Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom rule in SparseMatrixCSC #2013

Open
hochunlin opened this issue Oct 24, 2024 · 1 comment
Open

Custom rule in SparseMatrixCSC #2013

hochunlin opened this issue Oct 24, 2024 · 1 comment

Comments

@hochunlin
Copy link

I am writing some reverse-mode custom rules for manipulating sparse matrices with Enzyme. However, I did not get the correct result compared to the finite-difference result. The following is the MWE:

using Enzyme
import Enzyme.EnzymeCore
using Random
using Test
using SparseArrays

function mul_internal!(S, A, B)
    S[:] = A*B
end

function mul_custom!(S, A, B)
    S[:] = A*B
end

function EnzymeRules.augmented_primal(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
    ) where {RT,T,N}
    println("In custom augmented primal rule.")

    if typeof(S) <: EnzymeCore.Duplicated || typeof(S) <: EnzymeCore.BatchDuplicated
        func.val(S.val, A.val, B.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        S.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        S.dval
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(
    config,
    func::EnzymeRules.Const{typeof(mul_custom!)},
    ::Type{RT},
    cache,
    S::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    A::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    B::EnzymeCore.Const{<:AbstractArray{T,N}},
    ) where {RT,T,N}
    println("In custom reverse rule.")

    dys = S.dval
    dxs = A.dval
    
    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end

    for (dy, dx) in zip(dys, dxs)
        if !(typeof(S) <: EnzymeCore.Const) && dy !== S.val
            if !(typeof(A) <: EnzymeCore.Const) && dx !== A.val
                dx .+=  dy * (B.val)'
            end
            dy .= 0
        end
    end
    return (nothing, nothing, nothing)
end


@testset "Test dense matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
    Random.seed!(1234) # Make the result reproducible

    tuning_ind_i = 2
    tuning_ind_j = 1
    resulting_ind_i = 2
    resulting_ind_j = 2

    # A, B, and S are dense matrices
    A_internal = rand(3,3)
    dA_internal = make_zero(A_internal)
    B_internal = rand(3,3)
    S_internal = zeros(3,3)
    dS_internal = make_zero(S_internal)
    
    dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]

    A_custom = deepcopy(A_internal)
    dA_custom = deepcopy(dA_internal)
    B_custom = deepcopy(B_internal)
    S_custom = deepcopy(S_internal)
    dS_custom = deepcopy(dS_internal)

    ϵ = 1e-5
    r_matrix = zeros(3,3)
    r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2

    finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ

    # Case 1: internal Enzyme rule for mul in dense matrix (it works)
    autodiff(
        Reverse, 
        mul_internal!, 
        Const,
        Duplicated(S_internal, dS_internal),
        Duplicated(A_internal, dA_internal), 
        Const(B_internal), 
    ) 
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0

    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2

    # Case 2: custom rule for mul in dense matrix (it works)
    autodiff(
        Reverse, 
        mul_custom!, 
        Const,
        Duplicated(S_custom, dS_custom),
        Duplicated(A_custom, dA_custom), 
        Const(B_custom), 
    ) 
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0
     
     # 0.08344008943212289  ≈ 0.0834400894378362
     @test dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end

@testset "Test sparse matrix S = A*B: Computing S[2,2] pullback from tuning A[2,1]" begin
    Random.seed!(1234) # Make the result reproducible

    tuning_ind_i = 2
    tuning_ind_j = 1
    resulting_ind_i = 2
    resulting_ind_j = 2

    # A, B, and S are sparse matrices
    A_internal =sprand(3,3,1.0)
    dA_internal = make_zero(A_internal)
    B_internal =sprand(3,3,1.0)
    S_internal = sparse(ones(Float64,3,3))
    dS_internal = make_zero(S_internal)

    dS_internal[resulting_ind_i,resulting_ind_j] = 1 # Set the pullback at S[2,2]

    A_custom = deepcopy(A_internal)
    dA_custom = deepcopy(dA_internal)
    B_custom = deepcopy(B_internal)
    S_custom = deepcopy(S_internal)
    dS_custom = deepcopy(dS_internal)

    ϵ = 1e-5
    r_matrix = zeros(3,3)
    r_matrix[tuning_ind_i,tuning_ind_j] = ϵ/2

    finite_difference = ((A_internal + r_matrix)*B_internal - (A_internal - r_matrix)*B_internal)/ϵ

    # Case 3: internal Enzyme rule for mul in sparse matrix (it works)
    autodiff(
        Reverse, 
        mul_internal!, 
        Const,
        Duplicated(S_internal, dS_internal),
        Duplicated(A_internal, dA_internal), 
        Const(B_internal), 
    ) 
    dA_internal[1,1] # 0.0
    dA_internal[2,1] # 0.08344008943212289
    dA_internal[3,1] # 0.0
    dA_internal[1,2] # 0.0
    dA_internal[2,2] # 0.525795663891226
    dA_internal[3,2] # 0.0
    dA_internal[1,3] # 0.0
    dA_internal[2,3] # 0.8406409194782338
    dA_internal[3,3] # 0.0

    # 0.08344008943212289 ≈ 0.0834400894378362
    @test dA_internal[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2

    # Case 4: custom rule for mul in sparse matrix (it does not work)
    autodiff(
        Reverse, 
        mul_custom!, 
        Const,
        Duplicated(S_custom, dS_custom),
        Duplicated(A_custom, dA_custom), 
        Const(B_custom), 
    ) 
     dA_custom[1,1] # 0.08344008943212289 (which should be dA[2,1])
     dA_custom[2,1] # 0.525795663891226 (which should be dA[2,2])
     dA_custom[3,1] # 0.8406409194782338 (which should be dA[2,3])
     @test_throws BoundsError dA_custom[1,2] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[2,2] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[3,2] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[1,3] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[2,3] # ERROR: BoundsError
     @test_throws BoundsError dA_custom[3,3] # ERROR: BoundsError
     
     # 0.525795663891226 ≈ 0.0834400894378362
     @test_broken dA_custom[tuning_ind_i,tuning_ind_j] ≈ finite_difference[resulting_ind_i,resulting_ind_j] rtol = 1e-2
end

In this example, I tried to get the autodiff from a matrix multiplication: S = A*B, where B is a constant matrix.

I test the four cases:

Case 1. Dense matrix multiplication with internal enzyme rule
Case 2. Dense matrix multiplication with my custom enzyme rule
Case 3. Sparse matrix multiplication with internal enzyme rule
Case 4. Sparse matrix multiplication with my custom enzyme rule

The case 4 does not give the right result. The custom rule in sparse matrix multiplication does not work because the pullbacks passed to it are incorrect, such as passing dA[2,1] wrongly to dA[1,1].

Now, I am confused why case 3 (internal Enzyme rule) can work, but case 4 (custom Enzyme rule) doesn't. Is there a subtlety in the implementation of the custom rule for the sparse matrix I am missing? Thanks for any suggestion or insight in advance!

@hochunlin
Copy link
Author

Hi, I tested more on the sparse matrix. I bumped into more issues.

In the following MWE, I tried to get the autodiff from an inversion matrix: y = x⁻¹, where the element x[3,2] is changed by the parameter r (i.e. x[3,2] = x[3,2] + r):

using Enzyme
import EnzymeCore
using SparseArrays
using Test
using Random

function inv_plus_r_without_custom_rule_for_plus!(y::AbstractArray, x::AbstractArray, r::AbstractArray)
    x[3,2] = x[3,2] + r[1]
    inv!(y,x)
end

function inv_plus_r!(y::AbstractArray, x::AbstractArray, r::AbstractArray)
    x[3,2] = x[3,2] + r[1]
    inv!(y,x)
end

function inv!(y::AbstractArray, x::AbstractArray)
    y[:] = inv(Matrix(x))
end


function EnzymeRules.augmented_primal(config, func::EnzymeRules.Const{typeof(inv_plus_r!)}, ::Type{RT},
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    r::EnzymeCore.Annotation{<:AbstractArray{T,1}}
) where {RT,T,N}
    println("In custom augmented primal rule in inv_plus_r!.")

    if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
        func.val(y.val, x.val, r.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        y.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        y.dval
    else
        nothing
    end

    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config, func::EnzymeRules.Const{typeof(inv_plus_r!)}, ::Type{RT}, cache,
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    r::EnzymeCore.Annotation{<:AbstractArray{T,1}}
) where {RT,T,N}
    println("In custom reverse rule in inv_plus_r!.")

    dys = y.dval
    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval
    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end
    for (dy, dx) in zip(dys, dxs)
        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                dx .+=  - inv(Matrix(x.val))' * dy * inv(Matrix(x.val))'
            end
            dy .= 0
        end
    end
    r.dval[1] = dxs[1][3,2]
    return (nothing, nothing,nothing)
end

function EnzymeRules.augmented_primal(config, func::EnzymeRules.Const{typeof(inv!)}, ::Type{RT},
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}}
) where {RT,T,N}
    println("In custom augmented primal rule in inv!.")

    if typeof(y) <: EnzymeCore.Duplicated || typeof(y) <: EnzymeCore.BatchDuplicated
        func.val(y.val, x.val)
    end

    primal = if EnzymeRules.needs_primal(config)
        y.val
    else
        nothing
    end
    shadow = if EnzymeRules.needs_shadow(config)
        y.dval
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeRules.reverse(config, func::EnzymeRules.Const{typeof(inv!)}, ::Type{RT}, cache,
    y::EnzymeCore.Annotation{<:AbstractArray{T,N}},
    x::EnzymeCore.Annotation{<:AbstractArray{T,N}}
) where {RT,T,N}
    println("In custom reverse rule in inv!.")

    dys = y.dval
    dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval

    if EnzymeRules.width(config) == 1
        dys = (dys,)
        dxs = (dxs,)
    end

    for (dy, dx) in zip(dys, dxs)
        if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val
            if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                dx .+=  - inv(Matrix(x.val))' * dy * inv(Matrix(x.val))'
            end
            dy .= 0
        end
    end
    return (nothing, nothing)
end

@testset "Test: Computing y[2,2] (y = x⁻¹) pullback at x[3,2] = x[3,2] + r from x = sprand(1000,1000,probability = 0.9) with partial custom rule and whole custom rule" begin
    Random.seed!(1234) # Make the result reproducible

    x_partial_rule = sprand(Float64,100,100,0.7)
    dx_partial_rule = make_zero(x_partial_rule)
    y_partial_rule = make_zero(sparse(ones(Float64, 100,100)))
    dy_partial_rule = make_zero(y_partial_rule);
    dy_partial_rule[2, 2] = 1; # set pullback seed to [1, 1] index
    r_partial_rule = [0.0]
    dr_partial_rule = [0.0]
    

    Random.seed!(1234) # Make the result reproducible

    x_whole_rule = sprand(Float64,100,100,0.7)
    dx_whole_rule = make_zero(x_whole_rule)
    y_whole_rule = make_zero(sparse(ones(Float64, 100,100)))
    dy_whole_rule = make_zero(y_whole_rule);
    dy_whole_rule[2, 2] = 1; # set pullback seed to [1, 1] index
    r_whole_rule = [0.0]
    dr_whole_rule = [0.0]

    delta = spzeros(size(x_whole_rule));
    fd_delta = 1e-5;
    delta[3, 2] = fd_delta;
    delta;
    grad_fd = (inv(Matrix(x_whole_rule .+ delta / 2)) - inv(Matrix(x_whole_rule .- delta / 2)) ) / fd_delta

    Enzyme.autodiff(
        Reverse,
        inv_plus_r_without_custom_rule_for_plus!,
        Const,
        Duplicated(y_partial_rule, dy_partial_rule),
        Duplicated(x_partial_rule, dx_partial_rule),
        Duplicated(r_partial_rule, dr_partial_rule),
    )

    Enzyme.autodiff(
        Reverse,
        inv_plus_r!,
        Const,
        Duplicated(y_whole_rule, dy_whole_rule),
        Duplicated(x_whole_rule, dx_whole_rule),
        Duplicated(r_whole_rule, dr_whole_rule),
    )

    @test dx_partial_rule[3,2] ≈ grad_fd[2, 2] rtol = 1e-3
    @test_broken dr_partial_rule[1] ≈ grad_fd[2, 2] rtol = 1e-3
    @test_broken dr_partial_rule[1] ≈ dx_partial_rule[3,2] rtol = 1e-3

    @test dx_whole_rule[3,2] ≈ grad_fd[2, 2] rtol = 1e-3
    @test dr_whole_rule[1] ≈ grad_fd[2, 2] rtol = 1e-3
    @test dr_whole_rule[1] ≈ dx_whole_rule[3,2] rtol = 1e-3
end

Here I compute the pullback by writing

  • the custom rule for the only inversion part
  • the custom rule for the inversion and addition part

I expect the pullback dx[3,2] = dr in this simple case. However, if I just wrote the custom rule for the only inversion part, then dx[3,2] != dr.

After testing, this issue also only happened in the sparse matrices, any thought and insight would be appreciated. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant