Skip to content

Commit

Permalink
Fix test script
Browse files Browse the repository at this point in the history
  • Loading branch information
Tcc0403 committed Oct 15, 2024
1 parent 5630d73 commit 8350b91
Showing 1 changed file with 14 additions and 86 deletions.
100 changes: 14 additions & 86 deletions test/transformers/test_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,99 +100,27 @@ def forward(self, x, y):
],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("label_smoothing, ignore_index", [(0.0, -100), (0.1, 42)])
def test_correctness(
    B,
    T,
    H,
    V,
    scalar,
    dtype,
    bias,
    label_smoothing,
    ignore_index,
    reduction,
    atol,
    rtol,
):
    """Check LigerLMHeadCE matches TorchLMHeadCE on outputs and all gradients.

    Runs the same (B*T, H) input through both heads (initialized with
    identical linear weights/bias) and asserts forward outputs, input grads,
    weight grads, and bias grads agree within (atol, rtol). Covers
    label_smoothing and ignore_index together so one test exercises both.
    """
    device = "cuda"
    torch_lm_head_ce = TorchLMHeadCE(
        H=H,
        V=V,
        bias=bias,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
        dtype=dtype,
    ).to(device)
    liger_lm_head_ce = LigerLMHeadCE(
        H=H,
        V=V,
        bias=bias,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
        dtype=dtype,
    ).to(device)

    # init the linear in all CEs with the same weights
    torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(
        V, H, device=device, dtype=dtype
    )

    if bias:
        torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand(
            V, device=device, dtype=dtype
        )

    _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
    _input1 = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # Assign some random number of elements as ignore_index so the
    # ignore_index path is actually exercised (at least 1, at most half).
    # NOTE(review): mirrors the masking in the original ignore_index test —
    # confirm index-selection details against the pre-merge version.
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T, device=device)[:num_elements_to_assign]
    target[indices_to_assign] = ignore_index

    output1 = torch_lm_head_ce(_input1, target)
    output2 = liger_lm_head_ce(_input2, target)

    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

    output1.backward()
    output2.backward()

    assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

    assert_verbose_allclose(
        torch_lm_head_ce.lin.weight.grad,
        liger_lm_head_ce.lin.weight.grad,
        atol=atol,
        rtol=rtol,
    )

    if bias:
        assert_verbose_allclose(
            torch_lm_head_ce.lin.bias.grad,
            liger_lm_head_ce.lin.bias.grad,
            atol=atol,
            rtol=rtol,
        )


@pytest.mark.parametrize(
"B, T, H, V",
[
# (2, 4, 512, 512), # The test does not work on some CI GPUs. Issue #160
(8, 2048, 4096, 32000), # llama2, mistral
# Comment out to speed up testing
# (4, 2048, 4096, 128256), # llama3 8B
# (4, 1024, 8192, 128256), # llama3 70B
(4, 423, 8192, 32000), # random shape
],
)
@pytest.mark.parametrize(
"reduction, scalar, dtype, atol, rtol",
[
("mean", 1.0, torch.bfloat16, 5e-3, 5e-2),
("mean", 1.0, torch.float32, 1e-5, 5e-4),
("sum", 1.0, torch.bfloat16, 5e-0, 5e1),
("sum", 1.0, torch.float32, 1e-3, 5e-2),
],
)
@pytest.mark.parametrize("ignore_index", [-100, 42])
def test_correctness_with_ignore_index(
B, T, H, V, scalar, dtype, bias, ignore_index, reduction, atol, rtol
):
device = "cuda"
torch_lm_head_ce = TorchLMHeadCE(
H=H,
V=V,
bias=bias,
ignore_index=ignore_index,
reduction=reduction,
dtype=dtype,
Expand All @@ -201,6 +129,7 @@ def test_correctness_with_ignore_index(
H=H,
V=V,
bias=bias,
label_smoothing=label_smoothing,
ignore_index=ignore_index,
reduction=reduction,
dtype=dtype,
Expand All @@ -221,7 +150,6 @@ def test_correctness_with_ignore_index(
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

# Assign some random number of elements as ignore_index
num_elements_to_assign = torch.randint(
1, B * T // 2, (1,)
Expand Down

0 comments on commit 8350b91

Please sign in to comment.