Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
PasoStudio73 committed Oct 8, 2024
1 parent 2e9ef69 commit 7153610
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 6 deletions.
12 changes: 7 additions & 5 deletions src/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ function propositional_analisys(
X_propos = DataFrame([name => Float64[] for name in [match(r_select, v)[1] for v in p_variable_names]])
push!(X_propos, vcat([vcat([map(func, Array(row)) for func in metaconditions]...) for row in eachrow(X)])...)

println(rng)
X_train, y_train, X_test, y_test = partitioning(X_propos, y; train_ratio=train_ratio, rng=rng)

@info("Propositional analysis: train model...")
Expand Down Expand Up @@ -203,10 +202,13 @@ function modal_analisys(
error("Unknown set of features: $features.")
end

learned_dt_tree = begin
model = ModalDecisionTree(; relations = :IA7, features = metaconditions)
mach = machine(model, X_train, y_train) |> fit!
end
# learned_dt_tree = begin
# model = ModalDecisionTree(; relations = :IA7, features = metaconditions)
# mach = machine(model, X_train, y_train) |> fit!
# end

model = ModalDecisionTree(; relations = :IA7, features = metaconditions)
mach = machine(model, X_train, y_train) |> fit!
_, mtree = report(mach).sprinkle(X_test, y_test)

# ModalDecisionTrees.translate(mtree)
Expand Down
2 changes: 2 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ function get_df_from_rawaudio(;
)
end,
)
isfile(wav_path) || throw(ArgumentError("wav_path '$wav_path' does not exist."))

df = collect_audio_from_folder(wav_path; audioparams=audioparams, fragmented=fragmented, frag_func=frag_func)
labels = isnothing(csv_file) ?
collect_classes(df, classes_dict; classes_func=classes_func) :
Expand Down
2 changes: 1 addition & 1 deletion src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ r_p_ant = [
r"^\e\[\d+m▣ ",
]
r_m_ant = [r"^SyntaxBranch:\s*", r"\e\[(?:1m|0m)", r"^SoleLogics.SyntaxBranch: *"]
r_var = r"\[V(\d+)\]"
# r_var = r"\[V(\d+)\]"

format_float(x) = replace(x, r"(\d+\.\d+)" => s -> @sprintf("%.3f", parse(Float64, s)))

Expand Down
132 changes: 132 additions & 0 deletions test/debug_experiment.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
using DataFrames, JLD2
using SoleAudio, Random
# using Plots

# -------------------------------------------------------------------------- #
# experiment specific parameters #
# -------------------------------------------------------------------------- #
wav_path = "/home/paso/Documents/Aclai/Datasets/emotion_recognition/Ravdess/audio_speech_actors_01-24"
# wav_path = "/home/paso/datasets/emotion_recognition/Ravdess/audio_speech_actors_01-24"

classes = :emo2bins
# classes = :emo3bins
# classes = :emo8bins

if classes == :emo2bins
classes_dict = Dict{String,String}(
"01" => "positive",
"02" => "positive",
"03" => "positive",
"04" => "negative",
"05" => "negative",
"06" => "negative",
"07" => "negative",
"08" => "positive"
)
elseif classes == :emo3bins
classes_dict = Dict{String,String}(
"01" => "neutral",
"02" => "neutral",
"03" => "positive",
"05" => "negative",
"07" => "negative",
)
elseif classes == :emo8bins
classes_dict = Dict{String,String}(
"01" => "neutral",
"02" => "calm",
"03" => "happy",
"04" => "sad",
"05" => "angry",
"06" => "fearful",
"07" => "disgust",
"08" => "surprised"
)
end

jld2_file = string("ravdess_", classes)

# classes will be taken from audio filename, no csv available
classes_func(row) = match(r"^(?:[^-]*-){2}([^-]*)", row.filename)[1]

# -------------------------------------------------------------------------- #
# global parameters #
# -------------------------------------------------------------------------- #
featset = (:mel, :mfcc, :f0, :spectrals)

# audioparams = let sr = 8000
# (
# sr = sr,
# norm = true,
# speech_detect = true,
# sdetect_thresholds=(0,0),
# sdetect_spread_threshold=0.02,
# nfft = 256,
# mel_scale = :semitones, # :mel_htk, :mel_slaney, :erb, :bark, :semitones, :tuned_semitones
# mel_nbands = 26,
# mfcc_ncoeffs = 13,
# mel_freqrange = (100, round(Int, sr / 2)),
# )
# end

audioparams = let sr = 8000
(
sr = sr,
norm = true,
speech_detect = true,
sdetect_thresholds=(0,0),
sdetect_spread_threshold=0.02,
nfft = 256,
mel_scale = :erb, # :mel_htk, :mel_slaney, :erb, :bark, :semitones, :tuned_semitones
mel_nbands = 26,
mfcc_ncoeffs = 13,
mel_freqrange = (100, round(Int, sr / 2)),
)
end

min_length = 18000
min_samples = 10

features = :catch9
# features = :minmax
# features = :custom

# modal analysis
nwindows = 20
relative_overlap = 0.05

# partitioning
# train_ratio = 0.8
# train_seed = 1
train_ratio = 0.7
train_seed = 9
rng = Random.MersenneTwister(train_seed)
Random.seed!(train_seed)

# -------------------------------------------------------------------------- #
# main #
# -------------------------------------------------------------------------- #
df = get_df_from_rawaudio(
wav_path=wav_path,
classes_dict=classes_dict,
classes_func=classes_func,
audioparams=audioparams,
)

irules = get_interesting_rules(
df;
featset=featset,
audioparams=audioparams,
min_length=min_length,
min_samples=min_samples,
features=features,
nwindows=nwindows,
relative_overlap=relative_overlap,
train_ratio=train_ratio,
rng=rng,
)

println(irules)

jldsave(jld2_file * ".jld2", true; irules)
@info "Done."

0 comments on commit 7153610

Please sign in to comment.