fix: other algorithms are now functional 🎉
avik-pal committed Sep 20, 2024
1 parent e0c9ef4 commit e3c13c7
Showing 7 changed files with 49 additions and 98 deletions.
2 changes: 1 addition & 1 deletion lib/DataDrivenLux/src/DataDrivenLux.jl
@@ -14,7 +14,7 @@ using ConcreteStructs: @concrete
 using Setfield: Setfield, @set!
 
 using Optim: Optim, LBFGS
-using Optimisers: Optimisers, ADAM
+using Optimisers: Optimisers, Adam
 
 using Lux: Lux, logsoftmax, softmax!
 using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
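Side note on this change: Optimisers.jl renamed `ADAM` to `Adam`, and the old all-caps alias was eventually removed, which is why the import had to move. A minimal sketch of the rule under its new name; the toy model, gradient, and learning rate are illustrative:

```julia
using Optimisers: Optimisers, Adam

# `Adam` replaces the deprecated `ADAM`; construction is otherwise unchanged.
model = (; w = rand(Float32, 3))           # toy "model" as a NamedTuple
state = Optimisers.setup(Adam(0.001f0), model)
grads = (; w = ones(Float32, 3))           # pretend gradient
state, model = Optimisers.update(state, model, grads)
```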
59 changes: 17 additions & 42 deletions lib/DataDrivenLux/src/algorithms/randomsearch.jl
@@ -1,51 +1,26 @@
-"""
-$(TYPEDEF)
-
-Performs a random search over the space of possible solutions to the
-symbolic regression problem.
-
-# Fields
-$(FIELDS)
-"""
-@kwdef struct RandomSearch{F, A, L, O} <: AbstractDAGSRAlgorithm
-    # "The number of candidates to track"
-    # populationsize::Int = 100
-    # "The functions to include in the search"
-    # functions::F = (sin, exp, cos, log, +, -, /, *)
-    # "The arities of the functions"
-    # arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    # "The number of layers"
-    # n_layers::Int = 1
-    # "Include skip layers"
-    # skip::Bool = true
-    # "Simplex mapping"
-    # simplex::AbstractSimplex = Softmax()
-    # "Evaluation function to sort the samples"
-    # loss::L = aicc
-    # "The number of candidates to keep in each iteration"
-    # keep::Union{Real, Int} = 0.1
-    # "Use protected operators"
-    # use_protected::Bool = true
-    # "Use distributed optimization and resampling"
-    # distributed::Bool = false
-    # "Use threaded optimization and resampling - not implemented right now."
-    # threaded::Bool = false
-    # "Random seed"
-    # rng::AbstractRNG = Random.default_rng()
-    # "Optim optimiser"
-    # optimizer::O = LBFGS()
-    # "Optim options"
-    # optim_options::Optim.Options = Optim.Options()
-    # "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
-    # observed::Union{ObservedModel, Nothing} = nothing
-    # "Field for possible optimiser - no use for RandomSearch"
-    # optimiser::Nothing = nothing
-end
+@concrete struct RandomSearch <: AbstractDAGSRAlgorithm
+    options <: CommonAlgOptions
+end
+
+"""
+$(SIGNATURES)
+
+Performs a random search over the space of possible solutions to the symbolic regression
+problem.
+"""
+function RandomSearch(; populationsize = 100, functions = (sin, exp, cos, log, +, -, /, *),
+        arities = (1, 1, 1, 1, 2, 2, 2, 2), n_layers = 1, skip = true, loss = aicc,
+        keep = 0.1, use_protected = true, distributed = false, threaded = false,
+        rng = Random.default_rng(), optimizer = LBFGS(), optim_options = Optim.Options(),
+        observed = nothing, alpha = 0.999f0)
+    return RandomSearch(CommonAlgOptions(;
+        populationsize, functions, arities, n_layers, skip, simplex = Softmax(), loss,
+        keep, use_protected, distributed, threaded, rng, optimizer,
+        optim_options, optimiser = nothing, observed, alpha))
+end
 
 Base.print(io::IO, ::RandomSearch) = print(io, "RandomSearch")
 Base.summary(io::IO, x::RandomSearch) = print(io, x)
 
-# RandomSearch does not update any parameters
-function update_parameters!(::SearchCache)
-    return
-end
+update_parameters!(::SearchCache) = nothing
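For orientation, a hedged sketch of how the refactored algorithm might be constructed after this commit. The keyword surface is unchanged, but the hyperparameters now live in a single `CommonAlgOptions` value on the struct; the option values below are illustrative, and `alg.options.populationsize` assumes the field name used in the constructor above:

```julia
using DataDrivenLux

# Same keyword interface as before the refactor; every option is funneled
# into one `CommonAlgOptions` instance stored on the algorithm.
alg = RandomSearch(; populationsize = 50, n_layers = 2, keep = 0.2)

# Hyperparameters are now reached through `alg.options` rather than as
# direct struct fields:
alg.options.populationsize  # 50
```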
67 changes: 21 additions & 46 deletions lib/DataDrivenLux/src/algorithms/reinforce.jl
@@ -1,60 +1,35 @@
-"""
-$(TYPEDEF)
-
-Uses the REINFORCE algorithm to search over the space of possible solutions to the
-symbolic regression problem.
-
-# Fields
-$(FIELDS)
-"""
-@kwdef struct Reinforce{F, A, L, O, R} <: AbstractDAGSRAlgorithm
-    "Reward function which should convert the loss to a reward."
-    reward::R = RelativeReward(false)
-    # "The number of candidates to track"
-    # populationsize::Int = 100
-    # "The functions to include in the search"
-    # functions::F = (sin, exp, cos, log, +, -, /, *)
-    # "The arities of the functions"
-    # arities::A = (1, 1, 1, 1, 2, 2, 2, 2)
-    # "The number of layers"
-    # n_layers::Int = 1
-    # "Include skip layers"
-    # skip::Bool = true
-    # "Simplex mapping"
-    # simplex::AbstractSimplex = Softmax()
-    # "Evaluation function to sort the samples"
-    # loss::L = aicc
-    # "The number of candidates to keep in each iteration"
-    # keep::Union{Real, Int} = 0.1
-    # "Use protected operators"
-    # use_protected::Bool = true
-    # "Use distributed optimization and resampling"
-    # distributed::Bool = false
-    # "Use threaded optimization and resampling - not implemented right now."
-    # threaded::Bool = false
-    # "Random seed"
-    # rng::AbstractRNG = Random.default_rng()
-    # "Optim optimiser"
-    # optimizer::O = LBFGS()
-    # "Optim options"
-    # optim_options::Optim.Options = Optim.Options()
-    # "Observed model - if `nothing` is used, a normally distributed additive error with fixed variance is assumed."
-    # observed::Union{ObservedModel, Nothing} = nothing
-    # "AD Backend"
-    # ad_backend::AD.AbstractBackend = AD.ForwardDiffBackend()
-    # "Optimiser"
-    # optimiser::Optimisers.AbstractRule = ADAM()
-end
+@concrete struct Reinforce <: AbstractDAGSRAlgorithm
+    reward
+    ad_backend <: AD.AbstractBackend
+    options <: CommonAlgOptions
+end
+
+"""
+$(SIGNATURES)
+
+Uses the REINFORCE algorithm to search over the space of possible solutions to the
+symbolic regression problem.
+"""
+function Reinforce(reward = RelativeReward(false); populationsize = 100,
+        functions = (sin, exp, cos, log, +, -, /, *), arities = (1, 1, 1, 1, 2, 2, 2, 2),
+        n_layers = 1, skip = true, loss = aicc, keep = 0.1, use_protected = true,
+        distributed = false, threaded = false, rng = Random.default_rng(),
+        optimizer = LBFGS(), optim_options = Optim.Options(), observed = nothing,
+        alpha = 0.999f0, optimiser = Adam(), ad_backend = AD.ForwardDiffBackend())
+    return Reinforce(reward, ad_backend, CommonAlgOptions(;
+        populationsize, functions, arities, n_layers, skip, simplex = Softmax(), loss,
+        keep, use_protected, distributed, threaded, rng, optimizer,
+        optim_options, optimiser, observed, alpha))
+end
 
 Base.print(io::IO, ::Reinforce) = print(io, "Reinforce")
 Base.summary(io::IO, x::Reinforce) = print(io, x)
 
 function reinforce_loss(candidates, p, alg)
-    (; loss, reward) = alg
-    losses = map(loss, candidates)
-    rewards = reward(losses)
+    losses = map(alg.options.loss, candidates)
+    rewards = alg.reward(losses)
     # ∇U(θ) = E[∇log(p)*R(t)]
-    mean(map(enumerate(candidates)) do (i, candidate)
+    return mean(map(enumerate(candidates)) do (i, candidate)
         return rewards[i] * -candidate(p)
     end)
 end
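The `reinforce_loss` rewrite keeps the score-function (REINFORCE) estimator named in the comment, `∇U(θ) = E[∇log(p) * R(t)]`, and only changes where `loss` and `reward` are read from (`alg.options` and `alg`). A self-contained toy sketch of that identity, independent of the package; all names here are illustrative:

```julia
using Statistics: mean

# Score-function (REINFORCE) estimator for a Bernoulli(θ) policy:
#   d/dθ E[R(x)] = E[R(x) * d/dθ log p(x; θ)]
reward(x) = x ? 1.0 : 0.0
dlogp(x, θ) = x ? 1 / θ : -1 / (1 - θ)

function reinforce_grad(θ, n)
    xs = rand(n) .< θ                          # n samples from Bernoulli(θ)
    return mean(reward.(xs) .* dlogp.(xs, θ))  # Monte-Carlo gradient estimate
end

reinforce_grad(0.3, 100_000)  # ≈ 1.0, since E[R] = θ and d/dθ θ = 1
```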
2 changes: 1 addition & 1 deletion lib/DataDrivenLux/src/caches/cache.jl
@@ -12,7 +12,7 @@ end
 Base.show(io::IO, cache::SearchCache) = print(io, "SearchCache : $(cache.alg)")
 
 function init_model(x::AbstractDAGSRAlgorithm, basis::Basis, dataset::Dataset, intervals)
-    (; simplex, n_layers, arities, functions, use_protected, skip) = x
+    (; simplex, n_layers, arities, functions, use_protected, skip) = x.options
 
     # Get the parameter mapping
     variable_mask = map(enumerate(equations(basis))) do (i, eq)
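The one-line fix in `init_model` is the same refactor surfacing in the cache layer: the hyperparameters moved off the algorithm struct into its `options` field, so the property destructuring has to reach one level deeper. A toy illustration of the Julia syntax involved; the types are stand-ins, not the package's:

```julia
struct ToyOptions
    n_layers::Int
    skip::Bool
end

struct ToyAlgorithm
    options::ToyOptions
end

alg = ToyAlgorithm(ToyOptions(1, true))

# `(; a, b) = x` destructures by property name; after the refactor the
# names live on `alg.options` instead of `alg` itself:
(; n_layers, skip) = alg.options
(n_layers, skip)  # (1, true)
```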
4 changes: 2 additions & 2 deletions lib/DataDrivenLux/test/randomsearch_solve.jl
@@ -27,9 +27,9 @@ dummy_dataset = DataDrivenLux.Dataset(dummy_problem)
 
 @test isempty(dummy_dataset.u_intervals)
 
-for (data, interval) in zip((X, Y, 1:size(X, 2)),
+for (data, _interval) in zip((X, Y, 1:size(X, 2)),
     (dummy_dataset.x_intervals[1], dummy_dataset.y_intervals[1], dummy_dataset.t_interval))
-    @test (interval.lo, interval.hi) == extrema(data)
+    @test isequal_interval(_interval, interval(extrema(data)))
 end
 
 # We have 1 Choices in the first layer, 2 in the last
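The updated test tracks IntervalArithmetic.jl's newer comparison API, where `==` and direct `.lo`/`.hi` field access give way to `isequal_interval` and the `interval` constructor. A small sketch of the pattern, assuming IntervalArithmetic v0.22 or later:

```julia
using IntervalArithmetic: interval, isequal_interval, inf, sup

data = [1.0, 4.0, 2.5]
x = interval(extrema(data)...)           # the interval [1.0, 4.0]

isequal_interval(x, interval(1.0, 4.0))  # true

# Bounds are read through accessor functions rather than `.lo` / `.hi`:
(inf(x), sup(x))  # (1.0, 4.0)
```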
1 change: 1 addition & 0 deletions lib/DataDrivenLux/test/reinforce_solve.jl
@@ -6,6 +6,7 @@ using Random
 using Distributions
 using Test
 using Optimisers
+using Optim
 using StableRNGs
 
 rng = StableRNG(1234)
12 changes: 6 additions & 6 deletions lib/DataDrivenLux/test/runtests.jl
@@ -9,21 +9,21 @@ const GROUP = get(ENV, "GROUP", "All")
 
 @time begin
     if GROUP == "All" || GROUP == "DataDrivenLux"
-        @safetestset "Lux" begin
+        @testset "Lux" begin
             @safetestset "Nodes" include("nodes.jl")
             @safetestset "Layers" include("layers.jl")
             @safetestset "Graphs" include("graphs.jl")
         end
 
-        @safetestset "Caches" begin
+        @testset "Caches" begin
             @safetestset "Candidate" include("candidate.jl") # FIXME
             @safetestset "Cache" include("cache.jl")
         end
 
-        @safetestset "Algorithms" begin
-            @safetestset "RandomSearch" include("randomsearch_solve.jl") # FIXME
-            @safetestset "Reinforce" include("reinforce_solve.jl") # FIXME
-            @safetestset "CrossEntropy" include("crossentropy_solve.jl") # FIXME
+        @testset "Algorithms" begin
+            @safetestset "RandomSearch" include("randomsearch_solve.jl")
+            @safetestset "Reinforce" include("reinforce_solve.jl")
+            @safetestset "CrossEntropy" include("crossentropy_solve.jl")
         end
     end
 end
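On the runtests change: the per-file `@safetestset`s already isolate each file in a fresh module, so the outer grouping only needs result aggregation, which a plain `@testset` provides. A minimal sketch of the pattern, with illustrative file names:

```julia
using SafeTestsets, Test

# Outer `@testset` groups and aggregates results; inner `@safetestset`s keep
# each file sandboxed in its own module so tests cannot leak state.
@testset "Algorithms" begin
    @safetestset "RandomSearch" include("randomsearch_solve.jl")
    @safetestset "Reinforce" include("reinforce_solve.jl")
end
```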
