Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding flash attention for sequence parallel #565

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b4539bb
[wip] Adding flash attention for sequence parallel
stephenroller Nov 11, 2022
2830c3d
change to faster flash attn
dianaml0 Nov 30, 2022
6dc5006
add back standard attention
dianaml0 Nov 30, 2022
55af48f
gate mem efficient attn behind a flag
dianaml0 Nov 30, 2022
25270a6
cleanup
dianaml0 Nov 30, 2022
0514d61
save flags for backwards
dianaml0 Nov 30, 2022
d360526
linting
dianaml0 Nov 30, 2022
de36a37
add test
dianaml0 Dec 1, 2022
9b09c89
move test to gpu tests
dianaml0 Dec 1, 2022
1d42bba
fix
dianaml0 Dec 1, 2022
66d52f8
lint
dianaml0 Dec 1, 2022
2dbb1ca
fix tests
dianaml0 Dec 2, 2022
02717dd
add args
dianaml0 Dec 2, 2022
6740918
install xformers for gpu tests
dianaml0 Dec 2, 2022
f68a3d8
update shapes
dianaml0 Dec 2, 2022
3b1cb73
skip if xformers not available
dianaml0 Dec 2, 2022
7db010b
use Triton since fastest for zucchini shape, fix reshaping of attn
dianaml0 Dec 14, 2022
546458f
cleaner reshaping
dianaml0 Dec 14, 2022
156c057
clean up tests
dianaml0 Dec 14, 2022
dea3360
Do not install xFormers in circleCI, need updated Cuda
dianaml0 Dec 5, 2022
832340a
clean up tests
dianaml0 Dec 15, 2022
d33855c
fixing bwd, some tmp changes
dianaml0 Dec 15, 2022
c84ad93
add testing and logic for multiple heads, fix bug in bwd
dianaml0 Dec 22, 2022
2ab4400
clean up tests and add separate tolerances for fwd and bwd
dianaml0 Dec 23, 2022
722ede4
remove changes to code needed for testing with world size of 1
dianaml0 Dec 23, 2022
523f4e1
lint fixes
dianaml0 Dec 23, 2022
9aa46d2
formatting
dianaml0 Dec 23, 2022
d0aa8b6
Clean up comments
dianaml0 Jan 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ install_fairscale: &install_fairscale
cd ~/
fi

install_xformers: &install_xformers
- run:
name: Install xFormers from Source
working_directory: ~/
command: |
source activate metaseq
if ! python -c 'import xformers'; then
git clone https://github.com/facebookresearch/xformers.git
cd xformers
git checkout 4c06c79095ba8073dc870ee820e36a83302ae30c
git submodule update --init --recursive
pip install .
cd ~/
fi

install_dep_pt19: &install_dep_pt19
- run:
name: Install Pytorch Dependencies
Expand Down
124 changes: 124 additions & 0 deletions gpu_tests/test_sequence_parallel_transformer_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import unittest
import random
import sys
import torch
from types import SimpleNamespace
from megatron.mpu import destroy_model_parallel, initialize_model_parallel
from metaseq.model_parallel.modules import ModelParallelTransformerDecoderLayer


def reset_seeds():
torch.manual_seed(42)
torch.cuda.manual_seed(42)
random.seed(42)


def _distributed_init():
backend = "nccl"
rank = 0
world_size = 1
device = 0
torch.cuda.set_device(device)

# Call the init process.
init_method = "tcp://"
master_ip = "localhost"
master_port = "6000"
init_method += master_ip + ":" + master_port
torch.distributed.init_process_group(
backend=backend, world_size=world_size, rank=rank, init_method=init_method
)


def _assert_allclose(out, ref, atol, rtol, msg="failed"):
flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
max_pos = flatten_diff.argmax()
max_diff = flatten_diff[max_pos]
num_different = torch.count_nonzero(flatten_diff > 0)
percentage = num_different / flatten_diff.numel()
del flatten_diff
assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
f"{msg}: "
f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
f"/ atol={atol}, rtol={rtol}"
f"/ total failing elements: {num_different}, percentage={percentage}"
)


class TestParity(unittest.TestCase):
def test_xformers_parity(self):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA not available, skipping test")
if "xformers" not in sys.modules:
raise unittest.SkipTest("xformers not available, skipping test")

fw_atol = 4e-3
fw_rtol = 4e-4

bw_atol = 9e-2
bw_rtol = 2e-2

_distributed_init()
tensor_model_parallel_size_ = 1
initialize_model_parallel(tensor_model_parallel_size_)

S, B, E = 8, 16, 64
H = 2
args = SimpleNamespace(
sequence_parallel=True,
decoder_embed_dim=E,
dropout=0.0,
decoder_attention_heads=H,
decoder_ffn_embed_dim=64,
decoder_layers=1,
attention_dropout=0.0,
memory_efficient_fp16=True,
bf16=False,
)
S, B, E = 64, 128, 64
x = torch.rand(
(S, B, E), device="cuda", dtype=torch.float16, requires_grad=False
)
x_ = x.clone()
x.requires_grad = True
x_.requires_grad = True

xf_attn_variant = "xformers_default"
std_attn_variant = "default"

# xformers
args.attn_variant = xf_attn_variant
reset_seeds()
xf_decoder = ModelParallelTransformerDecoderLayer(args).cuda()
reset_seeds()
xf_result = xf_decoder(x)

# std attn
args.attn_variant = std_attn_variant
reset_seeds()
decoder = ModelParallelTransformerDecoderLayer(args).cuda()
reset_seeds()
result = decoder(x_)

torch.distributed.barrier()
_assert_allclose(xf_result, result, atol=fw_atol, rtol=fw_rtol)

# Test Backwards
reset_seeds()
xf_result.backward(torch.ones_like(x))
reset_seeds()
result.backward(torch.ones_like(x_))

torch.distributed.barrier()
_assert_allclose(x.grad, x_.grad, atol=bw_atol, rtol=bw_rtol)

# Reset groups
destroy_model_parallel()

torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(">> passed the test")


if __name__ == "__main__":
unittest.main()
143 changes: 130 additions & 13 deletions metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,23 @@
from metaseq.modules.activation_functions import gelu, gelu_back, relu, relu_back

import importlib
import logging
import math
import torch
from types import SimpleNamespace
from metaseq.dataclass.constants import AttentionVariants

# Not importing here cause cpu tests don't like it
global fused_layer_norm_cuda
fused_layer_norm_cuda = None

try:
import xformers.ops as xops

has_xformers = True
except (ImportError, ModuleNotFoundError):
has_xformers = False

try:
from megatron.mpu.mappings import (
_reduce_scatter_along_first_dim,
Expand All @@ -26,6 +36,17 @@
has_megatron_submodule = False


class _FakeContext(SimpleNamespace):
"""
Used to provide a temporary buffer for FlashAttention's saved buffers
"""

saved_tensors = None

def save_for_backward(self, *args):
self.saved_tensors = args


class SequeuceParallelTransformerBlock(torch.autograd.Function):
"""
This is custom FFN autograd function hardcoded for:
Expand Down Expand Up @@ -101,10 +122,27 @@ def forward(
head_dim,
recompute_fc1,
activation_fn_name, # "relu" or "gelu" for now
attn_variant,
xf_attn_op,
):
assert (
activation_fn_name == "relu" or activation_fn_name == "gelu"
), "Only relu/gelu is supported!"

xf_eff_attn = attn_variant == AttentionVariants.XFORMERS
if xf_eff_attn and not has_xformers:
raise ImportError(
"\n\nPlease install xformers to use memory efficient attention"
)

xf_op = xops.MemoryEfficientAttentionCutlassFwdFlashBwOp
# xf_op = xops.MemoryEfficientAttentionTritonFwdFlashBwOp
if xf_eff_attn and xf_attn_op is not None:
try:
xf_op = getattr(xops, xf_attn_op)
except AttributeError:
logging.warning(f"Invalid xformers memorry efficient op specified.")

# import from apex
global fused_layer_norm_cuda
if fused_layer_norm_cuda is None:
Expand Down Expand Up @@ -139,16 +177,38 @@ def forward(

k, v, q = split_tensor_along_last_dim(kvq_out, 3, contiguous_split_chunks=True)
seq_len, bsz, embed_dim_per_partition = q.size()
q = q.view(seq_len, -1, head_dim)
k = k.view(seq_len, -1, head_dim)
v = v.view(seq_len, -1, head_dim).transpose(0, 1)

attn, _ = SequeuceParallelTransformerBlock.forward_mha(
q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
)
if xf_eff_attn:
num_heads = embed_dim_per_partition // head_dim
q = q.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)
k = k.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)
v = v.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)

attn = (
xops.memory_efficient_attention_forward(
q,
k,
v,
attn_bias=xops.LowerTriangularMask(),
p=0.0,
scale=None,
op=xf_op[0],
)
.transpose(0, 1)
.reshape(seq_len, bsz, num_heads * head_dim)
)
else:
q = q.view(seq_len, -1, head_dim)
k = k.view(seq_len, -1, head_dim)
v = v.view(seq_len, -1, head_dim).transpose(0, 1)

attn, _ = SequeuceParallelTransformerBlock.forward_mha(
q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
)

out_proj_out = torch.matmul(attn, out_proj_weight.t())
out_proj_out = _reduce_scatter_along_first_dim(out_proj_out)
out_proj_out = out_proj_out.view_as(residual)

out_proj_out = out_proj_out + residual

Expand Down Expand Up @@ -195,12 +255,16 @@ def forward(
ctx.head_dim,
ctx.embed_dim_per_partition,
ctx.activation_fn_name,
ctx.xf_eff_attn,
ctx.xf_op,
) = (
bsz,
seq_len,
head_dim,
embed_dim_per_partition,
activation_fn_name,
xf_eff_attn,
xf_op,
)

# apply scatter gather,
Expand All @@ -224,12 +288,22 @@ def backward(ctx, grad_output):
fc1_weight,
fc2_weight,
) = ctx.saved_tensors
bsz, seq_len, head_dim, embed_dim_per_partition, activation_fn_name = (
(
bsz,
seq_len,
head_dim,
embed_dim_per_partition,
activation_fn_name,
xf_eff_attn,
xf_op,
) = (
ctx.bsz,
ctx.seq_len,
ctx.head_dim,
ctx.embed_dim_per_partition,
ctx.activation_fn_name,
ctx.xf_eff_attn,
ctx.xf_op,
)
dtype = grad_output.dtype

Expand Down Expand Up @@ -320,9 +394,28 @@ def backward(ctx, grad_output):
)

# recalculate attention
attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
)
if xf_eff_attn:
num_heads = embed_dim_per_partition // head_dim

attn, lse = xops.memory_efficient_attention_forward_requires_grad(
q,
k,
v,
attn_bias=xops.LowerTriangularMask(),
p=0.0,
scale=None,
op=xf_op[0],
)
out = attn
attn = (
attn.transpose(0, 1)
.reshape(seq_len, bsz, num_heads * head_dim)
.contiguous()
)
else:
attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
)

handle.wait()

Expand All @@ -335,9 +428,31 @@ def backward(ctx, grad_output):
attn = SequeuceParallelTransformerBlock._collapse_first_dimensions(attn)
grad_out_proj_weight = grad_attention_output.t().matmul(attn)

grad_kvq_proj_output = SequeuceParallelTransformerBlock.backward_mha(
grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim
)
if xf_eff_attn:
grad_out_proj_input = grad_out_proj_input.reshape(
seq_len, bsz, -1, head_dim
).transpose(0, 1)
d_q, d_k, d_v = xops.memory_efficient_attention_backward(
grad=grad_out_proj_input,
output=out,
lse=lse,
query=q,
key=k,
value=v,
attn_bias=xops.LowerTriangularMask(),
p=0.0,
scale=None,
op=xf_op[1],
)
# bmhk => m b hk
d_q = d_q.transpose(0, 1).view(seq_len, bsz, -1)
d_k = d_k.transpose(0, 1).view(seq_len, bsz, -1)
d_v = d_v.transpose(0, 1).view(seq_len, bsz, -1)
grad_kvq_proj_output = torch.cat([d_k, d_v, d_q], dim=-1)
else:
grad_kvq_proj_output = SequeuceParallelTransformerBlock.backward_mha(
grad_out_proj_input, q, k, v, attn_probs, seq_len, bsz, head_dim
)

(
mha_layer_norm_output,
Expand Down Expand Up @@ -386,6 +501,8 @@ def backward(ctx, grad_output):
None,
None,
None,
None,
None,
)

@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions metaseq/modules/transformer_decoder_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def __init__(
self.activation_fn_name = getattr(args, "activation_fn", "relu") or "relu"
self.skip_bias_add = (self.activation_fn_name == "gelu") and has_fused_bias_gelu

self.attn_variant = getattr(args, "attn_variant", "default")
self.xf_attn_op = getattr(args, "xf_attn_op", None)

# TODO[Susan]: Clean up these kwargs when unifying method signatures between model & non-model parallel.
fc1_kwargs = {
"initialize_params_on_gpu": initialize_params_on_gpu,
Expand Down Expand Up @@ -219,6 +222,8 @@ def forward(
self.self_attn.head_dim,
recompute_fc1,
self.activation_fn_name,
self.attn_variant,
self.xf_attn_op,
)
return x

Expand Down