Skip to content

Introduce virtual device#4091

Merged
steventk-g merged 1 commit into
masterfrom
virtual-device
Nov 19, 2022
Merged

Introduce virtual device#4091
steventk-g merged 1 commit into
masterfrom
virtual-device

Conversation

@steventk-g

@steventk-g steventk-g commented Oct 12, 2022

Copy link
Copy Markdown
Collaborator

Changes in this PR

  • Expose virtual device via a flag XLA_USE_SPMD
  • Use virtual device to conditionally delay the data transfer of a tensor. This is accomplished by setting the device on the backend based on the flag, so the PJRT computation client can check for the virtual device before transferring data (NOTE: We need to use the device on the backend rather than checking for the flag directly, so that we still have a way to transfer sharded data later on).
  • When the flag is enabled, transfer sharded data without redownloading from an xla device. This is done in _xla_mark_sharding. The re-downloading path is preserved so that XLA_USE_SPMD=0 still works as well.
  • When the flag is enabled, ensure that the user gets xla:0 from xm.xla_device(). At this point, users should expect all tensors to be treated as if they are on the virtual device when SPMD is enabled.

@yeounoh

yeounoh commented Oct 12, 2022

Copy link
Copy Markdown
Contributor

Let's make sure that we cover the explict sharded cases, where we want to avoid the initial unpartitioned data transfer. We will have to double-check, but Modify XLATensor::Compile to begin data transfer on implicitly sharded tensors. this may not be needed.

@steventk-g

Copy link
Copy Markdown
Collaborator Author

Notes after chat with Yeounoh:

  • We need to locate the place where data transfer is initiated to backend device. This is probably in upstream code. This is where we can check the device type and potentially skip the data transfer.
  • We need to determine how to check the device type of an at::Tensor or XLATensor. The XlaDeviceType of tensors to shard will be "SPMD", and the device type will be XLA (like physical XLA devices: TPU, CPU, GPU).
  • Explicitly sharded tensors on an SPMD device will be transferred to the backend device by a call to CreateTensorData in _xla_mark_sharding.
  • We need to decide when to transfer data for implicitly sharded tensors, if not in XLATensor::Compile

@JackCaoG

Copy link
Copy Markdown
Collaborator

We need to locate the place where data transfer is initiated to backend device add a log to

std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(

This is the only entry for the transfer data to device.

@steventk-g steventk-g force-pushed the virtual-device branch 14 times, most recently from 6a326ec to 32ff407 Compare October 18, 2022 06:18
@steventk-g

Copy link
Copy Markdown
Collaborator Author

Remaining implementation details before I can start testing:

  • Determine how to filter tensors and devices in CreateTensorsData methods. We want all devices passed into the sharded method to be SPMD, and we want to stop data transfer to backend devices when a tensor with an SPMD device is passed into the non-sharded method. Can we simply remove the non-SPMD tensors in the first case, and remove the SPMD tensors in the second case?
  • Figure out what to return from TensorToXlaData when we don't transfer data to a real backend.

@steventk-g steventk-g force-pushed the virtual-device branch 8 times, most recently from f64c1a5 to 01fa14d Compare October 26, 2022 19:08
@steventk-g steventk-g force-pushed the virtual-device branch 15 times, most recently from ccb5384 to 0adcbf7 Compare November 14, 2022 23:27
Comment thread test/test_virtual_device.py Outdated

@jonb377 jonb377 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work Steven!

Comment thread torch_xla/utils/utils.py Outdated
Comment thread torch_xla/core/xla_model.py
Comment thread third_party/xla_client/pjrt_computation_client.cc Outdated
Comment thread torch_xla/core/xla_model.py
Comment thread torch_xla/core/xla_model.py Outdated
Comment thread torch_xla/csrc/init_python_bindings.cpp Outdated
Comment thread torch_xla/csrc/init_python_bindings.cpp

@yeounoh yeounoh left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comments, looking good -- great work @steventk-g :)

Comment thread torch_xla/utils/utils.py
Comment thread test/test_xla_sharding.py
Comment thread torch_xla/csrc/aten_xla_type.cpp
Comment thread torch_xla/csrc/init_python_bindings.cpp
Comment thread torch_xla/csrc/tensor_util.cpp Outdated

@yeounoh yeounoh left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you @steventk-g 👍

Comment thread torch_xla/csrc/tensor_util.cpp
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants