From fe89b3b8101959b6a7a4581c46f7d5925fa071ea Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Fri, 23 Dec 2022 13:45:44 -0800
Subject: [PATCH] lint fixes

---
 .../modules/sequence_parallel_transformer_layer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
index 00a77ae2f..0823b0936 100644
--- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
+++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
@@ -30,7 +30,6 @@
     )
     from megatron.mpu.utils import split_tensor_along_last_dim
     from megatron.model.fused_softmax import scaled_upper_triang_masked_softmax_cuda
-    from megatron.mpu import get_tensor_model_parallel_world_size
 
     has_megatron_submodule = True
 except (ImportError, ModuleNotFoundError):
@@ -434,7 +433,7 @@ def backward(ctx, grad_output):
         if xf_eff_attn:
             grad_out_proj_input = grad_out_proj_input.reshape(
                 seq_len, bsz, -1, head_dim
-            ).transpose(0,1)
+            ).transpose(0, 1)
             d_q, d_k, d_v = xops.memory_efficient_attention_backward(
                 grad=grad_out_proj_input,
                 output=out,
@@ -457,7 +456,6 @@ def backward(ctx, grad_output):
                 grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim
             )
 
-
         (
             mha_layer_norm_output,
             mha_layer_norm_mean,