Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BE] replace the extra DeviceMesh _flatten with mesh access #666

Merged
merged 2 commits into from
Oct 31, 2024

Conversation

XilunWu
Copy link
Contributor

@XilunWu XilunWu commented Oct 30, 2024

Stack from ghstack (oldest at bottom):

Summary
pytorch/pytorch#138945 fixes DeviceMesh access on flattened mesh which are constructed from more than 2 meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling _flatten instead of direct accessing the flattened mesh. We want to turn back to mesh access which is more straightforward since the fix has been merged in PyTorch.

XilunWu added a commit that referenced this pull request Oct 30, 2024
ghstack-source-id: 6afa471f6e5320e998a99422e26b2f7f09dd1c6f
Pull Request resolved: #666
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 30, 2024
if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a new DeviceMesh functionality that reacts specifically to <name1>_<name2>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not new. DeviceMesh supports world_mesh[<name1>_<name2>] when the _flatten behavior was implemented. However, it has a bug -- if the flattened mesh is constructed from 3+ mesh dimensions (e.g. dp_cp is flattened using dp_shard, dp_replicate, and cp. Accessing world_mesh[dp_cp] throws error which breaks 3D/4D/5D composability).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we catch the error and ask users to update to some version?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, for dp, if hsdp is enabled, "dp" is the flatten mesh for "dp_replicate", "dp_shard", right? Otherwise, "dp" is just "dp_shard".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wz337 , that's right. To summarize:

  1. FSDP: the only dp dimension in mesh is "dp"
  2. DDP: the only dp dimension in mesh is "dp"
  3. HSDP: the basic dp dimensions in mesh are "dp_shard" and "dp_replicate", which are later on flattened into "dp"

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to have a try-except to indicate users are not using the latest PyTorch.

if parallel_dims.cp_enabled
else world_mesh[dp_mesh_dim_names]
)
dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we catch the error and ask users to update to some version?

@XilunWu
Copy link
Contributor Author

XilunWu commented Oct 30, 2024

It's better to have a try-except to indicate users are not using the latest PyTorch.

Oh yeah that's right...

**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access on flattened mesh which are constructed from more than 2 meshes. Refer to the fix PR for details if interested.

In #592 we avoided this issue by calling `_flatten` instead of direct accessing the flattened mesh. We want to turn back to mesh access which is more straightforward since the fix has been merged in PyTorch.


[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Oct 31, 2024
ghstack-source-id: a0689ec03803419d67a4a79ec325dfed15113cdf
Pull Request resolved: #666
@XilunWu XilunWu merged commit 3653bf2 into gh/XilunWu/9/base Oct 31, 2024
5 checks passed
XilunWu added a commit that referenced this pull request Oct 31, 2024
XilunWu added a commit that referenced this pull request Oct 31, 2024
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #667

Note: This PR is a reland of #666 where the PR was mistakenly merged
into a wrong branch.

**Summary**
pytorch/pytorch#138945 fixes DeviceMesh access
on flattened mesh which are constructed from more than 2 meshes. Refer
to the fix PR for details if interested.

In #592 we avoided this issue by calling `_flatten` instead of direct
accessing the flattened mesh. We want to turn back to mesh access which
is more straightforward since the fix has been merged in PyTorch.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants