Introduce cuda_p2p based fused_all_gather_matmul and fused_matmul_reduce_scatter #126634
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126634
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ You can merge normally! (3 unrelated failures) As of commit 29e6b1f with merge base ff65b18:
FLAKY: the following jobs failed but were likely due to flakiness present on trunk.
UNSTABLE: the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…_matmul_reduce_scatter" ## Context See context [here](#122163). ## This PR Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively. Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining. cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
I do feel like we should be able to provide a higher level API for this 🤔 It would be nice if it could be the same API for both allgather and reduce_scatter.
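For illustration, here is one purely hypothetical shape such a unified API could take (a sketch of the suggestion, not anything in this PR; the underscore-prefixed helpers are placeholders standing in for the two fused ops):

```python
def _fused_all_gather_matmul(A_shard, B, *, gather_dim, group):
    raise NotImplementedError  # placeholder for the PR's all-gather -> matmul op

def _fused_matmul_reduce_scatter(A_shard, B, *, scatter_dim, group):
    raise NotImplementedError  # placeholder for the PR's matmul -> reduce-scatter op

def fused_matmul(A_shard, B, *, group, collective, dim=0):
    """Single entry point that fuses a matmul with either collective."""
    if collective == "all_gather":
        return _fused_all_gather_matmul(A_shard, B, gather_dim=dim, group=group)
    if collective == "reduce_scatter":
        return _fused_matmul_reduce_scatter(A_shard, B, scatter_dim=dim, group=group)
    raise ValueError(f"unknown collective: {collective!r}")
```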
looks awesome!
@contextmanager
def test_with_non_cuda_p2p_group():
nit: these test utils should move to the torch.testing package instead?
ag_shape = list(A_shard.shape)
ag_shape[gather_dim] *= group_size
ag_out = A_shard.new_empty(ag_shape)
return ag_out, [ag_out @ B for B in Bs]
for meta formulas, wondering if this matmul would actually incur computation or just call the matmul meta kernel (I guess it's the latter?)
Yeah we are calling the meta kernels to deduce device, shape, and strides.
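A quick standalone illustration of that behavior (not code from the PR): a matmul on meta tensors runs only the meta kernel, producing output metadata without allocating storage or executing any FLOPs.

```python
import torch

# Meta tensors carry device/shape/stride information but no data.
a = torch.empty(128, 64, device="meta")
b = torch.empty(64, 256, device="meta")
out = a @ b  # dispatches to the matmul meta kernel; no computation happens
print(out.shape, out.device)  # torch.Size([128, 256]) meta
```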
…_matmul_reduce_scatter" ## Context See context [here](#122163). ## This PR Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively. Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining. cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
…_matmul_reduce_scatter" ## Context See context [here](#122163). ## This PR Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively. Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining. cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
…_matmul_reduce_scatter" ## Context See context [here](#122163). ## This PR Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively. Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining. cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
…uce_scatter ghstack-source-id: cfada01c278b4ed552914d073147c77aa29e6a04 Pull Request resolved: #126634
…_matmul_reduce_scatter" ## Context See context [here](#122163). ## This PR Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively. Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining. cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
…_matmul_reduce_scatter" ## Context See context [here](#122163). ## This PR Introduces `cuda_p2p` based `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops which performs micro-pipelining TP for `all-gather -> matmul` and `matmul -> reduce-scatter` respectively. Fusion vs. decomposition - in principle, the micro-pipelining is achieved via decomposition. However, in practice, today Inductor can't deal with the decomposed patterns well. So instead performing decomposition in Inductor, we fuse the patterns to be decomposed and dispatch them to corresponding operators that handle decomposition + micropipelining. cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
…_matmul_reduce_scatter" [ghstack-poisoned]
…_matmul_reduce_scatter" [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…uce_scatter (#126634) Pull Request resolved: #126634 Approved by: https://github.com/Chillee, https://github.com/wanchaol (cherry picked from commit 1071437)
Stack from ghstack (oldest at bottom):
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k