Fix and add rawmodel
giopaglia committed Jul 21, 2023
1 parent b44eeec commit 78cc5b2
Showing 8 changed files with 38 additions and 33 deletions.
19 changes: 12 additions & 7 deletions src/interfaces/MLJ.jl
@@ -100,8 +100,11 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class
# syntaxstring_kwargs = (; hidemodality = (length(var_grouping) == 1), variable_names_map = var_grouping)
))

rawmodel_full = model
rawmodel = MDT.prune(model; simplify = true)

solemodel_full = translate_function(model)
solemodel = translate_function(MDT.prune(model; simplify = true))
solemodel = translate_function(rawmodel)

fitresult = (
model = model,
@@ -115,9 +118,11 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class
(Xnew, ynew, var_grouping, classes_seen, w) = MMI.reformat(m, Xnew, ynew; passive_mode = true)
translate_function(ModalDecisionTrees.sprinkle(model, Xnew, ynew))
end,
model = model,
model_full = model_full,
# TODO remove redundancy?
model = solemodel,
model_full = solemodel_full,
rawmodel = rawmodel,
rawmodel_full = rawmodel_full,
solemodel = solemodel,
solemodel_full = solemodel_full,
var_grouping = var_grouping,
@@ -135,8 +140,8 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class
return fitresult, cache, report
end

MMI.fitted_params(::TreeModel, fitresult) = merge(fitresult, (; tree = fitresult.model))
MMI.fitted_params(::ForestModel, fitresult) = merge(fitresult, (; forest = fitresult.model))
MMI.fitted_params(::TreeModel, fitresult) = merge(fitresult, (; tree = fitresult.rawmodel))
MMI.fitted_params(::ForestModel, fitresult) = merge(fitresult, (; forest = fitresult.rawmodel))

############################################################################################
############################################################################################
@@ -149,7 +154,7 @@ function MMI.predict(m::SymbolicModel, fitresult, Xnew, var_grouping = nothing)
"var_grouping = $(var_grouping)" *
"\n"
end
MDT.apply_proba(fitresult.model, Xnew, get(fitresult, :classes_seen, nothing); suppress_parity_warning = true)
MDT.apply_proba(fitresult.rawmodel, Xnew, get(fitresult, :classes_seen, nothing); suppress_parity_warning = true)
end

############################################################################################
@@ -187,7 +192,7 @@ function MMI.feature_importances(m::SymbolicModel, fitresult, report)
error("Unexpected feature_importance encountered: $(m.feature_importance).")
end

featimportance_dict = compute_featureimportance(fitresult.model, fitresult.var_grouping; normalize=true)
featimportance_dict = compute_featureimportance(fitresult.rawmodel, fitresult.var_grouping; normalize=true)
featimportance_vec = collect(featimportance_dict)
sort!(featimportance_vec, rev=true, by=x->last(x))

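In short, after this change the fit result keeps the pruned ModalDecisionTrees tree as rawmodel (and the unpruned one as rawmodel_full) next to the translated SoleModels versions solemodel/solemodel_full; fitted_params, predict and feature_importances now read from rawmodel. A minimal usage sketch (not part of the commit), mirroring the updated tests below and assuming the same package imports the test files use, which this diff does not show:

using MLJ
using ModalDecisionTrees

X, y = @load_iris
mach = machine(ModalDecisionTree(; max_depth = 2), X, y) |> fit!

# Pruned ModalDecisionTrees tree, as now returned by fitted_params
tree = fitted_params(mach).rawmodel

# Translated SoleModels version, exposed through the report
# (printmodel is provided by the Sole packages the tests import)
printmodel(report(mach).solemodel, header = false)
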
2 changes: 1 addition & 1 deletion src/interfaces/MLJ/printer.jl
@@ -29,7 +29,7 @@ function (c::ModelPrinter)(
max_depth::Union{Nothing,Integer} = c.m.display_depth;
kwargs...
)
c(io, (print_solemodel ? c.solemodel : c.model); max_depth = max_depth, kwargs...)
c(io, (print_solemodel ? c.solemodel : c.rawmodel); max_depth = max_depth, kwargs...)
end

function (c::ModelPrinter)(
2 changes: 1 addition & 1 deletion test/classification/demo-juliacon2022.jl
@@ -62,6 +62,6 @@ cm = ConfusionMatrix(y_test, y_test_preds; force_class_order=["I have command",
@test overall_accuracy(cm) > 0.6

# Render model in LaTeX
# show_latex(mach.fitresult.model; variable_names = [variable_names_latex], silent = true);
# show_latex(mach.fitresult.rawmodel; variable_names = [variable_names_latex], silent = true);

end
24 changes: 12 additions & 12 deletions test/classification/digits.jl
@@ -18,13 +18,13 @@ train_idxs, test_idxs = p[1:round(Int, N*.4)], p[round(Int, N*.4)+1:end]

# Full training
mach = machine(ModalDecisionTree(;), X, y) |> fit!
@test nnodes(fitted_params(mach).model) == 191
@test nnodes(fitted_params(mach).rawmodel) == 191
@test sum(predict_mode(mach, X) .== y) / length(y) > 0.92

############################################################################################

mach = machine(ModalDecisionTree(;), X, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 115
@test nnodes(fitted_params(mach).rawmodel) == 115
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.78


@@ -33,7 +33,7 @@ mach = machine(ModalDecisionTree(;
max_depth = 6,
min_samples_leaf = 5,
), X, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 77
@test nnodes(fitted_params(mach).rawmodel) == 77
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.75


@@ -47,7 +47,7 @@ mach = machine(ModalRandomForest(;
min_purity_increase = 0.0,
rng = Random.MersenneTwister(1)
), X, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 1242
@test nnodes(fitted_params(mach).rawmodel) == 1242
@test predict_mode(mach, rows = test_idxs)
@test predict(mach, rows = test_idxs)
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.85
@@ -57,15 +57,15 @@ mach = machine(ModalRandomForest(;

# NamedTuple dataset
mach = machine(ModalDecisionTree(;), Xnt, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 131
@test nnodes(fitted_params(mach).rawmodel) == 131
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.68

mach = machine(ModalDecisionTree(;
relations = :IA7,
features = [minimum],
initconditions = :start_at_center,
), Xnt, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 147
@test nnodes(fitted_params(mach).rawmodel) == 147
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.67

mach = machine(ModalDecisionTree(;
@@ -74,7 +74,7 @@ mach = machine(ModalDecisionTree(;
# initconditions = :start_at_center,
featvaltype = Float32,
), selectrows(Xnt, train_idxs), selectrows(y, train_idxs)) |> m->fit!(m)
@test nnodes(fitted_params(mach).model) == 131
@test nnodes(fitted_params(mach).rawmodel) == 131
@test sum(predict_mode(mach, selectrows(Xnt, test_idxs)) .== y[test_idxs]) / length(y[test_idxs]) > 0.71


@@ -111,7 +111,7 @@ mach = machine(ModalDecisionTree(;
# initconditions = :start_at_center,
featvaltype = Float32,
), selectrows(Xnt, train_idxs), selectrows(y, train_idxs)) |> m->fit!(m)
@test nnodes(fitted_params(mach).model) == 71
@test nnodes(fitted_params(mach).rawmodel) == 71
@test sum(predict_mode(mach, selectrows(Xnt, test_idxs)) .== y[test_idxs]) / length(y[test_idxs]) > 0.73

preds, tree2 = report(mach).sprinkle(selectrows(Xnt, test_idxs), selectrows(y, test_idxs));
@@ -129,7 +129,7 @@ mach = machine(ModalDecisionTree(;
# initconditions = :start_at_center,
featvaltype = Float32,
), selectrows(Xnt, train_idxs), selectrows(y, train_idxs)) |> m->fit!(m)
@test nnodes(fitted_params(mach).model) == 79
@test nnodes(fitted_params(mach).rawmodel) == 79
@test sum(predict_mode(mach, selectrows(Xnt, test_idxs)) .== y[test_idxs]) / length(y[test_idxs]) > 0.73

preds, tree2 = report(mach).sprinkle(selectrows(Xnt, test_idxs), selectrows(y, test_idxs));
@@ -141,15 +141,15 @@ printmodel.(joinrules(listrules(ModalDecisionTrees.translate(tree2))); show_metr
readmetrics.(joinrules(listrules(ModalDecisionTrees.translate(tree2))))

mach = machine(ModalDecisionTree(;), Xnt, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 137
@test nnodes(fitted_params(mach).rawmodel) == 137
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.70

mach = machine(ModalDecisionTree(;
n_subfeatures = 0,
max_depth = 6,
min_samples_leaf = 5,
), Xnt, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 77
@test nnodes(fitted_params(mach).rawmodel) == 77
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.75


@@ -162,5 +162,5 @@ mach = machine(ModalRandomForest(;
min_samples_split = 2,
min_purity_increase = 0.0,
), Xnt, y) |> m->fit!(m, rows = train_idxs)
@test nnodes(fitted_params(mach).model) == 77
@test nnodes(fitted_params(mach).rawmodel) == 77
@test sum(predict_mode(mach, rows = test_idxs) .== y[test_idxs]) / length(y[test_idxs]) > 0.75
6 changes: 3 additions & 3 deletions test/classification/iris-params.jl
@@ -7,12 +7,12 @@ X, y = @load_iris

model = ModalDecisionTree(; max_depth = 0)
mach = machine(model, X, y) |> fit!
@test height(fitted_params(mach).model) == 0
@test depth(fitted_params(mach).model) == 0
@test height(fitted_params(mach).rawmodel) == 0
@test depth(fitted_params(mach).rawmodel) == 0

model = ModalDecisionTree(; max_depth = 2, )
mach = machine(model, X, y) |> fit!
@test depth(fitted_params(mach).model) == 2
@test depth(fitted_params(mach).rawmodel) == 2

model = ModalDecisionTree(;
min_samples_leaf = 2,
6 changes: 3 additions & 3 deletions test/classification/iris.jl
@@ -24,11 +24,11 @@ yhat = MLJ.predict_mode(mach, X)

@test MLJBase.accuracy(y, yhat) > 0.8

@test_nowarn fitted_params(mach).model
@test_nowarn fitted_params(mach).rawmodel
@test_nowarn report(mach).solemodel

@test_nowarn printmodel(prune(fitted_params(mach).model, simplify=true, min_samples_leaf = 20), max_depth = 3)
@test_nowarn printmodel(prune(fitted_params(mach).model, simplify=true, min_samples_leaf = 20))
@test_nowarn printmodel(prune(fitted_params(mach).rawmodel, simplify=true, min_samples_leaf = 20), max_depth = 3)
@test_nowarn printmodel(prune(fitted_params(mach).rawmodel, simplify=true, min_samples_leaf = 20))

@test_nowarn printmodel(report(mach).solemodel, header = false)
@test_nowarn printmodel(report(mach).solemodel, header = :brief)
6 changes: 3 additions & 3 deletions test/classification/japanesevowels.jl
@@ -54,7 +54,7 @@ acc = sum(yhat .== y[test_idxs])/length(yhat)
@test_throws ErrorException listrules(report(mach).solemodel; use_shortforms=false, use_leftmostlinearform = true, force_syntaxtree = true)

# Access raw model
fitted_params(mach).model;
fitted_params(mach).rawmodel;
report(mach).printmodel(3);

MLJ.fit!(mach)
@@ -113,5 +113,5 @@ yhat = MLJ.predict_mode(mach, rows=test_idxs)
acc = sum(yhat .== y[test_idxs])/length(yhat)
MLJ.kappa(yhat, y[test_idxs])

@test_nowarn prune(fitted_params(mach).model, simplify=true)
@test_nowarn prune(fitted_params(mach).model, simplify=true, min_samples_leaf = 20)
@test_nowarn prune(fitted_params(mach).rawmodel, simplify=true)
@test_nowarn prune(fitted_params(mach).rawmodel, simplify=true, min_samples_leaf = 20)
6 changes: 3 additions & 3 deletions test/other-test-stuff.jl
@@ -44,7 +44,7 @@ fitresult = MMI.fit(model, 0, X_train, Y_train);

Y_test_preds, test_tree = MMI.predict(model, fitresult[1], X_test, Y_test);

tree = fitresult[1].model
tree = fitresult[1].rawmodel

fitresult[3].print_tree()

@@ -57,7 +57,7 @@ println(test_tree)
# SoleModels.ConfusionMatrix(Y_test_preds, Y_test)


# tree = fitresult.model
# tree = fitresult.rawmodel
# println(tree)
# println(test_tree)
# fitreport.print_tree()
@@ -99,7 +99,7 @@ show_latex(test_tree, "test", [variable_names_latex])

# Y_test_preds, test_tree = MMI.predict(model, fitresult[1], X_static_test, Y_test);

# tree = fitresult[1].model
# tree = fitresult[1].rawmodel

# fitresult[3].print_tree()

