switch to pinning layout by default in FSDP to avoid data corruption in PJRT and be consistent with xm.all_reduce#4359
Conversation
…in PJRT and be consistent with xm.all_reduce
|
@aws-rhsoln @amithrm FYI, this might affect you guys, since now |
|
Thanks @JackCaoG for letting us know. |
|
Hi @JackCaoG, is the pinning behavior because of the way compiler handles CC ops in TPUs? Or is this a requirement of FSDP? Checking because do we need to this as a compulsory fix and hence test out all our use cases |
|
when layout is unpin, and each process generate slightly different graph(which we do in pytorch/xla), there is a chance data corruption will happen(because of some XLA optimization assume all graph will be the same). I don't know if this will be the case for you guys through(whether this is a TPU specified behavior), maybe check with your compiler team? |
This PR switches the XLA FSDP implementation to pinning layout by default to avoid data corruption in PJRT and be consistent with
xm.all_reduce.cc: @JackCaoG @hjm-aws @AlexWertheim