Flux.huber_loss
It looks like `Flux.huber_loss` is type-unstable when differentiated with Zygote. Here is a minimal comparison against `Flux.mse`:
```julia
using Flux, Zygote
import Statistics: mean

# Baseline: loss computed with Flux.mse
function internfunc_nobroad(m, x, y)
    modelvals = m(x)
    Flux.mse(modelvals, y)
end

# Same model call, but with Flux.huber_loss
function internfunc_nobroad_huberloss(m, x, y)
    modelvals = m(x)
    Flux.huber_loss(modelvals, y)
end

# Gradient of `func` with respect to the model (data let-bound for the closure)
function wrapfunc(model, xdata, ydata, func)
    grad = let xdata = xdata, ydata = ydata
        Zygote.gradient(m -> func(m, xdata, ydata), model)
    end
    return grad
end

fc = Flux.Chain(Flux.Dense(5 => 3, Flux.relu), Flux.Dense(3 => 3, Flux.relu), Flux.Dense(3 => 1))
fobs_ar = fill(5f0, 5, 10)    # 5 features × 10 samples
labels_ar = fill(2f0, 1, 10)  # 1 target × 10 samples
```
```julia
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad)

julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)
```
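For reference when reading the `@code_warntype` output, this is the Huber loss the second call differentiates: per element e = ŷ − y, it is e²/2 when |e| ≤ δ and δ(|e| − δ/2) otherwise, averaged over all elements. The sketch below is a hypothetical plain-Julia `huber` written for illustration, not Flux's `huber_loss` source. It assumes `Float32` arrays and δ = 1 by default, and uses `Test.@inferred` to check that the forward pass alone infers a concrete return type. It does not involve Zygote, so it can neither confirm nor rule out the instability reported above.

```julia
using Statistics: mean
using Test: @inferred

# Hypothetical reference Huber loss (not Flux's implementation), mean-aggregated.
# Per element e = ŷ - y:  e^2 / 2             if |e| <= δ
#                         δ * (|e| - δ / 2)   otherwise
function huber(ŷ::AbstractArray{T}, y::AbstractArray{T}; δ::T = one(T)) where {T<:AbstractFloat}
    err = abs.(ŷ .- y)
    # Broadcast `ifelse` computes both branches elementwise and selects one,
    # so both pieces keep element type T.
    losses = ifelse.(err .<= δ, err .^ 2 ./ 2, δ .* (err .- δ / 2))
    return mean(losses)
end

ŷ = fill(5f0, 1, 10)
y = fill(2f0, 1, 10)
# @inferred throws if the inferred return type differs from the actual one
@inferred huber(ŷ, y)
```

Swapping this `huber` in for `Flux.huber_loss` inside `internfunc_nobroad_huberloss` and re-running `@code_warntype` is one way to check whether the instability comes from the loss definition or from how Zygote differentiates it.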