align gptq check to transformers for supporting cpu
jiqing-feng committed Oct 16, 2024
1 parent 1e5014e commit 3b6ddfc
Showing 1 changed file with 11 additions and 6 deletions.
optimum/gptq/quantizer.py: 17 changes (11 additions, 6 deletions)
@@ -14,8 +14,10 @@
 # limitations under the License.
 import json
 import os
+import importlib
 from enum import Enum
 from logging import getLogger
+from packaging import version
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -320,7 +322,9 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
 
         if not is_auto_gptq_available():
             raise RuntimeError("auto-gptq is required in order to perform quantzation : `pip install auto-gptq`")
-        if not torch.cuda.is_available():
+
+        gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+        if not gptq_supports_cpu and not torch.cuda.is_available():
             raise RuntimeError("No GPU found. A GPU is needed to quantize model.")
 
         model.eval()
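
The new gate mirrors the check used in transformers (hence the commit title): auto-gptq releases newer than 0.4.2 can quantize on CPU, so a missing GPU is only fatal on older versions. A minimal standalone sketch of the same logic; the helper name is illustrative, not part of the commit:

import importlib.metadata

import torch
from packaging import version


def require_quantization_device() -> None:
    # Hypothetical helper reproducing the commit's gate: auto-gptq > 0.4.2
    # supports CPU quantization, so CUDA is only required on older releases.
    gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
    if not gptq_supports_cpu and not torch.cuda.is_available():
        raise RuntimeError("No GPU found. A GPU is needed to quantize model.")
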
@@ -405,12 +409,13 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
 
         if not has_device_map:
             # put modules from module_name_preceding_first_block on cuda
+            to_device = "cuda:0" if torch.cuda.is_available() else "cpu"
             for module_name in self.module_name_preceding_first_block:
                 module = recurse_getattr(model, module_name)
                 if module is None:
                     raise ValueError(f"Module {module_name} was not found in model")
-                module = module.to(0)
-            blocks[0] = blocks[0].to(0)
+                module = module.to(to_device)
+            blocks[0] = blocks[0].to(to_device)
 
         def store_input_hook(_, input, *args):
             kwargs = args[0]
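
In isolation, the device-fallback pattern this hunk introduces looks as follows; the embedding module is a placeholder standing in for the modules that precede the first block:

import torch
import torch.nn as nn

# Stand-in for a module preceding the first transformer block.
embedding = nn.Embedding(1000, 64)

to_device = "cuda:0" if torch.cuda.is_available() else "cpu"
embedding = embedding.to(to_device)  # a no-op on a CUDA-less machine, not an error
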
@@ -432,7 +437,7 @@ def store_input_hook(_, input, *args):
         for data in dataset:
             for k, v in data.items():
                 # put the data on gpu, we won't put them back to cpu
-                if not has_device_map or device.type == "cpu":
+                if (not has_device_map or device.type == "cpu") and torch.cuda.is_available():
                     data[k] = v.to(0)
                 else:
                     data[k] = v.to(device)
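
This hunk (and the identical loop further down) pins calibration tensors to cuda:0 only when CUDA actually exists; otherwise they go to `device`. A self-contained sketch with dummy values standing in for the quantizer's state:

import torch

# Dummy stand-ins for the quantizer's state during calibration.
has_device_map = False
device = torch.device("cpu")
data = {"input_ids": torch.randint(0, 1000, (1, 8))}

for k, v in data.items():
    if (not has_device_map or device.type == "cpu") and torch.cuda.is_available():
        data[k] = v.to(0)  # previous behaviour: always pin to the first GPU
    else:
        data[k] = v.to(device)  # CPU-safe path added by this commit
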
@@ -461,7 +466,7 @@ def store_input_hook(_, input, *args):
         for data in dataset:
             for k, v in data.items():
                 # put the data on gpu, we won't put them back to cpu
-                if not has_device_map or device.type == "cpu":
+                if (not has_device_map or device.type == "cpu") and torch.cuda.is_available():
                     data[k] = v.to(0)
                 else:
                     data[k] = v.to(device)
@@ -473,7 +478,7 @@ def store_input_hook(_, input, *args):
 
             # move block to cuda if needed
             # in case we have offload modules, we need to put them on cuda because of GPTQ object
-            if not has_device_map or get_device(block) == torch.device("cpu"):
+            if (not has_device_map or get_device(block) == torch.device("cpu")) and torch.cuda.is_available():
                 block = block.to(0)
             layers = get_layers(block)
             if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
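
Taken together, these changes let GPTQ quantization run end to end on a CUDA-less host once auto-gptq > 0.4.2 is installed. A hedged usage sketch; the model id and quantizer settings are illustrative, not taken from the commit:

from optimum.gptq import GPTQQuantizer
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"  # any causal LM; picked here only for its small size
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

quantizer = GPTQQuantizer(bits=4, dataset="c4", model_seqlen=2048)
# With this commit, the call below no longer raises "No GPU found" on CPU-only hosts.
quantized_model = quantizer.quantize_model(model, tokenizer)
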
