From 03c27ffb33f758b505c7ff280e1d57e1794d3911 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 16 Jul 2024 11:33:34 +0100 Subject: [PATCH] Various alloc reductions and optimizations Sch: Don't return values in Tasks Sch: Switch from state.cache to thunk.cache_ref tests: Improve test_throws_unwrap error comparisons --- src/Dagger.jl | 22 +- src/argument.jl | 44 ++ src/array/darray.jl | 9 +- src/array/indexing.jl | 2 - src/array/parallel-blocks.jl | 174 ++++++++ src/array/random.jl | 13 + src/chunks.jl | 21 +- src/datadeps.jl | 66 +-- src/dtask.jl | 13 +- src/file-io.jl | 293 ------------- src/precompile.jl | 2 +- src/queue.jl | 3 +- src/sch/Sch.jl | 813 ++++++++++++++++++----------------- src/sch/dynamic.jl | 55 +-- src/sch/eager.jl | 20 +- src/sch/fault-handler.jl | 8 +- src/sch/util.jl | 311 +++++++++----- src/submission.jl | 368 +++++++++------- src/task-tls.jl | 17 +- src/threadproc.jl | 4 +- src/thunk.jl | 283 +++++++----- src/utils/logging-events.jl | 2 +- src/utils/logging.jl | 11 +- src/utils/reuse.jl | 542 +++++++++++++++++++++++ test/logging.jl | 22 +- test/memory-spaces.jl | 50 ++- test/processors.jl | 4 +- test/scheduler.jl | 264 +++++++----- test/thunk.jl | 100 +++-- test/util.jl | 19 +- 30 files changed, 2186 insertions(+), 1369 deletions(-) create mode 100644 src/argument.jl create mode 100644 src/array/parallel-blocks.jl create mode 100644 src/array/random.jl create mode 100644 src/utils/reuse.jl diff --git a/src/Dagger.jl b/src/Dagger.jl index 093bc2cfa..b72ddca41 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -21,6 +21,7 @@ if !isdefined(Base, :ScopedValues) else import Base.ScopedValues: ScopedValue, with end +import TaskLocalValues: TaskLocalValue if !isdefined(Base, :get_extension) import Requires: @require @@ -34,9 +35,13 @@ import Adapt include("lib/util.jl") include("utils/dagdebug.jl") +# Logging Basics +include("utils/logging.jl") + # Distributed data include("utils/locked-object.jl") include("utils/tasks.jl") +include("utils/reuse.jl") import MacroTools: @capture include("options.jl") @@ -48,6 +53,7 @@ include("task-tls.jl") include("scopes.jl") include("utils/scopes.jl") include("dtask.jl") +include("argument.jl") include("queue.jl") include("thunk.jl") include("submission.jl") @@ -64,34 +70,34 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps.jl") +# File IO +include("file-io.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") include("array/map-reduce.jl") include("array/copy.jl") - -# File IO -include("file-io.jl") - +include("array/random.jl") include("array/operators.jl") include("array/indexing.jl") include("array/setindex.jl") include("array/matrix.jl") include("array/sparse_partition.jl") +include("array/parallel-blocks.jl") include("array/sort.jl") include("array/linalg.jl") include("array/mul.jl") include("array/cholesky.jl") +# Custom Logging Events +include("utils/logging-events.jl") + # Visualization include("visualization.jl") include("ui/gantt-common.jl") include("ui/gantt-text.jl") -# Logging -include("utils/logging-events.jl") -include("utils/logging.jl") - # Precompilation import PrecompileTools: @compile_workload include("precompile.jl") diff --git a/src/argument.jl b/src/argument.jl new file mode 100644 index 000000000..42ff52e6b --- /dev/null +++ b/src/argument.jl @@ -0,0 +1,44 @@ +mutable struct ArgPosition + positional::Bool + idx::Int + kw::Symbol +end +ArgPosition() = ArgPosition(true, 0, :NULL) +ArgPosition(pos::ArgPosition) = ArgPosition(pos.positional, pos.idx, 
pos.kw) +ispositional(pos::ArgPosition) = pos.positional +iskw(pos::ArgPosition) = !pos.positional +function pos_idx(pos::ArgPosition) + @assert pos.positional + @assert pos.idx > 0 + @assert pos.kw == :NULL + return pos.idx +end +function pos_kw(pos::ArgPosition) + @assert !pos.positional + @assert pos.idx == 0 + @assert pos.kw != :NULL + return pos.kw +end +mutable struct Argument + pos::ArgPosition + value +end +Argument(pos::Integer, value) = Argument(ArgPosition(true, pos, :NULL), value) +Argument(kw::Symbol, value) = Argument(ArgPosition(false, 0, kw), value) +ispositional(arg::Argument) = ispositional(arg.pos) +iskw(arg::Argument) = iskw(arg.pos) +pos_idx(arg::Argument) = pos_idx(arg.pos) +pos_kw(arg::Argument) = pos_kw(arg.pos) +value(arg::Argument) = arg.value +valuetype(arg::Argument) = typeof(arg.value) +Base.iterate(arg::Argument) = (arg.pos, true) +function Base.iterate(arg::Argument, state::Bool) + if state + return (arg.value, false) + else + return nothing + end +end + +Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value) +chunktype(arg::Argument) = chunktype(value(arg)) diff --git a/src/array/darray.jl b/src/array/darray.jl index 9e0bc4636..16444a217 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -173,7 +173,7 @@ domainchunks(d::DArray) = d.subdomains size(x::DArray) = size(domain(x)) stage(ctx, c::DArray) = c -function Base.collect(d::DArray; tree=false) +function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} a = fetch(d) if isempty(d.chunks) return Array{eltype(d)}(undef, size(d)...) @@ -183,6 +183,13 @@ function Base.collect(d::DArray; tree=false) return fetch(a.chunks[1]) end + if copyto + C = Array{T,N}(undef, size(a)) + DC = view(C, Blocks(size(a)...)) + copyto!(DC, a) + return C + end + dimcatfuncs = [(x...) -> d.concat(x..., dims=i) for i in 1:ndims(d)] if tree collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks))) diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..69725eb7a 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -1,5 +1,3 @@ -import TaskLocalValues: TaskLocalValue - ### getindex struct GetIndex{T,N} <: ArrayOp{T,N} diff --git a/src/array/parallel-blocks.jl b/src/array/parallel-blocks.jl new file mode 100644 index 000000000..26a30a941 --- /dev/null +++ b/src/array/parallel-blocks.jl @@ -0,0 +1,174 @@ +export ParallelBlocks + +using Statistics + +struct ParallelBlocks{N} <: Dagger.AbstractSingleBlocks{N} + n::Int +end +ParallelBlocks(n::Integer) = ParallelBlocks{0}(n) +ParallelBlocks{N}(dist::ParallelBlocks) where N = ParallelBlocks{N}(dist.n) +ParallelBlocks() = ParallelBlocks(Dagger.num_processors()) + +Base.convert(::Type{ParallelBlocks{N}}, dist::ParallelBlocks) where N = + ParallelBlocks{N}(dist.n) + +wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, dist::ParallelBlocks) = + wrap_chunks(chunks, N, dist.n) +wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, n::Integer) = + convert(Array{Any}, reshape(chunks, ntuple(i->i == 1 ? n : 1, N))) + +function _finish_allocation(f::Function, dist::ParallelBlocks, dims::NTuple{N,Int}) where N + d = ArrayDomain(map(x->1:x, dims)) + s = reshape([d for _ in 1:dist.n], + ntuple(i->i == 1 ? 
dist.n : 1, N)) + data = [f(dims) for _ in 1:dist.n] + dist = ParallelBlocks{N}(dist) + chunks = wrap_chunks(map(Dagger.tochunk, data), N, dist) + return Dagger.DArray(eltype(first(data)), d, s, chunks, dist) +end + +for fn in [:rand, :randn, :zeros, :ones] + @eval begin + function Base.$fn(dist::ParallelBlocks, ::Type{ET}, dims::Dims) where {ET} + f(block) = $fn(ET, block) + _finish_allocation(f, dist, dims) + end + Base.$fn(dist::ParallelBlocks, T::Type, dims::Integer...) = $fn(dist, T, dims) + Base.$fn(dist::ParallelBlocks, T::Type, dims::Tuple) = $fn(dist, T, dims) + Base.$fn(dist::ParallelBlocks, dims::Integer...) = $fn(dist, Float64, dims) + Base.$fn(dist::ParallelBlocks, dims::Tuple) = $fn(dist, Float64, dims) + end +end +# FIXME: sprand + +function Dagger.distribute(data::AbstractArray{T,N}, dist::ParallelBlocks) where {T,N} + dims = size(data) + d = ArrayDomain(map(x->1:x, dims)) + s = Dagger.DomainBlocks(ntuple(_->1, N), + ntuple(i->[dims[i]], N)) + chunks = [Dagger.tochunk(copy(data)) for _ in 1:dist.n] + new_dist = ParallelBlocks{N}(dist) + return Dagger.DArray(T, d, s, wrap_chunks(chunks, N, dist), new_dist) +end + +_invalid_call_pblocks(f::Symbol) = + error("`$f` is not valid for a `DArray` partitioned with `ParallelBlocks`.\nConsider `Dagger.pmap($f, x)` instead.") + +Base.collect(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}) = + _invalid_call_pblocks(:collect) +Base.getindex(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, x...) = + _invalid_call_pblocks(:getindex) +Base.setindex!(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, value, x...) = + _invalid_call_pblocks(:setindex!) + +function pmap(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} + # TODO: Chunks might not be `Array`s + # FIXME + #AT = Array{T,N} + #ET = eltype(Base.promote_op(f, AT)) + ET = Any + new_chunks = map(A.chunks) do chunk + Dagger.@spawn f(chunk) + end + return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning) +end +# FIXME: More useful `show` method +Base.show(io::IO, ::MIME"text/plain", A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} = + print(io, typeof(A)) +pfetch(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} = + map(fetch, A.chunks) +pcollect(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} = + map(collect, pfetch(A)) + +function Base.map(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} + ET = Base.promote_op(f, T) + new_chunks = map(A.chunks) do chunk + Dagger.@spawn map(f, chunk) + end + return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning) +end +function Base.map!(f::Function, + x::Dagger.DArray{T1,N1,ParallelBlocks{N1}} where {T1,N1}, + y::Dagger.DArray{T2,N2,ParallelBlocks{N2}} where {T2,N2}) + x_dist = x.partitioning + y_dist = y.partitioning + if x_dist.n != y_dist.n + throw(ArgumentError("Can't `map!` over non-matching `ParallelBlocks` distributions: $(x_dist.n) != $(y_dist.n)")) + end + @sync for i in 1:x_dist.n + Dagger.@spawn map!(f, x.chunks[i], y.chunks[i]) + end +end + +#= +function Base.reduce(f::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}}; + dims=:) where {T,N} + error("Out-of-place Reduce") + if dims == Base.:(:) + localpart = fetch(Dagger.reduce_async(f, x)) + return MPI.Allreduce(localpart, f, comm) + elseif dims === nothing + localpart = fetch(x.chunks[1]) + return MPI.Allreduce(localpart, f, comm) + else + error("Not yet implemented") + end +end +=# +function allreduce!(op::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}}; nchunks::Integer=0) where {T,N} + if nchunks == 0 + 
nchunks = x.partitioning.n + end + @assert nchunks == x.partitioning.n "Number of chunks must match the number of partitions" + + # Split each chunk along the last dimension + chunk_size = cld(size(x, ndims(x)), nchunks) + chunk_dist = Blocks(ntuple(i->i == N ? chunk_size : size(x, i), N)) + chunk_ds = partition(chunk_dist, x.subdomains[1]) + num_par_chunks = length(x.chunks) + + # Allocate temporary buffer + y = copy(x) + + # Ring-reduce into temporary buffer + Dagger.spawn_datadeps() do + for j in 1:length(chunk_ds) + for i in 1:num_par_chunks + for step in 1:(num_par_chunks-1) + from_idx = i + to_idx = mod1(i+step, num_par_chunks) + from_chunk = x.chunks[from_idx] + to_chunk = y.chunks[to_idx] + sd = chunk_ds[mod1(j+i-1, length(chunk_ds))].indexes + # FIXME: Specify aliasing based on `sd` + Dagger.@spawn _reduce_view!(op, + InOut(to_chunk), sd, + In(from_chunk), sd) + end + end + end + + # Copy from temporary buffer back to origin + for i in 1:num_par_chunks + Dagger.@spawn copyto!(Out(x.chunks[i]), In(y.chunks[i])) + end + end + + return x +end +function _reduce_view!(op, to, to_view, from, from_view) + to_viewed = view(to, to_view...) + from_viewed = view(from, from_view...) + reduce!(op, to_viewed, from_viewed) + return +end +function reduce!(op, to, from) + to .= op.(to, from) +end + +function Statistics.mean!(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N} + allreduce!(+, A) + len = length(A.chunks) + map!(x->x ./ len, A, A) + return A +end diff --git a/src/array/random.jl b/src/array/random.jl new file mode 100644 index 000000000..f896137e0 --- /dev/null +++ b/src/array/random.jl @@ -0,0 +1,13 @@ +using Random + +function Random.randn!(A::DArray{T}) where T + Ac = A.chunks + + Dagger.spawn_datadeps() do + for chunk in Ac + Dagger.@spawn randn!(InOut(chunk)) + end + end + + return A +end diff --git a/src/chunks.jl b/src/chunks.jl index 1eb56714e..0cb932dcf 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -250,7 +250,7 @@ Base.length(s::Shard) = length(s.chunks) ### Core Stuff """ - tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, kwargs...) -> Chunk + tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk Create a chunk from data `x` which resides on `proc` and which has scope `scope`. @@ -262,9 +262,12 @@ will be inspected to determine if it's safe to serialize; if so, the default MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will be used. +If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a +new `Chunk`. + All other kwargs are passed directly to `MemPool.poolset`. """ -function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S} +function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, rewrap=false, kwargs...) where {X,P,S} if device === nothing device = if Sch.walk_storage_safe(x) MemPool.GLOBAL_DEVICE[] @@ -275,7 +278,15 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cac ref = poolset(x; device, kwargs...) Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist) end -tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x +function tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; rewrap=false, kwargs...) + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) 
+ end + else + return x + end +end function savechunk(data, dir, f) sz = open(joinpath(dir, f), "w") do io @@ -302,9 +313,13 @@ function unwrap_weak_checked(c::WeakChunk) @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" return cw end +wrap_weak(c::Chunk) = WeakChunk(c) +isweak(c::WeakChunk) = true +isweak(c::Chunk) = false is_task_or_chunk(c::WeakChunk) = true Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = error("Cannot serialize a WeakChunk") +chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) Base.@deprecate_binding AbstractPart Union{Chunk, Thunk} Base.@deprecate_binding Part Chunk diff --git a/src/datadeps.jl b/src/datadeps.jl index 580009aaa..a9d26beb4 100644 --- a/src/datadeps.jl +++ b/src/datadeps.jl @@ -226,9 +226,9 @@ function populate_task_info!(state::DataDepsState, spec, task) dependencies_to_add = Vector{Tuple{Bool,Bool,AbstractAliasing,<:Any,<:Any}}() # Track the task's arguments and access patterns - for (idx, (pos, arg)) in enumerate(spec.args) + for (idx, _arg) in enumerate(spec.fargs) # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(arg) + arg, deps = unwrap_inout(value(_arg)) # Unwrap the Chunk underlying any DTask arguments arg = arg isa DTask ? fetch(arg; raw=true) : arg @@ -523,7 +523,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) scheduler = queue.scheduler if scheduler == :naive - raw_args = map(arg->tochunk(last(arg)), spec.args) + raw_args = map(arg->tochunk(value(arg)), spec.fargs) our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args Sch.init_eager() sch_state = Sch.EAGER_STATE[] @@ -538,13 +538,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end end elseif scheduler == :smart - raw_args = map(filter(arg->haskey(astate.data_locality, arg), spec.args)) do arg + raw_args = map(filter(arg->haskey(astate.data_locality, value(arg)), spec.fargs)) do arg arg_chunk = tochunk(last(arg)) # Only the owned slot is valid # FIXME: Track up-to-date copies and pass all of those return arg_chunk => data_locality[arg] end - f_chunk = tochunk(spec.f) + f_chunk = tochunk(value(f)) our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality Sch.init_eager() sch_state = Sch.EAGER_STATE[] @@ -585,7 +585,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # FIXME: Pressure should be decreased by pressure of syncdeps on same processor pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure elseif scheduler == :ultra - args = Base.mapany(spec.args) do arg + args = Base.mapany(spec.fargs) do arg pos, data = arg data, _ = unwrap_inout(data) if data isa DTask @@ -593,7 +593,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end return pos => tochunk(data) end - f_chunk = tochunk(spec.f) + f_chunk = tochunk(value(f)) task_time = remotecall_fetch(1, f_chunk, args) do f, args Sch.init_eager() sch_state = Sch.EAGER_STATE[] @@ -653,20 +653,21 @@ function distribute_tasks!(queue::DataDepsTaskQueue) our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) our_scope = UnionScope(map(ExactScope, our_procs)...) 
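# --- editor's illustrative sketch (not part of the patch) ----------------------
# A minimal, hypothetical usage of the `rewrap` keyword added to `tochunk` in
# src/chunks.jl earlier in this patch: with the default `rewrap=false` a Chunk
# argument is returned unchanged, while `rewrap=true` re-snapshots the data on
# the owning worker into a fresh Chunk. The variable names (`A`, `c1`, ...) are
# made up for illustration and assume the data lives on the current worker.
using Dagger
A  = rand(4, 4)
c1 = Dagger.tochunk(A)                                       # wrap local data
c2 = Dagger.tochunk(c1)                                      # default: returns c1 as-is
c3 = Dagger.tochunk(c1, Dagger.OSProc(), Dagger.AnyScope();  # rewrap into a new Chunk
                    rewrap=true)
@assert c2 === c1
@assert c3 !== c1
# -------------------------------------------------------------------------------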
- spec.f = move(ThreadProc(myid(), 1), our_proc, spec.f) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f))) Scheduling: $our_proc ($our_space)" + f = spec.fargs[1] + f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis - task_args = copy(spec.args) + task_args = map(copy, spec.fargs) # Copy args from local to remote - for (idx, (pos, arg)) in enumerate(task_args) + for (idx, _arg) in enumerate(task_args) # Is the data written previously or now? - arg, deps = unwrap_inout(arg) + arg, deps = unwrap_inout(value(_arg)) arg = arg isa DTask ? fetch(arg; raw=true) : arg if Base.datatype_pointerfree(typeof(arg)) || !has_writedep(state, arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)" - spec.args[idx] = pos => arg + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (unwritten)" + spec.fargs[idx].value = arg continue end @@ -681,20 +682,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue) nonlocal = our_space != data_space if nonlocal # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do generate_slot!(state, data_space, arg) end copy_to_scope = our_scope copy_to_syncdeps = Set{Any}() get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) add_writer!(state, ainfo, copy_to, write_num) astate.data_locality[ainfo] = our_space else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Skipped copy-to (local): $data_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Skipped copy-to (local): $data_space" end end else @@ -702,32 +703,33 @@ function distribute_tasks!(queue::DataDepsTaskQueue) nonlocal = our_space != data_space if nonlocal # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Enqueueing copy-to: $data_space => $our_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Enqueueing copy-to: $data_space => $our_space" arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do generate_slot!(state, data_space, arg) end copy_to_scope = our_scope copy_to_syncdeps = Set{Any}() get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] $(length(copy_to_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] $(length(copy_to_syncdeps)) syncdeps" copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) add_writer!(state, arg, copy_to, write_num) astate.data_locality[arg] = our_space else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (local): 
$data_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (local): $data_space" end end - spec.args[idx] = pos => arg_remote + spec.fargs[idx].value = arg_remote end write_num += 1 # Validate that we're not accidentally performing a copy - for (idx, (_, arg)) in enumerate(spec.args) - _, deps = unwrap_inout(task_args[idx][2]) + for (idx, _arg) in enumerate(spec.fargs) + arg = value(_arg) + _, deps = unwrap_inout(value(task_args[idx])) if is_writedep(arg, deps, task) arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(spec.f)))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" + @assert arg_space == our_space "($(repr(value(f))))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" end end @@ -741,24 +743,24 @@ function distribute_tasks!(queue::DataDepsTaskQueue) for (dep_mod, _, writedep) in deps ainfo = aliasing(arg, dep_mod) if writedep - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Syncing as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as writer" get_write_deps!(state, ainfo, task, write_num, syncdeps) else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Syncing as reader" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as reader" get_read_deps!(state, ainfo, task, write_num, syncdeps) end end else if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Syncing as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as writer" get_write_deps!(state, arg, task, write_num, syncdeps) else - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Syncing as reader" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as reader" get_read_deps!(state, arg, task, write_num, syncdeps) end end end - @dagdebug nothing :spawn_datadeps "($(repr(spec.f))) $(length(syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) $(length(syncdeps)) syncdeps" # Launch user's task task_scope = our_scope @@ -774,7 +776,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) for (dep_mod, _, writedep) in deps ainfo = aliasing(arg, dep_mod) if writedep - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Set as owner" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Set as owner" add_writer!(state, ainfo, task, write_num) else add_reader!(state, ainfo, task, write_num) @@ -782,7 +784,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end else if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Set as owner" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Set as owner" add_writer!(state, arg, task, write_num) else add_reader!(state, arg, task, write_num) @@ -935,7 +937,7 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true, throw(ArgumentError("Dynamic scheduling is no longer available")) end wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], :naive)::Symbol + scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool if launch_wait result = spawn_bulk() do diff --git a/src/dtask.jl b/src/dtask.jl index 94b2a64e0..197c322eb 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -48,9 +48,8 @@ more details. 
mutable struct DTask uid::UInt future::ThunkFuture - finalizer_ref::DRef thunk_ref::DRef - DTask(uid, future, finalizer_ref) = new(uid, future, finalizer_ref) + DTask(uid, future) = new(uid, future) end const EagerThunk = DTask @@ -81,16 +80,6 @@ function Base.show(io::IO, t::DTask) end istask(t::DTask) = true -"When finalized, cleans-up the associated `DTask`." -mutable struct DTaskFinalizer - uid::UInt - function DTaskFinalizer(uid) - x = new(uid) - finalizer(Sch.eager_cleanup, x) - x - end -end - const EAGER_ID_COUNTER = Threads.Atomic{UInt64}(1) function eager_next_id() if myid() == 1 diff --git a/src/file-io.jl b/src/file-io.jl index efe80c59f..1407fbced 100644 --- a/src/file-io.jl +++ b/src/file-io.jl @@ -74,296 +74,3 @@ Base.show(io::IO, file::File) = print(io, "Dagger.File(path=\"$(file.path)\")") function move(from_proc::Processor, to_proc::Processor, file::File) return move(from_proc, to_proc, file.chunk) end - -""" - FileReader - -Used as a `Chunk` handle for reading a file, starting at a given offset. -""" -mutable struct FileReader{T} - file::AbstractString - chunktype::Type{T} - data_offset::Int - mmap::Bool -end - -""" - save(io::IO, val) - -Save a value into the IO buffer. In the case of arrays and sparse -matrices, this will save it in a memory-mappable way. - -`load(io::IO, t::Type, domain)` will load the object given its domain -""" -function save(ctx, io::IO, val) - error("Save method for $(typeof(val)) not defined") -end - - -###### Save chunks ###### - -const PARTSPEC = 0x00 -const CAT = 0x01 - -# subparts are saved as Parts - -""" - save(ctx, chunk::Union{Chunk, Thunk}, file_path::AbsractString) - -Save a chunk to a file at `file_path`. -""" -function save(ctx, chunk::Union{Chunk, Thunk}, file_path::AbstractString) - open(file_path, "w") do io - save(ctx, io, chunk, file_path) - end -end - -""" - save(ctx, chunk, file_path) - -Special case distmem writing - write to disk on the process with the chunk. -""" -function save(ctx, chunk::Chunk{X,DRef}, file_path::AbstractString) where X - pid = chunk.handle.where - - remotecall_fetch(pid, file_path, chunk.handle) do path, rref - open(path, "w") do io - save(ctx, io, chunk, file_path) - end - end -end - -function save(ctx, io::IO, chunk::Chunk, file_path) - meta_io = IOBuffer() - - serialize(meta_io, (chunktype(chunk), domain(chunk))) - meta = take!(meta_io) - - write(io, PARTSPEC) - write(io, length(meta)) - write(io, meta) - data_offset = position(io) - - save(ctx, io, collect(ctx, chunk)) - - Chunk(chunktype(chunk), domain(chunk), FileReader(file_path, chunktype(chunk), data_offset, false), false) -end - -function save(ctx, io::IO, chunk::DArray, file_path::AbstractString, saved_parts::AbstractArray) - - metadata = (chunktype(chunk), domain(chunk), saved_parts) - - # save yourself - write(io, CAT) - serialize(io, metadata) - - DArray(metadata...) 
- # write each child -end - - -function save(ctx, io::IO, chunk::DArray, file_path) - dir_path = file_path*"_data" - if !isdir(dir_path) - mkdir(dir_path) - end - - # save the chunks - saved_parts = [save(ctx, c, joinpath(dir_path, lpad(i, 4, "0"))) - for (i, c) in enumerate(chunks(chunk))] - - save(ctx, io, chunk, file_path, saved_parts) - # write each child -end - -function save(ctx, chunk::Chunk{X, FileReader}, file_path::AbstractString) where X - if abspath(file_path) == abspath(chunk.reader.file) - chunk - else - cp(chunk.reader.file, file_path) - Chunk(chunktype(chunk), domain(chunk), - FileReader(file_path, chunktype(chunk), - chunk.reader.data_offset, false), false) - end -end - -save(chunk::Union{Chunk, Thunk}, file_path::AbstractString) = save(Context(global_context()), chunk, file_path) - - - -###### Load chunks ###### - -""" - load(ctx::Context, file_path) - -Load an Union{Chunk, Thunk} from a file. -""" -function load(ctx::Context, file_path::AbstractString; mmap=false) - - open(file_path) do f - part_typ = read(f, UInt8) - if part_typ == PARTSPEC - c = load(ctx, Chunk, file_path, mmap, f) - elseif part_typ == CAT - c = load(ctx, DArray, file_path, mmap, f) - else - error("Could not determine chunk type") - end - end - c -end - -""" - load(ctx::Context, ::Type{Chunk}, fpath, io) - -Load a Chunk object from a file, the file path -is required for creating a FileReader object -""" -function load(ctx::Context, ::Type{Chunk}, fname, mmap, io) - meta_len = read(io, Int) - io = IOBuffer(read(io, meta_len)) - - (T, dmn, sz) = deserialize(io) - - DArray(Chunk(T, dmn, sz, - FileReader(fname, T, meta_len+1, mmap), false)) -end - -function load(ctx::Context, ::Type{DArray}, file_path, mmap, io) - dir_path = file_path*"_data" - - metadata = deserialize(io) - c = DArray(metadata...) - for p in chunks(c) - if isa(p.handle, FileReader) - p.handle.mmap = mmap - end - end - DArray(c) -end - - -###### Save and Load for actual data ##### - -function save(ctx::Context, io::IO, m::Array) - write(io, reinterpret(UInt8, m, (sizeof(m),))) - m -end - -function save(ctx::Context, io::IO, m::BitArray) - save(ctx, io, convert(Array{Bool}, m)) -end - -function collect(ctx::Context, c::Chunk{X,FileReader{T}}) where {X,T<:Array} - h = c.handle - io = open(h.file, "r+") - seek(io, h.data_offset) - arr = h.mmap ? Mmap.mmap(io, h.chunktype, size(c.domain)) : - reshape(reinterpret(eltype(T), read(io)), size(c.domain)) - close(io) - arr -end - -function collect(ctx::Context, c::Chunk{X, FileReader{T}}) where {X,T<:BitArray} - h = c.handle - io = open(h.file, "r+") - seek(io, h.data_offset) - - arr = h.mmap ? 
Mmap.mmap(io, Bool, size(c.domain)) : - reshape(reinterpret(Bool, read(io)), size(c.domain)) - close(io) - arr -end - -function save(ctx::Context, io::IO, m::SparseMatrixCSC{Tv,Ti}) where {Tv, Ti} - write(io, m.m) - write(io, m.n) - write(io, length(m.nzval)) - - typ_io = IOBuffer() - serialize(typ_io, (Tv, Ti)) - buf = take!(typ_io) - write(io, sizeof(buf)) - write(io, buf) - - write(io, reinterpret(UInt8, m.colptr, (sizeof(m.colptr),))) - write(io, reinterpret(UInt8, m.rowval, (sizeof(m.rowval),))) - write(io, reinterpret(UInt8, m.nzval, (sizeof(m.nzval),))) - m -end - -function collect(ctx::Context, c::Chunk{X, FileReader{T}}) where {X, T<:SparseMatrixCSC} - h = c.handle - io = open(h.file, "r+") - seek(io, h.data_offset) - - m = read(io, Int) - n = read(io, Int) - nnz = read(io, Int) - - typ_len = read(io, Int) - typ_bytes = read(io, typ_len) - (Tv, Ti) = deserialize(IOBuffer(typ_bytes)) - - pos = position(io) - colptr = Mmap.mmap(io, Vector{Ti}, (n+1,), pos) - - pos += sizeof(Ti)*(n+1) - rowval = Mmap.mmap(io, Vector{Ti}, (nnz,), pos) - - pos += sizeof(Ti)*nnz - nnzval = Mmap.mmap(io, Vector{Tv}, (nnz,), pos) - close(io) - - SparseMatrixCSC(m, n, colptr, rowval, nnzval) -end - -function getsub(ctx::Context, c::Chunk{X,FileReader{T}}, d) where {X,T<:AbstractArray} - Chunk(collect(ctx, c)[d]) -end - - -#### Save computation - -struct Save <: Computation - input - name::AbstractString -end - -function save(p::Computation, name::AbstractString) - Save(p, name) -end - -function stage(ctx::Context, s::Save) - x = stage(ctx, s.input) - dir_path = s.name * "_data" - if !isdir(dir_path) - mkdir(dir_path) - end - function save_part(idx, data) - p = tochunk(data) - path = joinpath(dir_path, lpad(idx, 4, "0")) - saved = save(ctx, p, path) - - # release reference created for the purpose of save - release_token(p.handle) - saved - end - - saved_parts = similar(chunks(x), Thunk) - for i=1:length(chunks(x)) - saved_parts[i] = Thunk(save_part, i, chunks(x)[i]) - end - - sz = size(chunks(x)) - function save_cat_meta(chunks...) - f = open(s.name, "w") - saved_parts = reshape(Union{Chunk, Thunk}[c for c in chunks], sz) - res = save(ctx, f, x, s.name, saved_parts) - close(f) - res - end - - # The DAG has to block till saving is complete. - res = Thunk(save_cat_meta, saved_parts...; meta=true) -end diff --git a/src/precompile.jl b/src/precompile.jl index 9c37580c0..8b5c6a505 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -21,7 +21,7 @@ # Halt scheduler notify(state.halt) - put!(state.chan, (1, nothing, nothing, (Sch.SchedulerHaltedException(), nothing))) + put!(state.chan, Sch.TaskResult(1, OSProc(), 0, Sch.SchedulerHaltedException(), nothing)) state = nothing # Wait for halt diff --git a/src/queue.jl b/src/queue.jl index 71789c6fb..692f817fa 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -1,6 +1,5 @@ mutable struct DTaskSpec - f - args::Vector{Pair{Union{Symbol,Nothing},Any}} + fargs::Vector{Argument} options::NamedTuple end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 54c8eab96..01c4f29f2 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -4,16 +4,21 @@ import Distributed: Future, ProcessExitedException, RemoteChannel, RemoteExcepti import MemPool import MemPool: DRef, StorageResource import MemPool: poolset, storage_capacity, storage_utilized -import Random: randperm +import Random: randperm, randperm! 
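# --- editor's illustrative sketch (not part of the patch) ----------------------
# `DTaskSpec` now carries a single `fargs::Vector{Argument}` instead of separate
# `f`/`args` fields; elsewhere in this patch the callable is read back as
# `value(spec.fargs[1])`. A hedged reconstruction of how such a vector might
# look, built from the `Argument` API in src/argument.jl; the exact positional
# numbering assigned by the real submission path is an assumption here.
fargs = Dagger.Argument[
    Dagger.Argument(1, sum),          # the callable itself, positional slot 1
    Dagger.Argument(2, [1, 2, 3]),    # a positional user argument
    Dagger.Argument(:dims, 1),        # a keyword argument, addressed by Symbol
]
spec = Dagger.DTaskSpec(fargs, NamedTuple())     # options left empty here
@assert Dagger.value(spec.fargs[1]) === sum
@assert Dagger.ispositional(spec.fargs[2]) && Dagger.pos_idx(spec.fargs[2]) == 2
@assert Dagger.iskw(spec.fargs[3]) && Dagger.pos_kw(spec.fargs[3]) == :dims
# -------------------------------------------------------------------------------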
import Base: @invokelatest import ..Dagger -import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, ThunkFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, LockedObject -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, get_processors, get_parent, execute!, rmprocs!, thunk_processor, constrain, cputhreadtime -import ..Dagger: @dagdebug, @safe_lock_spin1 +import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, ThunkFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, LockedObject, Argument +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, get_processors, get_parent, execute!, rmprocs!, thunk_processor, constrain, cputhreadtime, maybe_take_or_alloc! +import ..Dagger: @dagdebug, @safe_lock_spin1, @maybelog, @take_or_alloc! import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek -import ..Dagger +import ..Dagger: ReusableCache, ReusableLinkedList, ReusableDict +import ..Dagger: @reusable, @reusable_dict, @reusable_vector, @reusable_tasks + +import TimespanLogging + +import TaskLocalValues: TaskLocalValue const OneToMany = Dict{Thunk, Set{Thunk}} @@ -40,7 +45,15 @@ function Base.show(io::IO, entry::ProcessorCacheEntry) print(io, "ProcessorCacheEntry(pid $(entry.gproc.pid), $(entry.proc), $entries entries)") end -const Signature = Vector{Any} +struct TaskResult + pid::Int + proc::Processor + thunk_id::Int + result::Any + metadata::Union{NamedTuple,Nothing} +end + +const AnyTaskResult = Union{RescheduleSignal, TaskResult} """ ComputeState @@ -72,7 +85,7 @@ Fields: - `futures::Dict{Thunk, Vector{ThunkFuture}}` - Futures registered for waiting on the result of a thunk. - `errored::WeakKeyDict{Thunk,Bool}` - Indicates if a thunk's result is an error. - `thunks_to_delete::Set{Thunk}` - The list of `Thunk`s ready to be deleted upon completion. -- `chan::RemoteChannel{Channel{Any}}` - Channel for receiving completed thunks. +- `chan::RemoteChannel{Channel{AnyTaskResult}}` - Channel for receiving completed thunks. """ struct ComputeState uid::UInt64 @@ -99,7 +112,7 @@ struct ComputeState futures::Dict{Thunk, Vector{ThunkFuture}} errored::WeakKeyDict{Thunk,Bool} thunks_to_delete::Set{Thunk} - chan::RemoteChannel{Channel{Any}} + chan::RemoteChannel{Channel{AnyTaskResult}} end const UID_COUNTER = Threads.Atomic{UInt64}(1) @@ -276,36 +289,39 @@ Base.merge(sopts::SchedulerOptions, ::Nothing) = nothing, sopts.allow_errors) """ - populate_defaults(opts::ThunkOptions, Tf, Targs) -> ThunkOptions + populate_defaults(opts::ThunkOptions, sig::Signature) -> ThunkOptions -Returns a `ThunkOptions` with default values filled in for a function of type -`Tf` with argument types `Targs`, if the option was previously unspecified in -`opts`. +Returns a `ThunkOptions` with default values filled in for a function with type +signature `sig`, if the option was previously unspecified in `opts`. """ -function populate_defaults(opts::ThunkOptions, Tf, Targs) - function maybe_default(opt::Symbol) - old_opt = getproperty(opts, opt) - if old_opt !== nothing - return old_opt - else - return Dagger.default_option(Val(opt), Tf, Targs...) 
- end - end +function populate_defaults(opts::ThunkOptions, sig::Signature) ThunkOptions( - maybe_default(:single), - maybe_default(:proclist), - maybe_default(:time_util), - maybe_default(:alloc_util), - maybe_default(:occupancy), - maybe_default(:allow_errors), - maybe_default(:checkpoint), - maybe_default(:restore), - maybe_default(:storage), - maybe_default(:storage_root_tag), - maybe_default(:storage_leaf_tag), - maybe_default(:storage_retain), + maybe_default(opts, Val{:single}(), sig), + maybe_default(opts, Val{:proclist}(), sig), + maybe_default(opts, Val{:time_util}(), sig), + maybe_default(opts, Val{:alloc_util}(), sig), + maybe_default(opts, Val{:occupancy}(), sig), + maybe_default(opts, Val{:allow_errors}(), sig), + maybe_default(opts, Val{:checkpoint}(), sig), + maybe_default(opts, Val{:restore}(), sig), + maybe_default(opts, Val{:storage}(), sig), + maybe_default(opts, Val{:storage_root_tag}(), sig), + maybe_default(opts, Val{:storage_leaf_tag}(), sig), + maybe_default(opts, Val{:storage_retain}(), sig), ) end +function maybe_default(opts, ::Val{opt}, sig::Signature) where opt + old_opt = getfield(opts, opt) + if old_opt !== nothing + return old_opt + else + @warn "SIGNATURE_DEFAULT_CACHE should use an LRU" maxlog=1 + return get!(SIGNATURE_DEFAULT_CACHE[], (sig.hash_nokw, opt)) do + Dagger.default_option(Val{opt}(), sig.sig_nokw...) + end + end +end +const SIGNATURE_DEFAULT_CACHE = TaskLocalValue{Dict{Tuple{UInt,Symbol},Any}}(()->Dict{Tuple{UInt,Symbol},Any}()) function cleanup(ctx) end @@ -318,7 +334,7 @@ const WORKER_MONITOR_TASKS = Dict{Int,Task}() const WORKER_MONITOR_CHANS = Dict{Int,Dict{UInt64,RemoteChannel}}() function init_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - timespan_start(ctx, :init_proc, (;worker=p.pid), nothing) + @maybelog ctx timespan_start(ctx, :init_proc, (;worker=p.pid), nothing) # Initialize pressure and capacity gproc = OSProc(p.pid) lock(state.lock) do @@ -355,7 +371,7 @@ function init_proc(state, p, log_sink) d = WORKER_MONITOR_CHANS[wid] for uid in keys(d) try - put!(d[uid], (wid, OSProc(wid), nothing, (ProcessExitedException(wid), nothing))) + put!(d[uid], TaskResult(wid, OSProc(wid), 0, ProcessExitedException(wid), nothing)) catch end end @@ -383,7 +399,7 @@ function init_proc(state, p, log_sink) # Setup dynamic listener dynamic_listener!(ctx, state, p.pid) - timespan_finish(ctx, :init_proc, (;worker=p.pid), nothing) + @maybelog ctx timespan_finish(ctx, :init_proc, (;worker=p.pid), nothing) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! @@ -399,7 +415,7 @@ end function cleanup_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) wid = p.pid - timespan_start(ctx, :cleanup_proc, (;worker=wid), nothing) + @maybelog ctx timespan_start(ctx, :cleanup_proc, (;worker=wid), nothing) lock(WORKER_MONITOR_LOCK) do if haskey(WORKER_MONITOR_CHANS, wid) delete!(WORKER_MONITOR_CHANS[wid], state.uid) @@ -419,7 +435,7 @@ function cleanup_proc(state, p, log_sink) end end - timespan_finish(ctx, :cleanup_proc, (;worker=wid), nothing) + @maybelog ctx timespan_finish(ctx, :cleanup_proc, (;worker=wid), nothing) end "Process-local condition variable (and lock) indicating task completion." 
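# --- editor's illustrative sketch (not part of the patch) ----------------------
# The caching pattern behind `SIGNATURE_DEFAULT_CACHE` above: a `TaskLocalValue`
# hands each Julia task its own Dict, so memoizing per-signature option defaults
# needs no locking across concurrently running scheduler tasks. As the patch's
# own `@warn` notes, the real cache should eventually become an LRU; this sketch
# keeps the unbounded Dict. `lookup_default`/`slow_default` are made-up
# stand-ins for the real `maybe_default`/`Dagger.default_option` pair.
using TaskLocalValues: TaskLocalValue

const DEFAULTS_CACHE = TaskLocalValue{Dict{Tuple{UInt,Symbol},Any}}(
    () -> Dict{Tuple{UInt,Symbol},Any}())

slow_default(opt::Symbol) = (sleep(0.01); opt === :single ? 0 : nothing)

function lookup_default(sig_hash::UInt, opt::Symbol)
    # Computed at most once per (signature hash, option) within each task
    return get!(DEFAULTS_CACHE[], (sig_hash, opt)) do
        slow_default(opt)
    end
end

lookup_default(UInt(0xbeef), :single)   # first call pays the cost
lookup_default(UInt(0xbeef), :single)   # second call is a Dict hit
# -------------------------------------------------------------------------------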
@@ -458,7 +474,7 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) end end - chan = RemoteChannel(()->Channel(typemax(Int))) + chan = RemoteChannel(()->Channel{AnyTaskResult}(typemax(Int))) deps = dependents(d) ord = order(d, noffspring(deps)) @@ -467,24 +483,24 @@ function compute_dag(ctx, d::Thunk; options=SchedulerOptions()) master = OSProc(myid()) - timespan_start(ctx, :scheduler_init, nothing, master) + @maybelog ctx timespan_start(ctx, :scheduler_init, nothing, master) try scheduler_init(ctx, state, d, options, deps) finally - timespan_finish(ctx, :scheduler_init, nothing, master) + @maybelog ctx timespan_finish(ctx, :scheduler_init, nothing, master) end value, errored = try scheduler_run(ctx, state, d, options) finally # Always try to tear down the scheduler - timespan_start(ctx, :scheduler_exit, nothing, master) + @maybelog ctx timespan_start(ctx, :scheduler_exit, nothing, master) try scheduler_exit(ctx, state, options) catch err @error "Error when tearing down scheduler" exception=(err,catch_backtrace()) finally - timespan_finish(ctx, :scheduler_exit, nothing, master) + @maybelog ctx timespan_finish(ctx, :scheduler_exit, nothing, master) end end @@ -539,20 +555,26 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) while !isempty(state.ready) || !isempty(state.running) if !isempty(state.ready) # Nothing running, so schedule up to N thunks, 1 per N workers - schedule!(ctx, state) + @invokelatest schedule!(ctx, state) end check_integrity(ctx) isempty(state.running) && continue - timespan_start(ctx, :take, nothing, nothing) + @maybelog ctx timespan_start(ctx, :take, nothing, nothing) @dagdebug nothing :take "Waiting for results" - chan_value = take!(state.chan) # get result of completed thunk - timespan_finish(ctx, :take, nothing, nothing) - if chan_value isa RescheduleSignal + tresult = take!(state.chan) # get result of completed thunk + @maybelog ctx timespan_finish(ctx, :take, nothing, nothing) + if tresult isa RescheduleSignal continue end - pid, proc, thunk_id, (res, metadata) = chan_value + + tresult::TaskResult + pid = tresult.pid + proc = tresult.proc + thunk_id = tresult.thunk_id + res = tresult.result + @dagdebug thunk_id :take "Got finished task" gproc = OSProc(pid) safepoint(state) @@ -563,13 +585,13 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) @warn "Worker $(pid) died, rescheduling work" # Remove dead worker from procs list - timespan_start(ctx, :remove_procs, (;worker=pid), nothing) + @maybelog ctx timespan_start(ctx, :remove_procs, (;worker=pid), nothing) remove_dead_proc!(ctx, state, gproc) - timespan_finish(ctx, :remove_procs, (;worker=pid), nothing) + @maybelog ctx timespan_finish(ctx, :remove_procs, (;worker=pid), nothing) - timespan_start(ctx, :handle_fault, (;worker=pid), nothing) + @maybelog ctx timespan_start(ctx, :handle_fault, (;worker=pid), nothing) handle_fault(ctx, state, gproc) - timespan_finish(ctx, :handle_fault, (;worker=pid), nothing) + @maybelog ctx timespan_finish(ctx, :handle_fault, (;worker=pid), nothing) return # effectively `continue` else if something(ctx.options.allow_errors, false) || @@ -580,13 +602,14 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) end end end - node = unwrap_weak_checked(state.thunk_dict[thunk_id]) + node = unwrap_weak_checked(state.thunk_dict[thunk_id])::Thunk + metadata = tresult.metadata if metadata !== nothing state.worker_time_pressure[pid][proc] = metadata.time_pressure #to_storage = fetch(node.options.storage) 
#state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure #state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity - state.worker_loadavg[pid] = metadata.loadavg + #state.worker_loadavg[pid] = metadata.loadavg sig = signature(state, node) state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2 state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2 @@ -594,8 +617,7 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) state.transfer_rate[] = (state.transfer_rate[] + metadata.transfer_rate) ÷ 2 end end - state.cache[node] = res - state.errored[node] = thunk_failed + store_result!(state, node, res; error=thunk_failed) if node.options !== nothing && node.options.checkpoint !== nothing try @invokelatest node.options.checkpoint(node, res) @@ -604,18 +626,20 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) end end - timespan_start(ctx, :finish, (;thunk_id), (;thunk_id)) + @maybelog ctx timespan_start(ctx, :finish, (;thunk_id), (;thunk_id)) finish_task!(ctx, state, node, thunk_failed) - timespan_finish(ctx, :finish, (;thunk_id), (;thunk_id)) - - delete_unused_tasks!(state) + @maybelog ctx timespan_finish(ctx, :finish, (;thunk_id), (;thunk_id)) end + # Allow data to be GC'd + tresult = nothing + res = nothing + safepoint(state) end # Final value is ready - value = state.cache[d] + value = load_result(state, d) errored = get(state.errored, d, false) if !errored if options.checkpoint !== nothing @@ -674,6 +698,17 @@ end const CHUNK_CACHE = Dict{Chunk,Dict{Processor,Any}}() +struct ScheduleTaskLocation + gproc::OSProc + proc::Processor +end +struct ScheduleTaskSpec + task::Thunk + scope::Dagger.AbstractScope + est_time_util::UInt64 + est_alloc_util::UInt64 + est_occupancy::UInt32 +end function schedule!(ctx, state, procs=procs_to_use(ctx)) lock(state.lock) do safepoint(state) @@ -682,21 +717,21 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) populate_processor_cache_list!(state, procs) # Schedule tasks - to_fire = Dict{Tuple{OSProc,<:Processor},Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}}() - failed_scheduling = Thunk[] + to_fire = @reusable_dict :schedule!_to_fire ScheduleTaskLocation Vector{ScheduleTaskSpec} ScheduleTaskLocation(OSProc(), OSProc()) ScheduleTaskSpec[] 32 + failed_scheduling = @reusable_vector :schedule!_failed_scheduling Union{Thunk,Nothing} nothing 32 # Select a new task and get its options task = nothing @label pop_task if task !== nothing - timespan_finish(ctx, :schedule, (;thunk_id=task.id), (;thunk_id=task.id)) + @maybelog ctx timespan_finish(ctx, :schedule, (;thunk_id=task.id), (;thunk_id=task.id)) end if isempty(state.ready) @goto fire_tasks end task = pop!(state.ready) - timespan_start(ctx, :schedule, (;thunk_id=task.id), (;thunk_id=task.id)) - if haskey(state.cache, task) + @maybelog ctx timespan_start(ctx, :schedule, (;thunk_id=task.id), (;thunk_id=task.id)) + if has_result(state, task) if haskey(state.errored, task) # An error was eagerly propagated to this task finish_failed!(state, task) @@ -705,19 +740,27 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) iob = IOBuffer() println(iob, "Scheduling inconsistency: Task being scheduled is already cached!") println(iob, " Task: $(task.id)") - println(iob, " Cache Entry: $(typeof(state.cache[task]))") + println(iob, " Cache Entry: $(typeof(something(task.cache_ref)))") ex = SchedulingException(String(take!(iob))) - state.cache[task] = ex - 
state.errored[task] = true + store_result!(state, task, ex; error=true) end @goto pop_task end - opts = merge(ctx.options, task.options) + + # Load task inputs + collect_task_inputs!(state, task) + + # Calculate signature sig = signature(state, task) + # Merge options and fill defaults + opts = merge(ctx.options, task.options) + opts = populate_defaults(opts, sig) + # Calculate scope - scope = if task.f isa Chunk - task.f.scope + f = Dagger.value(task.inputs[1]) + scope = if f isa Chunk + f.scope else if task.options.proclist !== nothing # proclist overrides scope selection @@ -726,11 +769,9 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) DefaultScope() end end - for (_,input) in task.inputs - input = unwrap_weak_checked(input) - chunk = if istask(input) - state.cache[input] - elseif input isa Chunk + for input in task.inputs + input = unwrap_weak_checked(Dagger.value(input)) + chunk = if input isa Chunk input else nothing @@ -739,38 +780,41 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) scope = constrain(scope, chunk.scope) if scope isa Dagger.InvalidScope ex = SchedulingException("Scopes are not compatible: $(scope.x), $(scope.y)") - state.cache[task] = ex - state.errored[task] = true + store_result!(state, task, ex; error=true) set_failed!(state, task) @goto pop_task end end - fallback_threshold = 1024 # TODO: Parameterize this threshold - if length(procs) > fallback_threshold - @goto fallback - end - local_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in procs]...)) - if length(local_procs) > fallback_threshold - @goto fallback + # FIXME: Use compatible_processors + input_procs = @reusable_vector :schedule!_input_procs Processor OSProc() 32 + for gp in procs + subprocs = get_processors(gp) + for proc in subprocs + if !(proc in input_procs) + push!(input_procs, proc) + end + end end - inputs = map(last, collect_task_inputs(state, task)) - opts = populate_defaults(opts, chunktype(task.f), map(chunktype, inputs)) - local_procs, costs = estimate_task_costs(state, local_procs, task, inputs) + sorted_procs = @reusable_vector :schedule!_sorted_procs Processor OSProc() 32 + resize!(sorted_procs, length(input_procs)) + costs = @reusable_dict :schedule!_costs Processor Float64 OSProc() 0.0 32 + estimate_task_costs!(sorted_procs, costs, state, input_procs, task) + empty!(costs) # We don't use costs here scheduled = false # Move our corresponding ThreadProc to be the last considered - if length(local_procs) > 1 + if length(sorted_procs) > 1 sch_threadproc = Dagger.ThreadProc(myid(), Threads.threadid()) - sch_thread_idx = findfirst(proc->proc==sch_threadproc, local_procs) + sch_thread_idx = findfirst(proc->proc==sch_threadproc, sorted_procs) if sch_thread_idx !== nothing - deleteat!(local_procs, sch_thread_idx) - push!(local_procs, sch_threadproc) + deleteat!(sorted_procs, sch_thread_idx) + push!(sorted_procs, sch_threadproc) end end - for proc in local_procs + for proc in sorted_procs gproc = get_parent(proc) can_use, scope = can_use_proc(task, gproc, proc, opts, scope) if can_use @@ -779,95 +823,35 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) if has_cap # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? 
cap : est_time_util - proc_tasks = get!(to_fire, (gproc, proc)) do - Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}() + proc_tasks = get!(to_fire, ScheduleTaskLocation(gproc, proc)) do + #=FIXME:REALLOC_VEC=# + Vector{ScheduleTaskSpec}() end - push!(proc_tasks, (task, scope, est_time_util, est_alloc_util, est_occupancy)) + push!(proc_tasks, ScheduleTaskSpec(task, scope, est_time_util, est_alloc_util, est_occupancy)) state.worker_time_pressure[gproc.pid][proc] = get(state.worker_time_pressure[gproc.pid], proc, 0) + est_time_util @dagdebug task :schedule "Scheduling to $gproc -> $proc" + empty!(sorted_procs) @goto pop_task end end end - state.cache[task] = SchedulingException("No processors available, try widening scope") - state.errored[task] = true + ex = SchedulingException("No processors available, try widening scope") + store_result!(state, task, ex; error=true) set_failed!(state, task) - @goto pop_task - - # Fast fallback algorithm, used when the smarter cost model algorithm - # would be too expensive - @label fallback - selected_entry = nothing - entry = state.procs_cache_list[] - cap, extra_util = nothing, nothing - procs_found = false - # N.B. if we only have one processor, we need to select it now - can_use, scope = can_use_proc(task, entry.gproc, entry.proc, opts, scope) - if can_use - has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig) - if has_cap - selected_entry = entry - else - procs_found = true - entry = entry.next - end - else - entry = entry.next - end - while selected_entry === nothing - if entry === state.procs_cache_list[] - # Exhausted all procs - if procs_found - push!(failed_scheduling, task) - else - state.cache[task] = SchedulingException("No processors available, try widening scope") - state.errored[task] = true - set_failed!(state, task) - end - @goto pop_task - end - - can_use, scope = can_use_proc(task, entry.gproc, entry.proc, opts, scope) - if can_use - has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, entry.proc, entry.gproc.pid, opts.time_util, opts.alloc_util, opts.occupancy, sig) - if has_cap - # Select this processor - selected_entry = entry - else - # We could have selected it otherwise - procs_found = true - entry = entry.next - end - else - # Try next processor - entry = entry.next - end - end - @assert selected_entry !== nothing - - # Schedule task onto proc - gproc, proc = entry.gproc, entry.proc - est_time_util = est_time_util isa MaxUtilization ? 
cap : est_time_util - proc_tasks = get!(to_fire, (gproc, proc)) do - Vector{Tuple{Thunk,<:Any,<:Any,UInt64,UInt32}}() - end - push!(proc_tasks, (task, scope, est_time_util, est_alloc_util, est_occupancy)) - - # Proceed to next entry to spread work - state.procs_cache_list[] = state.procs_cache_list[].next + empty!(sorted_procs) @goto pop_task # Fire all newly-scheduled tasks @label fire_tasks - for gpp in keys(to_fire) - fire_tasks!(ctx, to_fire[gpp], gpp, state) + for (task_loc, task_spec) in to_fire + fire_tasks!(ctx, task_loc, task_spec, state) end + empty!(to_fire) append!(state.ready, failed_scheduling) + empty!(failed_scheduling) end end @@ -885,7 +869,7 @@ function monitor_procs_changed!(ctx, state) wait(ctx.proc_notify) end - timespan_start(ctx, :assign_procs, nothing, nothing) + @maybelog ctx timespan_start(ctx, :assign_procs, nothing, nothing) # Load new set of procs new_ps = procs_to_use(ctx) @@ -913,7 +897,7 @@ function monitor_procs_changed!(ctx, state) end end - timespan_finish(ctx, :assign_procs, nothing, nothing) + @maybelog ctx timespan_finish(ctx, :assign_procs, nothing, nothing) old_ps = new_ps end end @@ -935,40 +919,29 @@ function finish_task!(ctx, state, node, thunk_failed) if thunk_failed set_failed!(state, node) end - if node.cache - node.cache_ref = state.cache[node] - end schedule_dependents!(state, node, thunk_failed) fill_registered_futures!(state, node, thunk_failed) + #= to_evict = cleanup_syncdeps!(state, node) if node.f isa Chunk # FIXME: Check the graph for matching chunks - push!(to_evict, node.f) + #push!(to_evict, node.f) end + =# + cleanup_syncdeps!(state, node) if haskey(state.waiting_data, node) && isempty(state.waiting_data[node]) delete!(state.waiting_data, node) end + if !haskey(state.waiting_data, node) + node.sch_accessible = false + delete_unused_task!(state, node) + end #evict_all_chunks!(ctx, to_evict) end -function delete_unused_tasks!(state) - to_delete = Thunk[] - for thunk in state.thunks_to_delete - if task_unused(state, thunk) - # Finished and nobody waiting on us, we can be deleted - push!(to_delete, thunk) - end - end - for thunk in to_delete - # Delete all cached data - task_delete!(state, thunk) - - pop!(state.thunks_to_delete, thunk) - end -end function delete_unused_task!(state, thunk) - if task_unused(state, thunk) + if has_result(state, thunk) && !thunk.eager_accessible && !thunk.sch_accessible # Will not be accessed further, delete all cached data task_delete!(state, thunk) return true @@ -976,11 +949,8 @@ function delete_unused_task!(state, thunk) return false end end -task_unused(state, thunk) = - haskey(state.cache, thunk) && !haskey(state.waiting_data, thunk) function task_delete!(state, thunk) - delete!(state.cache, thunk) - delete!(state.errored, thunk) + clear_result!(state, thunk) delete!(state.valid, thunk) delete!(state.thunk_dict, thunk.id) end @@ -995,46 +965,57 @@ end function evict_chunks!(log_sink, chunks::Set{Chunk}) # Need worker id or else Context might use Processors which user does not want us to use. 
# In particular workers which have not yet run using Dagger will cause the call below to throw an exception - ctx = Context([myid()];log_sink) + ctx = Context([myid()]; log_sink) for chunk in chunks lock(TASK_SYNC) do - timespan_start(ctx, :evict, (;worker=myid()), (;data=chunk)) + @maybelog ctx timespan_start(ctx, :evict, (;worker=myid()), (;data=chunk)) haskey(CHUNK_CACHE, chunk) && delete!(CHUNK_CACHE, chunk) - timespan_finish(ctx, :evict, (;worker=myid()), (;data=chunk)) + @maybelog ctx timespan_finish(ctx, :evict, (;worker=myid()), (;data=chunk)) end end nothing end -fire_task!(ctx, thunk::Thunk, p, state; scope=AnyScope(), time_util=10^9, alloc_util=10^6, occupancy=typemax(UInt32)) = - fire_task!(ctx, (thunk, scope, time_util, alloc_util, occupancy), p, state) -fire_task!(ctx, (thunk, scope, time_util, alloc_util, occupancy)::Tuple{Thunk,<:Any}, p, state) = - fire_tasks!(ctx, [(thunk, scope, time_util, alloc_util, occupancy)], p, state) -function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) - to_send = [] - for (thunk, scope, time_util, alloc_util, occupancy) in thunks +"A serializable description of a `Thunk` to be executed." +struct TaskSpec + thunk_id::Int + est_time_util::UInt64 + est_alloc_util::UInt64 + est_occupancy::UInt32 + scope::Dagger.AbstractScope + Tf::Type + data::Vector{Argument} + # TODO: Get these from options + get_result::Bool + persist::Bool + cache::Bool + meta::Bool + options#::Options + propagated::NamedTuple + ctx_vars::NamedTuple + sch_handle::SchedulerHandle + sch_uid::UInt64 +end +Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) + +function fire_tasks!(ctx, task_loc::ScheduleTaskLocation, task_specs::Vector{ScheduleTaskSpec}, state) + gproc, proc = task_loc.gproc, task_loc.proc + to_send = @reusable_vector :fire_tasks!_to_send Union{TaskSpec,Nothing} nothing 32 + for task_spec in task_specs + thunk = task_spec.task push!(state.running, thunk) state.running_on[thunk] = gproc - if thunk.cache && thunk.cache_ref !== nothing - # the result might be already cached - data = thunk.cache_ref - if data !== nothing - # cache hit - state.cache[thunk] = data - thunk_failed = get(state.errored, thunk, false) - finish_task!(ctx, state, thunk, thunk_failed) - continue - else - # cache miss - thunk.cache_ref = nothing - end + if has_result(state, thunk) + # the result is already cached + thunk_failed = get(state.errored, thunk, false) + finish_task!(ctx, state, thunk, thunk_failed) + continue end if thunk.options !== nothing && thunk.options.restore !== nothing try result = @invokelatest thunk.options.restore(thunk) if result isa Chunk - state.cache[thunk] = result - state.errored[thunk] = false + store_result!(state, thunk, result) finish_task!(ctx, state, thunk, false) continue elseif result !== nothing @@ -1045,16 +1026,14 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) end end - ids = Int[0] - data = Any[thunk.f] - positions = Union{Symbol,Nothing}[] - for (idx, pos_x) in enumerate(thunk.inputs) - pos, x = pos_x - x = unwrap_weak_checked(x) - push!(ids, istask(x) ? x.id : -idx) - push!(data, istask(x) ? state.cache[x] : x) - push!(positions, pos) + # Unwrap any weak arguments + args = map(copy, thunk.inputs) + for arg in args + # TODO: Only for non-delayed: @assert Dagger.isweak(Dagger.value(arg)) "Non-weak argument: $(arg)" + arg.value = unwrap_weak_checked(Dagger.value(arg)) end + Tf = chunktype(first(args)) + toptions = thunk.options !== nothing ? 
thunk.options : ThunkOptions() options = merge(ctx.options, toptions) propagated = get_propagated_options(thunk) @@ -1063,32 +1042,50 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state) sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) - push!(to_send, Any[thunk.id, time_util, alloc_util, occupancy, - scope, chunktype(thunk.f), data, - thunk.get_result, thunk.persist, thunk.cache, thunk.meta, options, - propagated, ids, positions, - (log_sink=ctx.log_sink, profile=ctx.profile), - sch_handle, state.uid]) + push!(to_send, TaskSpec( + thunk.id, + task_spec.est_time_util, task_spec.est_alloc_util, task_spec.est_occupancy, + task_spec.scope, Tf, args, + thunk.get_result, thunk.persist, thunk.cache, thunk.meta, options, + propagated, + (log_sink=ctx.log_sink, profile=ctx.profile), + sch_handle, state.uid)) end + # N.B. We don't batch these because we might get a deserialization # error due to something not being defined on the worker, and then we don't # know which task failed. - tasks = Task[] - for ts in to_send - # TODO: errormonitor - @async begin - timespan_start(ctx, :fire, (;worker=gproc.pid), nothing) - try - remotecall_wait(do_tasks, gproc.pid, proc, state.chan, [ts]); - catch err - bt = catch_backtrace() - thunk_id = ts[1] - put!(state.chan, (gproc.pid, proc, thunk_id, (CapturedException(err, bt), nothing))) - finally - timespan_finish(ctx, :fire, (;worker=gproc.pid), nothing) - end - end + for (idx, task_spec) in enumerate(to_send) + @reusable_tasks :fire_tasks!_task_cache 32 _->nothing "fire_tasks!" FireTaskSpec(proc, state.chan, task_spec) end + empty!(to_send) +end + +struct FireTaskSpec + init_proc::Processor + return_chan::RemoteChannel + task::TaskSpec +end +function (ets::FireTaskSpec)() + task = ets.task + ctx_vars = task.ctx_vars + ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) + + proc = ets.init_proc + chan = ets.return_chan + pid = Dagger.root_worker_id(proc) + + @maybelog ctx timespan_start(ctx, :fire, (;worker=pid), nothing) + try + remotecall_wait(do_tasks, pid, proc, chan, [task]); + catch err + bt = catch_backtrace() + thunk_id = task.thunk_id + put!(chan, TaskResult(pid, proc, thunk_id, CapturedException(err, bt), nothing)) + finally + @maybelog ctx timespan_finish(ctx, :fire, (;worker=pid), nothing) + end + return end @static if VERSION >= v"1.9" @@ -1151,18 +1148,10 @@ function Base.notify(db::Doorbell) end end -struct TaskSpecKey - task_id::Int - task_spec::Vector{Any} - TaskSpecKey(task_spec::Vector{Any}) = new(task_spec[1], task_spec) -end -Base.getindex(key::TaskSpecKey) = key.task_spec -Base.hash(key::TaskSpecKey, h::UInt) = hash(key.task_id, hash(TaskSpecKey, h)) - struct ProcessorInternalState ctx::Context proc::Processor - queue::LockedObject{PriorityQueue{TaskSpecKey, UInt32, Base.Order.ForwardOrdering}} + queue::LockedObject{PriorityQueue{TaskSpec, UInt32, Base.Order.ForwardOrdering}} reschedule::Doorbell tasks::Dict{Int,Task} proc_occupancy::Base.RefValue{UInt32} @@ -1211,12 +1200,12 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Wait for new tasks if !work_to_do @dagdebug nothing :processor "Waiting for tasks" - timespan_start(ctx, :proc_run_wait, (;worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_start(ctx, :proc_run_wait, (;worker=wid, processor=to_proc), nothing) wait(istate.reschedule) @static if VERSION >= v"1.9" reset(istate.reschedule) end - 
timespan_finish(ctx, :proc_run_wait, (;worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_finish(ctx, :proc_run_wait, (;worker=wid, processor=to_proc), nothing) if istate.done[] return end @@ -1224,7 +1213,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Fetch a new task to execute @dagdebug nothing :processor "Trying to dequeue" - timespan_start(ctx, :proc_run_fetch, (;worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_start(ctx, :proc_run_fetch, (;worker=wid, processor=to_proc), nothing) work_to_do = false task_and_occupancy = lock(istate.queue) do queue # Only steal if there are multiple queued tasks, to prevent @@ -1243,7 +1232,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re return queue_result end if task_and_occupancy === nothing - timespan_finish(ctx, :proc_run_fetch, (;worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_finish(ctx, :proc_run_fetch, (;worker=wid, processor=to_proc), nothing) @dagdebug nothing :processor "Failed to dequeue" @@ -1258,7 +1247,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re @dagdebug nothing :processor "Trying to steal" # Try to steal a task - timespan_start(ctx, :steal_local, (;worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_start(ctx, :steal_local, (;worker=wid, processor=to_proc), nothing) # Try to steal from local queues randomly # TODO: Prioritize stealing from busiest processors @@ -1278,9 +1267,8 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re if length(queue) == 0 return nothing end - task_spec, occupancy = peek(queue) - task = task_spec[] - scope = task[5] + task, occupancy = peek(queue) + scope = task.scope if !isa(constrain(scope, Dagger.ExactScope(to_proc)), Dagger.InvalidScope) && typemax(UInt32) - proc_occupancy_cached >= occupancy @@ -1291,14 +1279,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end if task_and_occupancy !== nothing from_proc = other_istate.proc - thunk_id = task[1] + thunk_id = task.thunk_id @dagdebug thunk_id :processor "Stolen from $from_proc by $to_proc" - timespan_finish(ctx, :steal_local, (;worker=wid, processor=to_proc), (;from_proc, thunk_id)) + @maybelog ctx timespan_finish(ctx, :steal_local, (;worker=wid, processor=to_proc), (;from_proc, thunk_id)) # TODO: Keep stealing until we hit full occupancy? 
@goto execute end end - timespan_finish(ctx, :steal_local, (;worker=wid, processor=to_proc), nothing) + @maybelog ctx timespan_finish(ctx, :steal_local, (;worker=wid, processor=to_proc), nothing) # TODO: Try to steal from remote queues @@ -1306,46 +1294,25 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end @label execute - task_spec, task_occupancy = task_and_occupancy - task = task_spec[] - thunk_id = task[1] - time_util = task[2] - timespan_finish(ctx, :proc_run_fetch, (;worker=wid, processor=to_proc), (;thunk_id, proc_occupancy=proc_occupancy[], task_occupancy)) + task, task_occupancy = task_and_occupancy + thunk_id = task.thunk_id + time_util = task.est_time_util + @maybelog ctx timespan_finish(ctx, :proc_run_fetch, (;worker=wid, processor=to_proc), (;thunk_id, proc_occupancy=proc_occupancy[], task_occupancy)) @dagdebug thunk_id :processor "Dequeued task" # Execute the task and return its result - t = @task begin - result = try - do_task(to_proc, task) - catch err - bt = catch_backtrace() - (CapturedException(err, bt), nothing) - finally - lock(istate.queue) do _ - delete!(tasks, thunk_id) - proc_occupancy[] -= task_occupancy - time_pressure[] -= time_util - end - notify(istate.reschedule) - end - try - put!(return_queue, (myid(), to_proc, thunk_id, result)) - catch err - if unwrap_nested_exception(err) isa InvalidStateException || !isopen(return_queue) - @dagdebug thunk_id :execute "Return queue is closed, failing to put result" chan=return_queue exception=(err, catch_backtrace()) + t = @reusable_tasks :start_processor_runner!_task_cache 32 t->begin + lock(istate.queue) do _ + tid = task_tid_for_processor(to_proc) + if tid !== nothing + Dagger.set_task_tid!(t, tid) else - rethrow(err) + t.sticky = false end end - end + end "thunk $thunk_id" DoTaskSpec(to_proc, return_queue, task) lock(istate.queue) do _ - tid = task_tid_for_processor(to_proc) - if tid !== nothing - Dagger.set_task_tid!(t, tid) - else - t.sticky = false - end - tasks[thunk_id] = errormonitor_tracked("thunk $thunk_id", schedule(t)) + tasks[thunk_id] = t proc_occupancy[] += task_occupancy time_pressure[] += time_util end @@ -1359,6 +1326,44 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end return errormonitor_tracked("processor $to_proc", schedule(proc_run_task)) end +struct DoTaskSpec + to_proc::Processor + chan::RemoteChannel + task::TaskSpec +end +function (dts::DoTaskSpec)() + to_proc = dts.to_proc + task = dts.task + tid = task.thunk_id + r = rand() + result, metadata = try + do_task(to_proc, task) + catch err + bt = catch_backtrace() + (CapturedException(err, bt), nothing) + finally + istate = proc_states(task.sch_uid) do states + states[to_proc].state + end + lock(istate.queue) do _ + delete!(istate.tasks, tid) + istate.proc_occupancy[] -= task.est_occupancy + istate.time_pressure[] -= task.est_time_util + end + notify(istate.reschedule) + end + return_queue = dts.chan + try + put!(return_queue, TaskResult(myid(), to_proc, tid, result, metadata)) + catch err + if unwrap_nested_exception(err) isa InvalidStateException || !isopen(return_queue) + @dagdebug tid :execute "Return queue is closed, failing to put result" chan=return_queue exception=(err, catch_backtrace()) + else + rethrow(err) + end + end + return +end """ do_tasks(to_proc, return_queue, tasks) @@ -1369,13 +1374,12 @@ Executes a batch of tasks on `to_proc`, returning their results through function do_tasks(to_proc, return_queue, tasks) @dagdebug nothing :processor "Enqueuing task 
batch" batch_size=length(tasks) - # FIXME: This is terrible - ctx_vars = first(tasks)[16] + ctx_vars = first(tasks).ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) - uid = first(tasks)[18] + uid = first(tasks).sch_uid state = proc_states(uid) do states get!(states, to_proc) do - queue = PriorityQueue{TaskSpecKey, UInt32}() + queue = PriorityQueue{TaskSpec, UInt32}() queue_locked = LockedObject(queue) reschedule = Doorbell() istate = ProcessorInternalState(ctx, to_proc, @@ -1393,9 +1397,9 @@ function do_tasks(to_proc, return_queue, tasks) istate = state.state lock(istate.queue) do queue for task in tasks - thunk_id = task[1] - occupancy = task[4] - timespan_start(ctx, :enqueue, (;processor=to_proc, thunk_id), nothing) + thunk_id = task.thunk_id + occupancy = task.est_occupancy + @maybelog ctx timespan_start(ctx, :enqueue, (;processor=to_proc, thunk_id), nothing) should_launch = lock(TASK_SYNC) do # Already running; don't try to re-launch if !(thunk_id in TASKS_RUNNING) @@ -1406,8 +1410,8 @@ function do_tasks(to_proc, return_queue, tasks) end end should_launch || continue - enqueue!(queue, TaskSpecKey(task), occupancy) - timespan_finish(ctx, :enqueue, (;processor=to_proc, thunk_id), nothing) + enqueue!(queue, task, occupancy) + @maybelog ctx timespan_finish(ctx, :enqueue, (;processor=to_proc, thunk_id), nothing) @dagdebug thunk_id :processor "Enqueued task" end end @@ -1428,42 +1432,48 @@ function do_tasks(to_proc, return_queue, tasks) end """ - do_task(to_proc, task_desc) -> Any + do_task(to_proc, task::TaskSpec) -> Any -Executes a single task specified by `task_desc` on `to_proc`. +Executes a single task specified by `task` on `to_proc`. """ -function do_task(to_proc, task_desc) - thunk_id, est_time_util, est_alloc_util, est_occupancy, - scope, Tf, data, - send_result, persist, cache, meta, - options, propagated, ids, positions, - ctx_vars, sch_handle, sch_uid = task_desc +function do_task(to_proc, task::TaskSpec) + thunk_id = task.thunk_id + + ctx_vars = task.ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) from_proc = OSProc() - Tdata = Any[] - for x in data - push!(Tdata, chunktype(x)) - end + #Tdata = Any[] + data = task.data + #for arg in data + # push!(Tdata, chunktype(Dagger.value(arg))) + #end + Tf = task.Tf f = isdefined(Tf, :instance) ? Tf.instance : nothing # Wait for required resources to become available + options = task.options to_storage = options.storage !== nothing ? 
fetch(options.storage) : MemPool.GLOBAL_DEVICE[] - to_storage_name = nameof(typeof(to_storage)) - storage_cap = storage_capacity(to_storage) + #to_storage_name = nameof(typeof(to_storage)) + #storage_cap = storage_capacity(to_storage) - timespan_start(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) + @maybelog ctx timespan_start(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) + est_time_util = task.est_time_util + est_alloc_util = task.est_alloc_util real_time_util = Ref{UInt64}(0) real_alloc_util = UInt64(0) + #= TODO if !meta # Factor in the memory costs for our lazy arguments for arg in data[2:end] - if arg isa Chunk - est_alloc_util += arg.handle.size + if Dagger.valuetype(arg) <: Chunk + est_alloc_util += Dagger.value(arg).handle.size end end end + =# + #= FIXME: Serialize on over-memory situation debug_storage(msg::String) = @debug begin let est_alloc_util=Base.format_bytes(est_alloc_util), real_alloc_util=Base.format_bytes(real_alloc_util), @@ -1474,7 +1484,7 @@ function do_task(to_proc, task_desc) lock(TASK_SYNC) do while true # Get current time utilization for the selected processor - time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, sch_uid) + time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, task.sch_uid) real_time_util = get!(()->Ref{UInt64}(UInt64(0)), time_dict, to_proc) # Get current allocation utilization and capacity @@ -1483,48 +1493,51 @@ function do_task(to_proc, task_desc) # Check if we'll go over memory capacity from running this thunk # Waits for free storage, if necessary - #= TODO: Implement a priority queue, ordered by est_alloc_util - if est_alloc_util > storage_cap - debug_storage("WARN: Estimated utilization above storage capacity on $to_storage_name, proceeding anyway") - break - end - if est_alloc_util + real_alloc_util > storage_cap - if MemPool.externally_varying(to_storage) - debug_storage("WARN: Insufficient space and allocation behavior is externally varying on $to_storage_name, proceeding anyway") - break - end - if length(TASKS_RUNNING) <= 2 # This task + eager submission task - debug_storage("WARN: Insufficient space and no other running tasks on $to_storage_name, proceeding anyway") - break - end - # Fully utilized, wait and re-check - debug_storage("Waiting for free $to_storage_name") - wait(TASK_SYNC) - else - # Sufficient free storage is available, prepare for execution - debug_storage("Using available $to_storage_name") - break - end - =# + # TODO: Implement a priority queue, ordered by est_alloc_util + #if est_alloc_util > storage_cap + # debug_storage("WARN: Estimated utilization above storage capacity on $to_storage_name, proceeding anyway") + # break + #end + #if est_alloc_util + real_alloc_util > storage_cap + # if MemPool.externally_varying(to_storage) + # debug_storage("WARN: Insufficient space and allocation behavior is externally varying on $to_storage_name, proceeding anyway") + # break + # end + # if length(TASKS_RUNNING) <= 2 # This task + eager submission task + # debug_storage("WARN: Insufficient space and no other running tasks on $to_storage_name, proceeding anyway") + # break + # end + # # Fully utilized, wait and re-check + # debug_storage("Waiting for free $to_storage_name") + # wait(TASK_SYNC) + #else + # # Sufficient free storage is available, prepare for execution + # debug_storage("Using available $to_storage_name") + # break + #end # FIXME break end end - timespan_finish(ctx, :storage_wait, 
(;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) + @maybelog ctx timespan_finish(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) + =# @dagdebug thunk_id :execute "Moving data" # Initiate data transfers for function and arguments transfer_time = Threads.Atomic{UInt64}(0) transfer_size = Threads.Atomic{UInt64}(0) - _data, _ids = if meta - (Any[first(data)], Int[first(ids)]) # always fetch function + _data = if task.meta + Argument[first(data)] # always fetch function else - (data, ids) + data end - fetch_tasks = map(Iterators.zip(_data,_ids)) do (x, id) + fetch_tasks = map(_data) do arg + #=FIXME:REALLOC_TASKS=# @async begin - timespan_start(ctx, :move, (;thunk_id, id, processor=to_proc), (;f, data=x)) + value = Dagger.value(arg) + pos = arg.pos + @maybelog ctx timespan_start(ctx, :move, (;thunk_id, pos, processor=to_proc), (;f, data=value)) #= FIXME: This isn't valid if x is written to x = if x isa Chunk value = lock(TASK_SYNC) do @@ -1567,31 +1580,28 @@ function do_task(to_proc, task_desc) end else =# - x = @invokelatest move(to_proc, x) + new_value = @invokelatest move(to_proc, value) #end - @dagdebug thunk_id :move "Moved argument $id to $to_proc: $(typeof(x))" - timespan_finish(ctx, :move, (;thunk_id, id, processor=to_proc), (;f, data=x); tasks=[Base.current_task()]) - return x + @dagdebug thunk_id :move "Moved argument @ $pos to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, pos, processor=to_proc), (;f, data=new_value); tasks=[Base.current_task()]) + arg.value = new_value + return end end - fetched = Any[] for task in fetch_tasks - push!(fetched, fetch_report(task)) - end - if meta - append!(fetched, data[2:end]) + fetch_report(task) end - f = popfirst!(fetched) + f = Dagger.value(first(data)) @assert !(f isa Chunk) "Failed to unwrap thunk function" - fetched_args = Any[] - fetched_kwargs = Pair{Symbol,Any}[] - for (idx, x) in enumerate(fetched) - pos = positions[idx] - if pos === nothing - push!(fetched_args, x) + fetched_args = @reusable_vector :do_task_fetched_args Any nothing 32 + fetched_kwargs = @reusable_vector :do_task_fetched_kwargs Pair{Symbol,Any} :NULL=>nothing 32 + for idx in 2:length(data) + arg = data[idx] + if Dagger.ispositional(arg) + push!(fetched_args, Dagger.value(arg)) else - push!(fetched_kwargs, pos => x) + push!(fetched_kwargs, Dagger.pos_kw(arg) => Dagger.value(arg)) end end @@ -1604,8 +1614,7 @@ function do_task(to_proc, task_desc) =# real_time_util[] += est_time_util - timespan_start(ctx, :compute, (;thunk_id, processor=to_proc), (;f)) - res = nothing + @maybelog ctx timespan_start(ctx, :compute, (;thunk_id, processor=to_proc), (;f)) # Start counting time and GC allocations threadtime_start = cputhreadtime() @@ -1614,47 +1623,60 @@ function do_task(to_proc, task_desc) @dagdebug thunk_id :execute "Executing $(typeof(f))" + logging_enabled = !(ctx.log_sink isa TimespanLogging.NoOpLog) + result_meta = try # Set TLS variables - Dagger.set_tls!(( - sch_uid, - sch_handle, + Dagger.set_tls!((; + sch_uid=task.sch_uid, + sch_handle=task.sch_handle, processor=to_proc, - task_spec=task_desc, + task_spec=task, + logging_enabled, )) - res = Dagger.with_options(propagated) do + result = Dagger.with_options(task.propagated) do # Execute execute!(to_proc, f, fetched_args...; fetched_kwargs...) 
end # Check if result is safe to store + # FIXME: Move here and below *after* timespan_finish for :compute device = nothing - if !(res isa Chunk) - timespan_start(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(res))) - device = if walk_storage_safe(res) + if !(result isa Chunk) + @maybelog ctx timespan_start(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(result))) + device = if walk_storage_safe(result) to_storage else MemPool.CPURAMDevice() end - timespan_finish(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(res))) + @maybelog ctx timespan_finish(ctx, :storage_safe_scan, (;thunk_id, processor=to_proc), (;T=typeof(result))) end # Construct result - # TODO: We should cache this locally - send_result || meta ? res : tochunk(res, to_proc; device, persist, cache=persist ? true : cache, - tag=options.storage_root_tag, - leaf_tag=something(options.storage_leaf_tag, MemPool.Tag()), - retain=options.storage_retain) + result_meta = if task.get_result || task.meta + result + else + # TODO: We should cache this locally + cache = task.persist || task.cache + tochunk(result, to_proc; device, task.persist, cache, + tag=options.storage_root_tag, + leaf_tag=something(options.storage_leaf_tag, MemPool.Tag()), + retain=options.storage_retain) + end catch ex bt = catch_backtrace() RemoteException(myid(), CapturedException(ex, bt)) + finally + empty!(fetched_args) + empty!(fetched_kwargs) end threadtime = cputhreadtime() - threadtime_start # FIXME: This is not a realistic measure of max. required memory #gc_allocd = min(max(UInt64(Base.gc_num().allocd) - UInt64(gcnum_start.allocd), UInt64(0)), UInt64(1024^4)) - timespan_finish(ctx, :compute, (;thunk_id, processor=to_proc), (;f)) + @maybelog ctx timespan_finish(ctx, :compute, (;thunk_id, processor=to_proc), (;f)) + lock(TASK_SYNC) do real_time_util[] -= est_time_util pop!(TASKS_RUNNING, thunk_id) @@ -1666,12 +1688,13 @@ function do_task(to_proc, task_desc) # TODO: debug_storage("Releasing $to_storage_name") metadata = ( time_pressure=real_time_util[], - storage_pressure=real_alloc_util, - storage_capacity=storage_cap, - loadavg=((Sys.loadavg()...,) ./ Sys.CPU_THREADS), + #storage_pressure=real_alloc_util, + #storage_capacity=storage_cap, + #loadavg=((Sys.loadavg()...,) ./ Sys.CPU_THREADS), threadtime=threadtime, # FIXME: Add runtime allocation tracking - gc_allocd=(isa(result_meta, Chunk) ? result_meta.handle.size : 0), + #gc_allocd=(isa(result_meta, Chunk) ? result_meta.handle.size : 0), + gc_allocd=0, transfer_rate=(transfer_size[] > 0 && transfer_time[] > 0) ? round(UInt64, transfer_size[] / (transfer_time[] / 10^9)) : nothing, ) return (result_meta, metadata) diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index df78a1dd1..b5355a3ee 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -96,13 +96,18 @@ function dynamic_listener!(ctx, state, wid) end end end + return end errormonitor_tracked("dynamic_listener! $wid", listener_task) errormonitor_tracked("dynamic_listener! 
(halt+throw) $wid", @async begin wait(state.halt) # TODO: Not sure why we need the @async here, but otherwise we # don't stop all the listener tasks - @async Base.throwto(listener_task, SchedulerHaltedException()) + @async begin + Base.throwto(listener_task, SchedulerHaltedException()) + return + end + return end) end @@ -124,7 +129,7 @@ end halt!(h::SchedulerHandle) = exec!(_halt, h, nothing) function _halt(ctx, state, task, tid, _) notify(state.halt) - put!(state.chan, (1, nothing, nothing, (SchedulerHaltedException(), nothing))) + put!(state.chan, TaskResult(1, OSProc(), 0, SchedulerHaltedException(), nothing)) Base.throwto(task, SchedulerHaltedException()) end @@ -172,8 +177,8 @@ function _register_future!(ctx, state, task, tid, (future, id, check)::Tuple{Thu end end # TODO: Assert that future will be fulfilled - if haskey(state.cache, thunk) - put!(future, state.cache[thunk]; error=state.errored[thunk]) + if has_result(state, thunk) + put!(future, load_result(state, thunk); error=state.errored[thunk]) else futures = get!(()->ThunkFuture[], state.futures, thunk) push!(futures, future) @@ -208,37 +213,15 @@ function _get_dag_ids(ctx, state, task, tid, _) end "Adds a new Thunk to the DAG." -add_thunk!(f, h::SchedulerHandle, args...; future=nothing, ref=nothing, options...) = - exec!(_add_thunk!, h, f, args, options, future, ref) -function _add_thunk!(ctx, state, task, tid, (f, args, options, future, ref)) - timespan_start(ctx, :add_thunk, (;thunk_id=tid), (;f, args, options)) - _args = map(args) do pos_arg - if pos_arg[2] isa ThunkID - return pos_arg[1] => state.thunk_dict[pos_arg[2].id] - else - return pos_arg[1] => pos_arg[2] - end - end - GC.@preserve _args begin - thunk = Thunk(f, _args...; options...) - # Create a `DRef` to `thunk` so that the caller can preserve it - thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice()) - thunk_id = ThunkID(thunk.id, thunk_ref) - state.thunk_dict[thunk.id] = WeakThunk(thunk) - reschedule_syncdeps!(state, thunk) - @dagdebug thunk :submit "Added to scheduler" - if future !== nothing - # Ensure we attach a future before the thunk is scheduled - _register_future!(ctx, state, task, tid, (future, thunk_id, false)) - @dagdebug thunk :submit "Registered future" - end - if ref !== nothing - # Preserve the `DTaskFinalizer` through `thunk` - thunk.eager_ref = ref - end - state.valid[thunk] = nothing - put!(state.chan, RescheduleSignal()) - timespan_finish(ctx, :add_thunk, (;thunk_id=tid), (;f, args, options)) - return thunk_id +function add_thunk!(f, h::SchedulerHandle, args...; future=nothing, ref=nothing, options...) 
+ if future !== nothing || ref !== nothing + @warn "`future` and `ref` arguments are no longer supported in `add_thunk!`" maxlog=1 end + return exec!(_add_thunk!, h, f, args, options) +end +function _add_thunk!(ctx, state, task, tid, (f, args, options)) + spec = Dagger.DTaskSpec(Dagger.args_kwargs_to_arguments(f, args), (;options...)) + dtask = Dagger.eager_spawn(spec) + Dagger.eager_launch!(spec=>dtask) + return dtask end diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 2e8886c5e..dbccd7586 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -68,7 +68,7 @@ function thunk_yield(f) proc_istate = proc_states(tls.sch_uid) do states states[proc].state end - task_occupancy = tls.task_spec[4] + task_occupancy = tls.task_spec.est_occupancy # Decrease our occupancy and inform the processor to reschedule lock(proc_istate.queue) do _ @@ -100,24 +100,6 @@ function thunk_yield(f) end end -eager_cleanup(t::Dagger.DTaskFinalizer) = - errormonitor_tracked("eager_cleanup $(t.uid)", Threads.@spawn eager_cleanup(EAGER_STATE[], t.uid)) -function eager_cleanup(state, uid) - tid = nothing - lock(EAGER_ID_MAP) do id_map - if !haskey(id_map, uid) - return - end - tid = id_map[uid] - delete!(id_map, uid) - end - tid === nothing && return - lock(state.lock) do - # N.B. cache and errored expire automatically - delete!(state.thunk_dict, tid) - end -end - function _find_thunk(e::Dagger.DTask) tid = lock(EAGER_ID_MAP) do id_map id_map[e.uid] diff --git a/src/sch/fault-handler.jl b/src/sch/fault-handler.jl index fca184cfa..56ccc3ca1 100644 --- a/src/sch/fault-handler.jl +++ b/src/sch/fault-handler.jl @@ -20,11 +20,13 @@ function handle_fault(ctx, state, deadproc) deadlist = Thunk[] # Evict cache entries that were stored on the worker - for t in keys(state.cache) - v = state.cache[t] + for t in values(state.thunk_dict) + t = unwrap_weak_checked(t) + has_result(state, t) || continue + v = load_result(state, t) if v isa Chunk && v.handle isa DRef && v.handle.owner == deadproc.pid push!(deadlist, t) - pop!(state.cache, t) + clear_result!(state, t) end end # Remove thunks that were running on the worker diff --git a/src/sch/util.jl b/src/sch/util.jl index 17b0fbfbe..b8521af38 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -18,6 +18,16 @@ function errormonitor_tracked(name::String, t::Task) end end) end +function errormonitor_tracked_set!(name::String, t::Task) + lock(ERRORMONITOR_TRACKED) do tracked + for idx in 1:length(tracked) + if tracked[idx][2] === t + tracked[idx] = name => t + return + end + end + end +end const ERRORMONITOR_TRACKED = LockedObject(Pair{String,Task}[]) """ @@ -29,16 +39,19 @@ unwrap_nested_exception(err::CapturedException) = unwrap_nested_exception(err.ex) unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) +unwrap_nested_exception(err::Dagger.ThunkFailedException) = + unwrap_nested_exception(err.ex) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." function get_propagated_options(thunk) nt = NamedTuple() + f = Dagger.value(thunk.inputs[1]) for key in thunk.propagates value = if key == :scope - isa(thunk.f, Chunk) ? thunk.f.scope : DefaultScope() + isa(f, Chunk) ? f.scope : DefaultScope() elseif key == :processor - isa(thunk.f, Chunk) ? thunk.f.processor : OSProc() + isa(f, Chunk) ? 
f.processor : OSProc() elseif key in fieldnames(Thunk) getproperty(thunk, key) elseif key in fieldnames(ThunkOptions) @@ -48,57 +61,84 @@ function get_propagated_options(thunk) end nt = merge(nt, (key=>value,)) end - nt + return nt +end + +has_result(state, thunk) = thunk.cache_ref !== nothing +load_result(state, thunk) = something(thunk.cache_ref) +function store_result!(state, thunk, value; error::Bool=false) + @assert islocked(state.lock) + @assert !thunk.finished "Thunk[$(thunk.id)] should not be finished yet" + @assert !has_result(state, thunk) "Thunk[$(thunk.id)] already contains a cached result" + thunk.finished = true + if error && value isa Exception && !(value isa ThunkFailedException) + thunk.cache_ref = Some{Any}(ThunkFailedException(thunk, thunk, value)) + else + thunk.cache_ref = Some{Any}(value) + end + state.errored[thunk] = error +end +function clear_result!(state, thunk) + @assert islocked(state.lock) + thunk.cache_ref = nothing + delete!(state.errored, thunk) end -"Fills the result for all registered futures of `node`." -function fill_registered_futures!(state, node, failed) - if haskey(state.futures, node) +"Fills the result for all registered futures of `thunk`." +function fill_registered_futures!(state, thunk, failed) + if haskey(state.futures, thunk) # Notify any listening thunks - for future in state.futures[node] - put!(future, state.cache[node]; error=failed) + for future in state.futures[thunk] + put!(future, load_result(state, thunk); error=failed) end - delete!(state.futures, node) + delete!(state.futures, thunk) end end "Cleans up any syncdeps that aren't needed any longer, and returns a `Set{Chunk}` of all chunks that can now be evicted from workers." -function cleanup_syncdeps!(state, node) - to_evict = Set{Chunk}() - for inp in node.syncdeps +function cleanup_syncdeps!(state, thunk) + #to_evict = Set{Chunk}() + for inp in thunk.syncdeps inp = unwrap_weak_checked(inp) if !istask(inp) && !(inp isa Chunk) continue end if inp in keys(state.waiting_data) w = state.waiting_data[inp] - if node in w - pop!(w, node) + if thunk in w + pop!(w, thunk) end if isempty(w) - if istask(inp) && haskey(state.cache, inp) - _node = state.cache[inp] - if _node isa Chunk - push!(to_evict, _node) + #= + if istask(inp) && has_result(state, inp) + _thunk = load_result(state, inp) + if _thunk isa Chunk + push!(to_evict, _thunk) end elseif inp isa Chunk push!(to_evict, inp) end + =# delete!(state.waiting_data, inp) + inp.sch_accessible = false + delete_unused_task!(state, inp) end end end - return to_evict + #return to_evict end "Schedules any dependents that may be ready to execute." -function schedule_dependents!(state, node, failed) - for dep in sort!(collect(get(()->Set{Thunk}(), state.waiting_data, node)), by=state.node_order) +function schedule_dependents!(state, thunk, failed) + if !haskey(state.waiting_data, thunk) || isempty(state.waiting_data[thunk]) + return + end + for dep in state.waiting_data[thunk] dep_isready = false if haskey(state.waiting, dep) set = state.waiting[dep] - node in set && pop!(set, node) + thunk in set && pop!(set, thunk) dep_isready = isempty(set) if dep_isready delete!(state.waiting, dep) @@ -118,68 +158,77 @@ end Prepares the scheduler to schedule `thunk`. Will mark `thunk` as ready if its inputs are satisfied. 
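
A minimal usage sketch (mirroring how `eager_submit_internal!` registers a
freshly-created `Thunk`; it assumes the caller already holds `state.lock`):

    state.thunk_dict[thunk.id] = WeakThunk(thunk)
    reschedule_syncdeps!(state, thunk)
    state.valid[thunk] = nothing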
""" -function reschedule_syncdeps!(state, thunk, seen=Set{Thunk}()) - to_visit = Thunk[thunk] - while !isempty(to_visit) - thunk = pop!(to_visit) - push!(seen, thunk) - if haskey(state.valid, thunk) - continue - end - if haskey(state.cache, thunk) || (thunk in state.ready) || (thunk in state.running) - continue - end - for (_,input) in thunk.inputs - if input isa WeakChunk - input = unwrap_weak_checked(input) +function reschedule_syncdeps!(state, thunk, seen=nothing) + Dagger.maybe_take_or_alloc!(RESCHEDULE_SYNCDEPS_SEEN_CACHE[], seen) do seen + #=FIXME:REALLOC=# + to_visit = Thunk[thunk] + while !isempty(to_visit) + thunk = pop!(to_visit) + push!(seen, thunk) + if haskey(state.valid, thunk) + continue end - if input isa Chunk - # N.B. Different Chunks with the same DRef handle will hash to the same slot, - # so we just pick an equivalent Chunk as our upstream - if !haskey(state.waiting_data, input) - push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) + if thunk.finished || (thunk in state.ready) || (thunk in state.running) + continue + end + for idx in 1:length(thunk.inputs) + input = Dagger.value(thunk.inputs[idx]) + if input isa WeakChunk + input = unwrap_weak_checked(input) + end + if input isa Chunk + # N.B. Different Chunks with the same DRef handle will hash to the same slot, + # so we just pick an equivalent Chunk as our upstream + if !haskey(state.waiting_data, input) + push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) + end end end - end - w = get!(()->Set{Thunk}(), state.waiting, thunk) - for input in thunk.syncdeps - input = unwrap_weak_checked(input) - istask(input) && input in seen && continue + w = get!(()->Set{Thunk}(), state.waiting, thunk) + for input in thunk.syncdeps + input = unwrap_weak_checked(input) + istask(input) && input in seen && continue - # Unseen - push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) - istask(input) || continue + # Unseen + push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) + istask(input) || continue - # Unseen task - if get(state.errored, input, false) - set_failed!(state, input, thunk) - end - haskey(state.cache, input) && continue + # Unseen task + if get(state.errored, input, false) + set_failed!(state, input, thunk) + end + input.finished && continue - # Unseen and unfinished task - push!(w, input) - if !((input in state.running) || (input in state.ready)) - push!(to_visit, input) + # Unseen and unfinished task + push!(w, input) + if !((input in state.running) || (input in state.ready)) + push!(to_visit, input) + end end - end - if isempty(w) - # Inputs are ready - delete!(state.waiting, thunk) - if !get(state.errored, thunk, false) - push!(state.ready, thunk) + if isempty(w) + # Inputs are ready + delete!(state.waiting, thunk) + if !get(state.errored, thunk, false) + push!(state.ready, thunk) + end end end end end +const RESCHEDULE_SYNCDEPS_SEEN_CACHE = TaskLocalValue{ReusableCache{Set{Thunk},Nothing}}(()->ReusableCache(Set{Thunk}, nothing, 1)) "Marks `thunk` and all dependent thunks as failed." 
function set_failed!(state, origin, thunk=origin) + @assert islocked(state.lock) filter!(x->x!==thunk, state.ready) - state.cache[thunk] = ThunkFailedException(thunk, origin, state.cache[origin]) - state.errored[thunk] = true + if origin !== thunk + ex = ThunkFailedException(thunk, origin, load_result(state, origin)) + store_result!(state, thunk, ex; error=true) + end finish_failed!(state, thunk, origin) end function finish_failed!(state, thunk, origin=nothing) + @assert islocked(state.lock) fill_registered_futures!(state, thunk, true) if haskey(state.waiting_data, thunk) for dep in state.waiting_data[thunk] @@ -190,6 +239,8 @@ function finish_failed!(state, thunk, origin=nothing) origin !== nothing && set_failed!(state, origin, dep) end delete!(state.waiting_data, thunk) + thunk.sch_accessible = false + delete_unused_task!(state, thunk) end if haskey(state.waiting, thunk) delete!(state.waiting, thunk) @@ -213,7 +264,7 @@ function print_sch_status(io::IO, state, thunk; offset=0, limit=5, max_inputs=3) status *= "r" elseif node in state.running status *= "R" - elseif haskey(state.cache, node) + elseif has_result(state, node) status *= "C" else status *= "?" @@ -284,38 +335,81 @@ function report_catch_error(err, desc=nothing) write(stderr, iob) end +struct Signature + sig::Vector{Any}#DataType} + hash::UInt + sig_nokw::SubArray{Any,1,Vector{Any},Tuple{UnitRange{Int}},true} + hash_nokw::UInt + function Signature(sig::Vector{Any})#DataType}) + # Hash full signature + h = hash(Signature) + for T in sig + h = hash(T, h) + end + + # Hash non-kwarg signature + @assert isdefined(Core, :kwcall) "FIXME: No kwcall! Use kwfunc" + idx = findfirst(T->T===typeof(Core.kwcall), sig) + if idx !== nothing + # Skip NT kwargs + sig_nokw = @view sig[idx+2:end] + else + sig_nokw = @view sig[1:end] + end + h_nokw = hash(Signature, UInt(1)) + for T in sig_nokw + h_nokw = hash(T, h_nokw) + end + + return new(sig, h, sig_nokw, h_nokw) + end +end +Base.hash(sig::Signature, h::UInt) = hash(sig.hash, h) +Base.isequal(sig1::Signature, sig2::Signature) = sig1.hash == sig2.hash + chunktype(x) = typeof(x) signature(state, task::Thunk) = - signature(task.f, collect_task_inputs(state, task.inputs)) + signature(task.inputs[1], @view task.inputs[2:end]) function signature(f, args) - sig = DataType[chunktype(f)] + n_pos = count(Dagger.ispositional, args) + any_kw = any(!Dagger.ispositional, args) + kw_extra = any_kw ? 
2 : 0 + sig = Vector{Any}(undef, 1+n_pos+kw_extra) + sig[1+kw_extra] = chunktype(f) + #=FIXME:REALLOC_N=# sig_kwarg_names = Symbol[] sig_kwarg_types = [] - for (pos, arg) in args - if arg isa Dagger.DTask + for idx in 1:length(args) + arg = args[idx] + value = Dagger.value(arg) + if value isa Dagger.DTask # Only occurs via manual usage of signature - arg = fetch(arg; raw=true) + value = fetch(value; raw=true) + end + if istask(value) + throw(ConcurrencyViolationError("Must call `collect_task_inputs!(state, task)` before calling `signature`")) end - T = chunktype(arg) - if pos === nothing - push!(sig, T) + T = chunktype(value) + if Dagger.ispositional(arg) + sig[1+idx+kw_extra] = T else - push!(sig_kwarg_names, pos) + push!(sig_kwarg_names, Dagger.pos_kw(arg)) push!(sig_kwarg_types, T) end end - if !isempty(sig_kwarg_names) + if any_kw NT = NamedTuple{(sig_kwarg_names...,), Base.to_tuple_type(sig_kwarg_types)} - pushfirst!(sig, NT) + sig[2] = NT @static if isdefined(Core, :kwcall) - pushfirst!(sig, typeof(Core.kwcall)) + sig[1] = typeof(Core.kwcall) else f_instance = chunktype(f).instance kw_f = Core.kwfunc(f_instance) - pushfirst!(sig, typeof(kw_f)) + sig[1] = typeof(kw_f) end end - return sig + #=FIXME:UNIQUE=# + return Signature(sig) end function can_use_proc(task, gproc, proc, opts, scope) @@ -378,7 +472,7 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) time_util[T] * 1000^3 else get(state.signature_time_cost, sig, 1000^3) - end) + end)::UInt64 est_alloc_util = if alloc_util !== nothing && haskey(alloc_util, T) alloc_util[T] else @@ -441,15 +535,17 @@ function impute_sum(xs) end "Collects all arguments for `task`, converting Thunk inputs to Chunks." -collect_task_inputs(state, task::Thunk) = - collect_task_inputs(state, task.inputs) -function collect_task_inputs(state, inputs) - new_inputs = Pair{Union{Symbol,Nothing},Any}[] - for (pos, input) in inputs +collect_task_inputs!(state, task::Thunk) = + collect_task_inputs!(state, task.inputs) +function collect_task_inputs!(state, inputs) + for idx in 1:length(inputs) + input = Dagger.value(inputs[idx]) input = unwrap_weak_checked(input) - push!(new_inputs, pos => (istask(input) ? state.cache[input] : input)) + if istask(input) + inputs[idx].value = wrap_weak(load_result(state, input)) + end end - return new_inputs + return end """ @@ -458,20 +554,26 @@ current estimated per-processor compute pressure, and transfer costs for each `Chunk` argument to `task`. Returns `(procs, costs)`, with `procs` sorted in order of ascending cost. 
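
The allocation-free variant `estimate_task_costs!` fills caller-provided buffers
instead; a sketch of equivalent usage (mirroring the allocating wrapper defined
below):

    sorted_procs = Vector{Processor}(undef, length(procs))
    costs = Dict{Processor,Float64}()
    estimate_task_costs!(sorted_procs, costs, state, procs, task)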
""" -function estimate_task_costs(state, procs, task, inputs) +function estimate_task_costs(state, procs, task) + sorted_procs = Vector{Processor}(undef, length(procs)) + costs = Dict{Processor,Float64}() + estimate_task_costs!(sorted_procs, costs, state, procs, task) + return sorted_procs, costs +end +function estimate_task_costs!(sorted_procs, costs, state, procs, task) tx_rate = state.transfer_rate[] # Find all Chunks - chunks = Chunk[] - for input in inputs - if input isa Chunk - push!(chunks, input) + chunks = @reusable_vector :estimate_task_costs_chunks Union{Chunk,Nothing} nothing 32 + for input in task.inputs + if Dagger.valuetype(input) <: Chunk + push!(chunks, Dagger.value(input)::Chunk) end end - costs = Dict{Processor,Float64}() for proc in procs - chunks_filt = Iterators.filter(c->get_parent(processor(c))!=get_parent(proc), chunks) + gproc = get_parent(proc) + chunks_filt = Iterators.filter(c->get_parent(processor(c)) != gproc, chunks) # Estimate network transfer costs based on data size # N.B. `affinity(x)` really means "data size of `x`" @@ -481,18 +583,23 @@ function estimate_task_costs(state, procs, task, inputs) tx_cost = impute_sum(affinity(chunk)[2] for chunk in chunks_filt) # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(state.worker_time_pressure[get_parent(proc).pid], proc, 0) + est_time_util = get(state.worker_time_pressure[gproc.pid], proc, 0) costs[proc] = est_time_util + (tx_cost/tx_rate) end + empty!(chunks) # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(procs)) - procs = getindex.(Ref(procs), P) + np = length(procs) + @reusable :estimate_task_costs_P Vector{Int} 0 4 np P begin + copyto!(P, 1:np) + randperm!(P) + for idx in 1:np + sorted_procs[idx] = procs[P[idx]] + end + end # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - return procs, costs + sort!(sorted_procs, by=p->costs[p]) end """ diff --git a/src/submission.jl b/src/submission.jl index 2c5ee042e..59764b613 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -1,145 +1,211 @@ +mutable struct PayloadOne + uid::UInt + future::ThunkFuture + fargs::Vector{Argument} + options::NamedTuple + reschedule::Bool + + PayloadOne() = new() + PayloadOne(uid::UInt, future::ThunkFuture, fargs::Vector{Argument}, options::NamedTuple, reschedule::Bool) = + new(uid, future, fargs, options, reschedule) +end +mutable struct PayloadMulti + ntasks::Int + uid::Vector{UInt} + future::Vector{ThunkFuture} + fargs::Vector{Vector{Argument}} + options::Vector{NamedTuple} + reschedule::Bool +end +const AnyPayload = Union{PayloadOne, PayloadMulti} +function payload_extract(f, payload::PayloadMulti, i::Integer) + take_or_alloc!(PAYLOAD_ONE_CACHE[]; no_alloc=true) do p1 + p1.uid = payload.uid[i] + p1.future = payload.future[i] + p1.fargs = payload.fargs[i] + p1.options = payload.options[i] + p1.reschedule = true + return f(p1) + end +end +const PAYLOAD_ONE_CACHE = TaskLocalValue{ReusableCache{PayloadOne,Nothing}}(()->ReusableCache(PayloadOne, nothing, 1)) + +const THUNK_SPEC_CACHE = TaskLocalValue{ReusableCache{ThunkSpec,Nothing}}(()->ReusableCache(ThunkSpec, nothing, 1)) + # Remote -function eager_submit_internal!(@nospecialize(payload)) +function eager_submit_internal!(payload::AnyPayload) ctx = Dagger.Sch.eager_context() state = Dagger.Sch.EAGER_STATE[] task = current_task() tid = 0 return eager_submit_internal!(ctx, state, task, tid, payload) end -function eager_submit_internal!(ctx, state, task, tid, payload; 
uid_to_tid=Dict{UInt64,Int}())
-    @nospecialize payload
-    ntasks, uid, future, ref, f, args, options, reschedule = payload
-
-    if uid isa Vector
-        thunk_ids = Sch.ThunkID[]
-        for i in 1:ntasks
-            tid = eager_submit_internal!(ctx, state, task, tid,
-                                         (1, uid[i], future[i], ref[i],
-                                          f[i], args[i], options[i],
-                                          false); uid_to_tid)
-            push!(thunk_ids, tid)
-            uid_to_tid[uid[i]] = tid.id
+eager_submit_internal!(ctx, state, task, tid, payload::Tuple{<:AnyPayload}) =
+    eager_submit_internal!(ctx, state, task, tid, payload[1])
+const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}}(()->ReusableCache(Dict{UInt64,Int}, nothing, 1))
+function eager_submit_internal!(ctx, state, task, tid, payload::AnyPayload; uid_to_tid=nothing)
+    maybe_take_or_alloc!(UID_TO_TID_CACHE[], uid_to_tid; no_alloc=true) do uid_to_tid
+    if payload isa PayloadMulti
+        thunk_ids = Sch.ThunkID[]
+        for i in 1:payload.ntasks
+            tid = payload_extract(payload, i) do p1
+                eager_submit_internal!(ctx, state, task, tid, p1;
+                                       uid_to_tid)
+            end
+            push!(thunk_ids, tid)
+            uid_to_tid[payload.uid[i]] = tid.id
+        end
+        put!(state.chan, Sch.RescheduleSignal())
+        return thunk_ids
    end
-        put!(state.chan, Sch.RescheduleSignal())
-        return thunk_ids
-    end
+    payload::PayloadOne
-    id = next_id()
+    uid, future = payload.uid, payload.future
+    fargs, options, reschedule = payload.fargs, payload.options, payload.reschedule
-    timespan_start(ctx, :add_thunk, (;thunk_id=id), (;f, args, options))
+    id = next_id()
-    # Lookup DTask/ThunkID -> Thunk
-    old_args = copy(args)
-    args::Vector{Any}
-    syncdeps = if haskey(options, :syncdeps)
-        collect(options.syncdeps)
-    else
-        nothing
-    end::Union{Vector{Any},Nothing}
-    lock(Sch.EAGER_ID_MAP) do id_map
-        for (idx, (pos, arg)) in enumerate(args)
-            pos::Union{Symbol,Nothing}
-            newarg = if arg isa DTask
-                arg_uid = arg.uid
-                arg_tid = if haskey(id_map, arg_uid)
-                    id_map[arg_uid]
-                else
-                    uid_to_tid[arg_uid]
-                end
-                state.thunk_dict[arg_tid]
-            elseif arg isa Sch.ThunkID
-                arg_tid = arg.id
-                state.thunk_dict[arg_tid]
-            elseif arg isa Chunk
-                # N.B. Different Chunks with the same DRef handle will hash to the same slot,
-                # so we just pick an equivalent Chunk as our upstream
-                if haskey(state.waiting_data, arg)
-                    newarg = nothing
-                    for other in keys(state.waiting_data)
-                        if other isa Chunk && other.handle == arg.handle
-                            newarg = other
-                            break
+    @maybelog ctx timespan_start(ctx, :add_thunk, (;thunk_id=id), (;f=fargs[1], args=fargs[2:end], options))
+
+    old_fargs = @reusable_vector :eager_submit_internal!_old_fargs Argument Argument(ArgPosition(), nothing) 32
+    append!(old_fargs, Iterators.map(copy, fargs))
+    syncdeps_vec = @reusable_vector :eager_submit_internal!_syncdeps_vec Any nothing 32
+    if haskey(options, :syncdeps)
+        append!(syncdeps_vec, options.syncdeps)
+    end
+
+    # Lookup DTask/ThunkID -> Thunk
+    lock(Sch.EAGER_ID_MAP) do id_map
+        for (idx, arg) in enumerate(fargs)
+            if valuetype(arg) <: DTask
+                arg_uid = (value(arg)::DTask).uid
+                arg_tid = if haskey(id_map, arg_uid)
+                    id_map[arg_uid]
+                else
+                    uid_to_tid[arg_uid]
+                end
+                @inbounds fargs[idx] = Argument(arg.pos, state.thunk_dict[arg_tid])
+            elseif valuetype(arg) <: Sch.ThunkID
+                arg_tid = (value(arg)::Sch.ThunkID).id
+                @inbounds fargs[idx] = Argument(arg.pos, state.thunk_dict[arg_tid])
+            elseif valuetype(arg) <: Chunk
+                # N.B.
Different Chunks with the same DRef handle will hash to the same slot, + # so we just pick an equivalent Chunk as our upstream + chunk = value(arg)::Chunk + if haskey(state.waiting_data, chunk) + newchunk = nothing + for other in keys(state.waiting_data) + if other isa Chunk && other.handle == chunk.handle + newchunk = other + break + end end + @assert newchunk !== nothing + chunk = newchunk::Chunk end - @assert newarg !== nothing - arg = newarg::Chunk + #=FIXME:UNIQUE=# + @inbounds fargs[idx] = Argument(arg.pos, WeakChunk(chunk)) + end + end + # TODO: Iteration protocol would be faster + for idx in 1:length(syncdeps_vec) + dep = syncdeps_vec[idx] + if dep isa DTask + tid = if haskey(id_map, dep.uid) + id_map[dep.uid] + else + uid_to_tid[dep.uid] + end + @inbounds syncdeps_vec[idx] = state.thunk_dict[tid] + elseif dep isa Sch.ThunkID + tid = dep.id + @inbounds syncdeps_vec[idx] = state.thunk_dict[tid] end - WeakChunk(arg) - else - arg end - @inbounds args[idx] = pos => newarg - end - if syncdeps === nothing - return end - for (idx, dep) in enumerate(syncdeps) - newdep = if dep isa DTask - tid = if haskey(id_map, dep.uid) - id_map[dep.uid] - else - uid_to_tid[dep.uid] + #=FIXME:REALLOC=# + if !isempty(syncdeps_vec) || any(arg->istask(value(arg)), fargs) + syncdeps = Set{Any}(syncdeps_vec) + for arg in fargs + if istask(value(arg)) + push!(syncdeps, value(arg)) end - state.thunk_dict[tid] - elseif dep isa Sch.ThunkID - tid = dep.id - state.thunk_dict[tid] - else - dep end - @inbounds syncdeps[idx] = newdep + else + syncdeps = EMPTY_SYNCDEPS end - end - if syncdeps !== nothing - options = merge(options, (;syncdeps)) - end - GC.@preserve old_args args begin - # Create the `Thunk` - thunk = Thunk(f, args...; id, options...) + GC.@preserve old_fargs fargs begin + # Create the `Thunk` + thunk = take_or_alloc!(THUNK_SPEC_CACHE[]; no_alloc=true) do thunk_spec + thunk_spec.fargs = fargs + thunk_spec.syncdeps = syncdeps + thunk_spec.id = id + thunk_spec.get_result = get(options, :get_result, false) + thunk_spec.meta = get(options, :meta, false) + thunk_spec.persist = get(options, :persist, false) + thunk_spec.cache = get(options, :cache, false) + new_options = (;filter(opt->hasfield(Sch.ThunkOptions, opt[1]), Base.pairs(options))...) + toptions = Sch.ThunkOptions(; new_options...) 
+ #= FIXME: Allow in-place options updates + for field in keys(options) + if hasfield(Sch.ThunkOptions, field) + setproperty!(toptions, field, getproperty(options, field)) + end + end + =# + thunk_spec.options = toptions + thunk_spec.propagates = get(options, :propagates, ()) + return Thunk(thunk_spec) + end - # Create a `DRef` to `thunk` so that the caller can preserve it - thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice(), - destructor=UnrefThunkByUser(thunk)) - thunk_id = Sch.ThunkID(thunk.id, thunk_ref) + # Create a `DRef` to `thunk` so that the caller can preserve it + thunk_ref = poolset(thunk; size=64, device=MemPool.CPURAMDevice(), + destructor=UnrefThunkByUser(uid, thunk)) + #=FIXME:UNIQUE=# + thunk_id = Sch.ThunkID(thunk.id, thunk_ref) - # Attach `thunk` within the scheduler - state.thunk_dict[thunk.id] = WeakThunk(thunk) - Sch.reschedule_syncdeps!(state, thunk) - @dagdebug thunk :submit "Added to scheduler" - if future !== nothing - # Ensure we attach a future before the thunk is scheduled - Sch._register_future!(ctx, state, task, tid, (future, thunk_id, false)) - @dagdebug thunk :submit "Registered future" - end - if ref !== nothing - # Preserve the `DTaskFinalizer` through `thunk` - thunk.eager_ref = ref - end - state.valid[thunk] = nothing + # Attach `thunk` within the scheduler + state.thunk_dict[thunk.id] = WeakThunk(thunk) + #=FIXME:REALLOC=# + Sch.reschedule_syncdeps!(state, thunk) + empty!(old_fargs) # reschedule_syncdeps! preserves all referenced tasks/chunks + @dagdebug thunk :submit "Added to scheduler" + if future !== nothing + # Ensure we attach a future before the thunk is scheduled + Sch._register_future!(ctx, state, task, tid, (future, thunk_id, false)) + @dagdebug thunk :submit "Registered future" + end + state.valid[thunk] = nothing - # Register Eager UID -> Sch TID - lock(Sch.EAGER_ID_MAP) do id_map - id_map[uid] = thunk.id - end + # Register Eager UID -> Sch TID + lock(Sch.EAGER_ID_MAP) do id_map + id_map[uid] = thunk.id + end - # Tell the scheduler that it has new tasks to schedule - if reschedule - put!(state.chan, Sch.RescheduleSignal()) - end + # Tell the scheduler that it has new tasks to schedule + if reschedule + put!(state.chan, Sch.RescheduleSignal()) + end - timespan_finish(ctx, :add_thunk, (;thunk_id=id), (;f, args, options)) + @maybelog ctx timespan_finish(ctx, :add_thunk, (;thunk_id=id), (;f=fargs[1], args=fargs[2:end], options)) - return thunk_id + return thunk_id + end end end struct UnrefThunkByUser + uid::UInt thunk::Thunk end function (unref::UnrefThunkByUser)() - Sch.errormonitor_tracked("unref thunk $(unref.thunk.id)", Threads.@spawn begin - # This thunk is no longer referenced by the user, mark it as ready to be - # cleaned up as eagerly as possible (or do so now) + Sch.errormonitor_tracked("unref DTask $(unref.uid) => Thunk $(unref.thunk.id)", Threads.@spawn begin + lock(Sch.EAGER_ID_MAP) do id_map + delete!(id_map, unref.uid) + end + + # The associated DTask is no longer referenced by the user, so mark the + # thunk as ready to be cleaned up as eagerly as possible (or do so now) thunk = unref.thunk state = Sch.EAGER_STATE[] if state === nothing @@ -147,27 +213,23 @@ function (unref::UnrefThunkByUser)() end @lock state.lock begin - if !Sch.delete_unused_task!(state, thunk) - # Register for deletion upon thunk completion - push!(state.thunks_to_delete, thunk) - end - # TODO: On success, walk down to children, as a fast-path + thunk.eager_accessible = false + Sch.delete_unused_task!(state, thunk) end end) end # Local -> 
Remote -function eager_submit!(ntasks, uid, future, finalizer_ref, f, args, options) +function eager_submit!(payload::AnyPayload) if Dagger.in_thunk() h = Dagger.sch_handle() - return exec!(eager_submit_internal!, h, ntasks, uid, future, finalizer_ref, f, args, options, true) + return exec!(eager_submit_internal!, h, payload) elseif myid() != 1 - return remotecall_fetch(1, (ntasks, uid, future, finalizer_ref, f, args, options, true)) do payload - @nospecialize payload + return remotecall_fetch(1, payload) do payload Sch.init_eager() state = Dagger.Sch.EAGER_STATE[] - lock(state.lock) do + @lock state.lock begin eager_submit_internal!(payload) end end @@ -175,33 +237,29 @@ function eager_submit!(ntasks, uid, future, finalizer_ref, f, args, options) Sch.init_eager() state = Dagger.Sch.EAGER_STATE[] return lock(state.lock) do - eager_submit_internal!((ntasks, uid, future, finalizer_ref, - f, args, options, - true)) + eager_submit_internal!(payload) end end end # Submission -> Local -function eager_process_elem_submission_to_local(id_map, x) - @nospecialize x - @assert !isa(x, Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" - if x isa Dagger.DTask && haskey(id_map, x.uid) - return Sch.ThunkID(id_map[x.uid], x.thunk_ref) - else - return x +function eager_process_elem_submission_to_local!(id_map, arg::Argument) + T = valuetype(arg) + @assert !(T <: Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" + if T <: DTask && haskey(id_map, (value(arg)::DTask).uid) + #=FIXME:UNIQUE=# + arg.value = Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) end end -# TODO: This can probably operate in-place -function eager_process_args_submission_to_local(id_map, spec::Pair{DTaskSpec,DTask}) - return Base.mapany(first(spec).args) do pos_x - pos, x = pos_x - return pos => eager_process_elem_submission_to_local(id_map, x) +function eager_process_args_submission_to_local!(id_map, spec_pair::Pair{DTaskSpec,DTask}) + spec, task = spec_pair + for arg in spec.fargs + eager_process_elem_submission_to_local!(id_map, arg) end end -function eager_process_args_submission_to_local(id_map, specs::Vector{Pair{DTaskSpec,DTask}}) - return Base.mapany(specs) do spec - eager_process_args_submission_to_local(id_map, spec) +function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair{DTaskSpec,DTask}}) + for spec_pair in spec_pairs + eager_process_args_submission_to_local!(id_map, spec_pair) end end function eager_process_options_submission_to_local(id_map, options::NamedTuple) @@ -210,7 +268,13 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) raw_syncdeps = options.syncdeps syncdeps = Set{Any}() for raw_dep in raw_syncdeps - push!(syncdeps, eager_process_elem_submission_to_local(id_map, raw_dep)) + if raw_dep isa DTask + push!(syncdeps, Sch.ThunkID(id_map[raw_dep.uid], raw_dep.thunk_ref)) + elseif raw_dep isa Sch.ThunkID + push!(syncdeps, raw_dep) + else + error("Invalid syncdep type: $(typeof(raw_dep))") + end end return merge(options, (;syncdeps)) else @@ -221,42 +285,44 @@ function eager_spawn(spec::DTaskSpec) # Generate new DTask uid = eager_next_id() future = ThunkFuture() - finalizer_ref = poolset(DTaskFinalizer(uid); device=MemPool.CPURAMDevice()) # Return unlaunched DTask - return DTask(uid, future, finalizer_ref) + return DTask(uid, future) end function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) # Lookup DTask -> ThunkID - local args, options + local options lock(Sch.EAGER_ID_MAP) do id_map - args = eager_process_args_submission_to_local(id_map, spec=>task) 
+ eager_process_args_submission_to_local!(id_map, spec=>task) options = eager_process_options_submission_to_local(id_map, spec.options) end # Submit the task - thunk_id = eager_submit!(1, - task.uid, task.future, task.finalizer_ref, - spec.f, args, options) + #=FIXME:REALLOC=# + thunk_id = eager_submit!(PayloadOne(task.uid, task.future, + spec.fargs, options, true)) task.thunk_ref = thunk_id.ref end function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) ntasks = length(specs) + #=FIXME:REALLOC_N=# uids = [task.uid for (_, task) in specs] futures = [task.future for (_, task) in specs] - finalizer_refs = [task.finalizer_ref for (_, task) in specs] # Get all functions, args/kwargs, and options - all_fs = Any[spec.f for (spec, _) in specs] - all_args = lock(Sch.EAGER_ID_MAP) do id_map + #=FIXME:REALLOC_N=# + all_fargs = lock(Sch.EAGER_ID_MAP) do id_map # Lookup DTask -> ThunkID - eager_process_args_submission_to_local(id_map, specs) + eager_process_args_submission_to_local!(id_map, specs) + [spec.fargs for (spec, _) in specs] end all_options = Any[spec.options for (spec, _) in specs] # Submit the tasks - thunk_ids = eager_submit!(ntasks, uids, futures, finalizer_refs, all_fs, all_args, all_options) + #=FIXME:REALLOC=# + thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures, + all_fargs, all_options)) for i in 1:ntasks task = specs[i][2] task.thunk_ref = thunk_ids[i].ref diff --git a/src/task-tls.jl b/src/task-tls.jl index fb42dfbc9..7623b4372 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,21 +1,28 @@ # In-Thunk Helpers """ - thunk_processor() + thunk_processor() -> Processor Get the current processor executing the current thunk. """ thunk_processor() = task_local_storage(:_dagger_processor)::Processor """ - in_thunk() + in_thunk() -> Bool Returns `true` if currently in a [`Thunk`](@ref) process, else `false`. """ in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid) """ - get_tls() + thunk_logging_enabled() -> Bool + +Returns `true` if logging is enabled for the current thunk, else `false`. +""" +thunk_logging_enabled() = task_local_storage(:_dagger_logging_enabled) + +""" + get_tls() -> NamedTuple Gets all Dagger TLS variable as a `NamedTuple`. """ @@ -24,10 +31,11 @@ get_tls() = ( sch_handle=task_local_storage(:_dagger_sch_handle), processor=thunk_processor(), task_spec=task_local_storage(:_dagger_task_spec), + logging_enabled=thunk_logging_enabled(), ) """ - set_tls!(tls) + set_tls!(tls::NamedTuple) Sets all Dagger TLS variables from the `NamedTuple` `tls`. """ @@ -36,4 +44,5 @@ function set_tls!(tls) task_local_storage(:_dagger_sch_handle, tls.sch_handle) task_local_storage(:_dagger_processor, tls.processor) task_local_storage(:_dagger_task_spec, tls.task_spec) + task_local_storage(:_dagger_logging_enabled, tls.logging_enabled) end diff --git a/src/threadproc.jl b/src/threadproc.jl index e4f25363f..1a8fb8625 100644 --- a/src/threadproc.jl +++ b/src/threadproc.jl @@ -16,7 +16,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n result = Ref{Any}() task = Task() do set_tls!(tls) - TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id) + if thunk_logging_enabled() + TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id) + end result[] = @invokelatest f(args...; kwargs...) 
return end diff --git a/src/thunk.jl b/src/thunk.jl index cea8a9d60..33e9fa433 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -3,14 +3,33 @@ export Thunk, delayed const ID_COUNTER = Threads.Atomic{Int}(1) next_id() = Threads.atomic_add!(ID_COUNTER, 1) -function filterany(f::Base.Callable, xs) - xs_filt = Any[] - for x in xs - if f(x) - push!(xs_filt, x) - end - end - return xs_filt +const EMPTY_ARGS = Argument[] +const EMPTY_SYNCDEPS = Set{Any}() +Base.@kwdef mutable struct ThunkSpec + fargs::Vector{Argument} = EMPTY_ARGS + syncdeps::Set{Any} = EMPTY_SYNCDEPS + id::Int = 0 + get_result::Bool = false + meta::Bool = false + persist::Bool = false + cache::Bool = false + cache_ref::Any = nothing + affinity::Union{Pair{OSProc,Int}, Nothing} = nothing + options::Any#=FIXME:ThunkOptions=# = nothing + propagates::Tuple = () +end +function unset!(spec::ThunkSpec, _) + spec.fargs = EMPTY_ARGS + spec.syncdeps = EMPTY_SYNCDEPS + spec.id = 0 + spec.get_result = false + spec.meta = false + spec.persist = false + spec.cache = false + spec.cache_ref = nothing + spec.affinity = nothing + spec.options = nothing + spec.propagates = () end """ @@ -35,23 +54,22 @@ julia> collect(t) # computes the result and returns it to the current process ``` ## Arguments -- `f`: The function to be called upon execution of the `Thunk`. -- `args`: The arguments to be passed to the `Thunk`. +- `fargs`: The function and arguments to be called upon execution of the `Thunk`. - `kwargs`: The properties describing unique behavior of this `Thunk`. Details for each property are described in the next section. - `option=value`: The same as passing `kwargs` to `delayed`. ## Public Properties - `meta::Bool=false`: If `true`, instead of fetching cached arguments from -`Chunk`s and passing the raw arguments to `f`, instead pass the `Chunk`. Useful -for doing manual fetching or manipulation of `Chunk` references. Non-`Chunk` -arguments are still passed as-is. -- `processor::Processor=OSProc()` - The processor associated with `f`. Useful if -`f` is a callable struct that exists on a given processor and should be -transferred appropriately. -- `scope::Dagger.AbstractScope=DefaultScope()` - The scope associated with `f`. -Useful if `f` is a function or callable struct that may only be transferred to, -and executed within, the specified scope. +`Chunk`s and passing the raw arguments to the called function, instead pass the +`Chunk`. Useful for doing manual fetching or manipulation of `Chunk` +references. Non-`Chunk` arguments are still passed as-is. - +`processor::Processor=OSProc()` - The processor associated with the called +function. Useful if the called function is a callable struct that exists on a +given processor and should be transferred appropriately. - +`scope::Dagger.AbstractScope=DefaultScope()` - The scope associated with the +called function. Useful if the called function is a callable struct that may +only be transferred to, and executed within, the specified scope. ## Options - `options`: A `Sch.ThunkOptions` struct providing the options for the `Thunk`. @@ -59,8 +77,7 @@ If omitted, options can also be specified by passing key-value pairs as `kwargs`. 
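As a usage note for the `processor`/`scope` properties described above, the task API accepts the same options and wraps the callee in a `Chunk` before submission; for example (mirroring the scope usage exercised in test/memory-spaces.jl later in this patch):

using Dagger

# Constrain where the call may run; the callee is wrapped in a Chunk
# carrying this scope before the task is submitted.
t = Dagger.@spawn scope=Dagger.scope(worker=1) sum([1, 2, 3])
@assert fetch(t) == 6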
""" mutable struct Thunk - f::Any # usually a Function, but could be any callable - inputs::Vector{Pair{Union{Symbol,Nothing},Any}} # TODO: Use `ImmutableArray` in 1.8 + inputs::Vector{Argument} # TODO: Use `ImmutableArray` in 1.8 syncdeps::Set{Any} id::Int get_result::Bool # whether the worker should send the result or only the metadata @@ -69,52 +86,95 @@ mutable struct Thunk cache::Bool # release the result giving the worker an opportunity to # cache it cache_ref::Any - affinity::Union{Nothing, Pair{OSProc, Int}} - eager_ref::Union{DRef,Nothing} + affinity::Union{Pair{OSProc,Int}, Nothing} options::Any # stores scheduler-specific options propagates::Tuple # which options we'll propagate - function Thunk(f, xs...; - syncdeps=nothing, - id::Int=next_id(), - get_result::Bool=false, - meta::Bool=false, - persist::Bool=false, - cache::Bool=false, - cache_ref=nothing, - affinity=nothing, - eager_ref=nothing, - processor=nothing, - scope=nothing, - options=nothing, - propagates=(), - kwargs... - ) - if !isa(f, Chunk) && (!isnothing(processor) || !isnothing(scope)) - f = tochunk(f, - something(processor, OSProc()), - something(scope, DefaultScope())) + eager_accessible::Bool + sch_accessible::Bool + finished::Bool + function Thunk(spec::ThunkSpec) + return new(spec.fargs, + spec.syncdeps, spec.id, + spec.get_result, spec.meta, spec.persist, spec.cache, + spec.cache_ref, spec.affinity, + spec.options, spec.propagates, + true, true, false) + end +end +function Thunk(f, xs...; + syncdeps=nothing, + id::Int=next_id(), + get_result::Bool=false, + meta::Bool=false, + persist::Bool=false, + cache::Bool=false, + cache_ref=nothing, + affinity=nothing, + processor=nothing, + scope=nothing, + options=nothing, + propagates=(), + kwargs... + ) + + spec = ThunkSpec() + if !(f isa Argument) + f = Argument(ArgPosition(true, 0, :NULL), f) + end + if !(valuetype(f) <: Chunk) && (!isnothing(processor) || !isnothing(scope)) + f.value = tochunk(value(f), + something(processor, OSProc()), + something(scope, DefaultScope()); rewrap=true) + end + spec.fargs = Vector{Argument}(undef, length(xs)+1) + spec.fargs[1] = f + for idx in 1:length(xs) + x = xs[idx] + if x isa Argument + spec.fargs[idx+1] = x + else + @assert x isa Pair "Invalid Thunk argument: $x" + spec.fargs[idx+1] = Argument(something(x.first, idx), x.second) end - xs = Base.mapany(identity, xs) - syncdeps_set = Set{Any}(filterany(is_task_or_chunk, Base.mapany(last, xs))) - if syncdeps !== nothing - for dep in syncdeps - push!(syncdeps_set, dep) - end + end + syncdeps_set = Set{Any}() + for idx in 2:length(spec.fargs) + x = value(spec.fargs[idx]) + if is_task_or_chunk(x) + push!(syncdeps_set, x) end - @assert all(x->x isa Pair, xs) - if options !== nothing - @assert isempty(kwargs) - new(f, xs, syncdeps_set, id, get_result, meta, persist, cache, - cache_ref, affinity, eager_ref, options, propagates) - else - new(f, xs, syncdeps_set, id, get_result, meta, persist, cache, - cache_ref, affinity, eager_ref, Sch.ThunkOptions(;kwargs...), - propagates) + end + if syncdeps !== nothing + for dep in syncdeps + push!(syncdeps_set, dep) end end + spec.syncdeps = syncdeps_set + spec.id = id + spec.get_result = get_result + spec.meta = meta + spec.persist = persist + spec.cache = cache + spec.cache_ref = cache_ref + spec.affinity = affinity + if options !== nothing + @assert isempty(kwargs) + spec.options = options::Sch.ThunkOptions + else + spec.options = Sch.ThunkOptions(;kwargs...) 
+ end + spec.propagates = propagates + return Thunk(spec) end Serialization.serialize(io::AbstractSerializer, t::Thunk) = throw(ArgumentError("Cannot serialize a Thunk")) +function Base.getproperty(thunk::Thunk, field::Symbol) + if field == :f + return unwrap_weak_checked(value(first(thunk.inputs))) + else + return getfield(thunk, field) + end +end function affinity(t::Thunk) if t.affinity !== nothing @@ -152,13 +212,32 @@ end is_task_or_chunk(x) = istask(x) -function args_kwargs_to_pairs(args, kwargs) - args_kwargs = Pair{Union{Symbol,Nothing},Any}[] - for arg in args - push!(args_kwargs, nothing => arg) +function args_kwargs_to_arguments(f, args, kwargs) + @nospecialize f args kwargs + args_kwargs = Argument[] + push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) + for idx in 1:length(args) + arg = args[idx] + push!(args_kwargs, Argument(idx, arg)) + end + for (kw, value) in kwargs + push!(args_kwargs, Argument(kw, value)) end - for kwarg in kwargs - push!(args_kwargs, kwarg[1] => kwarg[2]) + return args_kwargs +end +function args_kwargs_to_arguments(f, args) + @nospecialize f args + args_kwargs = Argument[] + push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) + pos_ctr = 1 + for idx in 1:length(args) + pos, arg = args[idx]::Pair + if pos === nothing + push!(args_kwargs, Argument(pos_ctr, arg)) + pos_ctr += 1 + else + push!(args_kwargs, Argument(pos, arg)) + end end return args_kwargs end @@ -172,7 +251,7 @@ Creates a [`Thunk`](@ref) object which can be executed later, which will call resulting `Thunk`. """ function _delayed(f, options::Options) - (args...; kwargs...) -> Thunk(f, args_kwargs_to_pairs(args, kwargs)...; options.options...) + (args...; kwargs...) -> Thunk(args_kwargs_to_arguments(f, args, kwargs)...; options.options...) end function delayed(f, options::Options) @warn "`delayed` is deprecated. Use `Dagger.@spawn` or `Dagger.spawn` instead." maxlog=1 @@ -195,22 +274,31 @@ function unwrap_weak_checked(t::WeakThunk) t end unwrap_weak_checked(t) = t +wrap_weak(t::Thunk) = WeakThunk(t) +wrap_weak(t::WeakThunk) = t +wrap_weak(t) = t +isweak(t::WeakThunk) = true +isweak(t::Thunk) = false +isweak(t) = true Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) +chunktype(t::WeakThunk) = chunktype(unwrap_weak_checked(t)) "A summary of the data contained in a Thunk, which can be safely serialized." struct ThunkSummary id::Int - f - inputs::Vector{Pair{Union{Symbol,Nothing},Any}} + inputs::Vector{Argument} end inputs(t::ThunkSummary) = t.inputs Base.show(io::IO, t::ThunkSummary) = show_thunk(io, t) function Base.convert(::Type{ThunkSummary}, t::Thunk) - return ThunkSummary(t.id, - t.f, - map(pos_inp->istask(pos_inp[2]) ? 
pos_inp[1]=>convert(ThunkSummary, pos_inp[2]) : pos_inp, - t.inputs)) + args = map(copy, t.inputs) + for arg in args + if istask(value(arg)) + arg.value = convert(ThunkSummary, value(arg)) + end + end + return ThunkSummary(t.id, args) end function Base.convert(::Type{ThunkSummary}, t::WeakThunk) t = unwrap_weak(t) @@ -244,17 +332,19 @@ function Base.showerror(io::IO, ex::ThunkFailedException) function thunk_string(t) Tinputs = Any[] - for (_, input) in t.inputs - if istask(input) - push!(Tinputs, "Thunk(id=$(input.id))") + for input in @view t.inputs[2:end] + x = value(input) + if istask(x) + push!(Tinputs, "Thunk(id=$(x.id))") else - push!(Tinputs, input) + push!(Tinputs, x) end end + f = value(t.inputs[1]) t_sig = if length(Tinputs) <= 4 - "$(t.f)($(join(Tinputs, ", ")))" + "$(f)($(join(Tinputs, ", ")))" else - "$(t.f)($(length(Tinputs)) inputs...)" + "$(f)($(length(Tinputs)) inputs...)" end return "Thunk(id=$(t.id), $t_sig)" end @@ -461,14 +551,14 @@ function spawn(f, args...; kwargs...) # Wrap f in a Chunk if necessary processor = haskey(options, :processor) ? options.processor : nothing scope = haskey(options, :scope) ? options.scope : nothing - if !isnothing(processor) || !isnothing(scope) + if !isa(f, Chunk) && !isnothing(processor) || !isnothing(scope) f = tochunk(f, something(processor, get_options(:processor, OSProc())), - something(scope, get_options(:scope, DefaultScope()))) + something(scope, get_options(:scope, DefaultScope())); rewrap=true) end # Process the args and kwargs into Pair form - args_kwargs = args_kwargs_to_pairs(args, kwargs) + args_kwargs = args_kwargs_to_arguments(f, args, kwargs) # Get task queue, and don't let it propagate task_queue = get_options(:task_queue, DefaultTaskQueue()) @@ -477,7 +567,7 @@ function spawn(f, args...; kwargs...) options = merge(options, (;propagates)) # Construct task spec and handle - spec = DTaskSpec(f, args_kwargs, options) + spec = DTaskSpec(args_kwargs, options) task = eager_spawn(spec) # Enqueue the task into the task queue @@ -503,47 +593,32 @@ fetch_all(x) = Adapt.adapt(FetchAdaptor(), x) persist!(t::Thunk) = (t.persist=true; t) cache_result!(t::Thunk) = (t.cache=true; t) -# @generated function compose{N}(f, g, t::NTuple{N}) -# if N <= 4 -# ( :(()->f(g())), -# :((a)->f(g(a))), -# :((a,b)->f(g(a,b))), -# :((a,b,c)->f(g(a,b,c))), -# :((a,b,c,d)->f(g(a,b,c,d))), )[N+1] -# else -# :((xs...) 
-> f(g(xs...))) -# end -# end - -# function Thunk(f::Function, t::Tuple{Thunk}) -# g = compose(f, t[1].f, t[1].inputs) -# Thunk(g, t[1].inputs) -# end - # this gives a ~30x speedup in hashing Base.hash(x::Thunk, h::UInt) = hash(x.id, hash(h, 0x7ad3bac49089a05f % UInt)) Base.isequal(x::Thunk, y::Thunk) = x.id==y.id function show_thunk(io::IO, t) lvl = get(io, :lazy_level, 0) - f = if t.f isa Chunk - Tf = t.f.chunktype + f = value(first(t.inputs)) + f = if f isa Chunk + Tf = f.chunktype if isdefined(Tf, :instance) Tf.instance else "instance of $Tf" end else - t.f + f end print(io, "Thunk[$(t.id)]($f, ") if lvl > 0 t_inputs = Any[] - for (pos, input) in inputs(t) - if pos === nothing + for arg in inputs(t)[2:end] + input = value(arg) + if ispositional(arg) push!(t_inputs, input) else - push!(t_inputs, pos => input) + push!(t_inputs, pos_kw(arg) => input) end end show(IOContext(io, :lazy_level => lvl-1), t_inputs) diff --git a/src/utils/logging-events.jl b/src/utils/logging-events.jl index f2de153b1..ad07fdd3d 100644 --- a/src/utils/logging-events.jl +++ b/src/utils/logging-events.jl @@ -199,7 +199,7 @@ function (::TaskDependencies)(ev::Event{:start}) end if ev.category == :add_thunk deps_tids = Int[] - get_deps!(Iterators.filter(Dagger.istask, Iterators.map(last, ev.timeline.args))) + get_deps!(Iterators.filter(Dagger.istask, Iterators.map(Dagger.value, ev.timeline.args))) get_deps!(get(Set, ev.timeline.options, :syncdeps)) return ev.id.thunk_id => deps_tids end diff --git a/src/utils/logging.jl b/src/utils/logging.jl index 9b923c526..d3d15dffe 100644 --- a/src/utils/logging.jl +++ b/src/utils/logging.jl @@ -76,6 +76,15 @@ these logs. """ fetch_logs!() = TimespanLogging.get_logs!(Dagger.Sch.eager_context()) +# Convenience macros to reduce allocations when logging is disabled +macro maybelog(ctx, ex) + quote + if !($(esc(ctx)).log_sink isa $(TimespanLogging.NoOpLog)) + $(esc(ex)) + end + end +end + function logs_event_pairs(f, logs::Dict) running_events = Dict{Tuple,Int}() for w in keys(logs) @@ -104,7 +113,7 @@ end Associates an argument `arg` with `name` in the logs, which logs renderers may utilize for display purposes. 
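The `@maybelog` macro defined above lets instrumentation sites skip building log payloads entirely when the context's sink is a `NoOpLog`. A hedged sketch of a call site (the event name and empty id tuple are illustrative, not taken from the patch):

ctx = Dagger.Sch.eager_context()

# The timespan calls (and any allocations they imply) are only evaluated
# when ctx.log_sink is not a TimespanLogging.NoOpLog.
Dagger.@maybelog ctx Dagger.TimespanLogging.timespan_start(ctx, :my_event, (;), nothing)
# ... do some work ...
Dagger.@maybelog ctx Dagger.TimespanLogging.timespan_finish(ctx, :my_event, (;), nothing)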
""" -function logs_annotate!(ctx::Context, arg, name::Union{String,Symbol}) +function logs_annotate!(ctx#=::Context=#, arg, name::Union{String,Symbol}) ismutable(arg) || throw(ArgumentError("Argument must be mutable to be annotated")) Dagger.TimespanLogging.timespan_start(ctx, :data_annotation, (;objectid=objectid(arg), name), nothing) # TODO: Remove redundant log event diff --git a/src/utils/reuse.jl b/src/utils/reuse.jl new file mode 100644 index 000000000..9aa79e319 --- /dev/null +++ b/src/utils/reuse.jl @@ -0,0 +1,542 @@ +struct ReusableCache{T,Tnull} + cache::Vector{T} + used::Vector{Bool} + null::Tnull + sized::Bool + function ReusableCache(T, null, N::Integer; sized::Bool=false) + @assert !Base.datatype_pointerfree(T) "ReusableCache is only useful for non-pointerfree types (got $T)" + #cache = [T() for _ in 1:N] + cache = Vector{T}(undef, N) + used = zeros(Bool, N) + return new{T,typeof(null)}(cache, used, null, sized) + end +end +function maybetake!(cache::ReusableCache{T}, len=nothing) where T + for idx in 1:length(cache.used) + cache.used[idx] && continue + if cache.sized && isassigned(cache.cache, idx) && length(cache.cache[idx]) != len + @debug "Skipping length $(length(cache.cache[idx])) (want length $len) @ $idx" + continue + end + cache.used[idx] = true + if !isassigned(cache.cache, idx) + if cache.sized + @debug "Allocating length $len @ $idx" + cache.cache[idx] = alloc!(T, len) + else + cache.cache[idx] = alloc!(T) + end + end + return (idx, cache.cache[idx]) + end + return nothing +end +function putback!(cache::ReusableCache{T}, idx::Integer) where T + cache.used[idx] = false +end +function take_or_alloc!(f::Function, cache::ReusableCache{T}, len=nothing; no_alloc::Bool=false) where T + idx_value = maybetake!(cache, len) + if idx_value !== nothing + idx, value = idx_value + try + return f(value) + finally + unset!(value, cache.null) + putback!(cache, idx) + end + else + if no_alloc + error("No more entries available in cache for type $T") + end + return f(T()) + end +end +function maybe_take_or_alloc!(f::Function, cache::ReusableCache{T}, value::Union{T,Nothing}, len=nothing; no_alloc::Bool=false) where T + if value !== nothing + return f(value) + else + return take_or_alloc!(f, cache, len; no_alloc=no_alloc) + end +end + +alloc!(::Type{V}, n::Integer) where V<:Vector = V(undef, n) +alloc!(::Type{D}) where {D<:Dict} = D() +alloc!(::Type{S}) where {S<:Set} = S() +alloc!(::Type{T}) where T = T() + +unset!(v::Vector, null) = fill!(v, null) +# FIXME: Inefficient to use these +unset!(d::Dict, _) = empty!(d) +unset!(s::Set, _) = empty!(s) + +macro take_or_alloc!(cache, T, var, ex) + @gensym idx_value idx + quote + $idx_value = $maybetake!($(esc(cache))) + if $idx_value !== nothing + $idx, $(esc(var)) = $idx_value + try + $(esc(ex)) + finally + $unset!($(esc(var)), $(esc(cache)).null) + $putback!($(esc(cache)), $idx) + end + else + #=let=# $(esc(var)) = $(esc(T))() + $(esc(ex)) + #end + end + end +end +macro take_or_alloc!(cache, T, len, var, ex) + @gensym idx_value idx + quote + $idx_value = $maybetake!($(esc(cache)), $(esc(len))) + if $idx_value !== nothing + $idx, $(esc(var)) = $idx_value + try + $(esc(ex)) + finally + $unset!($(esc(var)), $(esc(cache)).null) + $putback!($(esc(cache)), $idx) + end + else + #=let=# $(esc(var)) = $(esc(T))() + $(esc(ex)) + #end + end + end +end +# TODO: const causes issues with Revise +macro reusable(name, T, null, N, var, ex) + cache_name = Symbol("__$(name)_reuse_cache") + if !hasproperty(__module__, cache_name) + __module__.eval(:(#=const=# 
$cache_name = TaskLocalValue{ReusableCache{$T}}(()->ReusableCache($T, $null, $N)))) + end + quote + @take_or_alloc! $(esc(cache_name))[] $T $(esc(var)) $(esc(ex)) + end +end +macro reusable(name, T, null, N, len, var, ex) + cache_name = Symbol("__$(name)_reuse_cache") + if !hasproperty(__module__, cache_name) + __module__.eval(:(#=const=# $cache_name = TaskLocalValue{ReusableCache{$T}}(()->ReusableCache($T, $null, $N; sized=true)))) + end + quote + @take_or_alloc! $(esc(cache_name))[] $T $(esc(len)) $(esc(var)) $(esc(ex)) + end +end + +# FIXME: Provide ReusableObject{T} interface +# FIXME: Allow objects to be GC'd (if lost via throw/unexpected control flow) (provide optional warning mode on finalization) +# FIXME: Add take/replace interface +# FIXME: Add function annotation for multiple reuse points + +#= FIXME: UniquingCache +struct UniquingCache{K,V} + cache::Dict{WeakRef,WeakRef} + function UniquingCache(K, V) + return new(Dict{K,V}()) + end +end +=# + +mutable struct ReusableNode{T} + value::T + next::Union{ReusableNode{T},Nothing} +end +mutable struct ReusableLinkedList{T} <: AbstractVector{T} + head::Union{ReusableNode{T},Nothing} + tail::Union{ReusableNode{T},Nothing} + free_nodes::ReusableNode{T} + null::T + len::Int + function ReusableLinkedList{T}(null, N) where T + free_root = ReusableNode{T}(null, nothing) + for _ in 1:N + free_node = ReusableNode{T}(null, nothing) + free_node.next = free_root + free_root = free_node + end + return new{T}(nothing, nothing, free_root, null, N) + end +end +Base.eltype(list::ReusableLinkedList{T}) where T = T +function Base.getindex(list::ReusableLinkedList{T}, idx::Integer) where T + node = list.head + for _ in 1:(idx-1) + node === nothing && throw(BoundsError(list, idx)) + node = node.next + end + node === nothing && throw(BoundsError(list, idx)) + return node.value +end +function Base.setindex!(list::ReusableLinkedList{T}, value::T, idx::Integer) where T + node = list.head + for _ in 1:(idx-1) + node === nothing && throw(BoundsError(list, idx)) + node = node.next + end + node === nothing && throw(BoundsError(list, idx)) + node.value = value + return value +end +function Base.push!(list::ReusableLinkedList{T}, value) where T + value_conv = convert(T, value) + node = list.free_nodes + if node.next === nothing + # FIXME: Optionally allocate extras + error("No more entries available in cache for type $T") + end + list.free_nodes = node.next + node.value = value_conv + node.next = nothing + if list.head === nothing + list.head = list.tail = node + else + list.tail.next = node + list.tail = node + end + return list +end +function Base.pop!(list::ReusableLinkedList{T}) where T + if list.head === nothing + throw(ArgumentError("list must be non-empty")) + end + node = list.head + list.head = node.next + if list.head === nothing + list.tail = nothing + end + node.next = list.free_nodes + list.free_nodes = node + value = node.value + node.value = list.null + return value +end +Base.size(list::ReusableLinkedList{T}) where T = (length(list),) +function Base.length(list::ReusableLinkedList{T}) where T + node = list.head + if node === nothing + return 0 + end + len = 1 + while node.next !== nothing + len += 1 + node = node.next + end + return len +end +function Base.iterate(list::ReusableLinkedList{T}) where T + node = list.head + if node === nothing + return nothing + end + return (node.value, node) +end +function Base.iterate(list::ReusableLinkedList{T}, state::Union{Nothing,ReusableNode{T}}) where T + if state === nothing + return nothing + end + node = 
state.next + if node === nothing + return nothing + end + return (node.value, node) +end +function Base.in(list::ReusableLinkedList{T}, value::T) where T + node = list.head + while node !== nothing + if node.value == value + return true + end + end + return false +end +function Base.findfirst(f::Function, list::ReusableLinkedList) + node = list.head + idx = 1 + while node !== nothing + if f(node.value) + return idx + end + node = node.next + idx += 1 + end + return nothing +end +Base.sizehint!(list::ReusableLinkedList, len::Integer) = nothing +function Base.empty!(list::ReusableLinkedList{T}) where T + if list.tail !== nothing + fill!(list, list.null) + list.tail.next = list.free_nodes + list.free_nodes = list.head + list.head = list.tail = nothing + end + return list +end +function Base.fill!(list::ReusableLinkedList{T}, value::T) where T + node = list.head + while node !== nothing + node.value = value + node = node.next + end + return list +end +function Base.resize!(list::ReusableLinkedList, N::Integer) + while length(list) < N + push!(list, list.null) + end + while length(list) > N + pop!(list) + end + return list +end +function Base.deleteat!(list::ReusableLinkedList, idx::Integer) + checkbounds(list, idx) + if idx == 1 + deleted = list.head + list.head = list.head.next + deleted.next = list.free_nodes + list.free_nodes = deleted + deleted.value = list.null + return list + end + node = list.head + for _ in 1:(idx-2) + if node === nothing + throw(BoundsError(idx)) + end + node = node.next + end + if idx == length(list) + list.tail = node + end + deleted = node.next + node.next = deleted.next + deleted.next = list.free_nodes + list.free_nodes = deleted + deleted.value = list.null + return list +end +function Base.map!(f, list_out::ReusableLinkedList{T}, list_in::ReusableLinkedList{V}) where {T,V} + @assert length(list_out) == length(list_in) "lists must have the same length" + node_out = list_out.head + node_in = list_in.head + while node_in !== nothing + node_out.value = f(node_in.value) + node_in = node_in.next + node_out = node_out.next + end + return list_out +end +function Base.copyto!(list_out::ReusableLinkedList{T}, list_in::ReusableLinkedList{T}) where T + Base.map!(identity, list_out, list_in) +end + +struct ReusableSet{T} <: AbstractSet{T} + list::ReusableLinkedList{T} +end +function ReusableSet(T, null, N) + return ReusableSet{T}(ReusableLinkedList{T}(null, N)) +end +function Base.push!(set::ReusableSet{T}, value::T) where T + if !(value in set) + push!(set.list, value) + end + return set +end +function Base.pop!(set::ReusableSet{T}, value) where T + value_conv = convert(T, value) + idx = findfirst(==(value_conv), set) + if idx === nothing + throw(KeyError(value_conv)) + end + deleteat!(set, idx) + return value +end +Base.length(set::ReusableSet) = length(set.list) +function Base.iterate(set::ReusableSet) + return iterate(set.list) +end +function Base.iterate(set::ReusableSet, state) + return iterate(set.list, state) +end +function Base.empty!(set::ReusableSet{T}) where T + empty!(set.list) + return set +end + +struct ReusableDict{K,V} <: AbstractDict{K,V} + keys::ReusableLinkedList{K} + values::ReusableLinkedList{V} +end +function ReusableDict{K,V}(null_key, null_value, N::Integer) where {K,V} + keys = ReusableLinkedList{K}(null_key, N) + values = ReusableLinkedList{V}(null_value, N) + return ReusableDict{K,V}(keys, values) +end +function Base.getindex(dict::ReusableDict{K,V}, key) where {K,V} + key_conv = convert(K, key) + idx = findfirst(==(key_conv), dict.keys) + if idx 
=== nothing + throw(KeyError(key_conv)) + end + return dict.values[idx] +end +function Base.setindex!(dict::ReusableDict{K,V}, value, key) where {K,V} + key_conv = convert(K, key) + value_conv = convert(V, value) + idx = findfirst(==(key_conv), dict.keys) + if idx === nothing + push!(dict.keys, key_conv) + push!(dict.values, value_conv) + else + dict.values[idx] = value_conv + end + return value +end +function Base.delete!(dict::ReusableDict{K,V}, key) where {K,V} + key_conv = convert(K, key) + idx = findfirst(==(key_conv), dict.keys) + if idx === nothing + throw(KeyError(key_conv)) + end + deleteat!(dict.keys, idx) + deleteat!(dict.values, idx) + return dict +end +function Base.haskey(dict::ReusableDict{K,V}, key) where {K,V} + key_conv = convert(K, key) + return key_conv in dict.keys +end +function Base.iterate(dict::ReusableDict) + key = dict.keys.head + if key === nothing + return nothing + end + value = dict.values.head + return (key.value => value.value, (key, value)) +end +Base.length(dict::ReusableDict) = length(dict.keys) +function Base.iterate(dict::ReusableDict, state) + if state === nothing + return nothing + end + key, value = state + key = key.next + if key === nothing + return nothing + end + value = value.next + return (key.value => value.value, (key, value)) +end +Base.keys(dict::ReusableDict) = dict.keys +Base.values(dict::ReusableDict) = dict.values +function Base.empty!(dict::ReusableDict{K,V}) where {K,V} + empty!(dict.keys) + empty!(dict.values) + return dict +end + +macro reusable_vector(name, T, null, N) + vec_name = Symbol("__$(name)_TLV_ReusableLinkedList") + if !hasproperty(__module__, vec_name) + __module__.eval(:(#=const=# $vec_name = $TaskLocalValue{$ReusableLinkedList{$T}}(()->$ReusableLinkedList{$T}($null, $N)))) + end + return :($(esc(vec_name))[]) +end +macro reusable_dict(name, K, V, null_key, null_value, N) + dict_name = Symbol("__$(name)_TLV_ReusableDict") + if !hasproperty(__module__, dict_name) + __module__.eval(:(#=const=# $dict_name = $TaskLocalValue{$ReusableDict{$K,$V}}(()->$ReusableDict{$K,$V}($null_key, $null_value, $N)))) + end + return :($(esc(dict_name))[]) +end + +mutable struct ReusableTaskCache + tasks::Vector{Task} + chans::Vector{Channel{Any}} + ready::Vector{Threads.Atomic{Bool}} + setup_f::Function + N::Int + init::Bool + function ReusableTaskCache(N::Integer) + tasks = Vector{Task}(undef, N) + chans = Vector{Channel{Any}}(undef, N) + ready = [Threads.Atomic{Bool}(true) for _ in 1:N] + for idx in 1:N + chans[idx] = Channel{Any}(0) + chan, r = chans[idx], ready[idx] + tasks[idx] = @task reusable_task_loop(chan, r) + end + cache = new(tasks, chans, ready, t->nothing, N, false) + finalizer(cache) do cache + # Ask tasks to shut down + for idx in 1:N + Threads.atomic_xchg!(cache.ready[idx], false) + close(cache.chans[idx]) + end + end + return cache + end +end +function reusable_task_cache_init!(setup_f::Function, cache::ReusableTaskCache) + cache.init && return + cache.setup_f = setup_f + for idx in 1:cache.N + task = cache.tasks[idx] + setup_f(task) + schedule(task) + Sch.errormonitor_tracked("reusable_task_$idx", task) + end + cache.init = true + return +end +function reusable_task_loop(chan::Channel{Any}, ready::Threads.Atomic{Bool}) + r = rand(1:128) + while true + f = try + take!(chan) + catch + if !isopen(chan) + return + else + rethrow() + end + end + try + @invokelatest f() + catch err + @error "[$r] Error in reusable task" exception=(err, catch_backtrace()) + end + Threads.atomic_xchg!(ready, true) + end +end +function 
(cache::ReusableTaskCache)(f, name::String) + idx = findfirst(getindex, cache.ready) + if idx !== nothing + @assert Threads.atomic_xchg!(cache.ready[idx], false) + put!(cache.chans[idx], f) + Sch.errormonitor_tracked_set!(name, cache.tasks[idx]) + return cache.tasks[idx] + else + t = Task(()->f) + cache.setup_f(t) + schedule(t) + Sch.errormonitor_tracked(name, t) + return t + end + return +end + +macro reusable_tasks(name, N, setup_ex, task_name, task_ex) + cache_name = Symbol("__$(name)_TLV_ReusableTaskCache") + if !hasproperty(__module__, cache_name) + __module__.eval(:(#=const=# $cache_name = $TaskLocalValue{$ReusableTaskCache}(()->$ReusableTaskCache($N)))) + end + return esc(quote + $reusable_task_cache_init!($setup_ex, $cache_name[]) + $cache_name[]($task_ex, $task_name) + end) +end diff --git a/test/logging.jl b/test/logging.jl index bfc5de025..343dec24e 100644 --- a/test/logging.jl +++ b/test/logging.jl @@ -119,7 +119,9 @@ import Colors, GraphViz, DataFrames, Plots, JSON3 end end end - @test length(keys(logs)) > 1 + if nprocs() > 1 + @test length(keys(logs)) > 1 + end l1 = logs[1] core = l1[:core] @@ -131,12 +133,14 @@ import Colors, GraphViz, DataFrames, Plots, JSON3 @test any(e->haskey(e, :take), esat) @test any(e->haskey(e, :finish), esat) if Threads.nthreads() == 1 - # Note: May one day be true as scheduler evolves - @test !any(e->haskey(e, :compute), esat) - @test !any(e->haskey(e, :move), esat) - psat = l1[:psat] - # Note: May become false - @test all(e->length(e) == 0, psat) + if nprocs() > 1 + # Note: May one day be true as scheduler evolves + @test !any(e->haskey(e, :compute), esat) + @test !any(e->haskey(e, :move), esat) + psat = l1[:psat] + # Note: May become false + @test all(e->length(e) == 0, psat) + end end had_psat_proc = 0 @@ -155,7 +159,9 @@ import Colors, GraphViz, DataFrames, Plots, JSON3 @test any(e->haskey(e, :move), esat) end end - @test had_psat_proc > 0 + if nprocs() > 1 + @test had_psat_proc > 0 + end logs = TimespanLogging.get_logs!(ml) for w in keys(logs) diff --git a/test/memory-spaces.jl b/test/memory-spaces.jl index df4f69905..7e27f78e9 100644 --- a/test/memory-spaces.jl +++ b/test/memory-spaces.jl @@ -3,41 +3,59 @@ # OSProc x = 123 @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(1) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + if nprocs() > 1 + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + end # ThreadProc x = Dagger.tochunk(123) @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(1) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + if nprocs() > 1 + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + end - x = remotecall_fetch(Dagger.tochunk, 2, 123) - @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + if nprocs() > 1 + x = remotecall_fetch(Dagger.tochunk, 2, 123) + @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + end x = Dagger.@spawn scope=Dagger.scope(worker=1) identity(123) @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(1) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + if nprocs() > 1 + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(1) + end - x = Dagger.@spawn scope=Dagger.scope(worker=2) identity(123) - @test 
Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) - @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + if nprocs() > 1 + x = Dagger.@spawn scope=Dagger.scope(worker=2) identity(123) + @test Dagger.memory_space(x) == Dagger.CPURAMMemorySpace(2) + @test remotecall_fetch(Dagger.memory_space, 2, x) == Dagger.CPURAMMemorySpace(2) + end end @testset "Processor Queries" begin w1_t1_proc = Dagger.ThreadProc(1,1) w1_t2_proc = Dagger.ThreadProc(1,2) - w2_t1_proc = Dagger.ThreadProc(2,1) - w2_t2_proc = Dagger.ThreadProc(2,2) + if nprocs() > 1 + w2_t1_proc = Dagger.ThreadProc(2,1) + w2_t2_proc = Dagger.ThreadProc(2,2) + end @test Dagger.memory_spaces(w1_t1_proc) == Set([Dagger.CPURAMMemorySpace(1)]) @test Dagger.memory_spaces(w1_t2_proc) == Set([Dagger.CPURAMMemorySpace(1)]) - @test Dagger.memory_spaces(w2_t1_proc) == Set([Dagger.CPURAMMemorySpace(2)]) - @test Dagger.memory_spaces(w2_t2_proc) == Set([Dagger.CPURAMMemorySpace(2)]) + if nprocs() > 1 + @test Dagger.memory_spaces(w2_t1_proc) == Set([Dagger.CPURAMMemorySpace(2)]) + @test Dagger.memory_spaces(w2_t2_proc) == Set([Dagger.CPURAMMemorySpace(2)]) + end @test only(Dagger.memory_spaces(w1_t1_proc)) == only(Dagger.memory_spaces(w1_t2_proc)) - @test only(Dagger.memory_spaces(w2_t1_proc)) != only(Dagger.memory_spaces(w1_t1_proc)) + if nprocs() > 1 + @test only(Dagger.memory_spaces(w2_t1_proc)) != only(Dagger.memory_spaces(w1_t1_proc)) + end @test_throws ArgumentError Dagger.memory_spaces(FakeProc()) w1_mem = Dagger.CPURAMMemorySpace(1) - w2_mem = Dagger.CPURAMMemorySpace(2) @test Set(Dagger.processors(w1_mem)) == filter(proc->proc isa Dagger.ThreadProc, Dagger.get_processors(OSProc(1))) - @test Set(Dagger.processors(w2_mem)) == filter(proc->proc isa Dagger.ThreadProc, Dagger.get_processors(OSProc(2))) + if nprocs() > 1 + w2_mem = Dagger.CPURAMMemorySpace(2) + @test Set(Dagger.processors(w2_mem)) == filter(proc->proc isa Dagger.ThreadProc, Dagger.get_processors(OSProc(2))) + end end end diff --git a/test/processors.jl b/test/processors.jl index 88bd29667..51ec92663 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -37,9 +37,9 @@ end end @testset "Processor exhaustion" begin opts = ThunkOptions(proclist=[OptOutProc]) - @test_throws_unwrap Dagger.ThunkFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.ThunkFailedException, Dagger.Sch.SchedulingException) ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=(proc)->false) - @test_throws_unwrap Dagger.ThunkFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.ThunkFailedException, Dagger.Sch.SchedulingException) ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=nothing) @test collect(delayed(sum; options=opts)([1,2,3])) == 6 end diff --git a/test/scheduler.jl b/test/scheduler.jl index 96daa491d..56c6ab0a3 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -1,3 +1,4 @@ +import Dagger: Chunk import Dagger.Sch: SchedulerOptions, ThunkOptions, SchedulerHaltedException, ComputeState, ThunkID, sch_handle @everywhere begin @@ -50,18 +51,18 @@ function dynamic_get_dag(x...) 
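Stepping back to the reuse utilities added in src/utils/reuse.jl above: the core pattern is to check a preallocated object out of a small fixed-size cache, use it, and return it, falling back to a fresh allocation only when the cache is exhausted. A minimal sketch using `ReusableCache` and `take_or_alloc!` (the element type, length, and null value here are illustrative):

import Dagger: ReusableCache, take_or_alloc!

# Four reusable Vector{Int} slots; sized=true keys each slot on its length,
# and returned buffers are refilled with the null value (0) before reuse.
cache = ReusableCache(Vector{Int}, 0, 4; sized=true)

total = take_or_alloc!(cache, 8) do buf
    # `buf` is a cached Vector{Int} of length 8, or a fresh one on a cache miss
    buf .= 1:8
    sum(buf)
end
@assert total == 36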
end function dynamic_add_thunk(x) h = sch_handle() - id = Dagger.Sch.add_thunk!(h, nothing=>x) do y + t = Dagger.Sch.add_thunk!(h, nothing=>x) do y y+1 end - wait(h, id) - return fetch(h, id) + wait(t) + return fetch(t) end function dynamic_add_thunk_self_dominated(x) h = sch_handle() - id = Dagger.Sch.add_thunk!(h, nothing=>h.thunk_id, nothing=>x) do y + t = Dagger.Sch.add_thunk!(h, nothing=>h.thunk_id, nothing=>x) do y y+1 end - return fetch(h, id) + return fetch(t) end function dynamic_wait_fetch_multiple(x) h = sch_handle() @@ -182,7 +183,7 @@ end @testset "allow errors" begin opts = ThunkOptions(;allow_errors=true) a = delayed(error; options=opts)("Test") - @test_throws_unwrap Dagger.ThunkFailedException collect(a) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) collect(a) end end @@ -216,134 +217,136 @@ end end end - @testset "Add new workers" begin - ps = [] - try - ps1 = addprocs(2, exeflags="--project") - append!(ps, ps1) + if nprocs() > 1 # Skip if we've disabled workers + @testset "Add new workers" begin + ps = [] + try + ps1 = addprocs(2, exeflags="--project") + append!(ps, ps1) - @everywhere vcat(ps1, myid()) $setup + @everywhere vcat(ps1, myid()) $setup - ctx = Context(ps1) - ts = delayed(vcat)((delayed(testfun)(i) for i in 1:10)...) + ctx = Context(ps1) + ts = delayed(vcat)((delayed(testfun)(i) for i in 1:10)...) - job = @async collect(ctx, ts) + job = @async collect(ctx, ts) - while !istaskstarted(job) - sleep(0.001) - end + while !istaskstarted(job) + sleep(0.001) + end - # Will not be added, so they should never appear in output - ps2 = addprocs(2, exeflags="--project") - append!(ps, ps2) + # Will not be added, so they should never appear in output + ps2 = addprocs(2, exeflags="--project") + append!(ps, ps2) - ps3 = addprocs(2, exeflags="--project") - append!(ps, ps3) - @everywhere ps3 $setup - addprocs!(ctx, ps3) - @test length(procs(ctx)) == 4 + ps3 = addprocs(2, exeflags="--project") + append!(ps, ps3) + @everywhere ps3 $setup + addprocs!(ctx, ps3) + @test length(procs(ctx)) == 4 - @everywhere ps3 blocked=false + @everywhere ps3 blocked=false - ps_used = fetch(job) - @test ps_used isa Vector + ps_used = fetch(job) + @test ps_used isa Vector - @test any(p -> p in ps_used, ps1) - @test any(p -> p in ps_used, ps3) - @test !any(p -> p in ps2, ps_used) - finally - wait(rmprocs(ps)) + @test any(p -> p in ps_used, ps1) + @test any(p -> p in ps_used, ps3) + @test !any(p -> p in ps2, ps_used) + finally + wait(rmprocs(ps)) + end end - end - @test_skip "Remove workers" - #=@testset "Remove workers" begin - ps = [] - try - ps1 = addprocs(4, exeflags="--project") - append!(ps, ps1) - - @everywhere vcat(ps1, myid()) $setup - - # Use single to force scheduler to make use of all workers since we assert it below - ts = delayed(vcat)((delayed(testfun; single=ps1[mod1(i, end)])(i) for i in 1:10)...) - - # Use FilterLog as a callback function. - nprocs_removed = Ref(0) - first_rescheduled_thunk=Ref(false) - rmproctrigger = Dagger.FilterLog(Dagger.NoOpLog()) do event - if typeof(event) == Dagger.Event{:finish} && event.category === :cleanup_proc - nprocs_removed[] += 1 + @test_skip "Remove workers" + #=@testset "Remove workers" begin + ps = [] + try + ps1 = addprocs(4, exeflags="--project") + append!(ps, ps1) + + @everywhere vcat(ps1, myid()) $setup + + # Use single to force scheduler to make use of all workers since we assert it below + ts = delayed(vcat)((delayed(testfun; single=ps1[mod1(i, end)])(i) for i in 1:10)...) + + # Use FilterLog as a callback function. 
+ nprocs_removed = Ref(0) + first_rescheduled_thunk=Ref(false) + rmproctrigger = Dagger.FilterLog(Dagger.NoOpLog()) do event + if typeof(event) == Dagger.Event{:finish} && event.category === :cleanup_proc + nprocs_removed[] += 1 + end + if typeof(event) == Dagger.Event{:start} && event.category === :add_thunk + first_rescheduled_thunk[] = true + end + return false end - if typeof(event) == Dagger.Event{:start} && event.category === :add_thunk - first_rescheduled_thunk[] = true - end - return false - end - ctx = Context(ps1; log_sink=rmproctrigger) - job = @async collect(ctx, ts) + ctx = Context(ps1; log_sink=rmproctrigger) + job = @async collect(ctx, ts) - # Must wait for this or else we won't get callback for rmprocs! - # Timeout so we don't stall forever if something breaks - starttime = time() - while !first_rescheduled_thunk[] && (time() - starttime < 10.0) - sleep(0.1) - end - @test first_rescheduled_thunk[] + # Must wait for this or else we won't get callback for rmprocs! + # Timeout so we don't stall forever if something breaks + starttime = time() + while !first_rescheduled_thunk[] && (time() - starttime < 10.0) + sleep(0.1) + end + @test first_rescheduled_thunk[] - rmprocs!(ctx, ps1[3:end]) - @test length(procs(ctx)) == 2 + rmprocs!(ctx, ps1[3:end]) + @test length(procs(ctx)) == 2 - # Timeout so we don't stall forever if something breaks - starttime = time() - while (nprocs_removed[] < 2) && (time() - starttime < 10.0) - sleep(0.01) - end - # this will fail if we timeout. Verify that we get the logevent for :cleanup_proc - @test nprocs_removed[] >= 2 + # Timeout so we don't stall forever if something breaks + starttime = time() + while (nprocs_removed[] < 2) && (time() - starttime < 10.0) + sleep(0.01) + end + # this will fail if we timeout. Verify that we get the logevent for :cleanup_proc + @test nprocs_removed[] >= 2 - @everywhere ps1 blocked=false + @everywhere ps1 blocked=false - res = fetch(job) - @test res isa Vector + res = fetch(job) + @test res isa Vector - @test res[1:4] |> unique |> sort == ps1 - @test all(pid -> pid in ps1[1:2], res[5:end]) - finally - # Prints "From worker X: IOError:" :/ - wait(rmprocs(ps)) - end - end=# + @test res[1:4] |> unique |> sort == ps1 + @test all(pid -> pid in ps1[1:2], res[5:end]) + finally + # Prints "From worker X: IOError:" :/ + wait(rmprocs(ps)) + end + end=# - @testset "Remove all workers throws" begin - ps = [] - try - ps1 = addprocs(2, exeflags="--project") - append!(ps, ps1) + @testset "Remove all workers throws" begin + ps = [] + try + ps1 = addprocs(2, exeflags="--project") + append!(ps, ps1) - @everywhere vcat(ps1, myid()) $setup + @everywhere vcat(ps1, myid()) $setup - ts = delayed(vcat)((delayed(testfun)(i) for i in 1:16)...) + ts = delayed(vcat)((delayed(testfun)(i) for i in 1:16)...) 
- ctx = Context(ps1) - job = @async collect(ctx, ts) + ctx = Context(ps1) + job = @async collect(ctx, ts) - while !istaskstarted(job) - sleep(0.001) - end + while !istaskstarted(job) + sleep(0.001) + end - rmprocs!(ctx, ps1) - @test length(procs(ctx)) == 0 + rmprocs!(ctx, ps1) + @test length(procs(ctx)) == 0 - @everywhere ps1 blocked=false - if VERSION >= v"1.3.0-alpha.110" - @test_throws TaskFailedException fetch(job) - else - @test_throws Exception fetch(job) + @everywhere ps1 blocked=false + if VERSION >= v"1.3.0-alpha.110" + @test_throws TaskFailedException fetch(job) + else + @test_throws Exception fetch(job) + end + finally + wait(rmprocs(ps)) end - finally - wait(rmprocs(ps)) end end end @@ -351,21 +354,44 @@ end @testset "Scheduler algorithms" begin @testset "Signature Calculation" begin - @test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) isa Vector{DataType} - @test Dagger.Sch.signature(+, [nothing=>1, nothing=>2]) == [typeof(+), Int, Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]) isa Dagger.Sch.Signature + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).sig == [typeof(+), Int, Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).sig_nokw == [typeof(+), Int, Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash_nokw == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash_nokw + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash != + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(2, 2)]).hash_nokw if isdefined(Core, :kwcall) - @test Dagger.Sch.signature(+, [nothing=>1, :a=>2]) == [typeof(Core.kwcall), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig == [typeof(Core.kwcall), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig_nokw == [typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash != + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw else kw_f = Core.kwfunc(+) - @test Dagger.Sch.signature(+, [nothing=>1, :a=>2]) == [typeof(kw_f), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig == [typeof(kw_f), @NamedTuple{a::Int64}, typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).sig_nokw == [typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw == + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 
2)]).hash != + Dagger.Sch.signature(+, [Dagger.Argument(1, 1), Dagger.Argument(:a, 2)]).hash_nokw end - @test Dagger.Sch.signature(+, []) == [typeof(+)] - @test Dagger.Sch.signature(+, [nothing=>1]) == [typeof(+), Int] + @test Dagger.Sch.signature(+, []).sig == [typeof(+)] + @test Dagger.Sch.signature(+, []).sig_nokw == [typeof(+)] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1)]).sig == [typeof(+), Int] + @test Dagger.Sch.signature(+, [Dagger.Argument(1, 1)]).sig_nokw == [typeof(+), Int] c = Dagger.tochunk(1.0) - @test Dagger.Sch.signature(*, [nothing=>c, nothing=>3]) == [typeof(*), Float64, Int] + @test Dagger.Sch.signature(*, [Dagger.Argument(1, c), Dagger.Argument(2, 3)]).sig == [typeof(*), Float64, Int] t = Dagger.@spawn 1+2 - @test Dagger.Sch.signature(/, [nothing=>t, nothing=>c, nothing=>3]) == [typeof(/), Int, Float64, Int] + @test Dagger.Sch.signature(/, [Dagger.Argument(1, t), Dagger.Argument(2, c), Dagger.Argument(3, 3)]).sig == [typeof(/), Int, Float64, Int] end @testset "Cost Estimation" begin @@ -408,8 +434,8 @@ end @test est_tx_size == tx_size t = delayed(mynothing)(args...) - inputs = Dagger.Sch.collect_task_inputs(state, t) - sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t, inputs) + Dagger.Sch.collect_task_inputs!(state, t) + sorted_procs, costs = Dagger.Sch.estimate_task_costs(state, procs, t) @test tproc1 in sorted_procs @test tproc2 in sorted_procs @@ -421,7 +447,9 @@ end @test haskey(costs, tproc1) @test haskey(costs, tproc2) @test costs[tproc1] ≈ pres1 # All chunks are local - @test costs[tproc2] ≈ (tx_size/tx_rate) + pres2 # All chunks are remote + if nprocs() > 1 + @test costs[tproc2] ≈ (tx_size/tx_rate) + pres2 # All chunks are remote + end end end end diff --git a/test/thunk.jl b/test/thunk.jl index 82f5c84f5..a7ed0412a 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -49,8 +49,10 @@ end end @testset "@spawn" begin - @test_throws_unwrap ConcurrencyViolationError remotecall_fetch(last(workers())) do - Dagger.Sch.init_eager() + if nprocs() > 1 + @test_throws_unwrap ConcurrencyViolationError remotecall_fetch(last(workers())) do + Dagger.Sch.init_eager() + end end @test Dagger.Sch.EAGER_CONTEXT[] === nothing @testset "per-call" begin @@ -69,7 +71,7 @@ end A = rand(4, 4) @test fetch(@spawn sum(A; dims=1)) ≈ sum(A; dims=1) - @test_throws_unwrap Dagger.ThunkFailedException fetch(@spawn sum(A; fakearg=2)) + @test_throws_unwrap (Dagger.ThunkFailedException, MethodError) fetch(@spawn sum(A; fakearg=2)) @test fetch(@spawn reduce(+, A; dims=1, init=2.0)) ≈ reduce(+, A; dims=1, init=2.0) @@ -187,7 +189,7 @@ end a = @spawn error("Test") wait(a) @test isready(a) - @test_throws_unwrap Dagger.ThunkFailedException fetch(a) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(a) b = @spawn 1+2 @test fetch(b) == 3 end @@ -200,8 +202,7 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) - ex_str = sprint(io->Base.showerror(io,ex)) + ex_str = sprint(io->Base.showerror(io, ex)) @test occursin(r"^ThunkFailedException:", ex_str) @test occursin("Test", ex_str) @test !occursin("Root Thunk", ex_str) @@ -211,7 +212,6 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) ex_str = sprint(io->Base.showerror(io,ex)) @test occursin("Test", ex_str) @test occursin("Root Thunk", ex_str) @@ -219,28 +219,28 @@ end @testset "single dependent" begin a = @spawn error("Test") b = @spawn a+2 - @test_throws_unwrap Dagger.ThunkFailedException fetch(a) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) 
fetch(a) end @testset "multi dependent" begin a = @spawn error("Test") b = @spawn a+2 c = @spawn a*2 - @test_throws_unwrap Dagger.ThunkFailedException fetch(b) - @test_throws_unwrap Dagger.ThunkFailedException fetch(c) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(b) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(c) end @testset "dependent chain" begin a = @spawn error("Test") - @test_throws_unwrap Dagger.ThunkFailedException fetch(a) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(a) b = @spawn a+1 - @test_throws_unwrap Dagger.ThunkFailedException fetch(b) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(b) c = @spawn b+2 - @test_throws_unwrap Dagger.ThunkFailedException fetch(c) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(c) end @testset "single input" begin a = @spawn 1+1 b = @spawn (a->error("Test"))(a) @test fetch(a) == 2 - @test_throws_unwrap Dagger.ThunkFailedException fetch(b) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(b) end @testset "multi input" begin a = @spawn 1+1 @@ -248,7 +248,7 @@ end c = @spawn ((a,b)->error("Test"))(a,b) @test fetch(a) == 2 @test fetch(b) == 4 - @test_throws_unwrap Dagger.ThunkFailedException fetch(c) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(c) end @testset "diamond" begin a = @spawn 1+1 @@ -258,45 +258,49 @@ end @test fetch(a) == 2 @test fetch(b) == 3 @test fetch(c) == 4 - @test_throws_unwrap Dagger.ThunkFailedException fetch(d) + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(d) end end - @testset "remote spawn" begin - a = fetch(Distributed.@spawnat 2 Dagger.@spawn 1+2) - @test Dagger.Sch.EAGER_INIT[] - @test fetch(Distributed.@spawnat 2 !(Dagger.Sch.EAGER_INIT[])) - @test a isa Dagger.DTask - @test fetch(a) == 3 - - # Mild stress-test - @test dynamic_fib(10) == 55 - - # Errors on remote are correctly scrubbed (#430) - t2 = remotecall_fetch(2) do - t1 = Dagger.@spawn 1+"fail" - Dagger.@spawn t1+1 + if 2 in workers() + @testset "remote spawn" begin + a = fetch(Distributed.@spawnat 2 Dagger.@spawn 1+2) + @test Dagger.Sch.EAGER_INIT[] + @test fetch(Distributed.@spawnat 2 !(Dagger.Sch.EAGER_INIT[])) + @test a isa Dagger.DTask + @test fetch(a) == 3 + + # Mild stress-test + @test dynamic_fib(10) == 55 + + # Errors on remote are correctly scrubbed (#430) + t2 = remotecall_fetch(2) do + t1 = Dagger.@spawn 1+"fail" + Dagger.@spawn t1+1 + end + @test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(t2) end - @test_throws_unwrap Dagger.ThunkFailedException fetch(t2) end - @testset "undefined function" begin - # Issues #254, #255 + if nprocs() > 1 + @testset "undefined function" begin + # Issues #254, #255 - # only defined on head node - @eval evil_f(x) = x + # only defined on head node + @eval evil_f(x) = x - eager_thunks = map(1:10) do i - single = isodd(i) ? 1 : first(workers()) - Dagger.@spawn single=single evil_f(i) - end + eager_thunks = map(1:10) do i + single = isodd(i) ? 
1 : first(workers()) + Dagger.@spawn single=single evil_f(i) + end - errored(t) = try - fetch(t) - false - catch - true + errored(t) = try + fetch(t) + false + catch + true + end + @test any(t->errored(t), eager_thunks) + @test any(t->!errored(t), eager_thunks) end - @test any(t->errored(t), eager_thunks) - @test any(t->!errored(t), eager_thunks) end @testset "function chunks" begin @testset "lazy API" begin @@ -326,7 +330,9 @@ end @test_skip !all(x->x==43, collect(ctx, delayed(vcat)([delayed(pls)(1) for i in 1:10]...))) # Positive tests (no serialization) @test all(x->x==43, collect(ctx, delayed(vcat)([delayed(pls; scope=ProcessScope())(1) for i in 1:10]...))) - @test all(x->x==1, collect(ctx, delayed(vcat)([delayed(pls; scope=ProcessScope(first(workers())))(1) for i in 1:10]...))) + if nprocs() > 1 + @test all(x->x==1, collect(ctx, delayed(vcat)([delayed(pls; scope=ProcessScope(first(workers())))(1) for i in 1:10]...))) + end end @testset "Processor Data Movement" begin @everywhere Dagger.add_processor_callback!(()->MulProc(), :mulproc) diff --git a/test/util.jl b/test/util.jl index f01b3d95d..6e03c846c 100644 --- a/test/util.jl +++ b/test/util.jl @@ -14,15 +14,15 @@ end replace_obj!(ex::Symbol, obj) = Expr(:(.), obj, QuoteNode(ex)) replace_obj!(ex, obj) = ex function _test_throws_unwrap(terr, ex; to_match=[]) - @gensym rerr + @gensym oerr rerr match_expr = Expr(:block) for m in to_match if m.head == :(=) - lhs, rhs = replace_obj!(m.args[1], rerr), m.args[2] + lhs, rhs = replace_obj!(m.args[1], oerr), m.args[2] push!(match_expr.args, :(@test $lhs == $rhs)) elseif m.head == :call fn = m.args[1] - lhs, rhs = replace_obj!(m.args[2], rerr), m.args[3] + lhs, rhs = replace_obj!(m.args[2], oerr), m.args[3] if fn == :(<) push!(match_expr.args, :(@test startswith($lhs, $rhs))) elseif fn == :(>) @@ -35,12 +35,17 @@ function _test_throws_unwrap(terr, ex; to_match=[]) end end quote - $rerr = try - $(esc(ex)) + $oerr, $rerr = try + nothing, $(esc(ex)) catch err - Dagger.Sch.unwrap_nested_exception(err) + (err, Dagger.Sch.unwrap_nested_exception(err)) + end + if $terr isa Tuple + @test $oerr isa $terr[1] + @test $rerr isa $terr[2] + else + @test $rerr isa $terr end - @test $rerr isa $terr $match_expr end end
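For reference on the updated test helper above: `@test_throws_unwrap` now accepts either a single exception type, which is checked against the fully unwrapped exception, or a tuple, whose first element is checked against the outer (wrapping) exception and whose second element against the unwrapped cause. A usage sketch consistent with the test changes in this patch (usable inside the test suite, where the helper is defined):

using Dagger, Test

t = Dagger.@spawn error("boom")
# Outer exception is the Dagger wrapper; the unwrapped cause is the ErrorException.
@test_throws_unwrap (Dagger.ThunkFailedException, ErrorException) fetch(t)
# The single-type form still matches only the unwrapped exception.
@test_throws_unwrap ErrorException fetch(t)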