Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
giopaglia committed Jul 31, 2024
2 parents 4c52302 + edd5bae commit 646c67d
Show file tree
Hide file tree
Showing 34 changed files with 36,455 additions and 53 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ Reexport = "1"
ResumableFunctions = "0.6"
Revise = "3"
ScientificTypes = "3"
Sole = "0.5"
Sole = "0.6"
SoleBase = "0.12"
SoleData = "0.14"
SoleData = "0.15"
SoleLogics = "0.9"
SoleModels = "0.7"
SoleModels = "0.8"
StatsBase = "0.30 - 0.34"
Suppressor = "0.2"
Tables = "1"
Expand All @@ -79,6 +79,7 @@ julia = "1"

[extras]
ARFFFiles = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8"
D3Trees = "e3df1716-f71e-5df9-9e2d-98e193103c45"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Expand All @@ -92,7 +93,6 @@ PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
D3Trees = "e3df1716-f71e-5df9-9e2d-98e193103c45"

[targets]
test = ["ARFFFiles", "HTTP", "ImageFiltering", "Images", "InteractiveUtils", "LinearAlgebra", "MLDatasets", "Markdown", "PlutoUI", "RDatasets", "Test", "ZipFile", "MLJ", "MLJBase", "D3Trees"]
1 change: 1 addition & 0 deletions TODO
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
newpurity = -newent

Fixes:
☐ Be careful! parse_tree parses 1.65e-8 as 1.65.
☐ Add tests for AbstractTrees: `using GraphicRecipes; plot(TreePlot(dtree), method = :tree, nodeshape = :ellipse)`
☐ Entropy: regression, fix variance-based loss, and test it; Finally understand how entropy works, and fix it!!! (check test-entropy.txt)

Expand Down
4 changes: 2 additions & 2 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ struct DTLeaf{L<:Label} <: AbstractDecisionLeaf{L}
DTLeaf(prediction::L, supp_labels::AbstractVector) where {L<:Label} = DTLeaf{L}(prediction, supp_labels)

# create leaf without supporting labels
DTLeaf{L}(prediction) where {L<:Label} = DTLeaf{L}(prediction, [prediction])
DTLeaf(prediction::L) where {L<:Label} = DTLeaf{L}(prediction, [prediction])
DTLeaf{L}(prediction) where {L<:Label} = DTLeaf{L}(prediction, L[])
DTLeaf(prediction::L) where {L<:Label} = DTLeaf{L}(prediction, L[])

# create leaf from supporting labels
DTLeaf{L}(supp_labels::AbstractVector) where {L<:Label} = DTLeaf{L}(bestguess(L.(supp_labels)), supp_labels)
Expand Down
71 changes: 35 additions & 36 deletions src/experimentals/parse.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

function parse_tree(
tree_str::String;
check_format = true,
Expand Down Expand Up @@ -46,35 +45,35 @@ function _parse_tree(
# _threshold_ex = "[^\\)\\s)]+" # TODO use smarter regex (e.g., https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch06s10.html )
# Regex("[-+]?([0-9]+(\\.[0-9]*)?|\\.[0-9]+)") == r"[-+]?([0-9]+(\.[0-9]*)?|\.[0-9]+)"
# _threshold_ex = "[-+]?([0-9]+(\\.[0-9]*)?|\\.[0-9]+)" # GOOD
_threshold_ex = "[-+]?(?:[0-9]+(?:\\.[0-9]*)?|\\.[0-9]+)" # https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch06s10.html
_threshold_ex = "[-+]?(?:[0-9]+(?:\\.[0-9]*)?|\\.[0-9]+)(?:e?[-+]?)(?:[0-9]+)?" # https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch06s10.html

_indentation_ex = "[ │]*[✔✘]"
_metrics_ex = "\\(\\S*.*\\)"
_feature_ex = "(?:\\S+)\\s+(?:(?:⫹|⫺|⪳|⪴|⪵|⪶|↗|↘|>|<|=|≤|≥|<=|>=))"
_feature_ex = "(?:[^\\s\\(\\)]+)\\s+(?:(?:⫹|⫺|⪳|⪴|⪵|⪶|↗|↘|>|<|=|≤|≥|<=|>=))"
_normal_feature_ex_capturing = "^(\\S*)\\$openpar$V(\\d+)\\$closepar\\s+((?:>|<|=|≤|≥|<=|>=))\$"
_propositional_feature_ex_capturing = "^$V(\\d+)\\s+((?:>|<|=|≤|≥|<=|>=))\$"
_special_feature_ex_capturing = "^$V(\\d+)\\s+((?:⫹|⫺|⪳|⪴|⪵|⪶|↗|↘))\$"
_decision_ex = "$(_feature_ex)\\s+(?:$(_threshold_ex))"
_decision_ex__capturing = "($(_feature_ex))\\s+($(_threshold_ex))"
_decision_pr = "$(_feature_ex)\\s+(?:$(_threshold_ex))"
_decision_pr__capturing = "($(_feature_ex))\\s+($(_threshold_ex))"

leaf_ex = "(?:\\S+)\\s+:\\s+\\d+/\\d+(?:\\s+(?:$(_metrics_ex)))?"
leaf_ex__capturing = "(\\S+)\\s+:\\s+(\\d+)/(\\d+)(?:\\s+($(_metrics_ex)))?"
decision_ex = "(?:⟨(?:\\S+)⟩\\s*)?(?:$(_decision_ex)|\\(\\s*$(_decision_ex)\\s*\\))"
decision_ex__capturing = "(?:⟨(\\S+)⟩\\s*)?\\(?\\s*$(_decision_ex__capturing)\\s*\\)?"
# decision_ex__capturing = "(?:⟨(\\S+)⟩\\s*)?\\s*($(_decision_ex__capturing)|\\(\\s*$(_decision_ex__capturing)\\s*\\))"
decision_ex = "(?:SimpleDecision\\()?(?:⟨(?:\\S+)⟩\\s*)?(?:$(_decision_pr)|\\(\\s*$(_decision_pr)\\s*\\))(?:\\))?"
decision_ex__capturing = "(?:SimpleDecision\\()?(?:⟨(\\S+)⟩\\s*)?\\(?\\s*$(_decision_pr__capturing)\\s*\\)?(?:\\))?"
# _decision_ex__capturing = "(?:⟨(\\S+)⟩\\s*)?\\s*($(_decision_pr__capturing)|\\(\\s*$(_decision_pr__capturing)\\s*\\))"

# TODO default frame to 1
# split_ex = "(?:\\s*{(\\d+)}\\s+)?($(decision_ex))(?:\\s+($(leaf_ex)))?"
split_ex = "\\s*{(\\d+)}\\s+($(decision_ex))(?:\\s*($(leaf_ex)))?"

blank_line_regex = Regex("^\\s*\$")
split_line_regex = Regex("^($(_indentation_ex)\\s+)?$(split_ex)\\s*\$")
leaf_line_regex = Regex("^($(_indentation_ex)\\s+)?$(leaf_ex)\\s*\$")
leaf_line_regex = Regex("^($(_indentation_ex)\\s+)?$(leaf_ex)\\s*\$")

function _parse_simple_real(x)
x = parse(Float64, x)
x = isinteger(x) ? Int(x) : x
end
end

function _parse_decision((i_this_line, decision_str)::Tuple{<:Integer,<:AbstractString},)
function _parse_relation(relation_str)
Expand Down Expand Up @@ -107,9 +106,9 @@ function _parse_tree(
end

function _parse_feature_test_operator(feature_str)

m_normal = match(Regex(_normal_feature_ex_capturing), feature_str)
m_special = match(Regex(_special_feature_ex_capturing), feature_str)
m_special = match(Regex(_special_feature_ex_capturing), feature_str)
m_propos = match(Regex(_propositional_feature_ex_capturing), feature_str)

if !isnothing(m_normal) && length(m_normal) == 3
Expand Down Expand Up @@ -152,21 +151,21 @@ function _parse_tree(
else
error("Unexpected format encountered on line $(i_this_line+offset) when parsing feature: \"$(feature_str)\". Matches $(m_normal), $(m_special), $(m_propos)")
end
end
end

print(repeat(" ", _depth))
m = match(Regex(decision_ex), decision_str)
@assert !isnothing(m) "Unexpected format encountered on line $(i_this_line+offset) when parsing decision: \"$(decision_str)\". Matches: $(m)"
@assert !isnothing(m) "Unexpected format encountered on line $(i_this_line+offset) when parsing decision: \"$(decision_str)\". Matches: $(m)"

m = match(Regex(decision_ex__capturing), decision_str)
@assert !isnothing(m) && length(m) == 3 "Unexpected format encountered on line $(i_this_line+offset) when parsing decision: \"$(decision_str)\". Matches: $(m) Expected matches = 3"
# print(repeat(" ", _depth))
# println(m)

# println(m)
# @show m[3]
relation, feature_test_operator, threshold = m
relation = _parse_relation(relation)
feature, test_operator = _parse_feature_test_operator(feature_test_operator)
threshold = _parse_simple_real(threshold)
threshold = _parse_simple_real(threshold)

RestrictedDecision(ScalarExistentialFormula(relation, feature, test_operator, threshold))
end
Expand All @@ -187,51 +186,51 @@ function _parse_tree(
########################################################################################
########################################################################################
########################################################################################

# Can't do this because then i_line is misaligned
# tree_str = strip(tree_str)
lines = enumerate(split(tree_str, "\n")) |> collect

if check_format
for (i_line, line) in lines
!isempty(strip(line)) || continue
_line = line
_line = line

blank_match = match(blank_line_regex, _line)
split_match = match(split_line_regex, _line)
leaf_match = match(leaf_line_regex, _line)
is_blank = !isnothing(blank_match)
is_split = !isnothing(split_match)
is_leaf = !isnothing(leaf_match)

# DEBUG
# println(match(Regex("($(_indentation_ex)\\s+)?"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?$V(\\d+)"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?(\\S+\\s+)?$V(\\d+)"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?(\\S+\\s+)?$V(\\d+)\\s+([⫹⫺⪳⪴⪵⪶↗↘])"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?$(_decision_ex)"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?$(_decision_pr)"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?$(decision_ex)"), _line))
# println(match(Regex("^\\s*($(_indentation_ex)\\s+)?({(\\d+)}\\s+)?$(decision_ex)\\s+$(leaf_ex)"), _line))

@assert xor(is_blank, is_split, is_leaf) "Could not parse line $(i_line+offset): \"$(line)\". $((is_blank, is_split, is_leaf))"
end
end
end

_lines = filter(((i_line, line),)->(!isempty(strip(line))), lines)
_lines = filter(((i_line, line),)->(!isempty(strip(line))), lines)

if length(_lines) == 1 # a leaf
_parse_leaf(_lines[1])
else # a split

this_line, yes_line, no_line = begin
this_line = nothing
yes_line = -Inf
no_line = Inf
no_line = Inf

for (i_line, line) in lines
!isempty(strip(line)) || continue
_line = line
_line = line

if !isnothing(match(r"^\s*{.*$", _line))
@assert isnothing(this_line) "Cannot have more than one row beginning with '{'"
Expand All @@ -252,7 +251,7 @@ function _parse_tree(
end
this_line, yes_line, no_line
end

function clean_lines(lines)
join([(isempty(strip(line)) ? line : begin
begin_ex = Regex("^([ │]|[✔✘]\\s+)(.*)\$")
Expand All @@ -261,27 +260,27 @@ function _parse_tree(
end
left_tree_str, right_tree_str = clean_lines(lines[yes_line:no_line]), clean_lines(lines[no_line+1:end])
i_this_line, this_line = lines[this_line]

print(repeat(" ", _depth))
m = match(Regex(split_ex), this_line)
@assert !isnothing(m) && length(m) == 3 "Unexpected format encountered on line $(i_this_line+offset) : \"$(this_line)\". Matches: $(m) Expected matches = 3"
# println(m)
i_modality, decision_str, leaf_str = m
# @show i_modality, decision_str, leaf_str
i_modality = parse(Int, i_modality)
decision = _parse_decision((i_this_line, decision_str),)
decision = _parse_decision((i_this_line, decision_str),)

# println(clean_lines(lines[yes_line:no_line]))
# println("\n")
# println(clean_lines(lines[no_line+1:end]))
left = _parse_tree(left_tree_str; offset = yes_line-1, check_format = false, _depth = _depth + 1, child_kwargs...)
right = _parse_tree(right_tree_str; offset = no_line-1, check_format = false, _depth = _depth + 1, child_kwargs...)

if isnothing(leaf_str)
DTInternal(i_modality, decision, left, right)
else
this = _parse_leaf((i_this_line, leaf_str),)
DTInternal(i_modality, decision, this, left, right)
end
end
end
end
11 changes: 9 additions & 2 deletions src/interfaces/Sole/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,17 @@ function translate(
# end
# end

forthnode_as_a_leaf = ModalDecisionTrees.this(forthnode)
this_as_a_leaf = translate(forthnode_as_a_leaf, initconditions, new_path, new_pos_path, ancestors, ancestor_formulas; shortform = shortform, optimize_shortforms = optimize_shortforms)

info = merge(info, (;
this = translate(ModalDecisionTrees.this(forthnode), initconditions, new_path, new_pos_path, ancestors, ancestor_formulas; shortform = shortform, optimize_shortforms = optimize_shortforms),
supporting_labels = ModalDecisionTrees.supp_labels(forthnode),
this = this_as_a_leaf,
# supporting_labels = SoleModels.info(this_as_a_leaf, :supporting_labels),
supporting_labels = ModalDecisionTrees.supp_labels(forthnode_as_a_leaf),
# supporting_predictions = SoleModels.info(this_as_a_leaf, :supporting_predictions),
supporting_predictions = ModalDecisionTrees.predictions(forthnode_as_a_leaf),
))

if !isnothing(shortform)
# @show syntaxstring(shortform)
info = merge(info, (;
Expand Down
12 changes: 6 additions & 6 deletions src/interpret-onestep-decisions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ Base.@propagate_inbounds @resumable function generate_decisions(
allow_propositional_decisions::Bool,
allow_modal_decisions::Bool,
allow_global_decisions::Bool,
modal_relations_inds::AbstractVector{<:Integer},
features_inds::AbstractVector{<:Integer},
modal_relations_inds::AbstractVector,
features_inds::AbstractVector,
grouped_featsaggrsnops::AbstractVector{<:AbstractDict{<:Aggregator,<:AbstractVector{<:ScalarMetaCondition}}},
grouped_featsnaggrs::AbstractVector{<:AbstractVector{Tuple{<:Integer,<:Aggregator}}},
) where {W<:AbstractWorld,U}
Expand Down Expand Up @@ -328,7 +328,7 @@ Base.@propagate_inbounds @resumable function generate_propositional_decisions(
X::AbstractScalarLogiset{W,U,FT,FR},
i_instances::AbstractVector{<:Integer},
Sf::AbstractVector{<:AbstractWorlds{W}},
features_inds::AbstractVector{<:Integer},
features_inds::AbstractVector,
grouped_featsaggrsnops::AbstractVector{<:AbstractDict{<:Aggregator,<:AbstractVector{<:ScalarMetaCondition}}},
grouped_featsnaggrs::AbstractVector{<:AbstractVector{Tuple{<:Integer,<:Aggregator}}},
) where {W<:AbstractWorld,U,FT<:AbstractFeature,N,FR<:FullDimensionalFrame{N,W}}
Expand Down Expand Up @@ -395,8 +395,8 @@ Base.@propagate_inbounds @resumable function generate_modal_decisions(
X::AbstractScalarLogiset{W,U,FT,FR},
i_instances::AbstractVector{<:Integer},
Sf::AbstractVector{<:AbstractWorlds{W}},
modal_relations_inds::AbstractVector{<:Integer},
features_inds::AbstractVector{<:Integer},
modal_relations_inds::AbstractVector,
features_inds::AbstractVector,
grouped_featsaggrsnops::AbstractVector{<:AbstractDict{<:Aggregator,<:AbstractVector{<:ScalarMetaCondition}}},
grouped_featsnaggrs::AbstractVector{<:AbstractVector{Tuple{<:Integer,<:Aggregator}}},
) where {W<:AbstractWorld,U,FT<:AbstractFeature,N,FR<:FullDimensionalFrame{N,W}}
Expand Down Expand Up @@ -482,7 +482,7 @@ Base.@propagate_inbounds @resumable function generate_global_decisions(
X::AbstractScalarLogiset{W,U,FT,FR},
i_instances::AbstractVector{<:Integer},
Sf::AbstractVector{<:AbstractWorlds{W}},
features_inds::AbstractVector{<:Integer},
features_inds::AbstractVector,
grouped_featsaggrsnops::AbstractVector{<:AbstractDict{<:Aggregator,<:AbstractVector{<:ScalarMetaCondition}}},
grouped_featsnaggrs::AbstractVector{<:AbstractVector{Tuple{<:Integer,<:Aggregator}}},
) where {W<:AbstractWorld,U,FT<:AbstractFeature,N,FR<:FullDimensionalFrame{N,W}}
Expand Down
Loading

0 comments on commit 646c67d

Please sign in to comment.