diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
index 6c2159ef7..00a77ae2f 100644
--- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
+++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
@@ -330,9 +330,7 @@ def backward(ctx, grad_output):
             actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out)
 
         # Now wait for reduce scatter
-        world_size = get_tensor_model_parallel_world_size()
-        if world_size != 1:
-            handle.wait()
+        handle.wait()
 
         ffn_layer_norm_output, handle = _gather_along_first_dim(
             ffn_layer_norm_output, async_op=True, cached_buffer_name="mpu"
@@ -341,8 +339,7 @@ def backward(ctx, grad_output):
         grad_fc2_input = grad_output.matmul(fc2_weight)
 
         if ctx.recompute_fc1:
-            if world_size != 1:
-                handle.wait()
+            handle.wait()
             assert fc1_out is None
             fc1_out = torch.matmul(ffn_layer_norm_output, fc1_weight.t())
             actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out)
@@ -362,8 +359,7 @@ def backward(ctx, grad_output):
         grad_fc2_weight = grad_output.t().matmul(actv_out)
 
         grad_fc1_input = grad_actv_input.matmul(fc1_weight)
-        if world_size != 1:
-            handle.wait()
+        handle.wait()
 
         grad_actv_input = SequeuceParallelTransformerBlock._collapse_first_dimensions(
             grad_actv_input
@@ -380,8 +376,7 @@ def backward(ctx, grad_output):
 
         grad_fc1_weight = grad_actv_input.t().matmul(ffn_layer_norm_output)
 
-        if world_size != 1:
-            handle.wait()
+        handle.wait()
 
         grad_attention_output = fused_layer_norm_cuda.backward(
             grad_fc1_input.contiguous(),
@@ -425,8 +420,7 @@ def backward(ctx, grad_output):
             q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
         )
 
-        if world_size != 1:
-            handle.wait()
+        handle.wait()
 
         grad_out_proj_input = grad_attention_output.matmul(out_proj_weight)
         grad_attention_output = (
@@ -477,8 +471,7 @@ def backward(ctx, grad_output):
             cached_buffer_name="mpu",
         )
         grad_input = grad_kvq_proj_output.matmul(kvq_proj_weight)
-        if world_size != 1:
-            handle.wait()
+        handle.wait()
         grad_input, handle = _reduce_scatter_along_first_dim(grad_input, async_op=True)
 
         mha_layer_norm_output = (
@@ -492,8 +485,7 @@ def backward(ctx, grad_output):
             )
         )
         grad_kvq_weight = grad_kvq_proj_output.t().matmul(mha_layer_norm_output)
-        if world_size != 1:
-            handle.wait()
+        handle.wait()
 
         grad_input = fused_layer_norm_cuda.backward(
             grad_input.contiguous(),
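For context, the pattern this diff leans on is the standard async-collective handshake: launch a collective with `async_op=True`, overlap independent math with the communication, and call `handle.wait()` only at the point where the communicated buffer is actually read. Below is a minimal, hedged sketch of that pattern with plain `torch.distributed`; `gather_first_dim` and `overlap_example` are illustrative stand-ins and not metaseq's `_gather_along_first_dim` or `_reduce_scatter_along_first_dim`.

```python
# Minimal sketch (not metaseq code) of overlapping an async all-gather with
# independent compute, then waiting on the work handle before using the result.
import torch
import torch.distributed as dist


def gather_first_dim(t: torch.Tensor):
    """Hypothetical async all-gather along dim 0; returns buffers and a work handle."""
    world_size = dist.get_world_size()
    bufs = [torch.empty_like(t) for _ in range(world_size)]
    handle = dist.all_gather(bufs, t.contiguous(), async_op=True)
    return bufs, handle


def overlap_example(x: torch.Tensor, weight: torch.Tensor):
    bufs, handle = gather_first_dim(x)
    local = x.matmul(weight.t())  # independent math overlapped with comms
    handle.wait()                 # must complete before bufs are read
    gathered = torch.cat(bufs, dim=0)
    return gathered.matmul(weight.t()), local


if __name__ == "__main__":
    # Single-process gloo group, just to make the sketch runnable.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    out, local = overlap_example(torch.randn(4, 8), torch.randn(16, 8))
    print(out.shape, local.shape)
    dist.destroy_process_group()
```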