From fd8a69532edf17c54ddf8c4867a10ed4c17ad55c Mon Sep 17 00:00:00 2001 From: Itsuki Toyota Date: Sun, 27 Aug 2023 11:26:06 +0900 Subject: [PATCH] Make .from-matrix enable to accept unshaped @x --- README.md | 2 +- lib/Algorithm/XGBoost.rakumod | 2 +- lib/Algorithm/XGBoost/DMatrix.rakumod | 7 ++++++- t/01-basic.rakutest | 12 ++++++++++++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 25501ec..6a4b45b 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ say $dmat.num-col; # 127 my $model = Algorithm::XGBoost.train($dmat, 10); $model.num-feature.say; # 127 -my @test[2;2] = [[0e0,0e0],[0e0,1e0]]; +my @test = [[0e0,0e0],[0e0,1e0]]; my $test = Algorithm::XGBoost::DMatrix.from-matrix(@test); say $model.predict($test); # (0.9858561754226685 0.9858561754226685) ``` diff --git a/lib/Algorithm/XGBoost.rakumod b/lib/Algorithm/XGBoost.rakumod index cc9cb30..bce37be 100644 --- a/lib/Algorithm/XGBoost.rakumod +++ b/lib/Algorithm/XGBoost.rakumod @@ -72,7 +72,7 @@ say $dmat.num-col; # 127 my $model = Algorithm::XGBoost.train($dmat, 10); $model.num-feature.say; # 127 -my @test[2;2] = [[0e0,0e0],[0e0,1e0]]; +my @test = [[0e0,0e0],[0e0,1e0]]; my $test = Algorithm::XGBoost::DMatrix.from-matrix(@test); say $model.predict($test); # (0.9858561754226685 0.9858561754226685) diff --git a/lib/Algorithm/XGBoost/DMatrix.rakumod b/lib/Algorithm/XGBoost/DMatrix.rakumod index c8e77a9..4b94760 100644 --- a/lib/Algorithm/XGBoost/DMatrix.rakumod +++ b/lib/Algorithm/XGBoost/DMatrix.rakumod @@ -18,7 +18,12 @@ method from-file(::?CLASS:U: Str $path --> ::?CLASS:D) { nativecast(Algorithm::XGBoost::DMatrix, $h); } -method from-matrix(::?CLASS:U: @x where { $_.shape ~~ ($,$) }, @y?, Num :$missing = NaN --> ::?CLASS:D) { +multi method from-matrix(::?CLASS:U: @x, @y?, Num :$missing = NaN --> ::?CLASS:D) { + my @shaped-x[+@x;@x[0].elems] = @x.clone; + ::?CLASS.from-matrix(@shaped-x, @y); +} + +multi method from-matrix(::?CLASS:U: @x where { $_.shape ~~ ($,$) }, @y?, Num :$missing = NaN --> ::?CLASS:D) is default { my $h = Pointer.new; my $data = CArray[num32].new(@x.flat); my ($nr, $nc) = @x.shape; diff --git a/t/01-basic.rakutest b/t/01-basic.rakutest index 9352b06..9c4e60b 100644 --- a/t/01-basic.rakutest +++ b/t/01-basic.rakutest @@ -20,6 +20,18 @@ subtest { is $model.predict($test), (0.9858561754226685, 0.9858561754226685); }, "Make sure SYNOPSIS works fine"; +subtest { + my $dmat = Algorithm::XGBoost::DMatrix.from-file("{$*PROGRAM.parent.absolute}/../misc/agaricus.txt.train"); + is $dmat.num-row, 6513; + is $dmat.num-col, 127; + my $model = Algorithm::XGBoost.train($dmat, 10); + is $model.num-feature, 127; + + my @test = [[0e0,0e0],[0e0,1e0]]; + my $test = Algorithm::XGBoost::DMatrix.from-matrix(@test); + is $model.predict($test), (0.9858561754226685, 0.9858561754226685); +}, "Make sure SYNOPSIS (unshaped version) works fine"; + subtest { my @train[3;2] = [[0e0,0e0],[0e0,1e0],[1e0,0e0]]; my @y = [1e0, 0e0, 1e0];