Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Oct 3, 2023
1 parent a9c4a73 commit 98058a3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/Comrade.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ include("bayes/bayes.jl")
include("inference/inference.jl")
include("calibration/calibration.jl")
include("clean.jl")

include("rules.jl")

# Load extensions using requires for verions < 1.9
if !isdefined(Base, :get_extension)
Expand Down
31 changes: 31 additions & 0 deletions src/rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

#from Lux to speed up tuple merging
function ChainRulesCore.rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2}
y = merge(nt1, nt2)
function ∇merge(dy)
dnt1 = NamedTuple((f1 => (f1 in F2 ? NoTangent() : getproperty(dy, f1)) for f1 in F1))
dnt2 = NamedTuple((f2 => getproperty(dy, f2) for f2 in F2))
return (NoTangent(), dnt1, dnt2)
end
function ∇merge(dy::Union{NoTangent, ZeroTangent})
return (NoTangent(), NoTangent(), NoTangent())
end
return y, ∇merge
end

function ChainRulesCore.rrule(::typeof(vec), x::AbstractMatrix)
y = vec(x)
∇vec(dy) = (NoTangent(), reshape(dy, size(x)))
return y, ∇vec
end

function ChainRulesCore.rrule(::typeof(collect), v::Vector)
y = collect(v)
∇collect(dy) = (NoTangent(), dy)
return y, ∇collect
end

function ChainRulesCore.rrule(::typeof(copy), x)
∇copy(dy) = (NoTangent(), dy)
return copy(x), ∇copy
end

0 comments on commit 98058a3

Please sign in to comment.