Skip to content

Commit

Permalink
metrics function to get performance's metrics for a rule
Browse files Browse the repository at this point in the history
  • Loading branch information
Michele21 committed Feb 24, 2024
1 parent 60c50ed commit b4306de
Showing 1 changed file with 87 additions and 11 deletions.
98 changes: 87 additions & 11 deletions src/evaluation.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
using StatsBase
using SoleModels
using SoleModels: LeafModel
import SoleLogics: natoms

####################### Util Functions #####################
####################### Utils Functions #####################

function accuracy(
y_true::AbstractArray,
Expand All @@ -30,8 +29,7 @@ function mae(
@assert length(y_true) == length(y_pred) "True labels and Predicted labels don't have the same length"
@assert length(y_pred) > 0 "Don't pass the labels, the two vectors are empty"

# return sum(abs.(y_pred.-y_true)) / length(y_pred)
return StatsBase.mean(abs, y_true - y_pred)
return sum(abs.(y_pred.-y_true)) / length(y_pred)
end

function mse(
Expand All @@ -41,8 +39,7 @@ function mse(
@assert length(y_true) == length(y_pred) "True labels and Predicted labels don't have the same length"
@assert length(y_pred) > 0 "Don't pass the labels, the two vectors are empty"

# return sum((y_pred.-y_true).^2) / length(y_pred)
return StatsBase.mean(abs2, y_true - y_pred)
return sum((y_pred.-y_true).^2) / length(y_pred)
end

#############################################################
Expand All @@ -58,14 +55,14 @@ Performance metrics can be computed when the `info` structure of the model has t
The `digits` keyword argument is used to `round` accuracy/confidence metrics.
"""
function readmetrics(m::LeafModel{L}; digits = 2) where {L<:Label}
function readmetrics(m::LeafModel{L}; digits = 2, dictmetrics::Dict{Symbol,Function} = Dict{Symbol,Function}()) where {L<:Label}
merge(if haskey(info(m), :supporting_labels) && haskey(info(m), :supporting_predictions)
_gts = info(m).supporting_labels
_preds = info(m).supporting_predictions
if L <: CLabel
(; ninstances = length(_gts), confidence = round(accuracy(_gts, _preds); digits = digits))
elseif L <: RLabel
(; ninstances = length(_gts), mse = round(mse(_gts, _preds); digits = digits))
(; ninstances = length(_gts), mae = round(mae(_gts, _preds); digits = digits))
else
error("Could not compute readmetrics with unknown label type: $(L).")
end
Expand All @@ -76,14 +73,45 @@ function readmetrics(m::LeafModel{L}; digits = 2) where {L<:Label}
end, (; coverage = 1.0))
end

function readmetrics(m::Rule; digits = 2, kwargs...)
function readmetrics(m::Rule; digits = 2, dictmetrics::Dict{Symbol,Function} = Dict{Symbol,Function}(), kwargs...)
if haskey(info(m), :supporting_labels) && haskey(info(consequent(m)), :supporting_labels)
_gts = info(m).supporting_labels
_gts_leaf = info(consequent(m)).supporting_labels
coverage = length(_gts_leaf)/length(_gts)
merge(readmetrics(consequent(m); digits = digits, kwargs...), (; coverage = round(coverage; digits = digits)))

merge(
readmetrics(consequent(m); digits = digits, kwargs...),
(; coverage = round(coverage; digits = digits)),
if haskey(info(m), :supporting_predictions)
_preds = info(m).supporting_predictions
metrics = (;)

for (key,func) in dictmetrics
if Base.nargs(func) == 2
merge(metrics, (; key => func(_gts,_preds)))
end
end

metrics
end
)
elseif haskey(info(m), :supporting_labels)
return (; ninstances = length(info(m).supporting_labels))
_gts = info(m).supporting_labels

return merge(
(; ninstances = length(_gts)),
begin
metrics = (;)

for (key,func) in dictmetrics
if Base.nargs(func) == 1
merge(metrics, (; key => func(_gts)))
end
end

metrics
end
)
elseif haskey(info(consequent(m)), :supporting_labels)
return (; ninstances = length(info(m).supporting_labels))
else
Expand Down Expand Up @@ -125,6 +153,7 @@ function evaluaterule(

antsat = ys .!= nothing

#=
cons_sat = begin
idxs_sat = findall(antsat .== true)
cons_sat = Vector{Union{Bool,Nothing}}(fill(nothing, length(Y)))
Expand All @@ -141,6 +170,7 @@ function evaluaterule(
cons_sat[idxs_false] .= false
cons_sat
end
=#

# - `cons_sat::Vector{Union{Nothing,Bool}}`: for each instance in the dataset:
# - `nothing` if antecedent is not satisfied.
Expand Down Expand Up @@ -226,3 +256,49 @@ end
############################################################################################
############################################################################################
############################################################################################

function metrics(
rule::Rule,
X::Unione{Nothing,AbstractInterpretationSet},
Y::Unione{Nothing,AbstractVector{<:Label}};
return_model::Bool = false,
dictmetrics::Dict{Symbol,Function} = Dict{Symbol,Function}(),
kwargs...,
)
info_rule = info(rule)

supporting_labels = begin
if !isnothing(Y)
Y
elseif haskey(info_rule,:supporting_labels)
info_rule.supporting_labels
else
nothing
end
end

supporting_predictions = begin
if !isnothing(X)
apply(rule,X)
elseif haskey(info_rule,:supporting_predictions)
info_rule.supporting_predictions
else
nothing
end
end

new_inforule = begin
if !isnothing(supporting_labels) && !isnothing(supporting_predictions)
(info_rule..., supporting_labels = supporting_labels, supporting_predictions = supporting_predictions)
elseif !isnothing(supporting_labels)
(info_rule..., supporting_labels = supporting_labels)
elseif !isnothing(supporting_predictions)
(info_rule..., supporting_predictions = supporting_predictions)
else
info_rule
end
end

newrule = Rule(antecedent(rule),consequent(rule),new_inforule)
readmetrics(newrule; dictmetrics=dictmetrics)
end

0 comments on commit b4306de

Please sign in to comment.