Skip to content

Commit

Permalink
fix mat vec mul for float and add tests
Browse files Browse the repository at this point in the history
closes #17
  • Loading branch information
dastrobu committed Jan 20, 2021
1 parent 06ad7ef commit be39d13
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Sources/NdArray/Matrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ public func * (A: Matrix<Float>, x: Vector<Float>) -> Vector<Float> {

let m: Int32 = Int32(a.shape[0])
let n: Int32 = Int32(a.shape[1])
let lda: Int32 = Int32(a.strides[1])
let lda: Int32 = Int32(a.shape[0])
let incX: Int32 = Int32(x.strides[0])
let incY: Int32 = Int32(y.strides[0])
cblas_sgemv(order, CblasNoTrans, m, n, 1, a.data, lda, x.data, incX, 0, y.data, incY)
Expand Down
14 changes: 14 additions & 0 deletions Tests/NdArrayTests/MatrixTestsDouble.swift
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ class MatrixTestsDouble: XCTestCase {
}
}

func testMatMatMul() {
let A = Matrix<Double>.ones([2, 2])
let B = Matrix<Double>.ones([2, 2, ])
XCTAssertEqual((A * B).shape, [2, 2])
XCTAssertEqual((A * B).dataArray, [2.0, 2.0, 2.0, 2.0])
}

func testMatVecMul() {
let M = Matrix<Double>.ones([2, 2])
let x = Vector<Double>.ones(2)
XCTAssertEqual((M * x).shape, [2])
XCTAssertEqual((M * x).dataArray, [2.0, 2.0])
}

func testSolveAndInverted() throws {
// 2d effective 0d
do {
Expand Down
14 changes: 14 additions & 0 deletions Tests/NdArrayTests/MatrixTestsFloat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ class MatrixTestsFloat: XCTestCase {
}
}

func testMatMatMul() {
let A = Matrix<Float>.ones([2, 2])
let B = Matrix<Float>.ones([2, 2, ])
XCTAssertEqual((A * B).shape, [2, 2])
XCTAssertEqual((A * B).dataArray, [2.0, 2.0, 2.0, 2.0])
}

func testMatVecMul() {
let M = Matrix<Float>.ones([2, 2])
let x = Vector<Float>.ones(2)
XCTAssertEqual((M * x).shape, [2])
XCTAssertEqual((M * x).dataArray, [2.0, 2.0])
}

func testSolveAndInverted() throws {
// 2d effective 0d
do {
Expand Down

0 comments on commit be39d13

Please sign in to comment.