Skip to content

Commit

Permalink
fixed pmap call
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 4, 2024
1 parent 6b87ce1 commit 25c7c16
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions training.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def model_loss(params):
return train_state, loss, rng_state

if distributed_training:
train_step = jax.pmap(axis_name="data")(train_step)
train_step = jax.pmap(train_step, axis_name="data")
# train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
else:
train_step = jax.jit(train_step)
Expand Down Expand Up @@ -811,7 +811,7 @@ def model_loss(params):
return train_state, loss, rng_state

if distributed_training:
train_step = jax.pmap(axis_name="data")(train_step)
train_step = jax.pmap(train_step, axis_name="data")
# train_step = shard_map(train_step, mesh=self.mesh, in_specs=P('data'), out_specs=P())
else:
train_step = jax.jit(train_step)
Expand Down

0 comments on commit 25c7c16

Please sign in to comment.