diff --git a/src/SoleModels.jl b/src/SoleModels.jl
index eed9141..fd08cb2 100644
--- a/src/SoleModels.jl
+++ b/src/SoleModels.jl
@@ -68,7 +68,7 @@ export AssociationRule, ClassificationRule, RegressionRule
 
 include("machine-learning.jl")
 
-export rulemetrics, readmetrics
+export rulemetrics, readmetrics, metrics
 
 include("evaluation.jl")
 
diff --git a/src/evaluation.jl b/src/evaluation.jl
index 7716f93..faca862 100644
--- a/src/evaluation.jl
+++ b/src/evaluation.jl
@@ -2,6 +2,10 @@ using SoleModels
 using SoleModels: LeafModel
 import SoleLogics: natoms
 
+############################# Utils ###########################
+
+nargs(func::Function) = length(methods(func)[1].sig.parameters) - 1
+
 ####################### Utility Functions #####################
 
 function accuracy(
@@ -79,7 +83,7 @@ function readmetrics(m::Rule; digits = 2, dictmetrics::Dict{Symbol,Function} = D
     _gts_leaf = info(consequent(m)).supporting_labels
     coverage = length(_gts_leaf)/length(_gts)
 
-    merge(
+    return merge(
         readmetrics(consequent(m); digits = digits, kwargs...),
         (; coverage = round(coverage; digits = digits)),
         if haskey(info(m), :supporting_predictions)
@@ -87,12 +91,14 @@ function readmetrics(m::Rule; digits = 2, dictmetrics::Dict{Symbol,Function} = D
             metrics = (;)
 
             for (key,func) in dictmetrics
-                if Base.nargs(func) == 2
-                    merge(metrics, (; key => func(_gts,_preds)))
+                if nargs(func) == 2
+                    metrics = merge(metrics, (; key => func(_gts,_preds)))
                 end
             end
 
             metrics
+        else
+            (;)
         end
     )
 elseif haskey(info(m), :supporting_labels)
@@ -104,12 +110,26 @@ function readmetrics(m::Rule; digits = 2, dictmetrics::Dict{Symbol,Function} = D
             metrics = (;)
 
             for (key,func) in dictmetrics
-                if Base.nargs(func) == 1
-                    merge(metrics, (; key => func(_gts)))
+                if nargs(func) == 1
+                    metrics = merge(metrics, (; key => func(_gts)))
                 end
             end
 
             metrics
+        end,
+        if haskey(info(m), :supporting_predictions)
+            _preds = info(m).supporting_predictions
+            metrics = (;)
+
+            for (key,func) in dictmetrics
+                if nargs(func) == 2
+                    metrics = merge(metrics, (; key => func(_gts,_preds)))
+                end
+            end
+
+            metrics
+        else
+            (;)
         end
     )
 elseif haskey(info(consequent(m)), :supporting_labels)
@@ -259,10 +279,10 @@ end
 
 function metrics(
     rule::Rule,
-    X::Unione{Nothing,AbstractInterpretationSet},
-    Y::Unione{Nothing,AbstractVector{<:Label}};
+    X::Union{Nothing,AbstractInterpretationSet} = nothing,
+    Y::Union{Nothing,AbstractVector{<:Label}} = nothing;
     return_model::Bool = false,
-    dictmetrics::Dict{Symbol,Function} = Dict{Symbol,Function}(),
+    dictmetrics::Dict{Symbol,Function} = Dict{Symbol,Function}([(:accuracy,accuracy),]),
     kwargs...,
 )
     info_rule = info(rule)
@@ -287,18 +307,22 @@ function metrics(
         end
     end
 
-    new_inforule = begin
+    tmp_inforule = begin
         if !isnothing(supporting_labels) && !isnothing(supporting_predictions)
-            (info_rule..., supporting_labels = supporting_labels, supporting_predictions = supporting_predictions)
+            idxs_valid = supporting_predictions .!= nothing
+            (info_rule..., supporting_labels = supporting_labels[idxs_valid], supporting_predictions = supporting_predictions[idxs_valid])
         elseif !isnothing(supporting_labels)
            (info_rule..., supporting_labels = supporting_labels)
         elseif !isnothing(supporting_predictions)
-            (info_rule..., supporting_predictions = supporting_predictions)
+            idxs_valid = supporting_predictions .!= nothing
+            (info_rule..., supporting_predictions = supporting_predictions[idxs_valid])
         else
            info_rule
         end
     end
 
-    newrule = Rule(antecedent(rule),consequent(rule),new_inforule)
-    readmetrics(newrule; dictmetrics=dictmetrics)
+    tmprule = Rule(antecedent(rule),consequent(rule),tmp_inforule)
+    updatedinfo = readmetrics(tmprule; dictmetrics=dictmetrics)
+
+    return return_model ? Rule(antecedent(rule),consequent(rule),updatedinfo) : updatedinfo
 end
diff --git a/test/metrics-test.jl b/test/metrics-test.jl
new file mode 100644
index 0000000..2c0cc6b
--- /dev/null
+++ b/test/metrics-test.jl
@@ -0,0 +1,107 @@
+using Test
+using Revise
+using Random
+using DataFrames
+using SoleLogics
+using SoleModels
+using SoleData
+using SoleData.DimensionalDatasets
+
+n_instances = 20
+rng = MersenneTwister(42)
+
+# Dataset Construction
+attributes = ["fever","pressure"]
+attributes_values = [
+    [rand(rng,collect(36:0.5:40)) for i in 1:n_instances],
+    [rand(rng,collect(60:2:130)) for i in 1:n_instances],
+]
+y_true = [attributes_values[1][i] >= 37.5 && attributes_values[2][i] <= 100 ? "sick" : "not sick" for i in 1:n_instances]
+dataset = DataFrame(; NamedTuple([Symbol(attributes[i]) => attributes_values[i] for i in 1:length(attributes)])...)
+
+# Logiset Definition
+nvars = nvariables(dataset)
+features = collect(Iterators.flatten([[UnivariateMin(i_var)] for i_var in 1:nvars]))
+logiset = scalarlogiset(dataset, features; use_full_memoization = false, use_onestep_memoization = false)
+
+# Rule Definition: max[V1] >= 38, max[V2] < 110
+
+# Build a formula on scalar conditions
+condition1 = ScalarCondition(features[1], >=, 38.0)
+condition2 = ScalarCondition(features[2], <, 110)
+antecedentrule = Atom(condition1) ∧ Atom(condition2)
+
+# Build consequent
+consequentrule = "sick"
+
+# Build Rule without info
+rule = Rule(antecedentrule, consequentrule)
+
+inforule = metrics(rule)
+inforulelogiset = metrics(rule,logiset)
+inforuley = metrics(rule,Y = y_true)
+inforuleall = metrics(rule,logiset,y_true)
+newrule = metrics(rule; return_model=true)
+newrulelogiset = metrics(rule,logiset; return_model=true)
+newruley = metrics(rule,Y = y_true; return_model=true)
+newruleall = metrics(rule,logiset,y_true; return_model=true)
+
+
+@test inforule == NamedTuple()
+@test inforulelogiset == NamedTuple()
+@test inforuley == NamedTuple()
+@test inforuleall == (ninstances = 8, accuracy = 1.0,)
+@test SoleModels.info(newrule) == NamedTuple()
+@test SoleModels.info(newrule) == NamedTuple()
+@test SoleModels.info(newrulelogiset) == NamedTuple()
+@test SoleModels.info(newruley) == NamedTuple()
+@test SoleModels.info(newruleall) == (ninstances = 8, accuracy = 1.0,)
+
+# Build Rule with info
+rule = Rule(antecedentrule,consequentrule,(; supporting_labels = y_true))
+
+inforule = metrics(rule)
+inforulelogiset = metrics(rule,logiset)
+inforuley = metrics(rule,Y = y_true)
+inforuleall = metrics(rule,logiset,y_true)
+newrule = metrics(rule; return_model=true)
+newrulelogiset = metrics(rule,logiset; return_model=true)
+newruley = metrics(rule,Y = y_true; return_model=true)
+newruleall = metrics(rule,logiset,y_true; return_model=true)
+
+
+@test inforule == (ninstances = 20,)
+@test inforulelogiset == (ninstances = 8, accuracy = 1.0,)
+@test inforuley == (ninstances = 20,)
+@test inforuleall == (ninstances = 8, accuracy = 1.0,)
+@test SoleModels.info(newrule) == (ninstances = 20,)
+@test SoleModels.info(newrule) == (ninstances = 20,)
+@test SoleModels.info(newrulelogiset) == (ninstances = 8, accuracy = 1.0,)
+@test SoleModels.info(newruley) == (ninstances = 20,)
+@test SoleModels.info(newruleall) == (ninstances = 8, accuracy = 1.0,)
+
+# Build Rule with info
+supp_preds = apply(Rule(antecedentrule,consequentrule),logiset)
+rule = Rule(antecedentrule,consequentrule,(; supporting_labels = y_true, supporting_predictions = supp_preds))
+
+inforule = metrics(rule)
+inforulelogiset = metrics(rule,logiset)
+inforuley = metrics(rule,Y = y_true)
+inforuleall = metrics(rule,logiset,y_true)
+newrule = metrics(rule; return_model=true)
+newrulelogiset = metrics(rule,logiset; return_model=true)
+newruley = metrics(rule,Y = y_true; return_model=true)
+newruleall = metrics(rule,logiset,y_true; return_model=true)
+
+
+@test inforule == (ninstances = 8, accuracy = 1.0)
+@test inforulelogiset == (ninstances = 8, accuracy = 1.0,)
+@test inforuley == (ninstances = 8, accuracy = 1.0)
+@test inforuleall == (ninstances = 8, accuracy = 1.0,)
+@test SoleModels.info(newrule) == (ninstances = 8, accuracy = 1.0)
+@test SoleModels.info(newrule) == (ninstances = 8, accuracy = 1.0)
+@test SoleModels.info(newrulelogiset) == (ninstances = 8, accuracy = 1.0,)
+@test SoleModels.info(newruley) == (ninstances = 8, accuracy = 1.0)
+@test SoleModels.info(newruleall) == (ninstances = 8, accuracy = 1.0,)
+
+