Skip to content

Commit

Permalink
refactor slicing to be more type safe.
Browse files Browse the repository at this point in the history
  • Loading branch information
dastrobu committed Feb 20, 2022
1 parent df1491c commit 3e5c629
Show file tree
Hide file tree
Showing 32 changed files with 1,187 additions and 700 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Swift Package Manager
.build/
.swiftpm

DerivedData/

Expand Down
216 changes: 148 additions & 68 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Sources/NdArray/Equitable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ extension NdArray: Equatable where T: Equatable {
// make sure the array is not sliced
let a = NdArray(lhs)
let b = NdArray(rhs)
for i in 0..<a.shape[0] where a[i] != b[i] {
for i in 0..<a.shape[0] where a[[Slice(i)]] != b[[Slice(i)]] {
return false
}
}
Expand Down
14 changes: 7 additions & 7 deletions Sources/NdArray/Matrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ open class Matrix<T>: NdArray<T> {
Cannot transpose matrix with shape \(shape) to matrix with shape \(out.shape).
Precondition failed while trying to transpose \(debugDescription) to \(out.debugDescription).
""")
out[...][...] = self.transposed()[...][...]
out[0..., 0...] = self.transposed()[0..., 0...]
}
}
}
Expand Down Expand Up @@ -167,11 +167,11 @@ public extension Matrix where T == Double {
return B
}
// copy rhs to work space (thereby also making sure it is F contiguous)
B[...] = rhs[...]
B[0...] = rhs[0...]

// copy self to A, since it is modified (thereby also making sure it is F contiguous)
let A = Matrix<T>(empty: shape, order: .F)
A[...] = self[...]
A[0...] = self[0...]
var nrhs = __CLPK_integer(B.shape[1])
var ipiv: [__CLPK_integer] = [__CLPK_integer].init(repeating: 0, count: Int(n))
var lda: __CLPK_integer = __CLPK_integer(n)
Expand Down Expand Up @@ -202,7 +202,7 @@ public extension Matrix where T == Double {
Precondition failed while trying to solve \(debugDescription).
""")
let A = out ?? Matrix(empty: shape, order: .F)
A[...] = self[...]
A[0...] = self[0...]

var ipiv = try A.luFactor()

Expand Down Expand Up @@ -326,11 +326,11 @@ public extension Matrix where T == Float {
return B
}
// copy rhs to work space (thereby also making sure it is F contiguous)
B[...] = rhs[...]
B[0...] = rhs[0...]

// copy self to A, since it is modified (thereby also making sure it is F contiguous)
let A = Matrix<T>(empty: shape, order: .F)
A[...] = self[...]
A[0...] = self[0...]
var nrhs = __CLPK_integer(B.shape[1])
var ipiv: [__CLPK_integer] = [__CLPK_integer].init(repeating: 0, count: Int(n))
var lda: __CLPK_integer = __CLPK_integer(n)
Expand Down Expand Up @@ -361,7 +361,7 @@ public extension Matrix where T == Float {
Precondition failed while trying to solve \(debugDescription).
""")
let A = out ?? Matrix(empty: shape, order: .F)
A[...] = self[...]
A[0...] = self[0...]

var ipiv = try A.luFactor()

Expand Down
62 changes: 59 additions & 3 deletions Sources/NdArray/NdArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
self.count = a.count
self.shape = a.shape
self.strides = a.strides
precondition(self !== owner)
assert(self !== owner)
}

deinit {
Expand Down Expand Up @@ -247,17 +247,20 @@ open class NdArray<T>: CustomDebugStringConvertible,
"\(self, style: .multiLine)"
}

/// element access
/**
element access
*/
public subscript(index: [Int]) -> T {
get {
data[flatIndex(index)]
}
set {
self.data[flatIndex(index)] = newValue
data[flatIndex(index)] = newValue
}
}

/// full slice access
@available(*, deprecated, message: "prefer new slicing syntax a[0..., 0..., 0...] over old one a[...][...][...]")
public subscript(r: UnboundedRange) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r]
Expand All @@ -268,6 +271,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// partial range slice access
@available(*, deprecated, message: "prefer new slicing syntax a[0...42, 0...42, 0...42] over old one a[0...42][0...42][0...42]")
public subscript(r: ClosedRange<Int>) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r]
Expand All @@ -278,6 +282,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// partial range slice access
@available(*, deprecated, message: "prefer new slicing syntax a[...42, ...42, ...42] over old one a[...42][...42][...42]")
public subscript(r: PartialRangeThrough<Int>) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r]
Expand All @@ -288,6 +293,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// partial range slice access
@available(*, deprecated, message: "prefer new slicing syntax a[..<42, ..<42, ..<42] over old one a[..<42][..<42][..<42]")
public subscript(r: PartialRangeUpTo<Int>) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r]
Expand All @@ -298,6 +304,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// partial range slice access
@available(*, deprecated, message: "prefer new slicing syntax a[42..., 42.., 42..] over old one a[42...][42...][42...]")
public subscript(r: PartialRangeFrom<Int>) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r]
Expand All @@ -308,6 +315,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// range slice access
@available(*, deprecated, message: "prefer new slicing syntax a[1..<42, 1..<42, 1..<42] over old one a[1..<42][1..<42][1..<42]")
public subscript(r: Range<Int>) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r]
Expand All @@ -318,6 +326,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// range with stride
@available(*, deprecated, message: "prefer new slicing syntax a[1..<42 ~ 3, 1..<42 ~ 3, 1..<42 ~ 3] over old one a[1..<42, 3][1..<42, 3][1..<42, 3]")
public subscript(r: Range<Int>, stride: Int) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r, stride]
Expand All @@ -328,6 +337,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// closed range with stride
@available(*, deprecated, message: "prefer new slicing syntax a[0...42 ~ 3, 0...42 ~ 3, 0...42 ~ 3] over old one a[0...42, 3][0...42, 3][0...42, 3]")
public subscript(r: ClosedRange<Int>, stride: Int) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r, stride]
Expand All @@ -338,6 +348,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// partial range with stride
@available(*, deprecated, message: "prefer new slicing syntax a[42... ~ 3, 42.. ~ 3, 42.. ~ 3] over old one a[42..., 3][42..., 3][42..., 3]")
public subscript(r: PartialRangeFrom<Int>, stride: Int) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r, stride]
Expand All @@ -348,6 +359,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// partial range with stride
@available(*, deprecated, message: "prefer new slicing syntax a[...42 ~ 3, ...42 ~ 3, ...42 ~ 3] over old one a[...42, 3][...42, 3][...42, 3]")
public subscript(r: PartialRangeThrough<Int>, stride: Int) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r, stride]
Expand All @@ -358,6 +370,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// partial range with stride
@available(*, deprecated, message: "prefer new slicing syntax a[..<42 ~ 3, ..<42 ~ 3, ..<42 ~ 3] over old one a[..<42, 3][..<42, 3][..<42, 3]")
public subscript(r: PartialRangeUpTo<Int>, stride: Int) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r, stride]
Expand All @@ -368,6 +381,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// full range with stride
@available(*, deprecated, message: "prefer new slicing syntax a[0... ~ 3, 0... ~ 3, 0... ~ 3] over old one a[..., 3][..., 3][..., 3]")
public subscript(r: UnboundedRange, stride: Int) -> NdArray<T> {
get {
NdArraySlice(self, sliced: 0)[r, stride]
Expand All @@ -378,6 +392,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
}

/// single slice access
@available(*, deprecated, message: "prefer new slicing syntax a[42, 42, 42] over old one a[42][42][42]")
public subscript(i: Int) -> NdArray<T> {
get {
precondition(!isEmpty)
Expand All @@ -402,6 +417,47 @@ open class NdArray<T>: CustomDebugStringConvertible,
newValue.copyTo(self[i])
}
}

/**
slice access
*/
public subscript(slices: [Slice]) -> NdArray<T> {
get {
var a = NdArraySlice(self, sliced: 0)
for (i, s) in slices.enumerated() {
switch s.sliceKind {
case .range(lowerBound: let lowerBound, upperBound: let upperBound, stride: let stride):
let stride = stride ?? 1
let lowerBound = lowerBound ?? 0
let upperBound = upperBound ?? shape[i]
a = a.subscr(lowerBound: lowerBound, upperBound: upperBound, stride: stride)
case .index(let i):
a = a.subscr(i)
if a.shape.isEmpty {
a.shape = [1]
a.strides = [1]
a.count = 1
}
}
}
return NdArray(a)
}
set {
newValue.copyTo(self[slices])
}
}

/**
slice access
*/
public subscript(slices: Slice...) -> NdArray<T> {
get {
self[slices]
}
set {
self[slices] = newValue
}
}
}

// extension helping to handle different memory alignments
Expand Down
36 changes: 33 additions & 3 deletions Sources/NdArray/NdArraySlice.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import Darwin

public class NdArraySlice<T>: NdArray<T> {
internal class NdArraySlice<T>: NdArray<T> {
/// number of dimensions that have been sliced
private var sliced: Int

Expand Down Expand Up @@ -50,7 +50,7 @@ public class NdArraySlice<T>: NdArray<T> {
private func subscr(_ r: UnboundedRange) -> NdArraySlice {
let slice = NdArraySlice(self, sliced: sliced + 1)
slice.sliceDescription = sliceDescription
slice.sliceDescription.append("[...]")
slice.sliceDescription.append("[0...]")
return slice
}

Expand Down Expand Up @@ -141,6 +141,8 @@ public class NdArraySlice<T>: NdArray<T> {
let slice = NdArraySlice(self, startIndex: [Int](repeating: 0, count: ndim), sliced: sliced + 1)
slice.shape[sliced] = 0
slice.count = slice.len
slice.sliceDescription = sliceDescription
slice.sliceDescription.append("[\(r.lowerBound)..<\(r.upperBound)]")
return slice
}

Expand Down Expand Up @@ -200,7 +202,7 @@ public class NdArraySlice<T>: NdArray<T> {
/// partial range with stride
public override subscript(r: PartialRangeFrom<Int>, stride: Int) -> NdArray<T> {
get {
precondition(stride > 0, "\(stride) > 0")
precondition(stride > 0, "\(stride) > 0, \(debugDescription)")

let slice = self.subscr(r)
slice.sliceDescription.removeLast()
Expand Down Expand Up @@ -287,6 +289,34 @@ public class NdArraySlice<T>: NdArray<T> {
}
}

internal func subscr(lowerBound: Int, upperBound: Int, stride: Int) -> NdArraySlice<T> {
precondition(stride > 0, "\(stride) > 0")

let slice = self.subscr(lowerBound..<upperBound)
slice.sliceDescription.removeLast()
slice.sliceDescription.append("[\(lowerBound)..<\(upperBound), \(stride)]")
slice.strideBy(stride, axis: sliced)
return slice
}

internal func subscr(_ i: Int) -> NdArraySlice<T> {
precondition(!isEmpty)

// set the index on the sliced dimension
var startIndex = [Int](repeating: 0, count: ndim)
startIndex[sliced] = i

// here we reduce the shape, hence sliced stays the same
let slice = NdArraySlice(self, startIndex: startIndex, sliced: sliced)
slice.sliceDescription = sliceDescription
slice.sliceDescription.append("[\(i)]")
// drop shape and stride
slice.shape = Array(slice.shape[0..<sliced] + slice.shape[(sliced + 1)...])
slice.strides = Array(slice.strides[0..<sliced] + slice.strides[(sliced + 1)...])
slice.count = slice.len
return slice
}

public override var debugDescription: String {
let address = String(format: "%p", Int(bitPattern: data))
var sliceDescription = sliceDescription.joined()
Expand Down
Loading

0 comments on commit 3e5c629

Please sign in to comment.