Skip to content

Commit

Permalink
Make .from-matrix enable to accept unshaped @x
Browse files Browse the repository at this point in the history
  • Loading branch information
titsuki committed Aug 27, 2023
1 parent 87b3608 commit fd8a695
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ say $dmat.num-col; # 127
my $model = Algorithm::XGBoost.train($dmat, 10);
$model.num-feature.say; # 127

my @test[2;2] = [[0e0,0e0],[0e0,1e0]];
my @test = [[0e0,0e0],[0e0,1e0]];
my $test = Algorithm::XGBoost::DMatrix.from-matrix(@test);
say $model.predict($test); # (0.9858561754226685 0.9858561754226685)
```
Expand Down
2 changes: 1 addition & 1 deletion lib/Algorithm/XGBoost.rakumod
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ say $dmat.num-col; # 127
my $model = Algorithm::XGBoost.train($dmat, 10);
$model.num-feature.say; # 127
my @test[2;2] = [[0e0,0e0],[0e0,1e0]];
my @test = [[0e0,0e0],[0e0,1e0]];
my $test = Algorithm::XGBoost::DMatrix.from-matrix(@test);
say $model.predict($test); # (0.9858561754226685 0.9858561754226685)
Expand Down
7 changes: 6 additions & 1 deletion lib/Algorithm/XGBoost/DMatrix.rakumod
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ method from-file(::?CLASS:U: Str $path --> ::?CLASS:D) {
nativecast(Algorithm::XGBoost::DMatrix, $h);
}

method from-matrix(::?CLASS:U: @x where { $_.shape ~~ ($,$) }, @y?, Num :$missing = NaN --> ::?CLASS:D) {
multi method from-matrix(::?CLASS:U: @x, @y?, Num :$missing = NaN --> ::?CLASS:D) {
my @shaped-x[+@x;@x[0].elems] = @x.clone;
::?CLASS.from-matrix(@shaped-x, @y);
}

multi method from-matrix(::?CLASS:U: @x where { $_.shape ~~ ($,$) }, @y?, Num :$missing = NaN --> ::?CLASS:D) is default {
my $h = Pointer.new;
my $data = CArray[num32].new(@x.flat);
my ($nr, $nc) = @x.shape;
Expand Down
12 changes: 12 additions & 0 deletions t/01-basic.rakutest
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ subtest {
is $model.predict($test), (0.9858561754226685, 0.9858561754226685);
}, "Make sure SYNOPSIS works fine";

subtest {
my $dmat = Algorithm::XGBoost::DMatrix.from-file("{$*PROGRAM.parent.absolute}/../misc/agaricus.txt.train");
is $dmat.num-row, 6513;
is $dmat.num-col, 127;
my $model = Algorithm::XGBoost.train($dmat, 10);
is $model.num-feature, 127;

my @test = [[0e0,0e0],[0e0,1e0]];
my $test = Algorithm::XGBoost::DMatrix.from-matrix(@test);
is $model.predict($test), (0.9858561754226685, 0.9858561754226685);
}, "Make sure SYNOPSIS (unshaped version) works fine";

subtest {
my @train[3;2] = [[0e0,0e0],[0e0,1e0],[1e0,0e0]];
my @y = [1e0, 0e0, 1e0];
Expand Down

0 comments on commit fd8a695

Please sign in to comment.