Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization #12

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/BehaviorDispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ Dispatches a behavior to the correct method based on the arguments passed in.
It is called by the fallback definition of a behavior. It looks up the original
DuckType that implemented the behavior and calls the method on that DuckType.
"""
function dispatch_behavior(behavior::Type{Behavior{F, S}}, args...; kwargs...) where {F, S}
Base.@assume_effects :foldable function dispatch_behavior(
behavior::Type{Behavior{F, S}}, args...; kwargs...) where {F, S}
DuckT = get_specific_duck_type(fieldtypes(S), args)
OGDuckT = find_original_duck_type(DuckT, behavior)
if isnothing(OGDuckT)
Expand Down
11 changes: 11 additions & 0 deletions src/DuckDispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ include("DuckTypeMacro.jl")
return T[x for x in arg1]
end
@test my_collect((1, 2)) == [1, 2]

ch = Channel{Int}() do ch
for i in 1:2
put!(ch, i)
end
end

DuckDispatch.@duck_dispatch function container_collect(arg1::IsContainer{T}) where {T}
return T[x for x in arg1]
end
@test_throws ErrorException container_collect(ch)
end

end
4 changes: 2 additions & 2 deletions src/MethodDispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Base.@constprop :aggressive function wrap_args(duck_sigs, args)
check_quacks_like = CheckQuacksLike(Tuple{arg_types...})

# this is a tuple of bools which indicate if the method matches the input args
quack_check_result = map(check_quacks_like, duck_sigs)
quack_check_result = tuple_map(check_quacks_like, duck_sigs)

number_of_matches = sum(quack_check_result)
# todo make this a MethodError
Expand All @@ -97,7 +97,7 @@ Base.@constprop :aggressive function wrap_args(duck_sigs, args)

method_match = get_most_specific(quack_check_result, duck_sigs)
method_types = fieldtypes(method_match)[2:end]
wrapped_args = map(wrap_with_guise, method_types, args)
wrapped_args = tuple_map(wrap_with_guise, method_types, args)
return wrapped_args
end

Expand Down
53 changes: 31 additions & 22 deletions src/TypeUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,6 @@ Returns the union of all the behaviors that a `DuckType` implements, including t
return :($u)
end

"""
`find_original_duck_type(::Type{D}, ::Type{B}) -> Union{Nothing, DataType}`
Returns the DuckType that originally implemented a behavior `B` in the DuckType `D`.
This allows us to take a `DuckType` which was composed of many others, find the original,
and then rewrap to that original type.
"""
function find_original_duck_type(::Type{D}, ::Type{B}) where {D, B}
these_behaviors = tuple_collect(get_top_level_behaviors(D))
if any(implies(b, B) for b in these_behaviors)
return D
end
for dt in tuple_collect(get_duck_types(D))
child_res = find_original_duck_type(dt, B)::Union{Nothing, DataType}
!isnothing(child_res) && return child_res
end
end

"""
implies(::Type{DuckType}, ::Type{DuckType}})
Return true if the first DuckType is a composition that contains the second DuckType.
Expand Down Expand Up @@ -91,7 +74,27 @@ function implies(::Type{T1}, ::Type{T2}) where {T1, T2}
end
function implies(t1::Tuple, t2::Tuple)
length(t1) != length(t2) && return false
return all(Iterators.map(implies, t1, t2))
return tuple_all(implies, t1, t2)
end

"""
`find_original_duck_type(::Type{D}, ::Type{B}) -> Union{Nothing, DataType}`
Returns the DuckType that originally implemented a behavior `B` in the DuckType `D`.
This allows us to take a `DuckType` which was composed of many others, find the original,
and then rewrap to that original type.
"""
@generated function find_original_duck_type(::Type{D}, ::Type{B}) where {D, B}
these_behaviors = tuple_collect(get_top_level_behaviors(D))
res = if any(implies(b, B) for b in these_behaviors)
D
end
for dt in tuple_collect(get_duck_types(D))
child_res = find_original_duck_type(dt, B)::Union{Nothing, DataType}
if !isnothing(child_res)
res = child_res
end
end
return :($res)
end

"""
Expand Down Expand Up @@ -139,8 +142,14 @@ Rewraps a `Guise` to implement a different `DuckType`.
"""
rewrap(x::Guise{I1, <:Any}, ::Type{I2}) where {I1, I2 <: DuckType} = wrap(I2, unwrap(x))

function rewrap_where_this(sig::Type{<:Tuple}, ::Type{D}, args::Tuple) where {D <: DuckType}
return map(fieldtypes(sig), args) do T, arg
T === This ? rewrap(arg, D) : arg
end
@generated function wrap_if_this(::Type{T}, ::Type{D}, arg) where {T, D}
T === This && return :(rewrap(arg, D))
return :(arg)
end

@generated function rewrap_where_this(
::Type{T}, ::Type{D}, args::Tuple) where {D <: DuckType, T <: Tuple}
fields = fieldtypes(T)
duck_types = tuple((D for _ in fields)...)
return :(tuple_map(wrap_if_this, $fields, $duck_types, args))
end
12 changes: 7 additions & 5 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ get_duck_type(::G) where {G <: Guise} = get_duck_type(G)
struct TypeChecker{Data}
t::Type{Data}
end
function (::TypeChecker{Data})(::Type{B}) where {Data, B <: Behavior}
@generated function (::TypeChecker{Data})(::Type{B}) where {
Data, B <: Behavior}
sig_types = fieldtypes(get_signature(B))::Tuple
func_type = get_func_type(B)
replaced = map((x) -> x === This ? Data : x, sig_types)
return !isempty(methods(func_type.instance, replaced))
replaced = tuple_map((x) -> x === This ? Data : x, sig_types)
checks = :($hasmethod($(func_type.instance), $replaced) ||
!isempty(methods($(func_type.instance), $replaced)))
return checks
end

"""
Expand All @@ -83,8 +86,7 @@ end
function (x::CheckQuacksLike{T})(::Type{M}) where {T, M}
method_arg_types = fieldtypes(M)
input_arg_types = (DispatchedOnDuckType, fieldtypes(T)...)
can_quack = map(quacks_like, method_arg_types, input_arg_types)
return all(can_quack)
return tuple_all(quacks_like, method_arg_types, input_arg_types)
end

"""
Expand Down
31 changes: 31 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,35 @@ Returns a tuple of the types in a Union type.
U === Union{} && return ()
types = tuple(Base.uniontypes(U)...)
return :($tuple($(types...)))
end

function make_f_calls(f, tuples)
first_tuple_types = fieldtypes(tuples[1])
tuple_lengths = length(first_tuple_types)
number_of_tuples = length(tuples)
@assert all(length(fieldtypes(t)) == tuple_lengths for t in tuples) "all tuples must be same length"
f_calls = [:(f($(
(:(tuples[$j][$i]) for j in 1:number_of_tuples)...)
)
)
for i in 1:tuple_lengths]
return f_calls
end

@generated function tuple_map(f, tuples...)
f_calls = make_f_calls(f, tuples)
quote
tuple(
$(f_calls...)
)
end
end

@generated function tuple_all(f, tuples...)
f_calls = make_f_calls(f, tuples)
with_returns = [:($f_call || return false) for f_call in f_calls]
quote
$(with_returns...)
return true
end
end
Loading