No proper support for aten::triu OP inputs when inputs are CallNodes [Bug] #3

Closed
JushBJJ opened this issue Mar 30, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@JushBJJ
Contributor

JushBJJ commented Mar 30, 2024

Hi, I'm working on the bounties with @marty1885 and @JonathanALevine.

Context

While implementing Qwen 1.5 (0.5B) for tenstorrent/tt-buda-demos#20, we hit the following error:

Traceback (most recent call last):
  File "/home/jush/qwen/.venv/lib/python3.8/site-packages/numpy/lib/twodim_base.py", line 536, in triu
    mask = tri(*m.shape[-2:], k=k-1, dtype=bool)
  File "/home/jush/qwen/.venv/lib/python3.8/site-packages/tvm/relay/expr.py", line 146, in __sub__
    raise TypeError(f'convert "{str(other)}" with `const` first')
TypeError: convert "1" with `const` first
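
The failure mode in this traceback can be sketched without TVM installed. In the example below, `SymbolicExpr` and `try_triu` are hypothetical stand-ins written for illustration (they are not TVM classes or functions); `SymbolicExpr` only imitates the `__sub__` behaviour visible in the traceback.

```python
import numpy as np


class SymbolicExpr:
    """Hypothetical stand-in for a Relay expression (not a TVM class)."""

    def __init__(self, name):
        self.name = name

    def __sub__(self, other):
        # Imitates the behaviour seen in the traceback: arithmetic with a
        # plain Python number is rejected.
        raise TypeError(f'convert "{other}" with `const` first')


def try_triu(k):
    """Return the np.triu result, or the error message if it fails."""
    try:
        return np.triu(np.ones((3, 3)), k)
    except TypeError as err:
        return str(err)


concrete = try_triu(1)                   # a Python int offset works
symbolic = try_triu(SymbolicExpr("k"))   # a symbolic offset raises TypeError
print(concrete)
print(symbolic)
```

With an integer offset the call returns an array; with the symbolic stand-in it fails inside NumPy, which is the same shape of failure as above.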

Qwen 1.5 (0.5B) implementation demo can be found in tenstorrent/tt-buda-demos#37

The TorchScript graph node for triu in this case is:

%aten::triu_0_0 : Int(256, 256, strides=[256, 1], requires_grad=0, device=cpu) = aten::triu(%aten::ones_like_0_0, %aten::Int_7_0), scope: utils.modeling_qwen2.Qwen2ForCausalLM::/utils.modeling_qwen2.Qwen2Model::model # /home/jush/qwen/.venv/lib/python3.8/site-packages/transformers/modeling_attn_mask_utils.py:169:0

*Requires transformers >=v4.37.0

Root Cause

The triu method of the PyTorchOpConverter class in tvm/relay/frontend/pytorch.py passes its inputs straight to np.triu:

mask = np.triu(np.ones(x_shape), inputs[1]).astype(np.bool)

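However, np.triu needs a concrete integer offset, so this fails when inputs[1] is a Relay expression made of nested CallNodes rather than a Python int. As the traceback shows, np.triu computes k-1, which Relay rejects with a TypeError. In this case, inputs[1] would be: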

CallNode(Op(add), [CallNode(Op(subtract), [CallNode(Op(subtract), [CallNode(Op(add), [Constant(256), Constant(0)], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(256)], (nullptr), [])], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(32768)], (nullptr), [])], (nullptr), []), Constant(1)], (nullptr), [])
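
Every leaf in this dump is a Constant, so the offset is in principle computable ahead of time. As a rough illustration, the pure-Python sketch below folds a hand transcription of the tree. `OPS`, `fold`, and `k_expr` are hypothetical names introduced here; this is not TVM code and does not use any TVM API.

```python
import operator

# Binary ops that appear in the dump, mapped to Python operators.
OPS = {"add": operator.add, "subtract": operator.sub, "multiply": operator.mul}


def fold(expr):
    """Evaluate a nested (op, lhs, rhs) tuple whose leaves are ints."""
    if isinstance(expr, int):
        return expr
    op, lhs, rhs = expr
    return OPS[op](fold(lhs), fold(rhs))


# Hand transcription of the CallNode tree printed above.
k_expr = (
    "add",
    ("subtract",
        ("subtract", ("add", 256, 0), ("multiply", 1, 256)),
        ("multiply", 1, 32768)),
    1,
)

k = fold(k_expr)
print(k)  # -32767
```

The arithmetic is ((256 + 0) - (1 * 256)) - (1 * 32768) + 1, which gives -32767.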

Workaround

In draft PR #2, I used self.trilu(inputs, input_types, mode="triu"), because _op.trilu returns the upper triangle instead of the lower triangle when upper=True.

def trilu(self, inputs, input_types, mode):
    data = inputs[0]
    k = inputs[1] if inputs[1] else 0
    upper = True if mode == "triu" else False
    return _op.trilu(data, k, upper)

With this workaround, the Qwen 1.5 (0.5B) model compiles successfully but still does not run properly.

See #2 for the issues with this workaround.
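
For reference, the upper/lower triangle convention can be illustrated with NumPy. This is only a sketch of the NumPy semantics; it assumes, without verifying it here, that _op.trilu(data, k, upper) follows the same offset convention.

```python
import numpy as np

x = np.arange(1, 10).reshape(3, 3)

upper = np.triu(x, k=1)   # keep only elements strictly above the main diagonal
lower = np.tril(x, k=1)   # keep the first superdiagonal and everything below it
full = np.triu(x, k=-5)   # an offset below the matrix keeps every element

print(upper)
print(lower)
print(full)
```

Note that a sufficiently negative offset makes triu a no-op, which matters when the offset is computed from other values rather than written as a literal.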

Environment

OS: Ubuntu 20.04
Pybuda Version: v0.10.5.gs.240315
TVM Version (from latest Pybuda): 0.14.0

Steps to reproduce

See tenstorrent/tt-buda-demos#37

Triage

frontend:pytorch

@JushBJJ changed the title from "No proper support for aten::triu OP inputs when inputs are [Bug]" to "No proper support for aten::triu OP inputs when inputs are CallNodes [Bug]" on Mar 30, 2024
@staylorTT

Hi there, thanks for filing this issue. I will do my best to get someone from the TT side to take a look at this.

@milank94 added the bug label on May 7, 2024
@AleksKnezevic
Contributor

To clarify, @JushBJJ: are the inputs to the op constant (even if they're call nodes)? That is, would we be able to extract them using _infer_value?

@JushBJJ
Contributor Author

JushBJJ commented May 10, 2024

To clarify, @JushBJJ: are the inputs to the op constant (even if they're call nodes)? That is, would we be able to extract them using _infer_value?

I will look into this later today; my e75 just arrived yesterday.

@staylorTT

@JushBJJ Do you have any updates here?

@JushBJJ
Contributor Author

JushBJJ commented Jul 25, 2024

@JushBJJ Do you have any updates here?

Hi, I'm incredibly sorry for forgetting about this. Since I don't have much else to do at the moment, this is my top priority now.
I'm currently working through some Buda hurdles with the more recent updates, but so far I haven't hit this issue again... yet.

@staylorTT

Appreciate the update. If you hit this again please report back.

@JushBJJ
Contributor Author

JushBJJ commented Jul 25, 2024

@staylorTT Yup, this can stay closed. I can confirm that this is no longer an issue.
