Activation offloading for fullfinetuning + fix tied embedding (#1847)
Co-authored-by: Felipe Mello <[email protected]>
felipemello1 and Felipe Mello authored Oct 30, 2024
1 parent a1bcb97 commit e99b890
Showing 89 changed files with 384 additions and 103 deletions.
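
Every hunk below makes the same change: a new enable_activation_offloading flag is added (or annotated) next to the existing enable_activation_checkpointing flag in each recipe config. Below is a minimal sketch of the resulting memory-management block, assuming a single-device full-finetune recipe; the keys and values are taken from the diffs in this commit, while the trailing comments are our own gloss, not text from the configs:

    # Memory management (illustrative sketch, not a specific config from this commit)
    enable_activation_checkpointing: True  # recompute activations in the backward pass instead of storing them
    enable_activation_offloading: True     # stream saved activations to CPU between forward and backward; True reduces memory
    dtype: bf16

Most configs keep the new flag at False; the low-memory recipes (e.g. 7B_full_low_memory.yaml) enable it. In torchtune's recipes, offloading is typically paired with enable_activation_checkpointing: True, since it offloads the tensors that checkpointing saves for the backward pass.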
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_full_low_memory.yaml
@@ -69,6 +69,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: True # True reduces memory
dtype: bf16

# Logging
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -77,7 +77,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory
dtype: bf16

# Logging
2 changes: 1 addition & 1 deletion recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -76,7 +76,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory
dtype: bf16

# Logging
1 change: 1 addition & 0 deletions recipes/configs/dev/8B_full_experimental.yaml
@@ -65,6 +65,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
ac_mode: 'selective' # ['selective', 'full']
ac_option: 2 # [int] = ac every positive int layer
memory_efficient_fsdp_wrap: False
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_full.yaml
@@ -62,6 +62,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora.yaml
@@ -74,6 +74,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_lora_single_device.yaml
@@ -73,7 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -73,7 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_full.yaml
@@ -64,6 +64,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora.yaml
@@ -76,6 +76,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_lora_single_device.yaml
@@ -75,7 +75,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
2 changes: 1 addition & 1 deletion recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -75,7 +75,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_full.yaml
@@ -66,6 +66,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_lora.yaml
@@ -89,3 +89,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -85,7 +85,7 @@ device: cuda
dtype: bf16

enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_lora.yaml
@@ -88,3 +88,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_qlora.yaml
@@ -98,3 +98,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_full_low_memory.yaml
@@ -70,6 +70,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: True # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora.yaml
@@ -85,6 +85,7 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_lora_single_device.yaml
@@ -86,7 +86,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qlora.yaml
@@ -89,3 +89,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama3/70B_full.yaml
@@ -93,6 +93,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output']
fsdp_cpu_offload: True
compile: False # pytorch compile, set to true for perf/memory improvement
1 change: 1 addition & 0 deletions recipes/configs/llama3/70B_lora.yaml
@@ -104,3 +104,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora.yaml
@@ -79,3 +79,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_dora_single_device.yaml
@@ -81,6 +81,7 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output']

# Reduced precision
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_full_single_device.yaml
@@ -69,6 +69,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora.yaml
@@ -84,3 +84,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_lora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_qdora_single_device.yaml
@@ -82,6 +82,7 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
2 changes: 1 addition & 1 deletion recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -84,7 +84,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: True
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/405B_qlora.yaml
@@ -82,3 +82,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/70B_full.yaml
@@ -95,6 +95,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output']
fsdp_cpu_offload: True
compile: False # pytorch compile, set to true for perf/memory improvement
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/70B_lora.yaml
@@ -103,3 +103,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_full.yaml
@@ -68,6 +68,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output']
compile: False # pytorch compile, set to true for perf/memory improvement
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_full_single_device.yaml
@@ -69,6 +69,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora.yaml
@@ -87,3 +87,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -88,7 +88,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
2 changes: 1 addition & 1 deletion recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -87,7 +87,7 @@ dtype: bf16

# Activations Offloading
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/1B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
compile: False # pytorch compile, set to true for perf/memory improvement

# Reduced precision
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -66,6 +66,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/1B_lora.yaml
@@ -84,3 +84,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: False
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/1B_qlora_single_device.yaml
@@ -84,7 +84,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: False
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/3B_full.yaml
@@ -65,6 +65,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory
compile: False # pytorch compile, set to true for perf/memory improvement

# Reduced precision
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/3B_full_single_device.yaml
@@ -67,6 +67,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
+enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/llama3_2/3B_lora.yaml
@@ -85,3 +85,4 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False
+enable_activation_offloading: False # True reduces memory
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/3B_lora_single_device.yaml
@@ -86,7 +86,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
2 changes: 1 addition & 1 deletion recipes/configs/llama3_2/3B_qlora_single_device.yaml
@@ -85,7 +85,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
@@ -106,7 +106,6 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: False
-enable_activation_offloading: False

# Profiler (disabled)
profiler: