Skip to content

Commit

Permalink
fix: fixed sharding in dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 5, 2024
1 parent 1ac5875 commit 599aee7
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions training.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_dataset_grain(
shuffle=True,
seed=seed,
num_epochs=num_epochs,
- shard_options=pygrain.NoSharding(),
+ shard_options=pygrain.ShardByJaxProcess(),
)

transformations = [
Expand Down Expand Up @@ -1093,9 +1093,9 @@ def main(args):
main(args)

"""
- JAX_TRACEBACK_FILTERING=off python3 training.py --dataset=laiona_coco --dataset_path='/home/mrwhite0racle/gcs_mount/arrayrecord/laion-aesthetics-12m+mscoco-2017'\
- --epochs=40 --batch_size=128 \
+ python3 training.py --dataset=laiona_coco --dataset_path='/home/mrwhite0racle/gcs_mount/arrayrecord/laion-aesthetics-12m+mscoco-2017'\
+ --epochs=40 --batch_size=64 \
--learning_rate=2.7e-4 --num_res_blocks=3 \
- --use_self_and_cross=False --dtype=float32 --precision=high --attention_heads=16 \
- --experiment_name='batch 128 multi-host laiona_coco'
+ --use_self_and_cross=False --dtype=bfloat16 --precision=high --attention_heads=16\
+ --experiment_name='batch 64 v4-16 host laiona_coco'"
"""

0 comments on commit 599aee7

Please sign in to comment.