Commit
add empty_cache() after each padding (opendatahub-io#120)
charlifu authored Aug 6, 2024
1 parent c034d5d commit 98f31cd
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions vllm/model_executor/models/mixtral.py
@@ -187,9 +187,11 @@ def process_weights_after_loading(self):
         self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data,
                                              (0, 128), "constant", 0),
                                        requires_grad=False)
+        torch.cuda.empty_cache()
         self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data,
                                             (0, 128), "constant", 0),
                                       requires_grad=False)
+        torch.cuda.empty_cache()
         return

     # If checkpoint is fp16, quantize here.
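A minimal sketch of the pattern the diff introduces (the tensor name `w` and its shape are illustrative; the diff pads the real `w13_weight` and `w2_weight` MoE parameters). `F.pad` allocates a new, wider tensor, so the old allocation becomes garbage but remains held by PyTorch's caching allocator; calling `torch.cuda.empty_cache()` after each padding returns those cached blocks to the driver, lowering peak reserved GPU memory during weight loading:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative stand-in for one padded weight; shapes are made up.
w = nn.Parameter(torch.zeros(8, 256), requires_grad=False)

# Pad the last dimension by 128 zero columns, as in the diff:
# F.pad creates a fresh tensor; the old 8x256 buffer is now unreferenced.
w = nn.Parameter(F.pad(w.data, (0, 128), "constant", 0),
                 requires_grad=False)

# On CUDA builds this releases the old buffer's cached block back to the
# driver; guarded here so the sketch also runs on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print(w.shape)
```

Without the `empty_cache()` calls, both discarded pre-padding buffers can stay reserved simultaneously, which matters for large MoE weights.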
