Various alloc reductions and optimizations
Sch: Don't return values in Tasks
Sch: Switch from state.cache to thunk.cache_ref
tests: Improve test_throws_unwrap error comparisons
jpsamaroo committed Aug 6, 2024
1 parent 38230f3 commit 03c27ff
Showing 30 changed files with 2,186 additions and 1,369 deletions.
22 changes: 14 additions & 8 deletions src/Dagger.jl
@@ -21,6 +21,7 @@ if !isdefined(Base, :ScopedValues)
else
import Base.ScopedValues: ScopedValue, with
end
import TaskLocalValues: TaskLocalValue

if !isdefined(Base, :get_extension)
import Requires: @require
@@ -34,9 +35,13 @@ import Adapt
include("lib/util.jl")
include("utils/dagdebug.jl")

# Logging Basics
include("utils/logging.jl")

# Distributed data
include("utils/locked-object.jl")
include("utils/tasks.jl")
include("utils/reuse.jl")

import MacroTools: @capture
include("options.jl")
@@ -48,6 +53,7 @@ include("task-tls.jl")
include("scopes.jl")
include("utils/scopes.jl")
include("dtask.jl")
include("argument.jl")
include("queue.jl")
include("thunk.jl")
include("submission.jl")
@@ -64,34 +70,34 @@ include("sch/Sch.jl"); using .Sch
# Data dependency task queue
include("datadeps.jl")

# File IO
include("file-io.jl")

# Array computations
include("array/darray.jl")
include("array/alloc.jl")
include("array/map-reduce.jl")
include("array/copy.jl")

# File IO
include("file-io.jl")

include("array/random.jl")
include("array/operators.jl")
include("array/indexing.jl")
include("array/setindex.jl")
include("array/matrix.jl")
include("array/sparse_partition.jl")
include("array/parallel-blocks.jl")
include("array/sort.jl")
include("array/linalg.jl")
include("array/mul.jl")
include("array/cholesky.jl")

# Custom Logging Events
include("utils/logging-events.jl")

# Visualization
include("visualization.jl")
include("ui/gantt-common.jl")
include("ui/gantt-text.jl")

# Logging
include("utils/logging-events.jl")
include("utils/logging.jl")

# Precompilation
import PrecompileTools: @compile_workload
include("precompile.jl")
44 changes: 44 additions & 0 deletions src/argument.jl
@@ -0,0 +1,44 @@
mutable struct ArgPosition
positional::Bool
idx::Int
kw::Symbol
end
ArgPosition() = ArgPosition(true, 0, :NULL)
ArgPosition(pos::ArgPosition) = ArgPosition(pos.positional, pos.idx, pos.kw)
ispositional(pos::ArgPosition) = pos.positional
iskw(pos::ArgPosition) = !pos.positional
function pos_idx(pos::ArgPosition)
@assert pos.positional
@assert pos.idx > 0
@assert pos.kw == :NULL
return pos.idx
end
function pos_kw(pos::ArgPosition)
@assert !pos.positional
@assert pos.idx == 0
@assert pos.kw != :NULL
return pos.kw
end
mutable struct Argument
pos::ArgPosition
value
end
Argument(pos::Integer, value) = Argument(ArgPosition(true, pos, :NULL), value)
Argument(kw::Symbol, value) = Argument(ArgPosition(false, 0, kw), value)
ispositional(arg::Argument) = ispositional(arg.pos)
iskw(arg::Argument) = iskw(arg.pos)
pos_idx(arg::Argument) = pos_idx(arg.pos)
pos_kw(arg::Argument) = pos_kw(arg.pos)
value(arg::Argument) = arg.value
valuetype(arg::Argument) = typeof(arg.value)
Base.iterate(arg::Argument) = (arg.pos, true)
function Base.iterate(arg::Argument, state::Bool)
if state
return (arg.value, false)
else
return nothing
end
end

Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value)
chunktype(arg::Argument) = chunktype(value(arg))
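
As a rough usage sketch (not part of the commit), the new `Argument` container pairs a value with its call position, and its two-step `iterate` allows destructuring:

# Speculative example based only on the definitions above.
a = Argument(2, 42)               # positional argument at index 2
@assert ispositional(a) && pos_idx(a) == 2
b = Argument(:init, 0.0)          # keyword argument `init`
@assert iskw(b) && pos_kw(b) == :init
pos, val = a                      # iterate yields `a.pos`, then `a.value`
@assert val == 42 && valuetype(a) == Int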
9 changes: 8 additions & 1 deletion src/array/darray.jl
@@ -173,7 +173,7 @@ domainchunks(d::DArray) = d.subdomains
size(x::DArray) = size(domain(x))
stage(ctx, c::DArray) = c

function Base.collect(d::DArray; tree=false)
function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N}
a = fetch(d)
if isempty(d.chunks)
return Array{eltype(d)}(undef, size(d)...)
@@ -183,6 +183,13 @@ function Base.collect(d::DArray; tree=false)
return fetch(a.chunks[1])
end

if copyto
C = Array{T,N}(undef, size(a))
DC = view(C, Blocks(size(a)...))
copyto!(DC, a)
return C
end

dimcatfuncs = [(x...) -> d.concat(x..., dims=i) for i in 1:ndims(d)]
if tree
collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks)))
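A hedged sketch of the new `copyto` fast path in `collect` (the `distribute` and `Blocks` calls are pre-existing Dagger API; the sizes are illustrative):

using Dagger
A = rand(8, 8)
DA = Dagger.distribute(A, Dagger.Blocks(4, 4))
# `copyto=true` materializes into a preallocated Array via a blockwise
# `copyto!`, avoiding the chunk-concatenation path.
B = collect(DA; copyto=true)
@assert B == A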
2 changes: 0 additions & 2 deletions src/array/indexing.jl
@@ -1,5 +1,3 @@
import TaskLocalValues: TaskLocalValue

### getindex

struct GetIndex{T,N} <: ArrayOp{T,N}
174 changes: 174 additions & 0 deletions src/array/parallel-blocks.jl
@@ -0,0 +1,174 @@
export ParallelBlocks

using Statistics

struct ParallelBlocks{N} <: Dagger.AbstractSingleBlocks{N}
n::Int
end
ParallelBlocks(n::Integer) = ParallelBlocks{0}(n)
ParallelBlocks{N}(dist::ParallelBlocks) where N = ParallelBlocks{N}(dist.n)
ParallelBlocks() = ParallelBlocks(Dagger.num_processors())

Base.convert(::Type{ParallelBlocks{N}}, dist::ParallelBlocks) where N =
ParallelBlocks{N}(dist.n)

wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, dist::ParallelBlocks) =
wrap_chunks(chunks, N, dist.n)
wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, n::Integer) =
convert(Array{Any}, reshape(chunks, ntuple(i->i == 1 ? n : 1, N)))

function _finish_allocation(f::Function, dist::ParallelBlocks, dims::NTuple{N,Int}) where N
d = ArrayDomain(map(x->1:x, dims))
s = reshape([d for _ in 1:dist.n],
ntuple(i->i == 1 ? dist.n : 1, N))
data = [f(dims) for _ in 1:dist.n]
dist = ParallelBlocks{N}(dist)
chunks = wrap_chunks(map(Dagger.tochunk, data), N, dist)
return Dagger.DArray(eltype(first(data)), d, s, chunks, dist)
end

for fn in [:rand, :randn, :zeros, :ones]
@eval begin
function Base.$fn(dist::ParallelBlocks, ::Type{ET}, dims::Dims) where {ET}
f(block) = $fn(ET, block)
_finish_allocation(f, dist, dims)
end
Base.$fn(dist::ParallelBlocks, T::Type, dims::Integer...) = $fn(dist, T, dims)
Base.$fn(dist::ParallelBlocks, T::Type, dims::Tuple) = $fn(dist, T, dims)
Base.$fn(dist::ParallelBlocks, dims::Integer...) = $fn(dist, Float64, dims)
Base.$fn(dist::ParallelBlocks, dims::Tuple) = $fn(dist, Float64, dims)
end
end
# FIXME: sprand

function Dagger.distribute(data::AbstractArray{T,N}, dist::ParallelBlocks) where {T,N}
dims = size(data)
d = ArrayDomain(map(x->1:x, dims))
s = Dagger.DomainBlocks(ntuple(_->1, N),
ntuple(i->[dims[i]], N))
chunks = [Dagger.tochunk(copy(data)) for _ in 1:dist.n]
new_dist = ParallelBlocks{N}(dist)
return Dagger.DArray(T, d, s, wrap_chunks(chunks, N, dist), new_dist)
end

_invalid_call_pblocks(f::Symbol) =
error("`$f` is not valid for a `DArray` partitioned with `ParallelBlocks`.\nConsider `Dagger.pmap($f, x)` instead.")

Base.collect(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}) =
_invalid_call_pblocks(:collect)
Base.getindex(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, x...) =
_invalid_call_pblocks(:getindex)
Base.setindex!(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, value, x...) =
_invalid_call_pblocks(:setindex!)

function pmap(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
# TODO: Chunks might not be `Array`s
# FIXME
#AT = Array{T,N}
#ET = eltype(Base.promote_op(f, AT))
ET = Any
new_chunks = map(A.chunks) do chunk
Dagger.@spawn f(chunk)
end
return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning)
end
# FIXME: More useful `show` method
Base.show(io::IO, ::MIME"text/plain", A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} =
print(io, typeof(A))
pfetch(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} =
map(fetch, A.chunks)
pcollect(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} =
map(collect, pfetch(A))

function Base.map(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
ET = Base.promote_op(f, T)
new_chunks = map(A.chunks) do chunk
Dagger.@spawn map(f, chunk)
end
return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning)
end
function Base.map!(f::Function,
x::Dagger.DArray{T1,N1,ParallelBlocks{N1}} where {T1,N1},
y::Dagger.DArray{T2,N2,ParallelBlocks{N2}} where {T2,N2})
x_dist = x.partitioning
y_dist = y.partitioning
if x_dist.n != y_dist.n
throw(ArgumentError("Can't `map!` over non-matching `ParallelBlocks` distributions: $(x_dist.n) != $(y_dist.n)"))
end
@sync for i in 1:x_dist.n
Dagger.@spawn map!(f, x.chunks[i], y.chunks[i])
end
end

#=
function Base.reduce(f::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}};
dims=:) where {T,N}
error("Out-of-place Reduce")
if dims == Base.:(:)
localpart = fetch(Dagger.reduce_async(f, x))
return MPI.Allreduce(localpart, f, comm)
elseif dims === nothing
localpart = fetch(x.chunks[1])
return MPI.Allreduce(localpart, f, comm)
else
error("Not yet implemented")
end
end
=#
function allreduce!(op::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}}; nchunks::Integer=0) where {T,N}
if nchunks == 0
nchunks = x.partitioning.n
end
@assert nchunks == x.partitioning.n "Number of chunks must match the number of partitions"

# Split each chunk along the last dimension
chunk_size = cld(size(x, ndims(x)), nchunks)
chunk_dist = Blocks(ntuple(i->i == N ? chunk_size : size(x, i), N))
chunk_ds = partition(chunk_dist, x.subdomains[1])
num_par_chunks = length(x.chunks)

# Allocate temporary buffer
y = copy(x)

# Ring-reduce into temporary buffer
Dagger.spawn_datadeps() do
for j in 1:length(chunk_ds)
for i in 1:num_par_chunks
for step in 1:(num_par_chunks-1)
from_idx = i
to_idx = mod1(i+step, num_par_chunks)
from_chunk = x.chunks[from_idx]
to_chunk = y.chunks[to_idx]
sd = chunk_ds[mod1(j+i-1, length(chunk_ds))].indexes
# FIXME: Specify aliasing based on `sd`
Dagger.@spawn _reduce_view!(op,
InOut(to_chunk), sd,
In(from_chunk), sd)
end
end
end

# Copy from temporary buffer back to origin
for i in 1:num_par_chunks
Dagger.@spawn copyto!(Out(x.chunks[i]), In(y.chunks[i]))
end
end

return x
end
function _reduce_view!(op, to, to_view, from, from_view)
to_viewed = view(to, to_view...)
from_viewed = view(from, from_view...)
reduce!(op, to_viewed, from_viewed)
return
end
function reduce!(op, to, from)
to .= op.(to, from)
end

function Statistics.mean!(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
allreduce!(+, A)
len = length(A.chunks)
map!(x->x ./ len, A, A)
return A
end
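
Taken together, a speculative end-to-end sketch of the replicated `ParallelBlocks` layout (only `ParallelBlocks` itself is exported above; the other helpers are reached through the `Dagger` module):

using Dagger, Random
# Four full replicas of a 16x16 matrix, one per "partition".
X = rand(ParallelBlocks(4), Float64, 16, 16)
randn!(X)                          # fill each replica in-place
Dagger.allreduce!(+, X)            # ring-reduce: every replica ends up with the sum
parts = Dagger.pcollect(X)         # per-replica copies; plain `collect` throws here
@assert all(p -> p ≈ first(parts), parts)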
13 changes: 13 additions & 0 deletions src/array/random.jl
@@ -0,0 +1,13 @@
using Random

function Random.randn!(A::DArray{T}) where T
Ac = A.chunks

Dagger.spawn_datadeps() do
for chunk in Ac
Dagger.@spawn randn!(InOut(chunk))
end
end

return A
end
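
The same `spawn_datadeps` pattern presumably extends to other in-place fills; for example, a hypothetical `rand!` analogue, written in the same in-module style as the method above (not part of this commit):

# Hypothetical analogue of `randn!`; shown only to illustrate the pattern.
function Random.rand!(A::DArray{T}) where T
    Dagger.spawn_datadeps() do
        for chunk in A.chunks
            # InOut marks the chunk as read-write for the scheduler
            Dagger.@spawn rand!(InOut(chunk))
        end
    end
    return A
end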
21 changes: 18 additions & 3 deletions src/chunks.jl
@@ -250,7 +250,7 @@ Base.length(s::Shard) = length(s.chunks)
### Core Stuff

"""
tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, kwargs...) -> Chunk
tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk
Create a chunk from data `x` which resides on `proc` and which has scope
`scope`.
@@ -262,9 +262,12 @@ will be inspected to determine if it's safe to serialize; if so, the default
MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will
be used.
If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a
new `Chunk`.
All other kwargs are passed directly to `MemPool.poolset`.
"""
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S}
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, rewrap=false, kwargs...) where {X,P,S}
if device === nothing
device = if Sch.walk_storage_safe(x)
MemPool.GLOBAL_DEVICE[]
@@ -275,7 +278,15 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cac
ref = poolset(x; device, kwargs...)
Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist)
end
tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x
function tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; rewrap=false, kwargs...)
if rewrap
return remotecall_fetch(x.handle.owner) do
tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...)
end
else
return x
end
end

function savechunk(data, dir, f)
sz = open(joinpath(dir, f), "w") do io
@@ -302,9 +313,13 @@ function unwrap_weak_checked(c::WeakChunk)
@assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))"
return cw
end
wrap_weak(c::Chunk) = WeakChunk(c)
isweak(c::WeakChunk) = true
isweak(c::Chunk) = false
is_task_or_chunk(c::WeakChunk) = true
Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) =
error("Cannot serialize a WeakChunk")
chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c))

Base.@deprecate_binding AbstractPart Union{Chunk, Thunk}
Base.@deprecate_binding Part Chunk
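A rough sketch of the new `rewrap` keyword on `tochunk` (hedged; the `processor`/`scope` field accesses are illustrative):

using Dagger
c = Dagger.tochunk([1, 2, 3])
@assert Dagger.tochunk(c) === c    # default: an existing Chunk passes through
# With `rewrap=true`, the data is fetched on its owner process and re-pooled
# as a fresh Chunk.
c2 = Dagger.tochunk(c, c.processor, c.scope; rewrap=true)
@assert c2 !== c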