diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 5e60faf16..ba44b8ac6 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4269,13 +4269,18 @@ def new_ones(self, inputs, input_types): def tril(self, inputs, input_types): x = inputs[0] x_shape = _infer_shape(x) - - y = np.tril(np.ones(x_shape)).astype(_convert_tvm_to_np_dtype(input_types[0])) - y = tvm.nd.array(y) - y = tvm.relay.Constant(y) - - return _op.multiply(x, y) - + diagonal = inputs[1] + + count = np.arange(np.prod(x_shape)).reshape(x_shape) + comp = count.transpose(-1, -2) + count = tvm.relay.Constant(tvm.nd.array(count)) + count = _op.cast(count, self.infer_type(diagonal).dtype) + comp = tvm.relay.Constant(tvm.nd.array(comp)) + comp = _op.cast(comp, self.infer_type(diagonal).dtype) + + tril = _op.less_equal(comp, _op.add(count, diagonal)) + tril = _op.cast(tril, self.infer_type(x).dtype) + return _op.multiply(tril, x) def triu(self, inputs, input_types): x = inputs[0]