From c7b10bfe67d3fb8591b0c6a276ef0ac0e67c14f9 Mon Sep 17 00:00:00 2001 From: giopaglia <24519853+giopaglia@users.noreply.github.com> Date: Tue, 5 Dec 2023 00:00:13 +1100 Subject: [PATCH] Fix sprinkle --- src/interfaces/MLJ.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/interfaces/MLJ.jl b/src/interfaces/MLJ.jl index fbac40a..c12be9e 100644 --- a/src/interfaces/MLJ.jl +++ b/src/interfaces/MLJ.jl @@ -118,9 +118,12 @@ function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, class cache = nothing report = ( printmodel = printer, - sprinkle = (Xnew, ynew)->begin + sprinkle = (Xnew, ynew; simplify = false)->begin (Xnew, ynew, var_grouping, classes_seen, w) = MMI.reformat(m, Xnew, ynew; passive_mode = true) preds, sprinkledmodel = ModalDecisionTrees.sprinkle(model, Xnew, ynew) + if simplify + sprinkledmodel = MDT.prune(model; simplify = true) + end preds, translate_function(sprinkledmodel) end, # TODO remove redundancy?