Missing Rmath derivatives #1620

Open
mhauru opened this issue Jul 8, 2024 · 2 comments
Labels
help wanted Extra attention is needed

Comments

@mhauru
Contributor

mhauru commented Jul 8, 2024

A few Distributions.jl distributions rely on Rmath.jl functions, for which Enzyme derivatives do not seem to be defined.

module MWE

using Distributions: NoncentralBeta, logpdf
using Distributions: NoncentralChisq, NoncentralF, NoncentralT
using Enzyme

f(x) = logpdf(NoncentralBeta(1.0, 1.0, 1.0), x[1])
g(x) = logpdf(NoncentralChisq(1.0, 1.0), x[1])
h(x) = logpdf(NoncentralF(1.0, 1.0, 1.0), x[1])
i(x) = logpdf(NoncentralT(1.0, 1.0), x[1])
Enzyme.gradient(Enzyme.Forward, f, [0.5])
Enzyme.gradient(Enzyme.Forward, g, [0.5])
Enzyme.gradient(Enzyme.Forward, h, [0.5])
Enzyme.gradient(Enzyme.Forward, i, [0.5])

end

For NoncentralBeta, output:

Current scope:
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_8695({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="4757723584" "enzymejl_parmtype_ref"="2" %0) local_unnamed_addr #4 !dbg !47 {
top:
  %1 = call {}*** @julia.get_pgcstack() #5
  %ptls_field3 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %2, align 8, !tbaa !8
  %3 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !12
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #5, !dbg !48
  fence syncscope("singlethread") seq_cst
  %4 = addrspacecast {} addrspace(10)* %0 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !49
  %arraylen_ptr = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %4, i64 0, i32 1, !dbg !49
  %arraylen = load i64, i64 addrspace(11)* %arraylen_ptr, align 8, !dbg !49, !tbaa !18, !range !21, !alias.scope !22, !noalias !25
  %inbounds.not = icmp eq i64 %arraylen, 0, !dbg !49
  br i1 %inbounds.not, label %oob, label %idxend, !dbg !49

oob:                                              ; preds = %top
  %errorbox = alloca i64, align 8, !dbg !49
  store i64 1, i64* %errorbox, align 8, !dbg !49, !noalias !50
  %5 = addrspacecast {} addrspace(10)* %0 to {} addrspace(12)*, !dbg !49
  call void @ijl_bounds_error_ints({} addrspace(12)* noundef %5, i64* noundef nonnull align 8 %errorbox, i64 noundef 1) #6, !dbg !49
  unreachable, !dbg !49

idxend:                                           ; preds = %top
  %6 = addrspacecast {} addrspace(10)* %0 to double addrspace(13)* addrspace(11)*, !dbg !49
  %arrayptr6 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %6, align 16, !dbg !49, !tbaa !33, !alias.scope !53, !noalias !25, !nonnull !7
  %arrayref = load double, double addrspace(13)* %arrayptr6, align 8, !dbg !49, !tbaa !36, !alias.scope !39, !noalias !40
  %7 = call double @dnbeta(double %arrayref, double noundef 1.000000e+00, double noundef 1.000000e+00, double noundef 1.000000e+00, i32 noundef 1) #5, !dbg !54
  ret double %7, !dbg !54
}

No forward mode derivative found for dnbeta
 at context:   %7 = call double @dnbeta(double %arrayref, double noundef 1.000000e+00, double noundef 1.000000e+00, double noundef 1.000000e+00, i32 noundef 1) #5, !dbg !41

Stacktrace:
 [1] nbetalogpdf
   @ ~/.julia/packages/StatsFuns/mrf0e/src/rmath.jl:77
 [2] logpdf
   @ ~/.julia/packages/Distributions/ji8PW/src/univariates.jl:645
 [3] f
   @ ~/projects/Enzyme-mwes/callinst_metadata/noncentralbeta.jl:7

The others produce similar output, but for dnchisq, dnf, and dnt instead of dnbeta. The reverse-mode derivatives are likewise missing.

@wsmoses
Member

wsmoses commented Jul 8, 2024

Yeah, unfortunately I've never seen those functions before, so I'm not quite sure what they are meant to compute.

Do you know of any docs for them offhand? I'd also be happy to show you how to add internal support within Enzyme for function derivatives, if you're interested.
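
For reference, here is a rough sketch of what one of these functions computes and what a derivative rule would need to return. It assumes the Rmath convention that dnchisq(x, df, ncp, give_log) is the noncentral chi-squared density with df = k degrees of freedom and noncentrality ncp = λ, written as a Poisson-weighted mixture of central chi-squared densities. The symbols w_j and f_ν are notation introduced here, not names from Rmath or Enzyme, and the derivative is obtained by differentiating the series term by term for x > 0.

```latex
\begin{aligned}
f(x; k, \lambda) &= \sum_{j=0}^{\infty} w_j \, f_{k+2j}(x), \qquad
w_j = e^{-\lambda/2} \frac{(\lambda/2)^j}{j!}, \qquad
f_{\nu}(x) = \frac{x^{\nu/2 - 1} e^{-x/2}}{2^{\nu/2} \, \Gamma(\nu/2)}, \\
\frac{\partial}{\partial x} f_{\nu}(x) &= f_{\nu}(x) \left( \frac{\nu/2 - 1}{x} - \frac{1}{2} \right), \\
\frac{\partial}{\partial x} \log f(x; k, \lambda) &= \frac{1}{f(x; k, \lambda)} \sum_{j=0}^{\infty} w_j \, f_{k+2j}(x) \left( \frac{(k + 2j)/2 - 1}{x} - \frac{1}{2} \right).
\end{aligned}
```

The last line is the quantity a rule for dnchisq with give_log = 1 would have to supply with respect to x. The termwise form is used because it holds for any k > 0, including the k = 1 case in the MWE above. The other three functions (dnbeta, dnf, dnt) would each need their own derivation.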

@yebai

yebai commented Jul 9, 2024

Related: compintell/Mooncake.jl#31
