From 133dfaf0cc5d104d68080b5d9fd3fb9cad4dfeaa Mon Sep 17 00:00:00 2001 From: giopaglia <24519853+giopaglia@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:15:30 +1100 Subject: [PATCH] fix sprinkling --- Project.toml | 2 +- src/interfaces/MLJ.jl | 2 +- test/other-test-stuff.jl | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ca8b564..a67e641 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "e54bda2e-c571-11ec-9d64-0242ac120002" license = "MIT" desc = "Julia implementation of Modal Decision Trees and Random Forest algorithms" authors = ["Giovanni PAGLIARINI"] -version = "0.3.3" +version = "0.3.4" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/interfaces/MLJ.jl b/src/interfaces/MLJ.jl index c12be9e..89480be 100644 --- a/src/interfaces/MLJ.jl +++ b/src/interfaces/MLJ.jl @@ -122,7 +122,7 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class (Xnew, ynew, var_grouping, classes_seen, w) = MMI.reformat(m, Xnew, ynew; passive_mode = true) preds, sprinkledmodel = ModalDecisionTrees.sprinkle(model, Xnew, ynew) if simplify - sprinkledmodel = MDT.prune(model; simplify = true) + sprinkledmodel = MDT.prune(sprinkledmodel; simplify = true) end preds, translate_function(sprinkledmodel) end, diff --git a/test/other-test-stuff.jl b/test/other-test-stuff.jl index 8982319..544a385 100644 --- a/test/other-test-stuff.jl +++ b/test/other-test-stuff.jl @@ -37,9 +37,9 @@ dataset_name = "NATOPS" # dataset_name = "RacketSports" # dataset_name = "Libras" -X_train, y = SoleModels.load_arff_dataset(dataset_name) +X, y = SoleModels.load_arff_dataset(dataset_name) -fitresult = MMI.fit(model, 0, X_train, Y_train); +fitresult = MMI.fit(model, 0, X, Y); Y_test_preds, test_tree = MMI.predict(model, fitresult[1], X_test, Y_test);