From b4539bbe51bee7a8baca5b5633b08ef123cfa016 Mon Sep 17 00:00:00 2001
From: Stephen Roller
Date: Fri, 11 Nov 2022 10:24:33 -0800
Subject: [PATCH 01/28] [wip] Adding flash attention for sequence parallel

---
 .../sequence_parallel_transformer_layer.py | 101 +++++++-----------
 1 file changed, 36 insertions(+), 65 deletions(-)

diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
index 2f8925058..874610a08 100644
--- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
+++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
@@ -8,11 +8,19 @@
 import importlib
 import math
 import torch
+from types import SimpleNamespace
 
 # Not importing here cause cpu tests don't like it
 global fused_layer_norm_cuda
 fused_layer_norm_cuda = None
 
+try:
+    import xformers.ops as xops
+
+    has_xformers = True
+except (ImportError, ModuleNotFoundError):
+    has_xformers = False
+
 try:
     from megatron.mpu.mappings import (
         _reduce_scatter_along_first_dim,
@@ -26,6 +34,17 @@
     has_megatron_submodule = False
 
 
+class _FakeContext(SimpleNamespace):
+    """
+    Used to provide a temporary buffer for FlashAttention's saved buffers
+    """
+
+    saved_tensors = None
+
+    def save_for_backward(self, *args):
+        self.saved_tensors = args
+
+
 class SequeuceParallelTransformerBlock(torch.autograd.Function):
     """
     This is custom FFN autograd function hardcoded for:
@@ -36,60 +55,6 @@ class SequeuceParallelTransformerBlock(torch.autograd.Function):
     gelu, layernorm: always recomputed i.e. no activation memory for these
     """
 
-    @staticmethod
-    def forward_mha(q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype):
-        scaling = head_dim**-0.5
-        matmul_result = torch.empty(
-            bsz * (embed_dim_per_partition // head_dim),
-            seq_len,
-            seq_len,
-            dtype=dtype,
-            device=torch.cuda.current_device(),
-        )
-        # Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math
-        matmul_result = torch.baddbmm(
-            matmul_result,
-            math.sqrt(scaling) * q.transpose(0, 1),
-            math.sqrt(scaling) * k.transpose(0, 1).transpose(1, 2),
-            beta=0.0,
-        )
-        # attn_probs = matmul_result
-        scale_t = torch.tensor([1.0])
-        attn_probs = scaled_upper_triang_masked_softmax_cuda.forward(
-            matmul_result, scale_t[0]
-        )
-        attn = torch.bmm(attn_probs, v)
-        attn = attn.transpose(0, 1).contiguous().view(seq_len, bsz, -1)
-        return attn, attn_probs
-
-    @staticmethod
-    def backward_mha(grad_mha_output, q, k, v, attn_probs, seq_len, bsz, head_dim):
-        scaling = head_dim**-0.5
-        grad_mha_output = grad_mha_output.view(seq_len, -1, head_dim).transpose(0, 1)
-        grad_v = (
-            torch.bmm(attn_probs.transpose(1, 2), grad_mha_output)
-            .transpose(0, 1)
-            .contiguous()
-            .view(seq_len, bsz, -1)
-        )
-        grad_attn_probs_out = torch.bmm(grad_mha_output, v.transpose(1, 2))
-
-        grad_attn_probs_in = scaled_upper_triang_masked_softmax_cuda.backward(
-            grad_attn_probs_out, attn_probs, 1.0
-        )
-        grad_q = torch.bmm(
-            math.sqrt(scaling) * grad_attn_probs_in,
-            math.sqrt(scaling) * k.transpose(0, 1),
-        )
-        grad_q = grad_q.transpose(0, 1).contiguous().view(seq_len, bsz, -1)
-        grad_k = torch.bmm(
-            math.sqrt(scaling) * grad_attn_probs_in.transpose(1, 2),
-            math.sqrt(scaling) * q.transpose(0, 1),
-        )
-        grad_k = grad_k.transpose(0, 1).contiguous().view(seq_len, bsz, -1)
-        grad_kvq_proj_output = torch.cat([grad_k, grad_v, grad_q], dim=-1)
-        return grad_kvq_proj_output
-
     @staticmethod
     def forward(
         ctx,
@@ -139,16 +104,17 @@ def forward(
         k, v, q = split_tensor_along_last_dim(kvq_out, 3, contiguous_split_chunks=True)
         seq_len, bsz, embed_dim_per_partition = q.size()
 
-        q = q.view(seq_len, -1, head_dim)
-        k = k.view(seq_len, -1, head_dim)
-        v = v.view(seq_len, -1, head_dim).transpose(0, 1)
+        q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
+        k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
+        v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
 
-        attn, _ = SequeuceParallelTransformerBlock.forward_mha(
-            q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
-        )
+        attn = xops.MemoryEfficientAttentionFlashAttentionOp.forward(
+            None, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0
+        ).view(seq_len, bsz, -1)
 
         out_proj_out = torch.matmul(attn, out_proj_weight.t())
         out_proj_out = _reduce_scatter_along_first_dim(out_proj_out)
+        out_proj_out = out_proj_out.view_as(residual)
 
         out_proj_out = out_proj_out + residual
 
@@ -320,9 +286,10 @@ def backward(ctx, grad_output):
         )
 
         # recalculate attention
-        attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
-            q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
-        )
+        fake_ctx = _FakeContext()
+        attn = xops.MemoryEfficientAttentionFlashAttentionOp.forward(
+            fake_ctx, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0
+        ).view(seq_len, bsz, -1)
 
         handle.wait()
 
@@ -335,9 +302,13 @@ def backward(ctx, grad_output):
         attn = SequeuceParallelTransformerBlock._collapse_first_dimensions(attn)
         grad_out_proj_weight = grad_attention_output.t().matmul(attn)
 
-        grad_kvq_proj_output = SequeuceParallelTransformerBlock.backward_mha(
-            grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim
+        d_q, d_k, d_v, _, _ = xops.MemoryEfficientAttentionFlashAttentionOp.backward(
+            fake_ctx, grad_out_proj_input
         )
+        d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1)
+        d_k = d_k.transpose(0, 1).view(seq_len, bsz, -1)
+        d_v = d_v.transpose(0, 1).view(seq_len, bsz, -1)
+        grad_kvq_proj_output = torch.cat([d_k, d_v, d_q], dim=-1)
 
         (
             mha_layer_norm_output,

From 2830c3de7bcee9e6d053c73f18810d1ad0db2bf3 Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Wed, 30 Nov 2022 00:24:34 +0000
Subject: [PATCH 02/28] change to faster flash attn

---
 .../modules/sequence_parallel_transformer_layer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
index 874610a08..dd6756901 100644
--- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
+++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
@@ -108,7 +108,7 @@ def forward(
         k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
         v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1)
 
-        attn = xops.MemoryEfficientAttentionFlashAttentionOp.forward(
+        attn = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp.forward(
             None, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0
         ).view(seq_len, bsz, -1)
 
@@ -287,7 +287,7 @@ def backward(ctx, grad_output):
         # recalculate attention
         fake_ctx = _FakeContext()
-        attn = xops.MemoryEfficientAttentionFlashAttentionOp.forward(
+        attn = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp.forward(
             fake_ctx, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0
         ).view(seq_len, bsz, -1)
 
@@ -302,7 +302,7 @@ def backward(ctx, grad_output):
         attn = SequeuceParallelTransformerBlock._collapse_first_dimensions(attn)
         grad_out_proj_weight = grad_attention_output.t().matmul(attn)
 
-        d_q, d_k, d_v, _, _ =
xops.MemoryEfficientAttentionFlashAttentionOp.backward( + d_q, d_k, d_v, _, _ = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp.backward( fake_ctx, grad_out_proj_input ) d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1) From 6dc5006a3b88492336b7815672af9742f5038923 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 30 Nov 2022 00:31:52 +0000 Subject: [PATCH 03/28] add back standard attention --- .../sequence_parallel_transformer_layer.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index dd6756901..8651c6ebb 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -55,6 +55,60 @@ class SequeuceParallelTransformerBlock(torch.autograd.Function): gelu, layernorm: always recomputed i.e. no activation memory for these """ + @staticmethod + def forward_mha(q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype): + scaling = head_dim**-0.5 + matmul_result = torch.empty( + bsz * (embed_dim_per_partition // head_dim), + seq_len, + seq_len, + dtype=dtype, + device=torch.cuda.current_device(), + ) + # Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math + matmul_result = torch.baddbmm( + matmul_result, + math.sqrt(scaling) * q.transpose(0, 1), + math.sqrt(scaling) * k.transpose(0, 1).transpose(1, 2), + beta=0.0, + ) + # attn_probs = matmul_result + scale_t = torch.tensor([1.0]) + attn_probs = scaled_upper_triang_masked_softmax_cuda.forward( + matmul_result, scale_t[0] + ) + attn = torch.bmm(attn_probs, v) + attn = attn.transpose(0, 1).contiguous().view(seq_len, bsz, -1) + return attn, attn_probs + + @staticmethod + def backward_mha(grad_mha_output, q, k, v, attn_probs, seq_len, bsz, head_dim): + scaling = head_dim**-0.5 + grad_mha_output = grad_mha_output.view(seq_len, -1, head_dim).transpose(0, 1) + grad_v = ( + torch.bmm(attn_probs.transpose(1, 2), grad_mha_output) + .transpose(0, 1) + .contiguous() + .view(seq_len, bsz, -1) + ) + grad_attn_probs_out = torch.bmm(grad_mha_output, v.transpose(1, 2)) + + grad_attn_probs_in = scaled_upper_triang_masked_softmax_cuda.backward( + grad_attn_probs_out, attn_probs, 1.0 + ) + grad_q = torch.bmm( + math.sqrt(scaling) * grad_attn_probs_in, + math.sqrt(scaling) * k.transpose(0, 1), + ) + grad_q = grad_q.transpose(0, 1).contiguous().view(seq_len, bsz, -1) + grad_k = torch.bmm( + math.sqrt(scaling) * grad_attn_probs_in.transpose(1, 2), + math.sqrt(scaling) * q.transpose(0, 1), + ) + grad_k = grad_k.transpose(0, 1).contiguous().view(seq_len, bsz, -1) + grad_kvq_proj_output = torch.cat([grad_k, grad_v, grad_q], dim=-1) + return grad_kvq_proj_output + @staticmethod def forward( ctx, From 55af48ffb63b5644de2a10162b6a9ecfba9ccf3a Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 30 Nov 2022 01:13:53 +0000 Subject: [PATCH 04/28] gate mem efficient attn behind a flag --- .../sequence_parallel_transformer_layer.py | 69 ++++++++++++++----- metaseq/modules/transformer_decoder_layer.py | 2 + 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 8651c6ebb..2716a6f86 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ 
b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -9,6 +9,7 @@ import math import torch from types import SimpleNamespace +from metaseq.dataclass.constants import AttentionVariants # Not importing here cause cpu tests don't like it global fused_layer_norm_cuda @@ -120,10 +121,26 @@ def forward( head_dim, recompute_fc1, activation_fn_name, # "relu" or "gelu" for now + attn_variant, + xf_attn_op, ): assert ( activation_fn_name == "relu" or activation_fn_name == "gelu" ), "Only relu/gelu is supported!" + + xf_eff_attn = attn_variant == AttentionVariants.XFORMERS + if xf_eff_attn and not has_xformers: + raise ImportError( + "\n\nPlease install xformers to use memory efficient attention" + ) + + xf_op = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp + if xf_eff_attn and xf_attn_op is not None: + try: + xf_op = getattr(xops, xf_attn_op) + except AttributeError: + logging.warning(f"Invalid xformers memorry efficient op specified.") + # import from apex global fused_layer_norm_cuda if fused_layer_norm_cuda is None: @@ -158,13 +175,23 @@ def forward( k, v, q = split_tensor_along_last_dim(kvq_out, 3, contiguous_split_chunks=True) seq_len, bsz, embed_dim_per_partition = q.size() - q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - attn = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp.forward( - None, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0 - ).view(seq_len, bsz, -1) + if xf_eff_attn: + q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + + attn = xf_op.forward( + None, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0 + ).view(seq_len, bsz, -1) + else: + q = q.view(seq_len, -1, head_dim) + k = k.view(seq_len, -1, head_dim) + v = v.view(seq_len, -1, head_dim).transpose(0, 1) + + attn, _ = SequeuceParallelTransformerBlock.forward_mha( + q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype + ) out_proj_out = torch.matmul(attn, out_proj_weight.t()) out_proj_out = _reduce_scatter_along_first_dim(out_proj_out) @@ -340,10 +367,15 @@ def backward(ctx, grad_output): ) # recalculate attention - fake_ctx = _FakeContext() - attn = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp.forward( - fake_ctx, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0 - ).view(seq_len, bsz, -1) + if xf_eff_attn: + fake_ctx = _FakeContext() + attn = xf_op.forward( + fake_ctx, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0 + ).view(seq_len, bsz, -1) + else: + attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha( + q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype + ) handle.wait() @@ -356,13 +388,16 @@ def backward(ctx, grad_output): attn = SequeuceParallelTransformerBlock._collapse_first_dimensions(attn) grad_out_proj_weight = grad_attention_output.t().matmul(attn) - d_q, d_k, d_v, _, _ = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp.backward( - fake_ctx, grad_out_proj_input - ) - d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1) - d_k = d_k.transpose(0, 1).view(seq_len, bsz, -1) - d_v = d_v.transpose(0, 1).view(seq_len, bsz, -1) - grad_kvq_proj_output = torch.cat([d_k, d_v, d_q], dim=-1) + if xf_eff_attn: + d_q, d_k, d_v, _, _ = xf_op.backward(fake_ctx, grad_out_proj_input) + d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1) + d_k = d_k.transpose(0, 1).view(seq_len, bsz, -1) + d_v = d_v.transpose(0, 
1).view(seq_len, bsz, -1) + grad_kvq_proj_output = torch.cat([d_k, d_v, d_q], dim=-1) + else: + grad_kvq_proj_output = SequeuceParallelTransformerBlock.backward_mha( + grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim + ) ( mha_layer_norm_output, diff --git a/metaseq/modules/transformer_decoder_layer.py b/metaseq/modules/transformer_decoder_layer.py index 56cd35d9d..8e5d7ddd1 100644 --- a/metaseq/modules/transformer_decoder_layer.py +++ b/metaseq/modules/transformer_decoder_layer.py @@ -219,6 +219,8 @@ def forward( self.self_attn.head_dim, recompute_fc1, self.activation_fn_name, + attn_variant=getattr(args, "attn_variant", "default") + xf_attn_op=getattr(args, "xf_attn_op", None) ) return x From 25270a6defca6b7fe518e2ff99b62687dc64d36a Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 30 Nov 2022 01:57:23 +0000 Subject: [PATCH 05/28] cleanup --- metaseq/modules/transformer_decoder_layer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/metaseq/modules/transformer_decoder_layer.py b/metaseq/modules/transformer_decoder_layer.py index 8e5d7ddd1..b7336b9ff 100644 --- a/metaseq/modules/transformer_decoder_layer.py +++ b/metaseq/modules/transformer_decoder_layer.py @@ -70,6 +70,9 @@ def __init__( self.activation_fn_name = getattr(args, "activation_fn", "relu") or "relu" self.skip_bias_add = (self.activation_fn_name == "gelu") and has_fused_bias_gelu + self.attn_variant = getattr(args, "attn_variant", "default") + sefl.xf_attn_op = getattr(args, "xf_attn_op", None) + # TODO[Susan]: Clean up these kwargs when unifying method signatures between model & non-model parallel. fc1_kwargs = { "initialize_params_on_gpu": initialize_params_on_gpu, @@ -219,8 +222,8 @@ def forward( self.self_attn.head_dim, recompute_fc1, self.activation_fn_name, - attn_variant=getattr(args, "attn_variant", "default") - xf_attn_op=getattr(args, "xf_attn_op", None) + self.attn_variant, + self.xf_attn_op, ) return x From 0514d6185b59fbe1688d443cbcd53da443b9a6ef Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 30 Nov 2022 02:10:29 +0000 Subject: [PATCH 06/28] save flags for backwards --- .../sequence_parallel_transformer_layer.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 2716a6f86..550bbcec7 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -242,12 +242,16 @@ def forward( ctx.head_dim, ctx.embed_dim_per_partition, ctx.activation_fn_name, + ctx.xf_eff_attn, + ctx.xf_op, ) = ( bsz, seq_len, head_dim, embed_dim_per_partition, activation_fn_name, + xf_eff_attn, + xf_op, ) # apply scatter gather, @@ -271,12 +275,22 @@ def backward(ctx, grad_output): fc1_weight, fc2_weight, ) = ctx.saved_tensors - bsz, seq_len, head_dim, embed_dim_per_partition, activation_fn_name = ( + ( + bsz, + seq_len, + head_dim, + embed_dim_per_partition, + activation_fn_name, + xf_eff_attn, + xf_op, + ) = ( ctx.bsz, ctx.seq_len, ctx.head_dim, ctx.embed_dim_per_partition, ctx.activation_fn_name, + ctx.xf_eff_attn, + ctx.xf_op, ) dtype = grad_output.dtype From d3605260d915c33bc2dcdacdd44a4df8c822fc61 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 30 Nov 2022 02:24:11 +0000 Subject: [PATCH 07/28] linting --- .../modules/sequence_parallel_transformer_layer.py | 1 + 
metaseq/modules/transformer_decoder_layer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 550bbcec7..6e0467223 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -6,6 +6,7 @@ from metaseq.modules.activation_functions import gelu, gelu_back, relu, relu_back import importlib +import logging import math import torch from types import SimpleNamespace diff --git a/metaseq/modules/transformer_decoder_layer.py b/metaseq/modules/transformer_decoder_layer.py index b7336b9ff..5ca9a66ea 100644 --- a/metaseq/modules/transformer_decoder_layer.py +++ b/metaseq/modules/transformer_decoder_layer.py @@ -71,7 +71,7 @@ def __init__( self.skip_bias_add = (self.activation_fn_name == "gelu") and has_fused_bias_gelu self.attn_variant = getattr(args, "attn_variant", "default") - sefl.xf_attn_op = getattr(args, "xf_attn_op", None) + self.xf_attn_op = getattr(args, "xf_attn_op", None) # TODO[Susan]: Clean up these kwargs when unifying method signatures between model & non-model parallel. fc1_kwargs = { From de36a374fe37107dbac9347dd201764fad2c0964 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Thu, 1 Dec 2022 01:54:32 +0000 Subject: [PATCH 08/28] add test --- ...est_sequence_parallel_transformer_layer.py | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 tests/test_sequence_parallel_transformer_layer.py diff --git a/tests/test_sequence_parallel_transformer_layer.py b/tests/test_sequence_parallel_transformer_layer.py new file mode 100644 index 000000000..5f800ed97 --- /dev/null +++ b/tests/test_sequence_parallel_transformer_layer.py @@ -0,0 +1,87 @@ +import os +import unittest +import random +import tempfile +import torch +from types import SimpleNamespace +from megatron.mpu import initialize_model_parallel +from metaseq.model_parallel.modules import ModelParallelTransformerDecoderLayer + + +def reset_seeds(): + torch.manual_seed(42) + torch.cuda.manual_seed(42) + random.seed(42) + + +def _distributed_init(): + backend = "nccl" + local_rank = None + rank = 0 + world_size = 1 + device = 0 + torch.cuda.set_device(device) + + # Call the init process. 
+ init_method = "tcp://" + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6000") + init_method += master_ip + ":" + master_port + torch.distributed.init_process_group( + backend=backend, world_size=world_size, rank=rank, init_method=init_method + ) + + +# TODO: add dtype +class TestParity(unittest.TestCase): + def test_xformers_parity(self): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available, skipping test") + + _distributed_init() + tensor_model_parallel_size_ = 1 + initialize_model_parallel(tensor_model_parallel_size_) + + args = SimpleNamespace( + sequence_parallel=True, + decoder_embed_dim=64, + dropout=0.0, + decoder_attention_heads=1, + decoder_ffn_embed_dim=64, + decoder_layers=1, + attention_dropout=0.0, + ) + S, B, E = 128, 2, 64 + x = torch.rand((S, B, E), device="cuda", requires_grad=False) + x_ = x.clone() + x.requires_grad = True + x_.requires_grad = True + + xf_attn_variant = "xformers_default" + std_attn_variant = "default" + + # xformers + args.attn_variant = xf_attn_variant + reset_seeds() + xf_decoder = ModelParallelTransformerDecoderLayer(args).cuda() + xf_result = xf_decoder(x) + + # std attn + args.attn_variant = std_attn_variant + reset_seeds() + decoder = TransformerDecoderLayer(args).cuda() + result = decoder(x_) + + assert torch.allclose(xf_result, result) + + loss_xf = torch.norm(xf_result) + loss_xf.backward() + + loss = torch.norm(result) + loss.backward() + + assert torch.allclose(x.grad, x_.grad) + + +if __name__ == "__main__": + unittest.main() From 9b09c89123e6961804c74762ce3743dacd2dedb3 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Thu, 1 Dec 2022 20:37:02 +0000 Subject: [PATCH 09/28] move test to gpu tests --- .../test_sequence_parallel_transformer_layer.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) rename {tests => gpu_tests}/test_sequence_parallel_transformer_layer.py (85%) diff --git a/tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py similarity index 85% rename from tests/test_sequence_parallel_transformer_layer.py rename to gpu_tests/test_sequence_parallel_transformer_layer.py index 5f800ed97..ac44279d2 100644 --- a/tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -1,7 +1,6 @@ import os import unittest import random -import tempfile import torch from types import SimpleNamespace from megatron.mpu import initialize_model_parallel @@ -16,17 +15,18 @@ def reset_seeds(): def _distributed_init(): backend = "nccl" - local_rank = None rank = 0 world_size = 1 device = 0 torch.cuda.set_device(device) + # Call the init process. 
init_method = "tcp://" master_ip = os.getenv("MASTER_ADDR", "localhost") master_port = os.getenv("MASTER_PORT", "6000") init_method += master_ip + ":" + master_port + init_method="file:///d:/tmp/some_file" torch.distributed.init_process_group( backend=backend, world_size=world_size, rank=rank, init_method=init_method ) @@ -69,8 +69,10 @@ def test_xformers_parity(self): # std attn args.attn_variant = std_attn_variant reset_seeds() - decoder = TransformerDecoderLayer(args).cuda() + decoder = ModelParallelTransformerDecoderLayer((args).cuda() result = decoder(x_) + + torch.distributed.barrier() assert torch.allclose(xf_result, result) @@ -80,8 +82,16 @@ def test_xformers_parity(self): loss = torch.norm(result) loss.backward() + torch.distributed.barrier() assert torch.allclose(x.grad, x_.grad) + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + if __name__ == "__main__": unittest.main() From 1d42bba1bfe5ea1897474b102ab2874d7ec3c63c Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Thu, 1 Dec 2022 20:38:24 +0000 Subject: [PATCH 10/28] fix --- gpu_tests/test_sequence_parallel_transformer_layer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index ac44279d2..bce898e5b 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -20,13 +20,12 @@ def _distributed_init(): device = 0 torch.cuda.set_device(device) - # Call the init process. init_method = "tcp://" master_ip = os.getenv("MASTER_ADDR", "localhost") master_port = os.getenv("MASTER_PORT", "6000") init_method += master_ip + ":" + master_port - init_method="file:///d:/tmp/some_file" + init_method = "file:///d:/tmp/some_file" torch.distributed.init_process_group( backend=backend, world_size=world_size, rank=rank, init_method=init_method ) @@ -69,9 +68,9 @@ def test_xformers_parity(self): # std attn args.attn_variant = std_attn_variant reset_seeds() - decoder = ModelParallelTransformerDecoderLayer((args).cuda() + decoder = ModelParallelTransformerDecoderLayer(args).cuda() result = decoder(x_) - + torch.distributed.barrier() assert torch.allclose(xf_result, result) @@ -90,7 +89,7 @@ def test_xformers_parity(self): torch.distributed.barrier() if torch.distributed.get_rank() == 0: - print('>> passed the test :-)') + print(">> passed the test :-)") if __name__ == "__main__": From 66d52f859675b1e16f98fd5321d1c0e4af58a704 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Thu, 1 Dec 2022 22:50:53 +0000 Subject: [PATCH 11/28] lint --- gpu_tests/test_sequence_parallel_transformer_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index bce898e5b..941656fbd 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -3,7 +3,7 @@ import random import torch from types import SimpleNamespace -from megatron.mpu import initialize_model_parallel +from megatron.mpu import destroy_model_parallel, initialize_model_parallel from metaseq.model_parallel.modules import ModelParallelTransformerDecoderLayer @@ -85,7 +85,7 @@ def test_xformers_parity(self): assert torch.allclose(x.grad, x_.grad) # Reset groups - mpu.destroy_model_parallel() + 
destroy_model_parallel() torch.distributed.barrier() if torch.distributed.get_rank() == 0: From 2dbb1ca9b1b05751bcdc75aee26d6111557caf9e Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 2 Dec 2022 00:25:46 +0000 Subject: [PATCH 12/28] fix tests --- .../test_sequence_parallel_transformer_layer.py | 10 +++++----- .../modules/sequence_parallel_transformer_layer.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index 941656fbd..a5356dde5 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -22,16 +22,14 @@ def _distributed_init(): # Call the init process. init_method = "tcp://" - master_ip = os.getenv("MASTER_ADDR", "localhost") - master_port = os.getenv("MASTER_PORT", "6000") + master_ip = "localhost" + master_port = "6000" init_method += master_ip + ":" + master_port - init_method = "file:///d:/tmp/some_file" torch.distributed.init_process_group( backend=backend, world_size=world_size, rank=rank, init_method=init_method ) -# TODO: add dtype class TestParity(unittest.TestCase): def test_xformers_parity(self): if not torch.cuda.is_available(): @@ -51,7 +49,9 @@ def test_xformers_parity(self): attention_dropout=0.0, ) S, B, E = 128, 2, 64 - x = torch.rand((S, B, E), device="cuda", requires_grad=False) + x = torch.rand( + (S, B, E), device="cuda", dtype=torch.float16, requires_grad=False + ) x_ = x.clone() x.requires_grad = True x_.requires_grad = True diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 6e0467223..0dc8afd63 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -182,8 +182,8 @@ def forward( k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - attn = xf_op.forward( - None, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0 + attn = xf_op.forward_no_grad( + q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0, scale=None ).view(seq_len, bsz, -1) else: q = q.view(seq_len, -1, head_dim) @@ -385,7 +385,13 @@ def backward(ctx, grad_output): if xf_eff_attn: fake_ctx = _FakeContext() attn = xf_op.forward( - fake_ctx, q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0 + fake_ctx, + q, + k, + v, + attn_bias=xops.LowerTriangularMask(), + p=0.0, + scale=None, ).view(seq_len, bsz, -1) else: attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha( From 02717ddd7ead5711f4b1d3ff585a5c4750ff37c3 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 2 Dec 2022 00:34:18 +0000 Subject: [PATCH 13/28] add args --- gpu_tests/test_sequence_parallel_transformer_layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index a5356dde5..487152a22 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -47,8 +47,10 @@ def test_xformers_parity(self): decoder_ffn_embed_dim=64, decoder_layers=1, attention_dropout=0.0, + memory_efficient_fp16=True, + bf16=False, ) - S, B, E = 128, 2, 64 + S, B, E = 128, 1, 64 x = torch.rand( (S, B, E), device="cuda", dtype=torch.float16, requires_grad=False ) From 
674091832bccd5dd552fe6796d71cd9f5f0ee944 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 2 Dec 2022 18:36:46 +0000 Subject: [PATCH 14/28] install xformers for gpu tests --- .circleci/config.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 877ed569f..6e3d69492 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -68,6 +68,22 @@ install_fairscale: &install_fairscale cd ~/ fi + +install_xformers: &install_xformers + - run: + name: Install xFormers from Source + working_directory: ~/ + command: | + source activate metaseq + if ! python -c 'import xformers'; then + git clone https://github.com/facebookresearch/xformers.git + cd xformers + git checkout 4c06c79095ba8073dc870ee820e36a83302ae30c + git submodule update --init --recursive + pip install . + cd ~/ + fi + install_dep_pt19: &install_dep_pt19 - run: name: Install Pytorch Dependencies @@ -155,6 +171,7 @@ commands: steps: - <<: *install_dep_common - <<: *install_fairscale + - <<: *install_xformers - <<: *install_dep_fused_ops - <<: *install_repo - <<: *download_and_configure_125m_with_hf_dependencies From f68a3d8ee0fd4b427f3aaf80894fb506e3096160 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 2 Dec 2022 19:07:11 +0000 Subject: [PATCH 15/28] update shapes --- gpu_tests/test_sequence_parallel_transformer_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index 487152a22..aca8900fe 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -50,7 +50,7 @@ def test_xformers_parity(self): memory_efficient_fp16=True, bf16=False, ) - S, B, E = 128, 1, 64 + S, B, E = 64, 128, 64 x = torch.rand( (S, B, E), device="cuda", dtype=torch.float16, requires_grad=False ) From 3b1cb73deed72827271ee1ae82e9c89ad2f36b4b Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 2 Dec 2022 19:34:11 +0000 Subject: [PATCH 16/28] skip if xformers not available --- ...est_sequence_parallel_transformer_layer.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index aca8900fe..af9f17008 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -1,6 +1,6 @@ -import os import unittest import random +import sys import torch from types import SimpleNamespace from megatron.mpu import destroy_model_parallel, initialize_model_parallel @@ -30,10 +30,30 @@ def _distributed_init(): ) +def _allclose(out, ref, atol, rtol, msg="failed"): + flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten() + max_pos = flatten_diff.argmax() + max_diff = flatten_diff[max_pos] + num_different = torch.count_nonzero(flatten_diff > 0) + percentage = num_different / flatten_diff.numel() + del flatten_diff + assert torch.allclose(out, ref, rtol=rtol, atol=atol), ( + f"{msg}: " + f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)" + f"/ atol={atol}, rtol={rtol}" + f"/ total failing elements: {num_different}, percentage={percentage}" + ) + + class TestParity(unittest.TestCase): def test_xformers_parity(self): if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA not available, skipping test") + if "xformers" not in 
sys.modules: + raise unittest.SkipTest("xformers not available, skipping test") + + atol = 4e-3 + rtol = 4e-4 _distributed_init() tensor_model_parallel_size_ = 1 @@ -73,9 +93,7 @@ def test_xformers_parity(self): decoder = ModelParallelTransformerDecoderLayer(args).cuda() result = decoder(x_) - torch.distributed.barrier() - - assert torch.allclose(xf_result, result) + assert _allclose(xf_result, result, atol=atol, rtol=rtol) loss_xf = torch.norm(xf_result) loss_xf.backward() @@ -84,7 +102,7 @@ def test_xformers_parity(self): loss.backward() torch.distributed.barrier() - assert torch.allclose(x.grad, x_.grad) + assert torch.allclose(x.grad, x_.grad, atol=atol, rtol=rtol) # Reset groups destroy_model_parallel() From 7db010b81482951c0bfdfa6b86491f8b45b48fa0 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 14 Dec 2022 23:38:29 +0000 Subject: [PATCH 17/28] use Triton since fastest for zucchini shape, fix reshaping of attn --- .../modules/sequence_parallel_transformer_layer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 0dc8afd63..a916e4b74 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -135,7 +135,8 @@ def forward( "\n\nPlease install xformers to use memory efficient attention" ) - xf_op = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp + # xf_op = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp + xf_op = xops.MemoryEfficientAttentionTritonFwdFlashBwOp if xf_eff_attn and xf_attn_op is not None: try: xf_op = getattr(xops, xf_attn_op) @@ -182,9 +183,12 @@ def forward( k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + # bmhk -> m, b*h, k attn = xf_op.forward_no_grad( q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0, scale=None - ).view(seq_len, bsz, -1) + ) # .permute((0, 2, 1, 3)).view(-1, seq_len, head_dim).transpose(0, 1) + attn = attn.transpose(0, 1).view(seq_len, -1, head_dim) + # OR transpose 0/1, then view else: q = q.view(seq_len, -1, head_dim) k = k.view(seq_len, -1, head_dim) From 546458fd9c41aa02d54c6a3c253174a9a7d209a9 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 14 Dec 2022 23:39:54 +0000 Subject: [PATCH 18/28] cleaner reshaping --- .../modules/sequence_parallel_transformer_layer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index a916e4b74..535b62a9f 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -183,12 +183,13 @@ def forward( k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - # bmhk -> m, b*h, k - attn = xf_op.forward_no_grad( - q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0, scale=None - ) # .permute((0, 2, 1, 3)).view(-1, seq_len, head_dim).transpose(0, 1) - attn = attn.transpose(0, 1).view(seq_len, -1, head_dim) - # OR transpose 0/1, then view + attn = ( + xf_op.forward_no_grad( + q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0, scale=None + ) + .transpose(0, 1) + .view(seq_len, -1, head_dim) + ) else: q = q.view(seq_len, -1, head_dim) k = 
k.view(seq_len, -1, head_dim) From 156c057f162093232446672ca6ad35077c4ebbd8 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Wed, 14 Dec 2022 23:45:43 +0000 Subject: [PATCH 19/28] clean up tests --- gpu_tests/test_sequence_parallel_transformer_layer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index af9f17008..07a712e77 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -37,7 +37,7 @@ def _allclose(out, ref, atol, rtol, msg="failed"): num_different = torch.count_nonzero(flatten_diff > 0) percentage = num_different / flatten_diff.numel() del flatten_diff - assert torch.allclose(out, ref, rtol=rtol, atol=atol), ( + return torch.allclose(out, ref, rtol=rtol, atol=atol), ( f"{msg}: " f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)" f"/ atol={atol}, rtol={rtol}" @@ -93,6 +93,7 @@ def test_xformers_parity(self): decoder = ModelParallelTransformerDecoderLayer(args).cuda() result = decoder(x_) + torch.distributed.barrier() assert _allclose(xf_result, result, atol=atol, rtol=rtol) loss_xf = torch.norm(xf_result) @@ -102,14 +103,14 @@ def test_xformers_parity(self): loss.backward() torch.distributed.barrier() - assert torch.allclose(x.grad, x_.grad, atol=atol, rtol=rtol) + assert _allclose(x.grad, x_.grad, atol=atol, rtol=rtol) # Reset groups destroy_model_parallel() torch.distributed.barrier() if torch.distributed.get_rank() == 0: - print(">> passed the test :-)") + print(">> passed the test") if __name__ == "__main__": From dea3360fe3ed48393518d47f0ef71c38d0b009a9 Mon Sep 17 00:00:00 2001 From: dianaml0 <82468439+dianaml0@users.noreply.github.com> Date: Mon, 5 Dec 2022 14:00:25 -0500 Subject: [PATCH 20/28] Do not install xFormers in circleCI, need updated Cuda --- .circleci/config.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 6e3d69492..388b26cac 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -68,7 +68,6 @@ install_fairscale: &install_fairscale cd ~/ fi - install_xformers: &install_xformers - run: name: Install xFormers from Source @@ -171,7 +170,6 @@ commands: steps: - <<: *install_dep_common - <<: *install_fairscale - - <<: *install_xformers - <<: *install_dep_fused_ops - <<: *install_repo - <<: *download_and_configure_125m_with_hf_dependencies From 832340aadec34620d4749246d6eeedb49e4d597c Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Thu, 15 Dec 2022 18:03:45 +0000 Subject: [PATCH 21/28] clean up tests --- .../test_sequence_parallel_transformer_layer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index 07a712e77..725d55062 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -30,14 +30,14 @@ def _distributed_init(): ) -def _allclose(out, ref, atol, rtol, msg="failed"): +def _assert_allclose(out, ref, atol, rtol, msg="failed"): flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten() max_pos = flatten_diff.argmax() max_diff = flatten_diff[max_pos] num_different = torch.count_nonzero(flatten_diff > 0) percentage = num_different / flatten_diff.numel() del flatten_diff - return torch.allclose(out, ref, rtol=rtol, atol=atol), 
( + assert torch.allclose(out, ref, rtol=rtol, atol=atol), ( f"{msg}: " f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)" f"/ atol={atol}, rtol={rtol}" @@ -94,16 +94,14 @@ def test_xformers_parity(self): result = decoder(x_) torch.distributed.barrier() - assert _allclose(xf_result, result, atol=atol, rtol=rtol) + _assert_allclose(xf_result, result, atol=atol, rtol=rtol) - loss_xf = torch.norm(xf_result) - loss_xf.backward() - - loss = torch.norm(result) - loss.backward() + # Test Backwards + xf_result.backward(torch.ones_like(x)) + result.backward(torch.ones_like(x_)) torch.distributed.barrier() - assert _allclose(x.grad, x_.grad, atol=atol, rtol=rtol) + _assert_allclose(x.grad, x_.grad, atol=atol, rtol=rtol) # Reset groups destroy_model_parallel() From d33855c182a7db771ca45d4ae62917cc3e9a8414 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Thu, 15 Dec 2022 20:49:59 +0000 Subject: [PATCH 22/28] fixing bwd, some tmp changes --- .../sequence_parallel_transformer_layer.py | 69 ++++++++++++++----- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 535b62a9f..2fc11fbb9 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -30,6 +30,7 @@ ) from megatron.mpu.utils import split_tensor_along_last_dim from megatron.model.fused_softmax import scaled_upper_triang_masked_softmax_cuda + from megatron.mpu import get_tensor_model_parallel_world_size has_megatron_submodule = True except (ImportError, ModuleNotFoundError): @@ -135,8 +136,8 @@ def forward( "\n\nPlease install xformers to use memory efficient attention" ) - # xf_op = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp - xf_op = xops.MemoryEfficientAttentionTritonFwdFlashBwOp + xf_op = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp + # xf_op = xops.MemoryEfficientAttentionTritonFwdFlashBwOp if xf_eff_attn and xf_attn_op is not None: try: xf_op = getattr(xops, xf_attn_op) @@ -184,8 +185,14 @@ def forward( v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) attn = ( - xf_op.forward_no_grad( - q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0, scale=None + xops.memory_efficient_attention_forward( + q, + k, + v, + attn_bias=xops.LowerTriangularMask(), + p=0.0, + scale=None, + op=xf_op[0], ) .transpose(0, 1) .view(seq_len, -1, head_dim) @@ -321,7 +328,9 @@ def backward(ctx, grad_output): actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out) # Now wait for reduce scatter - handle.wait() + world_size = get_tensor_model_parallel_world_size() + if world_size != 1: + handle.wait() ffn_layer_norm_output, handle = _gather_along_first_dim( ffn_layer_norm_output, async_op=True, cached_buffer_name="mpu" @@ -330,7 +339,8 @@ def backward(ctx, grad_output): grad_fc2_input = grad_output.matmul(fc2_weight) if ctx.recompute_fc1: - handle.wait() + if world_size != 1: + handle.wait() assert fc1_out is None fc1_out = torch.matmul(ffn_layer_norm_output, fc1_weight.t()) actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out) @@ -350,7 +360,8 @@ def backward(ctx, grad_output): grad_fc2_weight = grad_output.t().matmul(actv_out) grad_fc1_input = grad_actv_input.matmul(fc1_weight) - handle.wait() + if world_size != 1: + handle.wait() grad_actv_input = SequeuceParallelTransformerBlock._collapse_first_dimensions( 
grad_actv_input @@ -367,7 +378,8 @@ def backward(ctx, grad_output): grad_fc1_weight = grad_actv_input.t().matmul(ffn_layer_norm_output) - handle.wait() + if world_size != 1: + handle.wait() grad_attention_output = fused_layer_norm_cuda.backward( grad_fc1_input.contiguous(), @@ -388,22 +400,25 @@ def backward(ctx, grad_output): # recalculate attention if xf_eff_attn: - fake_ctx = _FakeContext() - attn = xf_op.forward( - fake_ctx, + # TODO: reshape q/k/v? + attn, lse = xops.memory_efficient_attention_forward_requires_grad( q, k, v, attn_bias=xops.LowerTriangularMask(), p=0.0, scale=None, - ).view(seq_len, bsz, -1) + op=xf_op[0], + ) + out = attn + attn = attn.transpose(0, 1).reshape(seq_len, -1, head_dim).contiguous() else: attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha( q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype ) - handle.wait() + if world_size != 1: + handle.wait() grad_out_proj_input = grad_attention_output.matmul(out_proj_weight) grad_attention_output = ( @@ -415,7 +430,22 @@ def backward(ctx, grad_output): grad_out_proj_weight = grad_attention_output.t().matmul(attn) if xf_eff_attn: - d_q, d_k, d_v, _, _ = xf_op.backward(fake_ctx, grad_out_proj_input) + grad_out_proj_input = grad_out_proj_input.reshape( + seq_len, bsz, -1, head_dim + ) + d_q, d_k, d_v = xops.memory_efficient_attention_backward( + grad=grad_out_proj_input, + output=out, + lse=lse, + query=q, + key=k, + value=v, + attn_bias=xops.LowerTriangularMask(), + p=0.0, + scale=None, + op=xf_op[1], + ) + # bmhk => m bh k d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1) d_k = d_k.transpose(0, 1).view(seq_len, bsz, -1) d_v = d_v.transpose(0, 1).view(seq_len, bsz, -1) @@ -425,6 +455,9 @@ def backward(ctx, grad_output): grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim ) + print(f"Diana Debug: shape of dq in std = {grad_kvq_proj_output[2].shape}") + print(f"Diana Debug: seq_len={seq_len}, bsz={bsz}, head_dim={head_dim}") + ( mha_layer_norm_output, mha_layer_norm_mean, @@ -438,7 +471,8 @@ def backward(ctx, grad_output): cached_buffer_name="mpu", ) grad_input = grad_kvq_proj_output.matmul(kvq_proj_weight) - handle.wait() + if world_size != 1: + handle.wait() grad_input, handle = _reduce_scatter_along_first_dim(grad_input, async_op=True) mha_layer_norm_output = ( @@ -452,7 +486,8 @@ def backward(ctx, grad_output): ) ) grad_kvq_weight = grad_kvq_proj_output.t().matmul(mha_layer_norm_output) - handle.wait() + if world_size != 1: + handle.wait() grad_input = fused_layer_norm_cuda.backward( grad_input.contiguous(), @@ -472,6 +507,8 @@ def backward(ctx, grad_output): None, None, None, + None, + None, ) @staticmethod From c84ad934d3c948324e0c16e761083f5157e9909a Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Thu, 22 Dec 2022 11:06:28 -0800 Subject: [PATCH 23/28] add testing and logic for multiple heads, fix bug in bwd --- ...est_sequence_parallel_transformer_layer.py | 6 ++++- .../sequence_parallel_transformer_layer.py | 24 ++++++++++++------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index 725d55062..c331bed5c 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -63,7 +63,7 @@ def test_xformers_parity(self): sequence_parallel=True, decoder_embed_dim=64, dropout=0.0, - decoder_attention_heads=1, + decoder_attention_heads=2, decoder_ffn_embed_dim=64, decoder_layers=1, 
attention_dropout=0.0, @@ -85,19 +85,23 @@ def test_xformers_parity(self): args.attn_variant = xf_attn_variant reset_seeds() xf_decoder = ModelParallelTransformerDecoderLayer(args).cuda() + reset_seeds() xf_result = xf_decoder(x) # std attn args.attn_variant = std_attn_variant reset_seeds() decoder = ModelParallelTransformerDecoderLayer(args).cuda() + reset_seeds() result = decoder(x_) torch.distributed.barrier() _assert_allclose(xf_result, result, atol=atol, rtol=rtol) # Test Backwards + reset_seeds() xf_result.backward(torch.ones_like(x)) + reset_seeds() result.backward(torch.ones_like(x_)) torch.distributed.barrier() diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 2fc11fbb9..6c2159ef7 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -180,9 +180,10 @@ def forward( seq_len, bsz, embed_dim_per_partition = q.size() if xf_eff_attn: - q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + num_heads = embed_dim_per_partition // head_dim + q = q.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1) + k = k.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1) + v = v.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1) attn = ( xops.memory_efficient_attention_forward( @@ -195,8 +196,9 @@ def forward( op=xf_op[0], ) .transpose(0, 1) - .view(seq_len, -1, head_dim) + .reshape(seq_len, bsz, num_heads*head_dim) ) + # TODO: Reshape q/k/v back to original? else: q = q.view(seq_len, -1, head_dim) k = k.view(seq_len, -1, head_dim) @@ -401,6 +403,12 @@ def backward(ctx, grad_output): # recalculate attention if xf_eff_attn: # TODO: reshape q/k/v? 
+ # q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + # k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + # v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) + + num_heads = embed_dim_per_partition // head_dim + attn, lse = xops.memory_efficient_attention_forward_requires_grad( q, k, @@ -411,7 +419,7 @@ def backward(ctx, grad_output): op=xf_op[0], ) out = attn - attn = attn.transpose(0, 1).reshape(seq_len, -1, head_dim).contiguous() + attn = attn.transpose(0, 1).reshape(seq_len, bsz, num_heads*head_dim).contiguous() else: attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha( q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype @@ -432,7 +440,7 @@ def backward(ctx, grad_output): if xf_eff_attn: grad_out_proj_input = grad_out_proj_input.reshape( seq_len, bsz, -1, head_dim - ) + ).transpose(0,1) d_q, d_k, d_v = xops.memory_efficient_attention_backward( grad=grad_out_proj_input, output=out, @@ -445,7 +453,7 @@ def backward(ctx, grad_output): scale=None, op=xf_op[1], ) - # bmhk => m bh k + # bmhk => m b hk d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1) d_k = d_k.transpose(0, 1).view(seq_len, bsz, -1) d_v = d_v.transpose(0, 1).view(seq_len, bsz, -1) @@ -455,8 +463,6 @@ def backward(ctx, grad_output): grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim ) - print(f"Diana Debug: shape of dq in std = {grad_kvq_proj_output[2].shape}") - print(f"Diana Debug: seq_len={seq_len}, bsz={bsz}, head_dim={head_dim}") ( mha_layer_norm_output, From 2ab4400c3a96a958fbf6bb467f902f234184737d Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 23 Dec 2022 13:14:54 -0800 Subject: [PATCH 24/28] clean up tests and add separate tolerances for fwd and bwd --- .../test_sequence_parallel_transformer_layer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py index c331bed5c..9cd7d0def 100644 --- a/gpu_tests/test_sequence_parallel_transformer_layer.py +++ b/gpu_tests/test_sequence_parallel_transformer_layer.py @@ -52,18 +52,23 @@ def test_xformers_parity(self): if "xformers" not in sys.modules: raise unittest.SkipTest("xformers not available, skipping test") - atol = 4e-3 - rtol = 4e-4 + fw_atol = 4e-3 + fw_rtol = 4e-4 + + bw_atol = 9e-2 + bw_rtol = 2e-2 _distributed_init() tensor_model_parallel_size_ = 1 initialize_model_parallel(tensor_model_parallel_size_) + S, B, E = 8, 16, 64 + H = 2 args = SimpleNamespace( sequence_parallel=True, - decoder_embed_dim=64, + decoder_embed_dim=E, dropout=0.0, - decoder_attention_heads=2, + decoder_attention_heads=H, decoder_ffn_embed_dim=64, decoder_layers=1, attention_dropout=0.0, @@ -96,7 +101,7 @@ def test_xformers_parity(self): result = decoder(x_) torch.distributed.barrier() - _assert_allclose(xf_result, result, atol=atol, rtol=rtol) + _assert_allclose(xf_result, result, atol=fw_atol, rtol=fw_rtol) # Test Backwards reset_seeds() @@ -105,7 +110,7 @@ def test_xformers_parity(self): result.backward(torch.ones_like(x_)) torch.distributed.barrier() - _assert_allclose(x.grad, x_.grad, atol=atol, rtol=rtol) + _assert_allclose(x.grad, x_.grad, atol=bw_atol, rtol=bw_rtol) # Reset groups destroy_model_parallel() From 722ede49b4c29ef4f9bed6ac89d49a6ab24e9021 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 23 Dec 2022 13:17:30 -0800 Subject: [PATCH 25/28] remove changes to code needed for testing with world size of 1 --- .../sequence_parallel_transformer_layer.py | 22 
++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 6c2159ef7..00a77ae2f 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -330,9 +330,7 @@ def backward(ctx, grad_output): actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out) # Now wait for reduce scatter - world_size = get_tensor_model_parallel_world_size() - if world_size != 1: - handle.wait() + handle.wait() ffn_layer_norm_output, handle = _gather_along_first_dim( ffn_layer_norm_output, async_op=True, cached_buffer_name="mpu" @@ -341,8 +339,7 @@ def backward(ctx, grad_output): grad_fc2_input = grad_output.matmul(fc2_weight) if ctx.recompute_fc1: - if world_size != 1: - handle.wait() + handle.wait() assert fc1_out is None fc1_out = torch.matmul(ffn_layer_norm_output, fc1_weight.t()) actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out) @@ -362,8 +359,7 @@ def backward(ctx, grad_output): grad_fc2_weight = grad_output.t().matmul(actv_out) grad_fc1_input = grad_actv_input.matmul(fc1_weight) - if world_size != 1: - handle.wait() + handle.wait() grad_actv_input = SequeuceParallelTransformerBlock._collapse_first_dimensions( grad_actv_input @@ -380,8 +376,7 @@ def backward(ctx, grad_output): grad_fc1_weight = grad_actv_input.t().matmul(ffn_layer_norm_output) - if world_size != 1: - handle.wait() + handle.wait() grad_attention_output = fused_layer_norm_cuda.backward( grad_fc1_input.contiguous(), @@ -425,8 +420,7 @@ def backward(ctx, grad_output): q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype ) - if world_size != 1: - handle.wait() + handle.wait() grad_out_proj_input = grad_attention_output.matmul(out_proj_weight) grad_attention_output = ( @@ -477,8 +471,7 @@ def backward(ctx, grad_output): cached_buffer_name="mpu", ) grad_input = grad_kvq_proj_output.matmul(kvq_proj_weight) - if world_size != 1: - handle.wait() + handle.wait() grad_input, handle = _reduce_scatter_along_first_dim(grad_input, async_op=True) mha_layer_norm_output = ( @@ -492,8 +485,7 @@ def backward(ctx, grad_output): ) ) grad_kvq_weight = grad_kvq_proj_output.t().matmul(mha_layer_norm_output) - if world_size != 1: - handle.wait() + handle.wait() grad_input = fused_layer_norm_cuda.backward( grad_input.contiguous(), From 523f4e1356aa5d8996e09c66c13152a441236069 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 23 Dec 2022 13:45:44 -0800 Subject: [PATCH 26/28] lint fixes --- .../modules/sequence_parallel_transformer_layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 00a77ae2f..0823b0936 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -30,7 +30,6 @@ ) from megatron.mpu.utils import split_tensor_along_last_dim from megatron.model.fused_softmax import scaled_upper_triang_masked_softmax_cuda - from megatron.mpu import get_tensor_model_parallel_world_size has_megatron_submodule = True except (ImportError, ModuleNotFoundError): @@ -434,7 +433,7 @@ def backward(ctx, grad_output): if xf_eff_attn: grad_out_proj_input = grad_out_proj_input.reshape( seq_len, 
bsz, -1, head_dim - ).transpose(0,1) + ).transpose(0, 1) d_q, d_k, d_v = xops.memory_efficient_attention_backward( grad=grad_out_proj_input, output=out, @@ -457,7 +456,6 @@ def backward(ctx, grad_output): grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim ) - ( mha_layer_norm_output, mha_layer_norm_mean, From 9aa46d2bb61843c9b258e5b4523c8ad8b80f2ac8 Mon Sep 17 00:00:00 2001 From: Diana Liskovich Date: Fri, 23 Dec 2022 14:07:14 -0800 Subject: [PATCH 27/28] formatting --- .../modules/sequence_parallel_transformer_layer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index 0823b0936..d29b7ebe5 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -195,7 +195,7 @@ def forward( op=xf_op[0], ) .transpose(0, 1) - .reshape(seq_len, bsz, num_heads*head_dim) + .reshape(seq_len, bsz, num_heads * head_dim) ) # TODO: Reshape q/k/v back to original? else: @@ -413,7 +413,11 @@ def backward(ctx, grad_output): op=xf_op[0], ) out = attn - attn = attn.transpose(0, 1).reshape(seq_len, bsz, num_heads*head_dim).contiguous() + attn = ( + attn.transpose(0, 1) + .reshape(seq_len, bsz, num_heads * head_dim) + .contiguous() + ) else: attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha( q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype From d0aa8b6e2b9818505b05dcf071d21a63c194cd49 Mon Sep 17 00:00:00 2001 From: dianaml0 <82468439+dianaml0@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:25:22 -0500 Subject: [PATCH 28/28] Clean up comments --- .../modules/sequence_parallel_transformer_layer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py index d29b7ebe5..9129f81dd 100644 --- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py +++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py @@ -197,7 +197,6 @@ def forward( .transpose(0, 1) .reshape(seq_len, bsz, num_heads * head_dim) ) - # TODO: Reshape q/k/v back to original? else: q = q.view(seq_len, -1, head_dim) k = k.view(seq_len, -1, head_dim) @@ -396,11 +395,6 @@ def backward(ctx, grad_output): # recalculate attention if xf_eff_attn: - # TODO: reshape q/k/v? - # q = q.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - # k = k.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - # v = v.view(seq_len, bsz, -1, head_dim).transpose(0, 1) - num_heads = embed_dim_per_partition // head_dim attn, lse = xops.memory_efficient_attention_forward_requires_grad(