forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
autograd.cpp
217 lines (203 loc) · 6.43 KB
/
autograd.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/variable.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#endif
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <c10/util/irange.h>
namespace torch::autograd {
// NB: This code duplicates existing logic in torch/autograd/__init__.py and
// torch._C._EngineBase.run_backward in torch/csrc/autograd/python_engine.cpp.
// This is a pure C++ API for autograd without any dependency on Python, so
// it can be exposed in the PyTorch C++ API and TorchScript. If either this
// file or the Python implementation changes, the other must be updated so
// that their logic stays equivalent.
// TODO: Make the Python API above just call this C++ API.
static variable_list _make_grads(
const variable_list& outputs,
const variable_list& grad_outputs) {
size_t num_tensors = outputs.size();
size_t num_gradients = grad_outputs.size();
variable_list new_grads;
new_grads.reserve(num_tensors);
if (grad_outputs.empty()) {
for (const Variable& output : outputs) {
if (output.requires_grad()) {
TORCH_CHECK(
output.numel() == 1,
"grad can be implicitly created only for scalar outputs");
TORCH_CHECK(
c10::isFloatingType(output.scalar_type()),
"grad can be computed only for real scalar outputs but got ",
output.scalar_type());
new_grads.emplace_back(
at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
}
} else {
TORCH_CHECK(
num_tensors == num_gradients,
"got ",
num_tensors,
" tensors and ",
num_gradients,
" gradients");
for (const auto i : c10::irange(outputs.size())) {
const Variable& output = outputs[i];
const Variable& grad_output = grad_outputs[i];
if (!grad_output.defined()) {
if (output.requires_grad()) {
TORCH_CHECK(
output.numel() == 1,
"grad can be implicitly created only for scalar outputs");
TORCH_CHECK(
c10::isFloatingType(output.scalar_type()),
"grad can be computed only for real scalar outputs but got ",
output.scalar_type());
new_grads.emplace_back(
at::ones_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
} else {
TORCH_CHECK(
grad_output.is_complex() == output.is_complex(),
"For complex Tensors, both grad_output and output are required ",
"to have the same dtype. Mismatch in dtype: grad_output[",
grad_output,
"] has a dtype of ",
grad_output.scalar_type(),
" and output[",
output,
"] has a dtype of ",
output.scalar_type(),
".");
// grad output is defined, just append to the new_grads
new_grads.emplace_back(grad_output);
}
}
}
return new_grads;
}
// Collects the graph roots and requested input edges, then hands them to the
// default autograd engine.
//
// outputs:         tensors to differentiate; each must have a gradient edge
//                  with a non-null function.
// grad_outputs:    gradients passed to the engine together with the roots.
// keep_graph:      forwarded to Engine::execute.
// create_graph:    forwarded to Engine::execute.
// inputs:          tensors to compute gradients for; may be empty, in which
//                  case no output edges are passed to the engine.
// allow_unused:    when false, every entry of the engine's result that
//                  corresponds to an input must be defined.
// accumulate_grad: forwarded to Engine::execute; when true, retain_grad() is
//                  also called on every input.
//
// Returns whatever Engine::execute returns.
// Fails (via TORCH_CHECK) when an output has no gradient function, when an
// input does not require grad, or when `allow_unused` is false and a
// returned gradient is undefined.
static variable_list run_backward(
    const variable_list& outputs,
    const variable_list& grad_outputs,
    bool keep_graph,
    bool create_graph,
    const variable_list& inputs,
    bool allow_unused,
    bool accumulate_grad) {
  const size_t num_tensors = outputs.size();
  edge_list roots;
  roots.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    const Variable& output = outputs[i];
    auto gradient_edge = impl::gradient_edge(output);
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));
  }

  // With no inputs the loop below does not run and output_edges stays empty.
  const size_t num_inputs = inputs.size();
  edge_list output_edges;
  output_edges.reserve(num_inputs);
  for (const auto i : c10::irange(num_inputs)) {
    const Variable& input = inputs[i];
    const auto output_nr = input.output_nr();
    auto grad_fn = input.grad_fn();
    if (!grad_fn) {
      // No grad_fn: fall back to the gradient accumulator, if there is one.
      grad_fn = impl::try_get_grad_accumulator(input);
    }
    // NOTE(review): retain_grad() runs before the requires_grad check below;
    // confirm that this ordering is intentional.
    if (accumulate_grad) {
      input.retain_grad();
    }
    TORCH_CHECK(
        input.requires_grad(),
        "element ",
        i,
        " of the input tensors does not require grad");
    if (!grad_fn) {
      // See NOTE [ Autograd Unreachable Input ] for details
      output_edges.emplace_back(std::make_shared<Identity>(), 0);
    } else {
      output_edges.emplace_back(grad_fn, output_nr);
    }
  }

  variable_list grad_inputs = Engine::get_default_engine().execute(
      roots,
      grad_outputs,
      keep_graph,
      create_graph,
      accumulate_grad,
      output_edges);

  // Unless allow_unused is set, every requested input must have received a
  // defined gradient.
  if (!allow_unused) {
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          grad_inputs[i].defined(),
          "element ",
          i,
          " of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
    }
  }
  return grad_inputs;
}
void backward(
const variable_list& tensors,
const variable_list& grad_tensors,
std::optional<bool> retain_graph,
bool create_graph,
const variable_list& inputs) {
variable_list gradients = _make_grads(tensors, grad_tensors);
if (!retain_graph) {
retain_graph = create_graph;
}
run_backward(
tensors,
gradients,
retain_graph.value(),
create_graph,
inputs,
/*allow_unused=*/true,
/*accumulate_grad=*/true);
}
// Seeds gradients for `outputs` with _make_grads and returns the result of
// run_backward for `inputs`, with gradient accumulation disabled.
//
// outputs:      tensors to differentiate.
// inputs:       tensors to compute gradients for.
// grad_outputs: gradients for `outputs`, handed to _make_grads.
// retain_graph: passed on as run_backward's keep_graph; falls back to
//               `create_graph` when unset.
// create_graph: passed on to run_backward.
// allow_unused: passed on to run_backward.
variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs,
    std::optional<bool> retain_graph,
    bool create_graph,
    bool allow_unused) {
  const variable_list seed_grads = _make_grads(outputs, grad_outputs);
  const bool keep_graph = retain_graph.value_or(create_graph);
  return run_backward(
      outputs,
      seed_grads,
      keep_graph,
      create_graph,
      inputs,
      allow_unused,
      /*accumulate_grad=*/false);
}
namespace forward_ad {

// Returns the index obtained from ForwardADLevel::get_next_idx().
uint64_t enter_dual_level() {
  const uint64_t next_level = ForwardADLevel::get_next_idx();
  return next_level;
}

// Passes `level` to ForwardADLevel::release_idx().
void exit_dual_level(uint64_t level) {
  ForwardADLevel::release_idx(level);
}

} // namespace forward_ad
} // namespace torch::autograd