-
Notifications
You must be signed in to change notification settings - Fork 194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BE] replace the extra DeviceMesh _flatten with mesh access #666
Conversation
[ghstack-poisoned]
ghstack-source-id: 6afa471f6e5320e998a99422e26b2f7f09dd1c6f Pull Request resolved: #666
if parallel_dims.cp_enabled | ||
else world_mesh[dp_mesh_dim_names] | ||
) | ||
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a new DeviceMesh
functionality that reacts specifically to <name1>_<name2>
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not new. DeviceMesh supports world_mesh[<name1>_<name2>]
when the _flatten
behavior was implemented. However, it has a bug -- if the flattened mesh is constructed from 3+ mesh dimensions (e.g. dp_cp
is flattened using dp_shard
, dp_replicate
, and cp
. Accessing world_mesh[dp_cp]
throws error which breaks 3D/4D/5D composability).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we catch the error and ask users to update to some version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For my understanding, for dp, if hsdp is enabled, "dp" is the flatten mesh for "dp_replicate", "dp_shard", right? Otherwise, "dp" is just "dp_shard".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wz337 , that's right. To summarize:
- FSDP: the only dp dimension in mesh is "dp"
- DDP: the only dp dimension in mesh is "dp"
- HSDP: the basic dp dimensions in mesh are "dp_shard" and "dp_replicate", which are later on flattened into "dp"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
if parallel_dims.cp_enabled | ||
else world_mesh[dp_mesh_dim_names] | ||
) | ||
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to have a try-except to indicate users are not using the latest PyTorch.
if parallel_dims.cp_enabled | ||
else world_mesh[dp_mesh_dim_names] | ||
) | ||
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we catch the error and ask users to update to some version?
Oh yeah that's right... |
**Summary** pytorch/pytorch#138945 fixes DeviceMesh access on flattened mesh which are constructed from more than 2 meshes. Refer to the fix PR for details if interested. In #592 we avoided this issue by calling `_flatten` instead of direct accessing the flattened mesh. We want to turn back to mesh access which is more straightforward since the fix has been merged in PyTorch. [ghstack-poisoned]
ghstack-source-id: a0689ec03803419d67a4a79ec325dfed15113cdf Pull Request resolved: #666
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #667 Note: This PR is a reland of #666 where the PR was mistakenly merged into a wrong branch. **Summary** pytorch/pytorch#138945 fixes DeviceMesh access on flattened mesh which are constructed from more than 2 meshes. Refer to the fix PR for details if interested. In #592 we avoided this issue by calling `_flatten` instead of direct accessing the flattened mesh. We want to turn back to mesh access which is more straightforward since the fix has been merged in PyTorch.
Stack from ghstack (oldest at bottom):
Summary
pytorch/pytorch#138945 fixes DeviceMesh access on flattened mesh which are constructed from more than 2 meshes. Refer to the fix PR for details if interested.
In #592 we avoided this issue by calling
_flatten
instead of direct accessing the flattened mesh. We want to turn back to mesh access which is more straightforward since the fix has been merged in PyTorch.