Skip to content

Commit

Permalink
Fix over-eager casting to PrimeField in broadcast
Browse files Browse the repository at this point in the history
For example, this would cast a boolean result to the PrimeField (#16).
  • Loading branch information
tkluck committed Jul 14, 2021
1 parent 90f2b79 commit 5f73ae9
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 66 deletions.
159 changes: 93 additions & 66 deletions src/Broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,135 +77,162 @@ const FusedModBroadcasted{F, InnerStyle} = Broadcasted{<:FusedModStyle{F, InnerS

# -----------------------------------------------------------------------------
#
# Override instantiate to use lazy reduction operations
# Override instantiate to use deferred reduction operations
#
# -----------------------------------------------------------------------------

instantiate(bc::Broadcasted{<:FusedModStyle{F}}) where F <: PrimeField = instantiate(fieldvals(F, unreducedbroadcast(bc)))
instantiate(bc::FusedModBroadcasted) = instantiate(maybe_unreducedbroadcast(bc))

# -----------------------------------------------------------------------------
#
# Helper types and methods to compute the fused mod broadcast
#
# -----------------------------------------------------------------------------

"""
struct UnreducedBroadcast
Represents a broadcast whose eltype should be F, but which has delayed
the mod operation and so for now has eltype an Integer.
The instantiate(...) function can be used to convert it to a Broadcast
object with eltype equal to F.
"""
struct UnreducedBroadcast{F <: PrimeField, BC <: Broadcasted, Bounds}
bc :: BC
end

eltype(::UnreducedBroadcast{F}) where F <: PrimeField = F
@inline UnreducedBroadcast(F, bounds, bc) = UnreducedBroadcast{F, typeof(bc), bounds}(bc)

const TupleOf{F, N} = NTuple{N, F}
# the kinds of arguments we may find for a mergeable broadcast
const BroadcastableWithBounds{F} = Union{UnreducedBroadcast{F}, AbstractArray{F}, TupleOf{F}, F}

bounds(::Type{<:UnreducedBroadcast{F, BC, Bounds}}) where {F, BC, Bounds} = Bounds
bounds(a::Type{<:AbstractArray{<:PrimeField{I}}}) where I = I(0) : I(char(eltype(a)) - 1)
bounds(::Type{<:TupleOf{F}}) where F <: PrimeField{I} where I = I(0) : I(char(F) - 1)
bounds(a::Type{<:PrimeField{I}}) where I = I(0) : I(char(a) - 1)

widenleaves(I, ubc::UnreducedBroadcast) = widenleaves(I, ubc.bc)
widenleaves(I, a::AbstractArray{<:PrimeField}) = broadcasted(widenleaves, I, a)
widenleaves(I, a::AbstractArray{<:Integer}) = broadcasted(widenleaves, I, a)
widenleaves(I, a::TupleOf{<:PrimeField}) = map(a_i -> widenleaves(I, a_i), a)
widenleaves(I, a::TupleOf{<:Integer}) = map(a_i -> widenleaves(I, a_i), a)
widenleaves(I, a::PrimeField) = a.n % I
widenleaves(I, a::Integer) = a % I
function widenleaves(I, bc::Broadcasted)
if bc.f == widenleaves
return broadcasted(widenleaves, promote_type(I, bc.args[1][]), bc.args[2])
"""
struct Widener{I <: Integer}
A callable object that traverses a tree of Broadcasted operations
and converts the arguments to the integer type I.
"""
struct Widener{I <: Integer}
end

eltype(::Widener{I}) where I = I

(w::Widener)(a::PrimeField) = a.n % eltype(w)
(w::Widener)(a::Integer) = a % eltype(w)
(w::Widener)(a::AbstractArray{<:PrimeField}) = broadcasted(w, a)
(w::Widener)(a::AbstractArray{<:Integer}) = broadcasted(w, a)
(w::Widener)(a::TupleOf{<:PrimeField}) = map(w, a)
(w::Widener)(a::TupleOf{<:Integer}) = map(w, a)
(w::Widener)(ubc::UnreducedBroadcast) = UnreducedBroadcast(eltype(ubc), bounds(typeof(ubc)), w(ubc.bc))
(w::Widener)(bc::Broadcasted) = begin
B = Broadcasted{typeof(BroadcastStyle(typeof(bc)))}
if bc.f isa FusableOps
args = map(w, bc.args)
return B(bc.f, args, bc.axes)
elseif bc.f isa Widener
w′ = Widener{promote_type(eltype(w), eltype(bc.f))}()
return B(w′, bc.args, bc.axes)
else
extendedleaves = widenleavestuple(I, bc.args...)
return Broadcasted{typeof(BroadcastStyle(typeof(bc)))}(bc.f, extendedleaves, bc.axes)
return B(w, (bc,), nothing)
end
end

widenleavestuple(I) = ()
widenleavestuple(I, arg, args...) = (widenleaves(I, arg), widenleavestuple(I, args...)...)
widenleavestuple(I, arg, args...) = (Widener{I}()(arg), widenleavestuple(I, args...)...)

intvals(ubc::UnreducedBroadcast) = ubc.bc
intvals(a::Integer) = a
intvals(a::PrimeField) = a.n
intvals(a::AbstractArray{<:PrimeField}) = reinterpret(inttype(eltype(a)), a)
intvals(a::AbstractArray{<:Integer}) = a
intvals(a::TupleOf{<:PrimeField}) = map(a -> a.n, a)
intvals(a::TupleOf{<:PrimeField}) = map(intvals, a)
intvals(a::TupleOf{<:Integer}) = a
intvals(a::PrimeField) = a.n
intvals(a::Integer) = a
intvals(ubc::UnreducedBroadcast) = ubc.bc

intvalstuple() = ()
intvalstuple(arg, args...) = (intvals(arg), intvalstuple(args...)...)

reducedvals(F, ubc::UnreducedBroadcast) = broadcasted(posmod, intvals(ubc), char(F))
reducedvals(F, a::AbstractArray{<:PrimeField}) = broadcasted(a_i -> a_i.n, a)
reducedvals(F, a::AbstractArray{<:Integer}) = a
reducedvals(F, a::TupleOf{<:PrimeField}) = map(a_i -> a_i.n, a)
reducedvals(F, a::TupleOf{<:Integer}) = a
reducedvals(F, a::PrimeField) = a.n
reducedvals(F, a::Integer) = posmod(a, char(F))

reducedvalstuple(F) = ()
reducedvalstuple(F, arg, args...) = (reducedvals(F, arg), reducedvalstuple(F, args...)...)

fieldvals(F, bc::Broadcasted) = broadcasted(F, bc)
fieldvals(F, ubc::UnreducedBroadcast) = broadcasted(F, intvals(ubc))
fieldvals(F, a::AbstractArray{<:PrimeField}) = a
fieldvals(F, a::AbstractArray{<:Integer}) = broadcasted(F, a)
fieldvals(F, a::TupleOf{<:PrimeField}) = a
fieldvals(F, a::TupleOf{<:Integer}) = map(F, a)
fieldvals(F, a::PrimeField) = a
fieldvals(F, a::Integer) = F(a)

fieldvalstuple(F) = ()
fieldvalstuple(F, arg, args...) = (fieldvals(F, arg), fieldvalstuple(F, args...)...)

unreducedbroadcast(a) = a
unreducedbroadcast(bc::FusedModBroadcasted{F}) where F <: PrimeField = unreducedbroadcast(F, bc.f, style(BroadcastStyle(typeof(bc))), bc.axes, map(unreducedbroadcast, bc.args)...)

function unreducedbroadcast(F, f, innerstyle, axes, args...)
fieldargs = fieldvalstuple(F, args...)
bc = Broadcasted{typeof(innerstyle)}(f, fieldargs)
return UnreducedBroadcast(F, bounds(F), bc)
end
maybe_unreducedbroadcast(a) = a
maybe_unreducedbroadcast(bc::FusedModBroadcasted{F}) where F <: PrimeField = _maybe_unreducedbroadcast(F, bc.f, style(BroadcastStyle(typeof(bc))), bc.axes, map(maybe_unreducedbroadcast, bc.args)...)

_maybe_unreducedbroadcast(F, f, innerstyle, axes, args...) = Broadcasted{typeof(innerstyle)}(f, maybe_reduce.(args), axes)

notnarrower(I, J) = promote_type(I, J) == I
iswider(I, J) = !notnarrower(J, I)

# the kinds of arguments we may find for a mergeable broadcast
const FusableOps = Union{typeof(+), typeof(-), typeof(*), typeof(^)}
const BroadcastableWithBounds{F} = Union{UnreducedBroadcast{F}, AbstractArray{F}, TupleOf{F}, F}

function unreducedbroadcast(::Type{F}, f::FusableOps, innerstyle, axes, args::BroadcastableWithBounds{F}...) where F <: PrimeField
function _maybe_unreducedbroadcast(::Type{F}, f::FusableOps, innerstyle, axes, args::BroadcastableWithBounds{F}...) where F <: PrimeField
resultbounds = joinbounds(f, map(bounds typeof, args)...)
I = eltype(resultbounds)
J = promote_type(map(eltype bounds typeof, args)...)
# The integer operation won't overflow because J is at least as big as I
if notnarrower(J, I)
intargs = intvalstuple(args...)
return UnreducedBroadcast(F, resultbounds, Broadcasted{typeof(innerstyle)}(f, intargs, axes))
bci = Broadcasted{typeof(innerstyle)}(f, intargs, axes)
return UnreducedBroadcast(F, resultbounds, bci)
# It might overflow, but in any case it stays a bits type. We widen the arguments
# before doing the operation to prevent the overflow.
elseif notnarrower(Int, I)
extargs = widenleavestuple(I, args...)
return UnreducedBroadcast(F, resultbounds, Broadcasted{typeof(innerstyle)}(f, extargs, axes))
elseif iswider(J, inttype(F))
redargs = reducedvalstuple(F, args...)
return unreducedbroadcast(f, innerstyle, axes, redargs...)
w = Widener{I}()
intargs = intvalstuple(args...)
bci = Broadcasted{typeof(innerstyle)}(f, intargs, axes)
wbci = w(bci)
return UnreducedBroadcast(F, resultbounds, wbci)
else
fieldargs = fieldvalstuple(F, args...)
bci = Broadcasted{typeof(innerstyle)}((a -> a.n)f, fieldargs, axes)
return UnreducedBroadcast(F, bounds(F), bci)
return Broadcasted{typeof(innerstyle)}(f, maybe_reduce.(args), axes)
end
end

instantiate(ubc::UnreducedBroadcast) = broadcasted(eltype(ubc), intvals(ubc))
maybe_reduce(x) = x
maybe_reduce(ubc::UnreducedBroadcast) = instantiate(ubc)

# -----------------------------------------------------------------------------
#
# Transform / and // to a version that also works when reinterpreting as integers
#
# -----------------------------------------------------------------------------
const DivOp = Union{typeof(/), typeof(//)}
function unreducedbroadcast(::Type{F}, f::DivOp, innerstyle, axes, a::BroadcastableWithBounds{F}, b::BroadcastableWithBounds{F}) where F <: PrimeField
b_inverse = unreducedbroadcast(F, inv, innerstyle, axes, b)
return unreducedbroadcast(F, *, innerstyle, axes, a, b_inverse)
function _maybe_unreducedbroadcast(::Type{F}, f::DivOp, innerstyle, axes, a::BroadcastableWithBounds{F}, b::BroadcastableWithBounds{F}) where F <: PrimeField
b_inverse = _maybe_unreducedbroadcast(F, inv, innerstyle, axes, b)
return _maybe_unreducedbroadcast(F, *, innerstyle, axes, a, b_inverse)
end

function unreducedbroadcast(::Type{F}, f::typeof(inv), innerstyle, axes, a::BroadcastableWithBounds{F}) where F <: PrimeField
function _maybe_unreducedbroadcast(::Type{F}, f::typeof(inv), innerstyle, axes, a::BroadcastableWithBounds{F}) where F <: PrimeField
bc_inverse = broadcasted(invmod, intvals(a), char(F))
return UnreducedBroadcast(F, bounds(F), bc_inverse)
end


# -----------------------------------------------------------------------------
#
# Testability
#
# -----------------------------------------------------------------------------
_is_integer_broadcast(leaf) = eltype(leaf) <: Integer
_is_integer_broadcast(bc::Broadcasted) = all(_is_integer_broadcast, bc.args) || bc.f isa Widener
_is_integer_broadcast(ubc::UnreducedBroadcast) = _is_integer_broadcast(ubc.bc)
isfused(bc) = false
isfused(bc::FusedModBroadcasted) = _is_integer_broadcast(maybe_unreducedbroadcast(bc))

broadcasted_calls(x) = x
broadcasted_calls(expr::Expr) = if expr.head == :call
Expr(:call, broadcasted, broadcasted_calls.(expr.args)...)
elseif expr.head == :$
expr.args[1]
else
Expr(expr.head, broadcasted_calls.(expr.args)...)
end
macro broadcasted(expr)
expr = broadcasted_calls(expr)
esc(expr)
end

end # module
10 changes: 10 additions & 0 deletions test/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using LinearAlgebra: norm, tr
using GaloisFields.Broadcast: @broadcasted, isfused

const MAXITERATIONS = 100
const MAXITERATIONS2 = round(Int, sqrt(MAXITERATIONS))
Expand Down Expand Up @@ -334,6 +335,15 @@ const MAXITERATIONS3 = round(Int, cbrt(MAXITERATIONS))

# tuple broadcasting
@test (F[x[1:10];]...,) .+ (F[y[1:10];]...,) == (F[x[1:10] .+ y[1:10];]...,)

# Booleans
@test eltype(F[x;] .== F[y;]) == Bool

# test that operations get fused
@test isfused(@broadcasted $(F[x;]) + $(F[y;]))
@test isfused(@broadcasted $(F[x;]) + $(F(y[1])) * $(F[y;]))
@test isfused(@broadcasted $(F[x;]) - $(F(y[1])) * $(F[y;]))
@test isfused(@broadcasted $(F[x;]) - $(F[y;]) / $(F(y[1])))
end

@testset "Random selection" begin
Expand Down

0 comments on commit 5f73ae9

Please sign in to comment.