Skip to content

Commit

Permalink
Fix prefix tuning finetune issue and update test (huggingface#975)
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Yi A <[email protected]>
  • Loading branch information
sywangyi authored Jun 12, 2024
1 parent 1825d15 commit 33ee016
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
3 changes: 1 addition & 2 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,10 @@ def update(self, prev, cur, dim, idx, inp_seq_len):
if prev.shape == cur.shape:
prev.copy_(cur)
return orig_cur
if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
return orig_cur
assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
if idx is not None:
prev.index_copy_(dim, idx - 1, cur)
return prev
Expand Down
16 changes: 8 additions & 8 deletions tests/baselines/llama_7b.json
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,16 @@
"multi_card": {
"learning_rate": 5e-4,
"train_batch_size": 1,
"train_runtime": 16.5,
"train_samples_per_second": 63.161,
"perplexity": 1.224,
"train_runtime": 16.1,
"train_samples_per_second": 63.249,
"perplexity": 1.172,
"extra_arguments": [
"--num_virtual_tokens 8",
"--max_seq_length 64",
"--logging_steps 1",
"--report_to none",
"--max_steps 100",
"--peft_type prompt_tuning",
"--peft_type prefix_tuning",
"--max_seq_length 64",
"--lr_scheduler_type cosine",
"--warmup_steps 0",
Expand All @@ -256,16 +256,16 @@
"multi_card": {
"learning_rate": 5e-4,
"train_batch_size": 1,
"train_runtime": 16.5,
"train_runtime": 18.7,
"train_samples_per_second": 63.161,
"perplexity": 1.224,
"perplexity": 1.047,
"extra_arguments": [
"--num_virtual_tokens 8",
"--max_seq_length 64",
"--logging_steps 1",
"--report_to none",
"--max_steps 100",
"--peft_type prompt_tuning",
"--peft_type p_tuning",
"--max_seq_length 64",
"--lr_scheduler_type cosine",
"--warmup_steps 0",
Expand All @@ -276,4 +276,4 @@
}
}
}
}
}

0 comments on commit 33ee016

Please sign in to comment.