Skip to content

Commit

Permalink
fixed metrics function and adding some test about it
Browse files Browse the repository at this point in the history
  • Loading branch information
Michele21 committed Mar 6, 2024
1 parent e0c7012 commit 31c184d
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/SoleModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export AssociationRule, ClassificationRule, RegressionRule

include("machine-learning.jl")

export rulemetrics, readmetrics
export rulemetrics, readmetrics, metrics

include("evaluation.jl")

Expand Down
50 changes: 37 additions & 13 deletions src/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ using SoleModels
using SoleModels: LeafModel
import SoleLogics: natoms

############################# Utils ###########################

"""
    nargs(func::Function)

Return the number of positional arguments declared by the first method listed in
`methods(func)`, not counting the function object itself.

Only the first method is inspected, so the result is meaningful for functions that
have a single method (or whose methods all share the same arity).
"""
function nargs(func::Function)
    # A method declared with `where` type parameters has a `UnionAll` signature, which
    # has no `parameters` field; unwrap it to reach the underlying signature tuple.
    signature = Base.unwrap_unionall(methods(func)[1].sig)
    # The first slot of the signature tuple is the type of the function itself.
    return length(signature.parameters) - 1
end

####################### Utility Functions #####################

function accuracy(
Expand Down Expand Up @@ -79,20 +83,22 @@ function readmetrics(m::Rule; digits = 2, dictmetrics::Dict{Symbol,Function} = D
_gts_leaf = info(consequent(m)).supporting_labels
coverage = length(_gts_leaf)/length(_gts)

merge(
return merge(
readmetrics(consequent(m); digits = digits, kwargs...),
(; coverage = round(coverage; digits = digits)),
if haskey(info(m), :supporting_predictions)
_preds = info(m).supporting_predictions
metrics = (;)

for (key,func) in dictmetrics
if Base.nargs(func) == 2
merge(metrics, (; key => func(_gts,_preds)))
if nargs(func) == 2
metrics = merge(metrics, (; key => func(_gts,_preds)))
end
end

metrics
else
(;)
end
)
elseif haskey(info(m), :supporting_labels)
Expand All @@ -104,12 +110,26 @@ function readmetrics(m::Rule; digits = 2, dictmetrics::Dict{Symbol,Function} = D
metrics = (;)

for (key,func) in dictmetrics
if Base.nargs(func) == 1
merge(metrics, (; key => func(_gts)))
if nargs(func) == 1
metrics = merge(metrics, (; key => func(_gts)))
end
end

metrics
end,
if haskey(info(m), :supporting_predictions)
_preds = info(m).supporting_predictions
metrics = (;)

for (key,func) in dictmetrics
if nargs(func) == 2
metrics = merge(metrics, (; key => func(_gts,_preds)))
end
end

metrics
else
(;)
end
)
elseif haskey(info(consequent(m)), :supporting_labels)
Expand Down Expand Up @@ -259,10 +279,10 @@ end

function metrics(
rule::Rule,
X::Unione{Nothing,AbstractInterpretationSet},
Y::Unione{Nothing,AbstractVector{<:Label}};
X::Union{Nothing,AbstractInterpretationSet} = nothing,
Y::Union{Nothing,AbstractVector{<:Label}} = nothing;
return_model::Bool = false,
dictmetrics::Dict{Symbol,Function} = Dict{Symbol,Function}(),
dictmetrics::Dict{Symbol,Function} = Dict{Symbol,Function}([(:accuracy,accuracy),]),
kwargs...,
)
info_rule = info(rule)
Expand All @@ -287,18 +307,22 @@ function metrics(
end
end

new_inforule = begin
tmp_inforule = begin
if !isnothing(supporting_labels) && !isnothing(supporting_predictions)
(info_rule..., supporting_labels = supporting_labels, supporting_predictions = supporting_predictions)
idxs_valid = supporting_predictions .!= nothing
(info_rule..., supporting_labels = supporting_labels[idxs_valid], supporting_predictions = supporting_predictions[idxs_valid])
elseif !isnothing(supporting_labels)
(info_rule..., supporting_labels = supporting_labels)
elseif !isnothing(supporting_predictions)
(info_rule..., supporting_predictions = supporting_predictions)
idxs_valid = supporting_predictions .!= nothing
(info_rule..., supporting_predictions = supporting_predictions[idxs_valid])
else
info_rule
end
end

newrule = Rule(antecedent(rule),consequent(rule),new_inforule)
readmetrics(newrule; dictmetrics=dictmetrics)
tmprule = Rule(antecedent(rule),consequent(rule),tmp_inforule)
updatedinfo = readmetrics(tmprule; dictmetrics=dictmetrics)

return return_model ? Rule(antecedent(rule),consequent(rule),updatedinfo) : updatedinfo
end
107 changes: 107 additions & 0 deletions test/metrics-test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using Test
using Revise
using Random
using DataFrames
using SoleLogics
using SoleModels
using SoleData
using SoleData.DimensionalDatasets

n_instances = 20
rng = MersenneTwister(42)

# Synthetic dataset: one column per attribute, each value drawn from a fixed grid.
# All fever values are drawn before any pressure value, so the seeded draws are
# consumed in a fixed order.
attributes = ["fever", "pressure"]
fever_domain = collect(36:0.5:40)
pressure_domain = collect(60:2:130)
attributes_values = [
    map(_ -> rand(rng, fever_domain), 1:n_instances),
    map(_ -> rand(rng, pressure_domain), 1:n_instances),
]

# Ground truth: an instance is "sick" when fever >= 37.5 and pressure <= 100.
y_true = map(attributes_values[1], attributes_values[2]) do fever, pressure
    if fever >= 37.5 && pressure <= 100
        "sick"
    else
        "not sick"
    end
end

columns = [Symbol(name) => column for (name, column) in zip(attributes, attributes_values)]
dataset = DataFrame(; NamedTuple(columns)...)

# Logiset: one `UnivariateMin` feature per dataset variable, with memoization disabled.
nvars = nvariables(dataset)
features = [UnivariateMin(i_var) for i_var in 1:nvars]
logiset = scalarlogiset(
    dataset,
    features;
    use_full_memoization = false,
    use_onestep_memoization = false,
)

# Rule Definition: min[V1] >= 38 ∧ min[V2] < 110
# (the features are built with `UnivariateMin`, one per dataset variable)

# Build a formula on scalar conditions
condition1 = ScalarCondition(features[1], >=, 38.0)
condition2 = ScalarCondition(features[2], <, 110)
# The antecedent is the conjunction of the two atoms; without the binary connective
# between them this line does not parse.
antecedentrule = Atom(condition1) ∧ Atom(condition2)

# Build consequent
consequentrule = "sick"

# Build Rule without info
rule = Rule(antecedentrule, consequentrule)

# Rule without info: the assertions below expect non-empty metrics only when both a
# logiset and the labels are supplied positionally.
@testset "metrics: rule without info" begin
    inforule = metrics(rule)
    inforulelogiset = metrics(rule, logiset)
    # NOTE(review): `Y = y_true` is passed as a keyword argument. `metrics` declares
    # `X` and `Y` as optional positional arguments followed by `kwargs...`, so this
    # keyword appears to be collected by `kwargs...` and the positional `Y` stays
    # `nothing`. The expected values below match that behavior; confirm whether
    # `metrics(rule, nothing, y_true)` was the intended call.
    inforuley = metrics(rule, Y = y_true)
    inforuleall = metrics(rule, logiset, y_true)
    newrule = metrics(rule; return_model = true)
    newrulelogiset = metrics(rule, logiset; return_model = true)
    newruley = metrics(rule, Y = y_true; return_model = true)
    newruleall = metrics(rule, logiset, y_true; return_model = true)

    # Metrics returned directly as a NamedTuple.
    @test inforule == NamedTuple()
    @test inforulelogiset == NamedTuple()
    @test inforuley == NamedTuple()
    @test inforuleall == (ninstances = 8, accuracy = 1.0)

    # Same metrics, stored as the info of the returned rule (`return_model = true`).
    @test SoleModels.info(newrule) == NamedTuple()
    @test SoleModels.info(newrulelogiset) == NamedTuple()
    @test SoleModels.info(newruley) == NamedTuple()
    @test SoleModels.info(newruleall) == (ninstances = 8, accuracy = 1.0)
end

# Rule whose info already carries the supporting labels.
@testset "metrics: rule with supporting labels" begin
    labelled_rule = Rule(antecedentrule, consequentrule, (; supporting_labels = y_true))

    inforule = metrics(labelled_rule)
    inforulelogiset = metrics(labelled_rule, logiset)
    # NOTE(review): `Y = y_true` is passed as a keyword argument. `metrics` declares
    # `X` and `Y` as optional positional arguments followed by `kwargs...`, so this
    # keyword appears to be collected by `kwargs...` and the positional `Y` stays
    # `nothing`. The expected values below match that behavior; confirm whether
    # `metrics(labelled_rule, nothing, y_true)` was the intended call.
    inforuley = metrics(labelled_rule, Y = y_true)
    inforuleall = metrics(labelled_rule, logiset, y_true)
    newrule = metrics(labelled_rule; return_model = true)
    newrulelogiset = metrics(labelled_rule, logiset; return_model = true)
    newruley = metrics(labelled_rule, Y = y_true; return_model = true)
    newruleall = metrics(labelled_rule, logiset, y_true; return_model = true)

    # Metrics returned directly as a NamedTuple.
    @test inforule == (ninstances = 20,)
    @test inforulelogiset == (ninstances = 8, accuracy = 1.0)
    @test inforuley == (ninstances = 20,)
    @test inforuleall == (ninstances = 8, accuracy = 1.0)

    # Same metrics, stored as the info of the returned rule (`return_model = true`).
    @test SoleModels.info(newrule) == (ninstances = 20,)
    @test SoleModels.info(newrulelogiset) == (ninstances = 8, accuracy = 1.0)
    @test SoleModels.info(newruley) == (ninstances = 20,)
    @test SoleModels.info(newruleall) == (ninstances = 8, accuracy = 1.0)
end

# Rule whose info already carries both the supporting labels and the predictions
# obtained by applying the info-less rule to the logiset.
@testset "metrics: rule with supporting labels and predictions" begin
    supp_preds = apply(Rule(antecedentrule, consequentrule), logiset)
    full_rule = Rule(
        antecedentrule,
        consequentrule,
        (; supporting_labels = y_true, supporting_predictions = supp_preds),
    )

    inforule = metrics(full_rule)
    inforulelogiset = metrics(full_rule, logiset)
    # NOTE(review): `Y = y_true` is passed as a keyword argument. `metrics` declares
    # `X` and `Y` as optional positional arguments followed by `kwargs...`, so this
    # keyword appears to be collected by `kwargs...` and the positional `Y` stays
    # `nothing`. The expected values below match that behavior; confirm whether
    # `metrics(full_rule, nothing, y_true)` was the intended call.
    inforuley = metrics(full_rule, Y = y_true)
    inforuleall = metrics(full_rule, logiset, y_true)
    newrule = metrics(full_rule; return_model = true)
    newrulelogiset = metrics(full_rule, logiset; return_model = true)
    newruley = metrics(full_rule, Y = y_true; return_model = true)
    newruleall = metrics(full_rule, logiset, y_true; return_model = true)

    # Metrics returned directly as a NamedTuple.
    @test inforule == (ninstances = 8, accuracy = 1.0)
    @test inforulelogiset == (ninstances = 8, accuracy = 1.0)
    @test inforuley == (ninstances = 8, accuracy = 1.0)
    @test inforuleall == (ninstances = 8, accuracy = 1.0)

    # Same metrics, stored as the info of the returned rule (`return_model = true`).
    @test SoleModels.info(newrule) == (ninstances = 8, accuracy = 1.0)
    @test SoleModels.info(newrulelogiset) == (ninstances = 8, accuracy = 1.0)
    @test SoleModels.info(newruley) == (ninstances = 8, accuracy = 1.0)
    @test SoleModels.info(newruleall) == (ninstances = 8, accuracy = 1.0)
end


0 comments on commit 31c184d

Please sign in to comment.