EnzymeRule giving ERROR: AssertionError: sz == sizeof(Int) #2085

Open
ptiede opened this issue Nov 12, 2024 · 1 comment

ptiede commented Nov 12, 2024

OK, this is probably my fault for writing a buggy rule, but I am not sure what is causing it.

I have a custom rule for `\` on a sparse Cholesky factorization. Here is an MWE:

using LinearAlgebra
using SparseArrays
using Enzyme: EnzymeRules
using Enzyme

struct CholeskyFactor{T, P<:AbstractMatrix{T},C} <: AbstractMatrix{T}
    cov::P
    cho::C
end
CholeskyFactor(cov::AbstractMatrix) = CholeskyFactor(cov, cholesky(cov))
Base.parent(m::CholeskyFactor) = m.cov
Base.size(m::CholeskyFactor) = size(parent(m))
Base.getindex(m::CholeskyFactor, i::Int) = getindex(parent(m), i)
Base.getindex(m::CholeskyFactor, I::Vararg{Int, 2}) = getindex(parent(m), I...)
Base.adjoint(m::CholeskyFactor) = CholeskyFactor(m.cov, adjoint(m.cho))

Base.:\(c::CholeskyFactor, v::AbstractVector) = c.cho\v


function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, 
                                      c::Annotation{<:CholeskyFactor}, v::Annotation{<:AbstractArray}) where {RT}

    if !(typeof(c) <: Const)
        throw(ArgumentError("CholeskyFactor must be a Const"))
    end

    cache_c = if EnzymeRules.overwritten(config)[2]
        c.val
    else
        c.val
    end

    res = c.val \ v.val

    cache_v = if EnzymeRules.overwritten(config)[3]
        copy(v.val)
    else
        v.val
    end

    cache_res = if EnzymeRules.needs_primal(config)
        copy(res)
    else
        res
    end


    primal = if EnzymeRules.needs_primal(config)
        res
    else
        nothing
    end

    dres = if EnzymeRules.width(config) == 1
        zero(res)
    else
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            zero(res)
        end
    end

    retres = if EnzymeRules.needs_primal(config)
        res
    else
        nothing
    end

    cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{
        eltype(RT),
        EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing,
        typeof(cache_c),
        typeof(cache_v)
        }}(
        (cache_res, dres, cache_c, cache_v)
    )

    # For EnzymeCore 0.8
    return EnzymeRules.AugmentedReturn{
        EnzymeRules.primal_type(config, RT),
        EnzymeRules.shadow_type(config, RT),
        typeof(cache)
    }(retres, dres, cache)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfig, func::Const{typeof(\)}, ::Type{RT}, cache, 
                             c::Annotation{<:CholeskyFactor}, v::Annotation{<:AbstractArray}) where {RT}

    res, dres, cache_c, cache_v = cache
    
    if !(typeof(c) <: Const)
        throw(ArgumentError("CholeskyFactor must be a Const"))
    end

    
    if !EnzymeRules.overwritten(config)[3]
        cache_v = v.val
    end

    if !EnzymeRules.overwritten(config)[2]
        cache_c = c.val
    end

    if EnzymeRules.width(config) == 1
        dress = (dres,)
    else
        dress = dres
    end

    dvs = if EnzymeRules.width(config) == 1
            (v.dval,)
    else
        v.dval
    end

    for (dv, dres) in zip(dvs, dress)
        z = adjoint(c.val.cho) \ dres
        dv .+= z
        dres .= zero(eltype(dres))
    end

    return (nothing, nothing)
end

N = 5
σ = sprand(N,N, 0.1)
Σ = 0.5 .* (σ + σ') + 5 .* Diagonal(ones(N))
C = CholeskyFactor(Σ)

f(C, x) = sum(C\x)

x = rand(N)
dx = zero(x)
autodiff(Reverse, f, Active, Const(C), Duplicated(x, dx))
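
For reference, the reverse rule above is meant to implement the usual adjoint of a linear solve. A sketch, assuming `C` is held constant, writing $A$ for the matrix it factorizes (my notation, not something from Enzyme), $y = A^{-1} v$ for the result, and $\bar{v}$, $\bar{y}$ for the shadows `v.dval` and `dres`:

```math
\langle \bar{y}, \mathrm{d}y \rangle
  = \langle \bar{y}, A^{-1}\,\mathrm{d}v \rangle
  = \langle A^{-\top} \bar{y}, \mathrm{d}v \rangle
\quad\Longrightarrow\quad
\bar{v} \mathrel{+}= A^{-\top} \bar{y}, \qquad \bar{y} \leftarrow 0
```

This corresponds to the `adjoint(c.val.cho) \ dres` accumulation into `dv` followed by zeroing `dres`. Since $A$ is symmetric here, $A^{-\top} = A^{-1}$.
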

The error I get on my machine is

ERROR: AssertionError: sz == sizeof(Int)
Stacktrace:
  [1] should_recurse(typ2::Any, arg_t::LLVM.IntegerType, byref::GPUCompiler.ArgumentCC, dl::LLVM.DataLayout)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3Svej/src/absint.jl:210
  [2] abs_typeof(arg::LLVM.LoadInst, partial::Bool, seenphis::Set{LLVM.PHIInst})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3Svej/src/absint.jl:575
  [3] abs_typeof(arg::LLVM.LoadInst)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3Svej/src/absint.jl:281
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3Svej/src/compiler.jl:7048
  [5] codegen
    @ ~/.julia/packages/Enzyme/3Svej/src/compiler.jl:6128 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3Svej/src/compiler.jl:8431
  [7] _thunk
    @ ~/.julia/packages/Enzyme/3Svej/src/compiler.jl:8431 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/3Svej/src/compiler.jl:8472 [inlined]
  [9] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3Svej/src/compiler.jl:8604
 [10] #s2103#19072
    @ ~/.julia/packages/Enzyme/3Svej/src/compiler.jl:8741 [inlined]
 [11] 
    @ Enzyme.Compiler ./none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [13] autodiff
    @ ~/.julia/packages/Enzyme/3Svej/src/Enzyme.jl:473 [inlined]
 [14] autodiff(::ReverseMode{…}, ::typeof(f), ::Type{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/3Svej/src/Enzyme.jl:512
 [15] top-level scope
    @ ~/Research/Enzyme/cholrule.jl:134
Some type information was truncated. Use `show(err)` to see complete types.

This is on Enzyme#main with Julia 1.10.5. The error only started to appear in recent Enzyme versions; it was fine in, e.g., 0.13.9.

wsmoses commented Nov 13, 2024

using LinearAlgebra
using SparseArrays
using Enzyme: EnzymeRules
using Enzyme

N = 5
σ = sprand(N,N, 0.1)
s = 0.5 .* (σ + σ') + 5 .* Diagonal(ones(N))

function f(cho, x)
    F = cho
    B = x
    sys = SparseArrays.CHOLMOD.CHOLMOD_A
    SparseArrays.CHOLMOD.Dense{Float64}( SparseArrays.CHOLMOD.cholmod_l_solve(sys, F, B, getcommon(Int64)))
end

x = rand(N)
cho = cholesky(s)
x2 = SparseArrays.CHOLMOD.Dense{Float64}(x)
autodiff(Forward, f, Const, Const(cho), Const(x2))

@vchuravy I think we need to talk this through

In any case, Paul, this should be a workaround for you for now: #2086
