
[Observer Restructure]: Add Observers; Add calibration and frozen steps to QuantizationModifier #837

Merged: 30 commits merged into main from update-foward on Oct 31, 2024

Conversation

@dsikka (Collaborator) commented Oct 10, 2024

SUMMARY:

  • PR to add observers to llm-compressor
  • Adds the hooks required to run calibration as part of the QuantizationModifier. All required calibration lifecycle steps can now be found in calibration.py
  • Also adds the KV Cache object so that calibration can update k_scale and v_scale for kv_cache quantization (a small sketch of the scale-update idea follows this list)
  • Requires the following PR to land in compressed-tensors: Observer Restructure: Remove Observers, calibration, and applying frozen steps from lifecycle neuralmagic/compressed-tensors#189
  • Updated Calibration lifecycle (also shown in the docstrings). This will run as part of the calibration step within the QuantizationModifier
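For the KV cache piece, the intent is that calibration refreshes the per-layer k_scale and v_scale from the key/value tensors observed during the forward pass. Below is a minimal plain-PyTorch sketch of that scale-update idea only; the helper name `update_kv_scale`, the tensor shapes, and the symmetric int8 range are illustrative assumptions, not the actual QuantizedKVParameterCache implementation.

```python
import torch

def update_kv_scale(observed: torch.Tensor, quant_max: float = 127.0) -> torch.Tensor:
    # symmetric scale derived from the absolute max of the observed tensor
    return observed.detach().abs().amax() / quant_max

# inside a calibration forward pass, per attention layer:
k_states = torch.randn(1, 8, 128, 64)  # stand-in for observed key states
v_states = torch.randn(1, 8, 128, 64)  # stand-in for observed value states
k_scale = update_kv_scale(k_states)
v_scale = update_kv_scale(v_states)
```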

Calibration runs if input/output activation quantization or kv_cache quantization is enabled.

Calibration Lifecycle for a single torch.nn.Module:

      1. initialize_observer():
          if input/output activation:
              - observer = Observer.load_from_registry(...)
              - module.register_module(f"{base_name}_observer", observer)
              
      2. register_calibration_hooks():
          if input activation and not dynamic quant (used to call observers before input QDQ):
              - pre_hook_handle = module.register_forward_pre_hook(calibrate_input_hook())
          if output activation and not dynamic quant (used to call observers before output QDQ):
              - post_hook_handle = module.register_forward_hook(calibrate_output_hook())
          if kv_cache quantization (used to set kv_cache to QuantizedKVParameterCache and update k_scale/v_scale)
              - pre_hook_handle = module.register_forward_pre_hook(calibrate_kv_cache_input_hook(), with_kwargs=True)
              - post_hook_handle = module.register_forward_hook(calibrate_kv_cache_output_hook())
          self.calibration_hooks.append(pre_hook_handle)
          self.calibration_hooks.append(post_hook_handle)

      3. self._calibrate(module) # run forward pass through model using calibration data
      4. set_unset_kv_cache() # remove kv_cache objects attached to attention layers, initially set in _apply_modifier_to_model
      5. remove calibration hooks in self.calibration_hooks
      6. remove observers
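To make the lifecycle above concrete, here is a plain-PyTorch sketch of the same pattern: attach observers as submodules, register forward hooks that feed them during calibration, run calibration data through, then tear hooks and observers back down. This is illustrative only; `MinMaxObserver`, the hook functions, and the `input_observer`/`output_observer` attribute names are simplified stand-ins, not the implementations in calibration.py or the llm-compressor Observer registry.

```python
import torch

class MinMaxObserver(torch.nn.Module):
    """Tracks running min/max of observed tensors; enough to derive qparams later."""
    def __init__(self):
        super().__init__()
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

    def forward(self, observed: torch.Tensor):
        self.min_val = torch.minimum(self.min_val, observed.detach().min())
        self.max_val = torch.maximum(self.max_val, observed.detach().max())

def calibrate_input_hook(module, args):
    module.input_observer(args[0])    # observe inputs before any input QDQ

def calibrate_output_hook(module, args, output):
    module.output_observer(output)    # observe outputs before any output QDQ

linear = torch.nn.Linear(16, 16)

# 1. initialize_observer(): attach observers as submodules
linear.register_module("input_observer", MinMaxObserver())
linear.register_module("output_observer", MinMaxObserver())

# 2. register_calibration_hooks()
handles = [
    linear.register_forward_pre_hook(calibrate_input_hook),
    linear.register_forward_hook(calibrate_output_hook),
]

# 3. run the calibration data through the module
for _ in range(4):
    linear(torch.randn(2, 16))

# 5./6. remove the calibration hooks and the observers
for handle in handles:
    handle.remove()
del linear.input_observer, linear.output_observer
```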

Testing:

  • Tested w4a16, quantized kv_cache, and w8a8 int8 workflows
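For reference, roughly how one of these workflows would be driven end to end, assuming the existing oneshot entrypoint and the W8A8 preset scheme; the model, dataset, and calibration sizes below are placeholders rather than the exact commands used for testing, and the maintained recipes live under examples/.

```python
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

# int8 W8A8 scheme; the calibration pass feeds the observers added in this PR
recipe = QuantizationModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"])

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    dataset="open_platypus",                      # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=512,
    num_calibration_samples=512,
)
```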

@dsikka marked this pull request as draft October 10, 2024 17:03

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

@dsikka changed the title from "[Observer Restructure]: Update function call" to "[Observer Restructure]: Add Observers" Oct 14, 2024
@dsikka changed the title from "[Observer Restructure]: Add Observers" to "[Observer Restructure]: Add Observers, calibration, and frozen steps to lifecycle" Oct 22, 2024
@dsikka changed the title from "[Observer Restructure]: Add Observers, calibration, and frozen steps to lifecycle" to "[Observer Restructure]: Add Observers; Add calibration, and frozen steps to QuantizationModifier" Oct 22, 2024
@dsikka changed the title from "[Observer Restructure]: Add Observers; Add calibration, and frozen steps to QuantizationModifier" to "[Observer Restructure]: Add Observers; Add calibration and frozen steps to QuantizationModifier" Oct 22, 2024
@dsikka (Collaborator, Author) commented Oct 24, 2024

With the corresponding remove-observers branch checked out:

    python3 examples/quantization_w4a16/llama3_example.py
    Traceback (most recent call last):
      File "/home/ksayers/llm-compressor/examples/quantization_w4a16/llama3_example.py", line 4, in <module>
        from llmcompressor.modifiers.quantization import GPTQModifier
      File "/home/ksayers/llm-compressor/src/llmcompressor/modifiers/quantization/__init__.py", line 3, in <module>
        from .cache import *
      File "/home/ksayers/llm-compressor/src/llmcompressor/modifiers/quantization/cache.py", line 18, in <module>
        from compressed_tensors.quantization.lifecycle import KVCacheScaleType

Can you confirm you're using the most recent commit for both? I don't get this error, and in general you shouldn't, since the scales are defined under src/compressed_tensors/quantization/lifecycle/initialize.py.
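As a quick way to check which compressed-tensors is installed, a sketch of an import check, assuming the remove-observers branch from neuralmagic/compressed-tensors#189 is checked out and that KVCacheScaleType carries the k_scale/v_scale values:

```python
# expected to succeed and print "k_scale v_scale" on the branch; an ImportError
# here points at a stale compressed-tensors install
from compressed_tensors.quantization.lifecycle import KVCacheScaleType
print(KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value)
```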

@kylesayrs previously approved these changes Oct 30, 2024
@rahul-tuli previously approved these changes Oct 30, 2024
@rahul-tuli (Collaborator) left a comment

I really like the new structure, great work!

Left a few nits; I'd recommend revisiting the docstrings and updating them for consistency:
-> Start docstrings with a capital letter
-> Include param info in :param: fields rather than only describing parameters in the main docstring body

Otherwise no big red flags! Good tests as well.

Review threads on src/llmcompressor/observers/base.py (resolved)
@dsikka dismissed stale reviews from rahul-tuli and kylesayrs via 9fc10a9 October 30, 2024 21:49
@dsikka merged commit 18e9a9f into main Oct 31, 2024
6 of 7 checks passed
@dsikka deleted the update-foward branch October 31, 2024 14:41