Skip to content

switch to pinning layout by default in FSDP to avoid data corruption in PJRT and be consistent with xm.all_reduce#4359

Merged
JackCaoG merged 1 commit into
pytorch:masterfrom
ronghanghu:xla_fsdp_pin_layout_by_default
Dec 17, 2022
Merged

switch to pinning layout by default in FSDP to avoid data corruption in PJRT and be consistent with xm.all_reduce#4359
JackCaoG merged 1 commit into
pytorch:masterfrom
ronghanghu:xla_fsdp_pin_layout_by_default

Conversation

@ronghanghu

@ronghanghu ronghanghu commented Dec 16, 2022

Copy link
Copy Markdown
Collaborator

This PR switches the XLA FSDP implementation to pinning layout by default to avoid data corruption in PJRT and be consistent with xm.all_reduce.

cc: @JackCaoG @hjm-aws @AlexWertheim

…in PJRT and be consistent with xm.all_reduce

@JackCaoG JackCaoG left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@JackCaoG JackCaoG merged commit b357d87 into pytorch:master Dec 17, 2022
@ronghanghu ronghanghu deleted the xla_fsdp_pin_layout_by_default branch December 17, 2022 04:33
@JackCaoG

Copy link
Copy Markdown
Collaborator

@aws-rhsoln @amithrm FYI, this might affect you guys, since now all-gather will become all_reduce based when layout is pin. We need to do this for TPU.

@aws-rhsoln

Copy link
Copy Markdown
Contributor

Thanks @JackCaoG for letting us know.

@amithrm

amithrm commented Dec 20, 2022

Copy link
Copy Markdown
Contributor

Hi @JackCaoG, is the pinning behavior because of the way compiler handles CC ops in TPUs? Or is this a requirement of FSDP? Checking because do we need to this as a compulsory fix and hence test out all our use cases

@JackCaoG

Copy link
Copy Markdown
Collaborator

when layout is unpin, and each process generate slightly different graph(which we do in pytorch/xla), there is a chance data corruption will happen(because of some XLA optimization assume all graph will be the same). I don't know if this will be the case for you guys through(whether this is a TPU specified behavior), maybe check with your compiler team?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants