
About attention mask #141

Open
SuperiorDtj opened this issue Sep 9, 2024 · 18 comments

Comments

@SuperiorDtj

Why are there no attention masks in the DiT and U-Net?
The DiT removes the attention masks entirely, for both self- and cross-attention.
In the U-Net, the mask is applied by multiplying the keys (k) and values (v) by the mask before the softmax, rather than the more common approach of setting the masked positions to a value close to negative infinity before the softmax.
Are these implementations reasonable?
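
For reference, a minimal sketch of the two masking styles being contrasted here (illustrative shapes and function names, not the repository's actual code):

```python
import torch

def attn_multiplicative_mask(q, k, v, mask):
    # Style described above for the U-Net: zero out padded keys and values
    # before the scores are computed. Padded positions still receive softmax
    # weight, because their scores become 0 rather than -inf.
    k = k * mask[:, None, :, None]                                    # (B, H, S_kv, D)
    v = v * mask[:, None, :, None]
    sim = torch.einsum("bhqd,bhkd->bhqk", q, k) / q.shape[-1] ** 0.5
    return torch.einsum("bhqk,bhkd->bhqd", sim.softmax(dim=-1), v)

def attn_additive_mask(q, k, v, mask):
    # The more common style: push masked scores toward -inf so the softmax
    # assigns them (near-)zero weight.
    sim = torch.einsum("bhqd,bhkd->bhqk", q, k) / q.shape[-1] ** 0.5
    sim = sim.masked_fill(~mask[:, None, None, :].bool(), torch.finfo(sim.dtype).min)
    return torch.einsum("bhqk,bhkd->bhqd", sim.softmax(dim=-1), v)
```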

@nateraw

nateraw commented Sep 9, 2024

I have been meaning to ask about this too; thanks @SuperiorDtj.

For cross-attention specifically, the mask is disabled here with a comment referring to a "kernel issue for flash attention". @zqevans does this still need to be disabled?

```python
if cross_attn_cond_mask is not None:
    cross_attn_cond_mask = cross_attn_cond_mask.bool()
    cross_attn_cond_mask = None  # Temporarily disabling conditioning masks due to kernel issue for flash attention
```

@ksasso1028

ksasso1028 commented Sep 18, 2024

Because DiTs are essentially just transformer encoders, a mask is not needed: they are designed to see all timesteps at once (non-causal attention). Masks are typically reserved for transformer decoders, which work in an autoregressive way. The CNN filters in the U-Net are also non-causal, which is why you do not need to factor in the mask. For diffusion, seeing future timesteps could be very important, and could be a reason these models work well at generating audio that stays coherent across long timesteps. @nateraw I could only imagine this making training slower, especially if you have multiple conditions for cross-attention.
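
As a small illustration of the distinction, a sketch using PyTorch's built-in scaled dot-product attention (not code from this repository):

```python
import torch
import torch.nn.functional as F

B, H, S, D = 2, 8, 64, 64
q = k = v = torch.randn(B, H, S, D)

# Encoder-style (as in a DiT): every position can attend to every other
# position, so no mask is passed.
out_noncausal = F.scaled_dot_product_attention(q, k, v)

# Decoder-style (autoregressive): a causal mask hides future positions.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```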

@SuperiorDtj
Author

> Because DiTs are essentially just transformer encoders, a mask is not needed: they are designed to see all timesteps at once (non-causal attention). Masks are typically reserved for transformer decoders, which work in an autoregressive way. The CNN filters in the U-Net are also non-causal, which is why you do not need to factor in the mask. For diffusion, seeing future timesteps could be very important, and could be a reason these models work well at generating audio that stays coherent across long timesteps. @nateraw I could only imagine this making training slower, especially if you have multiple conditions for cross-attention.

Thank you for your response. The reason I want to use the attention mask is that I want to experiment with training on audio of varying lengths instead of using duration features, similar to some recent diffusion-based TTS work.
On the other hand, the text features received by cross-attention are also not of fixed length. We are wondering whether passing in the corresponding mask during training could help the model focus on the useful information. However, our experiments show that adding the mask quickly causes the training loss to become NaN. We would like to know whether this is due to this version of the DiT.
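
One frequent cause of NaNs when an additive attention mask is introduced (offered here as a general guess, not a diagnosis of this codebase) is a query row whose keys are all masked out: the softmax is then taken over a row of -inf. A minimal sketch of a guard, assuming a boolean key mask:

```python
import torch

def masked_softmax(sim, key_mask):
    # sim: (B, H, S_q, S_kv) attention scores; key_mask: (B, S_kv), True = attend.
    sim = sim.masked_fill(~key_mask[:, None, None, :], float("-inf"))
    attn = sim.softmax(dim=-1)
    # A query row with no valid keys is all -inf and softmaxes to NaN;
    # replace such rows with zeros so they contribute nothing downstream.
    return torch.nan_to_num(attn)
```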

@ksasso1028

ksasso1028 commented Sep 19, 2024

> Because DiTs are essentially just transformer encoders, a mask is not needed: they are designed to see all timesteps at once (non-causal attention). Masks are typically reserved for transformer decoders, which work in an autoregressive way. The CNN filters in the U-Net are also non-causal, which is why you do not need to factor in the mask. For diffusion, seeing future timesteps could be very important, and could be a reason these models work well at generating audio that stays coherent across long timesteps. @nateraw I could only imagine this making training slower, especially if you have multiple conditions for cross-attention.

> Thank you for your response. The reason I want to use the attention mask is that I want to experiment with training on audio of varying lengths instead of using duration features, similar to some recent diffusion-based TTS work. On the other hand, the text features received by cross-attention are also not of fixed length. We are wondering whether passing in the corresponding mask during training could help the model focus on the useful information. However, our experiments show that adding the mask quickly causes the training loss to become NaN. We would like to know whether this is due to this version of the DiT.

Can you link a paper to better describe what you are referring to? Cross-attention is typically used when the sequences being compared are not the same length, so this is normal. It is perfectly normal to train on audio of varying lengths, using adequate padding methods. Can you explain what you expect a mask to achieve in this context? Because this is non-causal self-attention, at each timestep the model has access to all tokens across the sequence and the cross-attention conditioning, as opposed to how an LLM works, where each timestep can only see previous tokens. So not using a mask exposes the model to much more information (which would lead to overfitting in the autoregressive case).
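
A minimal sketch of the kind of padding mentioned above, assuming per-sample latents of shape (T_i, D); names and shapes are illustrative:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three variable-length latent sequences, feature dimension 64.
latents = [torch.randn(173, 64), torch.randn(215, 64), torch.randn(98, 64)]
lengths = torch.tensor([x.shape[0] for x in latents])

# Pad to the longest sequence in the batch and record which frames are real.
batch = pad_sequence(latents, batch_first=True)                           # (B, T_max, D)
padding_mask = torch.arange(batch.shape[1])[None, :] < lengths[:, None]   # (B, T_max), True = real frame
```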

@ksasso1028

ksasso1028 commented Sep 19, 2024

If you are looking to generate audio that is not a fixed length without the timing conditions, then you will need to use a transformer decoder to autoregressively generate latent tokens until a stop token is produced or a limit is reached. Training that kind of network would require masks. But because this autoencoder is continuous rather than discrete, you would be better off using something like EnCodec for that task. @SuperiorDtj

@SuperiorDtj
Author

SuperiorDtj commented Sep 20, 2024

> Because DiTs are essentially just transformer encoders, a mask is not needed: they are designed to see all timesteps at once (non-causal attention). Masks are typically reserved for transformer decoders, which work in an autoregressive way. The CNN filters in the U-Net are also non-causal, which is why you do not need to factor in the mask. For diffusion, seeing future timesteps could be very important, and could be a reason these models work well at generating audio that stays coherent across long timesteps. @nateraw I could only imagine this making training slower, especially if you have multiple conditions for cross-attention.

> Thank you for your response. The reason I want to use the attention mask is that I want to experiment with training on audio of varying lengths instead of using duration features, similar to some recent diffusion-based TTS work. On the other hand, the text features received by cross-attention are also not of fixed length. We are wondering whether passing in the corresponding mask during training could help the model focus on the useful information. However, our experiments show that adding the mask quickly causes the training loss to become NaN. We would like to know whether this is due to this version of the DiT.

> Can you link a paper to better describe what you are referring to? Cross-attention is typically used when the sequences being compared are not the same length, so this is normal. It is perfectly normal to train on audio of varying lengths, using adequate padding methods. Can you explain what you expect a mask to achieve in this context? Because this is non-causal self-attention, at each timestep the model has access to all tokens across the sequence and the cross-attention conditioning, as opposed to how an LLM works, where each timestep can only see previous tokens. So not using a mask exposes the model to much more information (which would lead to overfitting in the autoregressive case).

https://arxiv.org/abs/2406.11427 This paper compares two methods: variable-length audio, and fixed-length audio that includes silent segments. My understanding is that the variable-length approach is achieved by passing in a self-attention mask. Specifically, through the self-attention mask, the model can learn how to handle noise latents of different lengths.

Regarding the cross-attention mask: the input features, such as phoneme sequences or text features, are not of fixed length. I hope that introducing a cross-attention mask will prevent the model from attending to the padding tokens, since the amount of padding at inference time may vary. I intuitively believe this could affect the model's performance.
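A minimal sketch of the cross-attention masking being proposed here (assumed shapes and names, not this repository's code), using a boolean key-padding mask over the text tokens where True means "may attend":

```python
import torch
import torch.nn.functional as F

B, H, S_q, S_txt, D = 2, 8, 256, 40, 64
q = torch.randn(B, H, S_q, D)            # noisy-latent queries
k = v = torch.randn(B, H, S_txt, D)      # text keys/values, padded to S_txt
text_mask = torch.zeros(B, S_txt, dtype=torch.bool)
text_mask[0, :32] = True                 # sample 0 has 32 real text tokens
text_mask[1, :17] = True                 # sample 1 has 17

# Broadcast the key-padding mask to (B, 1, 1, S_txt) so padded text tokens
# receive no attention weight.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=text_mask[:, None, None, :])
```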

@ksasso1028

ksasso1028 commented Sep 20, 2024

Looking at this paper, it introduces a second model to predict the length of the sample for the DiT. The DiT then operates on this fixed-size vector iteratively. The auxiliary network they use to predict the sequence length uses an encoder-decoder architecture, hence the reason for the mask. The mask is not used in the DiT. @SuperiorDtj

@ksasso1028

From the paper:
[image from the paper]

@ksasso1028

And regarding the cross-attention features not being fixed length: this is how cross-attention works in practice. The reason a mask is used in this paper is to prevent the model from looking ahead at textual features which contribute to how it predicts the length; otherwise it would overfit rather easily.

@SuperiorDtj
Author

> Looking at this paper, it introduces a second model to predict the length of the sample for the DiT. The DiT then operates on this fixed-size vector iteratively. The auxiliary network they use to predict the sequence length uses an encoder-decoder architecture, hence the reason for the mask. The mask is not used in the DiT. @SuperiorDtj

Thank you for your response.
Could you elaborate on what is meant by "The DiT then operates on this fixed-size vector iteratively"? From my perspective, the diffusion model still encounters variable-length audio during training. If we don't introduce a self-attention mask, the model will attend to the silent segments represented by the padding during batch training. @ksasso1028

@ksasso1028

@SuperiorDtj Diffusion operates in steps over a fixed-size vector of noise. All this paper is doing is setting the size of the noise based on the textual features it extracts from the transformer encoder. The same could be done for setting the seconds total timing condition for Stable Audio, allowing you to generate variable-length audio.

As to how they trained the model, I am not sure; that is a different story. But to train neural nets effectively you must use batching, which requires your batches to be of the same length. As I only glanced at the paper, I am not sure whether that was achieved through padding or other means, and I don't think they trained with a batch size of 1.

@ksasso1028

My best guess is that they use a mask on the DiT output to erase tokens per sample based on the predicted length. All this "mask" would be doing is removing the information in that output. It is similar to how the prepend mask is utilized here.
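
A minimal sketch of that guess (hypothetical names), zeroing out everything past each sample's predicted length so padded frames contribute neither to the output nor to the loss:

```python
import torch

def length_masked_mse(pred, target, lengths):
    # pred, target: (B, T, D); lengths: (B,) number of valid frames per sample.
    keep = torch.arange(pred.shape[1], device=pred.device)[None, :] < lengths[:, None]
    keep = keep[:, :, None].to(pred.dtype)                            # (B, T, 1)
    se = (pred - target) ** 2 * keep
    # Average only over the frames that are actually valid.
    return se.sum() / (keep.sum() * pred.shape[-1]).clamp(min=1)
```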

@SuperiorDtj
Author

> My best guess is that they use a mask on the DiT output to erase tokens per sample based on the predicted length. All this "mask" would be doing is removing the information in that output. It is similar to how the prepend mask is utilized here.

Yes, I have also found in my experiments that introducing the attention mask during training slows training down and causes the loss to become NaN. Therefore, I kept the settings from Stable Audio and did not introduce self- or cross-attention masks.

@ksasso1028

It would be rather simple in this case. You still need to train on fixed-length batches, but you can use a mask to zero out token information for the silent portions (anything after the predicted length), so during training your attention is not influenced by these tokens, but they still exist in terms of sequence length. For cross-attention it will not make sense, since the DiT needs access to all of that information. Assuming you have a length predictor, you can either set seconds total or create a "mask" to zero out all tokens after seconds total. The timing condition also lets the model know that anything after seconds total is not useful information, so I would not overthink it. The main piece you need here is a means to predict the sequence length given text, which is not the DiT.
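
A minimal sketch of zeroing out token information after the predicted length while keeping the sequence length fixed (illustrative names only):

```python
import torch

def zero_after_length(tokens, lengths):
    # tokens: (B, T, D) latent/noise tokens; lengths: (B,) valid frames per sample.
    keep = torch.arange(tokens.shape[1], device=tokens.device)[None, :] < lengths[:, None]
    # Padded positions keep their place in the sequence but carry no information.
    return tokens * keep[:, :, None].to(tokens.dtype)
```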

@SuperiorDtj
Author

> It would be rather simple in this case. You still need to train on fixed-length batches, but you can use a mask to zero out token information for the silent portions (anything after the predicted length), so during training your attention is not influenced by these tokens, but they still exist in terms of sequence length. For cross-attention it will not make sense, since the DiT needs access to all of that information. Assuming you have a length predictor, you can either set seconds total or create a "mask" to zero out all tokens after seconds total. The timing condition also lets the model know that anything after seconds total is not useful information, so I would not overthink it. The main piece you need here is a means to predict the sequence length given text, which is not the DiT.

Thank you for your help; this has been a valuable discussion!

@lixucuhk

lixucuhk commented Nov 1, 2024

> My best guess is that they use a mask on the DiT output to erase tokens per sample based on the predicted length. All this "mask" would be doing is removing the information in that output. It is similar to how the prepend mask is utilized here.

> Yes, I have also found in my experiments that introducing the attention mask during training slows training down and causes the loss to become NaN. Therefore, I kept the settings from Stable Audio and did not introduce self- or cross-attention masks.

I am in a similar situation to yours. For the U-Net version, may I ask how your attention mask was deployed when the loss became NaN? The default one in the code, i.e. multiplying the keys (k) and values (v) by the mask before the softmax, or the common one, setting the masked positions to a value close to negative infinity before the softmax? Thank you very much!

@SuperiorDtj
Author

> My best guess is that they use a mask on the DiT output to erase tokens per sample based on the predicted length. All this "mask" would be doing is removing the information in that output. It is similar to how the prepend mask is utilized here.

> Yes, I have also found in my experiments that introducing the attention mask during training slows training down and causes the loss to become NaN. Therefore, I kept the settings from Stable Audio and did not introduce self- or cross-attention masks.

> I am in a similar situation to yours. For the U-Net version, may I ask how your attention mask was deployed when the loss became NaN? The default one in the code, i.e. multiplying the keys (k) and values (v) by the mask before the softmax, or the common one, setting the masked positions to a value close to negative infinity before the softmax? Thank you very much!

My U-Net version does not encounter NaN issues. I have only tried the settings in Stable Audio, which multiply the similarity scores before the softmax by a fixed number (0). Have you encountered NaNs in your U-Net version?

@lixucuhk

lixucuhk commented Nov 4, 2024

> My best guess is that they use a mask on the DiT output to erase tokens per sample based on the predicted length. All this "mask" would be doing is removing the information in that output. It is similar to how the prepend mask is utilized here.

> Yes, I have also found in my experiments that introducing the attention mask during training slows training down and causes the loss to become NaN. Therefore, I kept the settings from Stable Audio and did not introduce self- or cross-attention masks.

> I am in a similar situation to yours. For the U-Net version, may I ask how your attention mask was deployed when the loss became NaN? The default one in the code, i.e. multiplying the keys (k) and values (v) by the mask before the softmax, or the common one, setting the masked positions to a value close to negative infinity before the softmax? Thank you very much!

> My U-Net version does not encounter NaN issues. I have only tried the settings in Stable Audio, which multiply the similarity scores before the softmax by a fixed number (0). Have you encountered NaNs in your U-Net version?

Thank you for your reply! Not yet. I decided not to use masking in the cross-attention module; instead, I plan to use an additional token to represent padding, so that we can pad all conditions to the same length with this "padding token".
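
A minimal sketch of that padding-token idea (hypothetical module and names): a learned embedding used to pad variable-length condition sequences to a common length before cross-attention:

```python
import torch
import torch.nn as nn

class ConditionPadder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Learned "padding token" embedding.
        self.pad_embedding = nn.Parameter(torch.zeros(dim))

    def forward(self, conds, max_len):
        # conds: list of (T_i, D) condition sequences with varying T_i.
        padded = []
        for c in conds:
            pad = self.pad_embedding.expand(max_len - c.shape[0], -1)
            padded.append(torch.cat([c, pad], dim=0))
        return torch.stack(padded)                                    # (B, max_len, D)
```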
