Skip to content

Commit

Permalink
add multinomial tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bvenn committed Feb 14, 2024
1 parent 64ea9fb commit 4b979b0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 75 deletions.
4 changes: 4 additions & 0 deletions src/FSharp.Stats/Distributions/Discrete/Multinomial.fs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ type Multinomial =
p < 0. || p > 1. || isNan(p)
if (p |> Seq.map checkBetween |> Seq.exists id) then
failwith "Multinomial distribution should be parametrized by 0 ≤ p_i ≤ 1."

let pSum = Seq.sum p
if Math.Round(pSum,15) <> 1. then
failwithf "Multinomial distribution: The sum of probabilities should sum up to 1 but sums up to %.16f" pSum

/// <summary>Computes the mean vector</summary>
/// <remarks></remarks>
Expand Down
119 changes: 44 additions & 75 deletions tests/FSharp.Stats.Tests/DistributionsDiscrete.fs
Original file line number Diff line number Diff line change
Expand Up @@ -347,81 +347,50 @@ let binomialTests =



//[<Tests>]
//let multinomialTests =
// // TestCases from R stats: dmultinom(prob, x)
//
// testList "Distributions.Discrete.Multinominal" [
// testCase "Binomial.Mean_n=0" <| fun () ->
// let testCase = Discrete.Multinomial.Mean 0.5 0
// let r_value = 0
// Expect.equal
// testCase
// r_value
// "Multinominal mean with n=0 does not yield the expected value of 0"
//
// testCase "Binomial.Mean" <| fun () ->
// let testCase = Discrete.Multinomial.Mean 0.5 500
// let r_value = 250
// Expect.equal
// testCase
// r_value
// "Multinominal mean with n=500 and p=0.5 does not yield the expected value of 250"
//
// testCase "Binomial.Variance_n=0" <| fun () ->
// let testCase = Discrete.Multinomial.Variance 0.5 0
// let r_value = 0
// Expect.equal
// testCase
// r_value
// "Multinominal Variance with n=0 a does not yield the expected value of 0"
//
// testCase "Binomial.Variance" <| fun () ->
// let testCase = Discrete.Multinomial.Variance 0.69 420
// let r_value = 89.838
// Expect.floatClose
// Accuracy.high
// testCase
// r_value
// "Multinominal Variance with n=420 and p=0.69 does not yield the expected value of 89.838"
//
// testCase "Binomial.StandardDeviation" <| fun () ->
// let testCase = Discrete.Multinomial.StandardDeviation 0.69 420
// let r_value = 9.478291
// Expect.floatClose
// Accuracy.high
// testCase
// r_value
// "Multinominal StandardDeviation with n=420 and p=0.69 does not yield the expected value of 9.478291"
//
// testCase "Binomial.PMF" <| fun () ->
// let testCase = Discrete.Multinomial.PMF 0.69 420 237
// let r_value = 4.064494e-08
// Expect.floatClose
// Accuracy.low
// testCase
// r_value
// "Binomial.PMF with n=420, p=0.69 and k=237 does not equal the expectd 4.064494e-08"
//
// testCase "Binomial.PMF_n=0" <| fun () ->
// let testCase = Discrete.Multinomial.PMF 0.69 0 237
// let r_value = 0
// Expect.floatClose
// Accuracy.low
// testCase
// r_value
// "Binomial.PMF with n=0, p=0.69 and k=237 does not equal the expectd 0"
//
// testCase "Binomial.PMF_k<0" <| fun () ->
// let testCase = Discrete.Multinomial.PMF 0.69 420 -10
// let r_value = 0
// Expect.floatClose
// Accuracy.low
// testCase
// r_value
// "Binomial.PMF with n=420, p=0.69 and k=-10 does not equal the expectd 0"
// ]
//
[<Tests>]
let multinomialTests =
// TestCases from R stats: dmultinom(prob, x)
let prob1 = vector [0.2;0.4;0.4;0.]
let x1 = Vector.Generic.ofList [2;4;2;0]

let prob2 = vector [0.02;0.04;0.02;0.;0.01;0.1;0.81]
let x2 = Vector.Generic.ofList [2;4;2;0;1;10;100]
testList "Distributions.Discrete.Multinominal" [
testCase "Mean" <| fun () ->
let testCase = Discrete.Multinomial.Mean prob1 100
let means = vector [20.;40.;40.;0.]
TestExtensions.TestExtensions.sequenceEqual Accuracy.veryHigh
testCase
means
"Multinominal mean vector is incorrect"

testCase "Variance" <| fun () ->
let testCase = Discrete.Multinomial.Variance prob2 119
let variances = vector [2.3324;4.5696;2.3324;0;1.1781;10.71;18.3141]
TestExtensions.TestExtensions.sequenceEqual Accuracy.veryHigh
testCase
variances
"Multinominal Variance vector is incorrect"

testCase "PMF1" <| fun () ->
let testCase = Discrete.Multinomial.PMF prob1 x1
let pmf = 0.0688128
Expect.floatClose
Accuracy.veryHigh
testCase
pmf
"Multinominal.PMF is incorrect"

testCase "PMF2" <| fun () ->
let testCase = Discrete.Multinomial.PMF prob2 x2
let pmf = 0.0004954918510266295
Expect.floatClose
Accuracy.veryHigh
testCase
pmf
"Multinominal.PMF is incorrect"
]

[<Tests>]
let hypergeometricTests =

Expand Down
2 changes: 2 additions & 0 deletions tests/FSharp.Stats.Tests/TestExtensions.fs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

static member sequenceEqual(accuracy: Expecto.Accuracy) =
fun actual expected message ->
if Seq.length actual <> Seq.length expected then Expect.isTrue false message
Seq.iter2 (fun a b -> Expect.floatClose accuracy a b message) actual expected

static member sequenceEqualRoundedNaN (digits: int) =
let round (v:float) = System.Math.Round(v,digits)
fun actual expected message ->
if Seq.length actual <> Seq.length expected then Expect.isTrue false message
Seq.iter2 (fun a b ->
if nan.Equals a then
Expect.isTrue (nan.Equals b) message
Expand Down

0 comments on commit 4b979b0

Please sign in to comment.