diff --git a/Project.toml b/Project.toml index ca8b564..a67e641 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "e54bda2e-c571-11ec-9d64-0242ac120002" license = "MIT" desc = "Julia implementation of Modal Decision Trees and Random Forest algorithms" authors = ["Giovanni PAGLIARINI"] -version = "0.3.3" +version = "0.3.4" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/interfaces/MLJ.jl b/src/interfaces/MLJ.jl index c12be9e..89480be 100644 --- a/src/interfaces/MLJ.jl +++ b/src/interfaces/MLJ.jl @@ -122,7 +122,7 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class (Xnew, ynew, var_grouping, classes_seen, w) = MMI.reformat(m, Xnew, ynew; passive_mode = true) preds, sprinkledmodel = ModalDecisionTrees.sprinkle(model, Xnew, ynew) if simplify - sprinkledmodel = MDT.prune(model; simplify = true) + sprinkledmodel = MDT.prune(sprinkledmodel; simplify = true) end preds, translate_function(sprinkledmodel) end, diff --git a/test/other-test-stuff.jl b/test/other-test-stuff.jl index 8982319..544a385 100644 --- a/test/other-test-stuff.jl +++ b/test/other-test-stuff.jl @@ -37,9 +37,9 @@ dataset_name = "NATOPS" # dataset_name = "RacketSports" # dataset_name = "Libras" -X_train, y = SoleModels.load_arff_dataset(dataset_name) +X, y = SoleModels.load_arff_dataset(dataset_name) -fitresult = MMI.fit(model, 0, X_train, Y_train); +fitresult = MMI.fit(model, 0, X, Y); Y_test_preds, test_tree = MMI.predict(model, fitresult[1], X_test, Y_test);