Support is_grads_batched parameter for autograd.grad (#10399)
Closes #10397

Here we only need to handle `is_grads_batched` during argument checking; there is no need to intrude into the AutogradEngine.
The actual backward computation performs the broadcast on its own; if the result is wrong, it means the operator's broadcast support is incomplete.

---------

Co-authored-by: wyg1997 <[email protected]>
wyg1997 authored Jan 10, 2024
1 parent 6ed4991 commit 9754f6d
Showing 5 changed files with 88 additions and 9 deletions.
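For orientation, the new flag is used like this (condensed from the test added in this commit; because `out = x.clone()`, the Jacobian is the identity and the returned grad simply equals the batched grad_output):

import oneflow as flow

x = flow.randn(2, 2, requires_grad=True)
out = x.clone()  # shape (2, 2)

# Three "vectors" for three vector-Jacobian products, stacked along dim 0.
batched_grad = flow.arange(3).expand(2, 2, 3).transpose(0, 2)  # shape (3, 2, 2)

# One call computes all three products; grad has shape (3, 2, 2).
(grad,) = flow.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)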
30 changes: 23 additions & 7 deletions oneflow/api/python/autograd/autograd.cpp
@@ -48,7 +48,8 @@ bool IsScalarTensor(const one::Tensor& tensor) {
// If output is a scalar tensor, out_grad will also be a scaler or empty(will be initted to
// `oneflow.ones([1])`).
Maybe<one::TensorTuple> CheckAndInitOutGrads(const one::TensorTuple& outputs,
const one::TensorTuple& out_grads) {
const one::TensorTuple& out_grads,
bool is_grads_batched) {
size_t grad_size = out_grads.empty() ? outputs.size() : out_grads.size();
auto gradients = std::make_shared<one::TensorTuple>(grad_size);
CHECK_EQ_OR_RETURN(outputs.size(), gradients->size())
@@ -70,9 +71,22 @@ Maybe<one::TensorTuple> CheckAndInitOutGrads(const one::TensorTuple& outputs,
<< "Grad can be implicitly created only for scalar outputs";
gradients->at(i) = JUST(one::functional::OnesLike(outputs.at(i)));
} else {
CHECK_OR_RETURN(*(outputs.at(i)->shape()) == *(out_grads.at(i)->shape()))
<< "out_grad's shape must be same as output's (" << outputs.at(i)->shape()->ToString()
<< " vs " << out_grads.at(i)->shape()->ToString() << ")";
if (is_grads_batched) {
if (*(outputs.at(i)->shape()) != *JUST(out_grads.at(i)->shape()->Slice(1))) {
THROW(RuntimeError) << "If `is_grads_batched=True`, we interpret the first "
<< "dimension of each grad_output as the batch dimension. "
<< "The sizes of the remaining dimensions are expected to match "
<< "the shape of corresponding output, but a mismatch "
<< "was detected: grad_output[" << i << "] has a shape of "
<< out_grads.at(i)->shape()->ToString() << " and output[" << i
<< "] has a shape of " << outputs.at(i)->shape()->ToString() << ".";
}

} else {
CHECK_EQ_OR_RETURN(*(outputs.at(i)->shape()), *(out_grads.at(i)->shape()))
<< "out_grad's shape must be same as output's (" << outputs.at(i)->shape()->ToString()
<< " vs " << out_grads.at(i)->shape()->ToString() << ")";
}
if (JUST(oneflow::VectorAt(outputs, i))->dtype()
!= JUST(oneflow::VectorAt(out_grads, i))->dtype()) {
JUST(oneflow::VectorAt(*gradients, i)) =
@@ -93,15 +107,16 @@ Maybe<one::TensorTuple> Backward(const one::TensorTuple& outputs, const one::Ten
PythonFrameGuard pf;
BackwardPassScopeGuard backward_guard;
if (create_graph) { retain_graph = true; }
std::shared_ptr<one::TensorTuple> gradients = JUST(CheckAndInitOutGrads(outputs, out_grads));
std::shared_ptr<one::TensorTuple> gradients =
JUST(CheckAndInitOutGrads(outputs, out_grads, /*is_grads_batched=*/false));
JUST(one::GetThreadLocalAutogradEngine()->RunBackwardAndSaveGrads4LeafTensorIf(
outputs, *gradients, retain_graph, create_graph));
return std::make_shared<one::TensorTuple>(0);
}

Maybe<one::TensorTuple> Grad(const one::TensorTuple& outputs, const one::TensorTuple& inputs,
const one::TensorTuple& out_grads, bool retain_graph,
bool create_graph, bool allow_unused) {
bool create_graph, bool allow_unused, bool is_grads_batched) {
PythonFrameGuard pf;
BackwardPassScopeGuard backward_guard;
if (create_graph) { retain_graph = true; }
@@ -110,7 +125,8 @@ Maybe<one::TensorTuple> Grad(const one::TensorTuple& outputs, const one::TensorT
inputs.begin(), inputs.end(),
[](const std::shared_ptr<one::Tensor>& tensor) { return tensor->requires_grad(); }))
<< "All input tensors `.requires_grad` should be true";
std::shared_ptr<one::TensorTuple> gradients = JUST(CheckAndInitOutGrads(outputs, out_grads));
std::shared_ptr<one::TensorTuple> gradients =
JUST(CheckAndInitOutGrads(outputs, out_grads, is_grads_batched));
return one::GetThreadLocalAutogradEngine()->RunBackwardAndReturnInputsTensorGradIf(
outputs, inputs, *gradients, retain_graph, create_graph, allow_unused);
}
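The added C++ check boils down to a simple rule; a rough Python rendering for readability (the helper name is illustrative only — the real check lives in CheckAndInitOutGrads above):

def check_out_grad_shape(output, out_grad, is_grads_batched):
    # With is_grads_batched=True, dim 0 of out_grad is the batch dimension,
    # so only dims 1..N have to match the output's shape.
    if is_grads_batched:
        if tuple(output.shape) != tuple(out_grad.shape)[1:]:
            raise RuntimeError(
                "If `is_grads_batched=True`, we interpret the first dimension "
                "of each grad_output as the batch dimension ..."
            )
    # Otherwise the shapes must match exactly, as before.
    elif tuple(output.shape) != tuple(out_grad.shape):
        raise RuntimeError("out_grad's shape must be same as output's")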
4 changes: 3 additions & 1 deletion oneflow/core/common/shape.cpp
@@ -216,10 +216,12 @@ Maybe<Shape> Shape::Slice(int64_t start_dim, int64_t end_dim) const {
if (start_dim > ndims) { start_dim = ndims; }
if (end_dim > ndims) { end_dim = ndims; }
std::shared_ptr<Shape> shape = std::make_shared<Shape>();
for (int64_t i = start_dim; i < end_dim && i < ndims; ++i) { shape->emplace_back(this->At(i)); }
shape->assign(this->begin() + start_dim, this->begin() + end_dim);
return shape;
}

Maybe<Shape> Shape::Slice(int64_t start_dim) const { return Slice(start_dim, NumAxes()); }

bool Shape::operator==(const Shape& rhs) const {
if (is_initialized_ != rhs.is_initialized_) { return false; }
if (is_initialized_ == false) { return true; }
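A Python analog of the new single-argument Slice overload, just to pin down its semantics (hypothetical helper, not part of the OneFlow API):

def slice_from(shape, start_dim):
    # Same as Shape::Slice(start_dim, NumAxes()): keep dimensions start_dim..end.
    return tuple(shape)[start_dim:]

assert slice_from((3, 2, 2), 1) == (2, 2)  # drops the leading batch dimension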
1 change: 1 addition & 0 deletions oneflow/core/common/shape.h
@@ -146,6 +146,7 @@ class Shape final : public DimVector, public MutShapeMixIn<Shape> {
AxisVector Axes4BroadcastTo(ShapeView broadcast_dim_vec) const;

Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const;
Maybe<Shape> Slice(int64_t start_dim) const;

bool operator==(const Shape& rhs) const;

7 changes: 7 additions & 0 deletions python/oneflow/autograd/autograd.py
@@ -29,6 +29,7 @@ def grad(
retain_graph: bool = False,
create_graph: bool = False,
allow_unused: bool = False,
is_grads_batched: bool = False,
) -> Tuple[Tensor]:
r"""
Computes and returns the sum of gradients of outputs with respect to the inputs.
@@ -56,6 +57,11 @@ def grad(
allow_unused (bool, optional): If ``False``, specifying inputs that were not
used when computing outputs (and therefore their grad is always zero)
is an error. Defaults to ``False``.
is_grads_batched (bool, optional): If True, the first dimension of each tensor in
grad_outputs will be interpreted as the batch dimension. Instead of computing a single
vector-Jacobian product, we compute a batch of vector-Jacobian products for each “vector”
in the batch. This should lead to performance improvements when compared to manually
looping and performing backward multiple times. Defaults to ``False``.
Returns:
Tuple(Tensor): A tuple of tensors containing the gradients for each ``inputs``.
@@ -67,6 +73,7 @@
retain_graph,
create_graph,
allow_unused,
is_grads_batched,
)
return tuple([x for x in in_grads])

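The "manually looping" baseline that the docstring compares against would look roughly like this (a sketch only; retain_graph=True is assumed to be needed because the same graph is reused on every iteration):

import oneflow as flow

x = flow.randn(2, 2, requires_grad=True)
out = x.clone()
batched_grad = flow.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=flow.float32)  # (3, 2, 2)

# One vector-Jacobian product per batch entry, then stack the results.
grads = [
    flow.autograd.grad(out, (x,), (batched_grad[i],), retain_graph=True)[0]
    for i in range(batched_grad.shape[0])
]
stacked = flow.stack(grads)  # matches the single is_grads_batched=True call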
55 changes: 54 additions & 1 deletion python/oneflow/test/modules/test_autograd.py
@@ -17,6 +17,7 @@
import unittest
from collections import OrderedDict

import torch as original_torch
import numpy as np
import oneflow as flow
import oneflow.unittest
@@ -210,7 +211,6 @@ def test_no_grad_domain_call_backward(test_case):
flow.autograd.backward(y, flow.ones_like(y))
test_case.assertTrue(np.array_equal(x.grad.numpy(), np.full(random_shape, 2.0)))

@unittest.skip("skip for now, becase it failed 2 times in past week")
@autotest(n=1, auto_backward=False, check_graph=False)
def test_acc_grad_inplace_update(test_case):
random_shape = [random(1, 5).to(int).value() for _ in range(4)]
@@ -283,6 +283,59 @@ def test_autograd_grad_allow_unused(test_case):
)[0]
test_case.assertTrue(dddx.oneflow is None and dddx.pytorch is None)

def test_autograd_is_grads_batched(test_case):
x = flow.randn(2, 2, requires_grad=True)

out = x.clone() # Size([2, 2])
batched_grad = flow.arange(3).expand(2, 2, 3).transpose(0, 2) # Size([3, 2, 2])
(grad,) = flow.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
test_case.assertTrue(
np.array_equal(
grad.cpu().detach().numpy(),
flow.arange(3)
.expand(2, 2, 3)
.transpose(0, 2)
.to(dtype=grad.dtype)
.numpy(),
)
)

# Detect shape mismatch
grad_out = flow.ones(2, 2)
with test_case.assertRaisesRegex(
RuntimeError, "If `is_grads_batched=True`, we interpret the first"
):
flow.autograd.grad(
outputs=out,
grad_outputs=(grad_out,),
inputs=(x,),
is_grads_batched=True,
)

# TODO: ReduceSum backward not support broadcast grad with shape (3, ) to (3, 2, 2)
# # Scalar outputs
# out = x.sum() # Size([])
# batched_grad = flow.arange(3) # Size([3])
# (grad,) = flow.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
# test_case.assertTrue(
# np.array_equal(
# grad.cpu().detach().numpy(),
# flow.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype).numpy(),
# )
# )

# We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior.
grad_out = flow.ones(2).unsqueeze(1)
with test_case.assertRaisesRegex(
RuntimeError, "If `is_grads_batched=True`, we interpret the first"
):
flow.autograd.grad(
outputs=out,
grad_outputs=(grad_out,),
inputs=(x,),
is_grads_batched=True,
)


if __name__ == "__main__":
unittest.main()
