diff --git a/.circleci/config.yml b/.circleci/config.yml
index 877ed569f..388b26cac 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -68,6 +68,21 @@ install_fairscale: &install_fairscale
           cd ~/
         fi
 
+install_xformers: &install_xformers
+  - run:
+      name: Install xFormers from Source
+      working_directory: ~/
+      command: |
+        source activate metaseq
+        if ! python -c 'import xformers'; then
+          git clone https://github.com/facebookresearch/xformers.git
+          cd xformers
+          git checkout 4c06c79095ba8073dc870ee820e36a83302ae30c
+          git submodule update --init --recursive
+          pip install .
+          cd ~/
+        fi
+
 install_dep_pt19: &install_dep_pt19
   - run:
       name: Install Pytorch Dependencies
diff --git a/gpu_tests/test_sequence_parallel_transformer_layer.py b/gpu_tests/test_sequence_parallel_transformer_layer.py
new file mode 100644
index 000000000..9cd7d0def
--- /dev/null
+++ b/gpu_tests/test_sequence_parallel_transformer_layer.py
@@ -0,0 +1,124 @@
+import unittest
+import random
+import sys
+import torch
+from types import SimpleNamespace
+from megatron.mpu import destroy_model_parallel, initialize_model_parallel
+from metaseq.model_parallel.modules import ModelParallelTransformerDecoderLayer
+
+
+def reset_seeds():
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+    random.seed(42)
+
+
+def _distributed_init():
+    backend = "nccl"
+    rank = 0
+    world_size = 1
+    device = 0
+    torch.cuda.set_device(device)
+
+    # Call the init process.
+    init_method = "tcp://"
+    master_ip = "localhost"
+    master_port = "6000"
+    init_method += master_ip + ":" + master_port
+    torch.distributed.init_process_group(
+        backend=backend, world_size=world_size, rank=rank, init_method=init_method
+    )
+
+
+def _assert_allclose(out, ref, atol, rtol, msg="failed"):
+    flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
+    max_pos = flatten_diff.argmax()
+    max_diff = flatten_diff[max_pos]
+    num_different = torch.count_nonzero(flatten_diff > 0)
+    percentage = num_different / flatten_diff.numel()
+    del flatten_diff
+    assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
+        f"{msg}: "
+        f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
+        f"/ atol={atol}, rtol={rtol}"
+        f"/ total failing elements: {num_different}, percentage={percentage}"
+    )
+
+
+class TestParity(unittest.TestCase):
+    def test_xformers_parity(self):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA not available, skipping test")
+        if "xformers" not in sys.modules:
+            raise unittest.SkipTest("xformers not available, skipping test")
+
+        fw_atol = 4e-3
+        fw_rtol = 4e-4
+
+        bw_atol = 9e-2
+        bw_rtol = 2e-2
+
+        _distributed_init()
+        tensor_model_parallel_size_ = 1
+        initialize_model_parallel(tensor_model_parallel_size_)
+
+        S, B, E = 8, 16, 64
+        H = 2
+        args = SimpleNamespace(
+            sequence_parallel=True,
+            decoder_embed_dim=E,
+            dropout=0.0,
+            decoder_attention_heads=H,
+            decoder_ffn_embed_dim=64,
+            decoder_layers=1,
+            attention_dropout=0.0,
+            memory_efficient_fp16=True,
+            bf16=False,
+        )
+        S, B, E = 64, 128, 64
+        x = torch.rand(
+            (S, B, E), device="cuda", dtype=torch.float16, requires_grad=False
+        )
+        x_ = x.clone()
+        x.requires_grad = True
+        x_.requires_grad = True
+
+        xf_attn_variant = "xformers_default"
+        std_attn_variant = "default"
+
+        # xformers
+        args.attn_variant = xf_attn_variant
+        reset_seeds()
+        xf_decoder = ModelParallelTransformerDecoderLayer(args).cuda()
+        reset_seeds()
+        xf_result = xf_decoder(x)
+
+        # std attn
+        args.attn_variant = std_attn_variant
+        reset_seeds()
+        decoder = ModelParallelTransformerDecoderLayer(args).cuda()
+        reset_seeds()
+        result = decoder(x_)
+
+        torch.distributed.barrier()
+        _assert_allclose(xf_result, result, atol=fw_atol, rtol=fw_rtol)
+
+        # Test Backwards
+        reset_seeds()
+        xf_result.backward(torch.ones_like(x))
+        reset_seeds()
+        result.backward(torch.ones_like(x_))
+
+        torch.distributed.barrier()
+        _assert_allclose(x.grad, x_.grad, atol=bw_atol, rtol=bw_rtol)
+
+        # Reset groups
+        destroy_model_parallel()
+
+        torch.distributed.barrier()
+        if torch.distributed.get_rank() == 0:
+            print(">> passed the test")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
index 2f8925058..9129f81dd 100644
--- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
+++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
@@ -6,13 +6,23 @@
 from metaseq.modules.activation_functions import gelu, gelu_back, relu, relu_back
 
 import importlib
+import logging
 import math
 import torch
+from types import SimpleNamespace
+from metaseq.dataclass.constants import AttentionVariants
 
 # Not importing here cause cpu tests don't like it
 global fused_layer_norm_cuda
 fused_layer_norm_cuda = None
 
+try:
+    import xformers.ops as xops
+
+    has_xformers = True
+except (ImportError, ModuleNotFoundError):
+    has_xformers = False
+
 try:
     from megatron.mpu.mappings import (
         _reduce_scatter_along_first_dim,
@@ -26,6 +36,17 @@
     has_megatron_submodule = False
 
 
+class _FakeContext(SimpleNamespace):
+    """
+    Used to provide a temporary buffer for FlashAttention's saved buffers
+    """
+
+    saved_tensors = None
+
+    def save_for_backward(self, *args):
+        self.saved_tensors = args
+
+
 class SequeuceParallelTransformerBlock(torch.autograd.Function):
     """
     This is custom FFN autograd function hardcoded for:
@@ -101,10 +122,27 @@ def forward(
         head_dim,
         recompute_fc1,
         activation_fn_name,  # "relu" or "gelu" for now
+        attn_variant,
+        xf_attn_op,
     ):
         assert (
             activation_fn_name == "relu" or activation_fn_name == "gelu"
         ), "Only relu/gelu is supported!"
+
+        xf_eff_attn = attn_variant == AttentionVariants.XFORMERS
+        if xf_eff_attn and not has_xformers:
+            raise ImportError(
+                "\n\nPlease install xformers to use memory efficient attention"
+            )
+
+        xf_op = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp
+        # xf_op = xops.MemoryEfficientAttentionTritonFwdFlashBwOp
+        if xf_eff_attn and xf_attn_op is not None:
+            try:
+                xf_op = getattr(xops, xf_attn_op)
+            except AttributeError:
+                logging.warning(f"Invalid xformers memory efficient op {xf_attn_op} specified.")
+
         # import from apex
         global fused_layer_norm_cuda
         if fused_layer_norm_cuda is None:
@@ -139,16 +177,38 @@ def forward(
         k, v, q = split_tensor_along_last_dim(kvq_out, 3, contiguous_split_chunks=True)
         seq_len, bsz, embed_dim_per_partition = q.size()
-        q = q.view(seq_len, -1, head_dim)
-        k = k.view(seq_len, -1, head_dim)
-        v = v.view(seq_len, -1, head_dim).transpose(0, 1)
-        attn, _ = SequeuceParallelTransformerBlock.forward_mha(
-            q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
-        )
+        if xf_eff_attn:
+            num_heads = embed_dim_per_partition // head_dim
+            q = q.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)
+            k = k.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)
+            v = v.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)
+
+            attn = (
+                xops.memory_efficient_attention_forward(
+                    q,
+                    k,
+                    v,
+                    attn_bias=xops.LowerTriangularMask(),
+                    p=0.0,
+                    scale=None,
+                    op=xf_op[0],
+                )
+                .transpose(0, 1)
+                .reshape(seq_len, bsz, num_heads * head_dim)
+            )
+        else:
+            q = q.view(seq_len, -1, head_dim)
+            k = k.view(seq_len, -1, head_dim)
+            v = v.view(seq_len, -1, head_dim).transpose(0, 1)
+
+            attn, _ = SequeuceParallelTransformerBlock.forward_mha(
+                q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
+            )
 
         out_proj_out = torch.matmul(attn, out_proj_weight.t())
         out_proj_out = _reduce_scatter_along_first_dim(out_proj_out)
+        out_proj_out = out_proj_out.view_as(residual)
 
         out_proj_out = out_proj_out + residual
@@ -195,12 +255,16 @@ def forward(
             ctx.head_dim,
             ctx.embed_dim_per_partition,
             ctx.activation_fn_name,
+            ctx.xf_eff_attn,
+            ctx.xf_op,
         ) = (
             bsz,
             seq_len,
             head_dim,
             embed_dim_per_partition,
             activation_fn_name,
+            xf_eff_attn,
+            xf_op,
         )
 
         # apply scatter gather,
@@ -224,12 +288,22 @@ def backward(ctx, grad_output):
             fc1_weight,
             fc2_weight,
         ) = ctx.saved_tensors
-        bsz, seq_len, head_dim, embed_dim_per_partition, activation_fn_name = (
+        (
+            bsz,
+            seq_len,
+            head_dim,
+            embed_dim_per_partition,
+            activation_fn_name,
+            xf_eff_attn,
+            xf_op,
+        ) = (
             ctx.bsz,
             ctx.seq_len,
             ctx.head_dim,
             ctx.embed_dim_per_partition,
             ctx.activation_fn_name,
+            ctx.xf_eff_attn,
+            ctx.xf_op,
         )
 
         dtype = grad_output.dtype
@@ -320,9 +394,28 @@ def backward(ctx, grad_output):
         )
 
         # recalculate attention
-        attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
-            q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
-        )
+        if xf_eff_attn:
+            num_heads = embed_dim_per_partition // head_dim
+
+            attn, lse = xops.memory_efficient_attention_forward_requires_grad(
+                q,
+                k,
+                v,
+                attn_bias=xops.LowerTriangularMask(),
+                p=0.0,
+                scale=None,
+                op=xf_op[0],
+            )
+            out = attn
+            attn = (
+                attn.transpose(0, 1)
+                .reshape(seq_len, bsz, num_heads * head_dim)
+                .contiguous()
+            )
+        else:
+            attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
+                q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
+            )
 
         handle.wait()
@@ -335,9 +428,31 @@ def backward(ctx, grad_output):
         attn = SequeuceParallelTransformerBlock._collapse_first_dimensions(attn)
         grad_out_proj_weight = grad_attention_output.t().matmul(attn)
-        grad_kvq_proj_output = SequeuceParallelTransformerBlock.backward_mha(
-            grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim
-        )
+        if xf_eff_attn:
+            grad_out_proj_input = grad_out_proj_input.reshape(
+                seq_len, bsz, -1, head_dim
+            ).transpose(0, 1)
+            d_q, d_k, d_v = xops.memory_efficient_attention_backward(
+                grad=grad_out_proj_input,
+                output=out,
+                lse=lse,
+                query=q,
+                key=k,
+                value=v,
+                attn_bias=xops.LowerTriangularMask(),
+                p=0.0,
+                scale=None,
+                op=xf_op[1],
+            )
+            # bmhk => m b hk
+            d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1)
+            d_k = d_k.transpose(0, 1).view(seq_len, bsz, -1)
+            d_v = d_v.transpose(0, 1).view(seq_len, bsz, -1)
+            grad_kvq_proj_output = torch.cat([d_k, d_v, d_q], dim=-1)
+        else:
+            grad_kvq_proj_output = SequeuceParallelTransformerBlock.backward_mha(
+                grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim
+            )
 
         (
             mha_layer_norm_output,
@@ -386,6 +501,8 @@ def backward(ctx, grad_output):
             None,
             None,
             None,
+            None,
+            None,
         )
 
     @staticmethod
diff --git a/metaseq/modules/transformer_decoder_layer.py b/metaseq/modules/transformer_decoder_layer.py
index 56cd35d9d..5ca9a66ea 100644
--- a/metaseq/modules/transformer_decoder_layer.py
+++ b/metaseq/modules/transformer_decoder_layer.py
@@ -70,6 +70,9 @@ def __init__(
         self.activation_fn_name = getattr(args, "activation_fn", "relu") or "relu"
         self.skip_bias_add = (self.activation_fn_name == "gelu") and has_fused_bias_gelu
 
+        self.attn_variant = getattr(args, "attn_variant", "default")
+        self.xf_attn_op = getattr(args, "xf_attn_op", None)
+
         # TODO[Susan]: Clean up these kwargs when unifying method signatures between model & non-model parallel.
         fc1_kwargs = {
             "initialize_params_on_gpu": initialize_params_on_gpu,
@@ -219,6 +222,8 @@ def forward(
                 self.self_attn.head_dim,
                 recompute_fc1,
                 self.activation_fn_name,
+                self.attn_variant,
+                self.xf_attn_op,
             )
         return x
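
Note (reviewer sketch, not part of the patch): the new xf_eff_attn path ultimately reduces to a single xformers memory-efficient attention call with a lower-triangular (causal) bias, and the GPU test above checks parity through the whole sequence-parallel layer. The snippet below is a minimal standalone sanity check of just that attention call against a naive causal reference; the tensor sizes, the fp32 reference, and the tolerances (borrowed from the test's fw_atol/fw_rtol) are illustrative assumptions, and it requires xformers plus a CUDA device.

    import math
    import torch
    import xformers.ops as xops

    B, M, H, K = 2, 64, 4, 32  # batch, sequence length, heads, head dim
    q, k, v = (
        torch.randn(B, M, H, K, device="cuda", dtype=torch.float16) for _ in range(3)
    )

    # Reference: naive causal attention, computed in fp32 for numerical stability.
    qf, kf, vf = (t.float().transpose(1, 2) for t in (q, k, v))  # B, H, M, K
    scores = qf @ kf.transpose(-2, -1) / math.sqrt(K)
    causal = torch.triu(torch.ones(M, M, device="cuda", dtype=torch.bool), diagonal=1)
    ref = (scores.masked_fill(causal, float("-inf")).softmax(-1) @ vf).transpose(1, 2)

    # xformers memory-efficient attention with the same causal attn_bias the
    # patch passes in both the forward and the recomputed backward path.
    out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

    torch.testing.assert_close(out.float(), ref, atol=4e-3, rtol=4e-4)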