mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
[xla:cpu] Replace xla::cpu::CollectivesInterface with xla::cpu::CpuCollectives
PiperOrigin-RevId: 713661518
This commit is contained in:
committed by
TensorFlower Gardener
parent
cf43bb53b5
commit
7b49ba401a
@@ -112,7 +112,6 @@ cc_library(
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:statusor",
|
||||
@@ -184,7 +183,6 @@ cc_library(
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:logging",
|
||||
|
||||
15
third_party/xla/xla/backends/cpu/runtime/BUILD
vendored
15
third_party/xla/xla/backends/cpu/runtime/BUILD
vendored
@@ -150,7 +150,6 @@ cc_library(
|
||||
"//xla/ffi:execution_context",
|
||||
"//xla/runtime:buffer_use",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/service/cpu:cpu_executable_run_options",
|
||||
"//xla/service/cpu:cpu_runtime",
|
||||
"//xla/service/cpu:in_process_collectives",
|
||||
@@ -193,13 +192,13 @@ xla_cc_test(
|
||||
deps = [
|
||||
":thunk",
|
||||
"//xla:executable_run_options",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/service/cpu:cpu_executable_run_options",
|
||||
"//xla/tsl/lib/core:status_test_util",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"//xla/tsl/platform:test",
|
||||
"//xla/tsl/platform:test_main",
|
||||
"@com_google_absl//absl/status",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
"@local_tsl//tsl/platform:test",
|
||||
"@local_tsl//tsl/platform:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -337,7 +336,6 @@ cc_library(
|
||||
"//xla/runtime:buffer_use",
|
||||
"//xla/service:buffer_assignment",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
@@ -466,7 +464,6 @@ cc_library(
|
||||
"//xla/runtime:buffer_use",
|
||||
"//xla/service:buffer_assignment",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"//xla/tsl/platform:errors",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
@@ -501,7 +498,6 @@ cc_library(
|
||||
"//xla/runtime:buffer_use",
|
||||
"//xla/service:buffer_assignment",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:logging",
|
||||
@@ -533,7 +529,6 @@ cc_library(
|
||||
"//xla/runtime:buffer_use",
|
||||
"//xla/service:buffer_assignment",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
@@ -567,7 +562,6 @@ cc_library(
|
||||
"//xla/service:buffer_assignment",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:computation_placer",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
@@ -606,7 +600,6 @@ cc_library(
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:computation_placer",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"//xla/tsl/platform:errors",
|
||||
|
||||
@@ -29,7 +29,6 @@ limitations under the License.
|
||||
#include "xla/backends/cpu/runtime/thunk.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
|
||||
@@ -33,7 +33,6 @@ limitations under the License.
|
||||
#include "xla/primitive_util.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
|
||||
@@ -28,7 +28,6 @@ limitations under the License.
|
||||
#include "xla/backends/cpu/runtime/thunk.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
|
||||
@@ -37,7 +37,6 @@ limitations under the License.
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/computation_placer.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/status_macros.h"
|
||||
|
||||
@@ -31,7 +31,6 @@ limitations under the License.
|
||||
#include "xla/backends/cpu/runtime/thunk.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
|
||||
@@ -30,7 +30,6 @@ limitations under the License.
|
||||
#include "xla/primitive_util.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/tsl/concurrency/async_value_ref.h"
|
||||
|
||||
@@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/cpu/cpu_executable_run_options.h"
|
||||
#include "xla/service/cpu/in_process_collectives.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
@@ -103,7 +102,7 @@ Thunk::CollectiveExecuteParams::Create(
|
||||
|
||||
// Default implementation of a collectives interface that can execute
|
||||
// collective operations within the same process.
|
||||
static CollectivesInterface* in_process_collectives =
|
||||
static CpuCollectives* in_process_collectives =
|
||||
new runtime::InProcessCollectives();
|
||||
|
||||
// If CPU executable run options are set, use the collectives interface
|
||||
@@ -111,7 +110,7 @@ Thunk::CollectiveExecuteParams::Create(
|
||||
// in-process collectives interface.
|
||||
const CpuExecutableRunOptions* cpu_run_options =
|
||||
run_options->cpu_executable_run_options();
|
||||
CollectivesInterface* collectives =
|
||||
CpuCollectives* collectives =
|
||||
cpu_run_options && cpu_run_options->collectives()
|
||||
? cpu_run_options->collectives()
|
||||
: in_process_collectives;
|
||||
|
||||
@@ -17,11 +17,11 @@ limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/cpu/cpu_executable_run_options.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
#include "tsl/platform/test.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
namespace {
|
||||
@@ -93,13 +93,12 @@ TEST(ThunkTest, CollectiveExecuteParams) {
|
||||
// Test forwarding collectives interface from CpuExecutableRunOptions.
|
||||
CpuExecutableRunOptions cpu_run_options;
|
||||
cpu_run_options.set_collectives(
|
||||
reinterpret_cast<CollectivesInterface*>(0x12345678));
|
||||
reinterpret_cast<CpuCollectives*>(0x12345678));
|
||||
run_options.set_cpu_executable_run_options(&cpu_run_options);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(params,
|
||||
Thunk::CollectiveExecuteParams::Create(&run_options));
|
||||
EXPECT_EQ(params.collectives,
|
||||
reinterpret_cast<CollectivesInterface*>(0x12345678));
|
||||
EXPECT_EQ(params.collectives, reinterpret_cast<CpuCollectives*>(0x12345678));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
@@ -27,8 +26,8 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
CliqueKey::CliqueKey(std::vector<GlobalDeviceId> devices)
|
||||
: devices_(std::move(devices)) {}
|
||||
CliqueKey::CliqueKey(absl::Span<const GlobalDeviceId> devices)
|
||||
: devices_(devices.begin(), devices.end()) {}
|
||||
|
||||
absl::Span<const GlobalDeviceId> CliqueKey::devices() const { return devices_; }
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ namespace xla {
|
||||
// these cliques launch operations (device kernels) on different device streams.
|
||||
class CliqueKey {
|
||||
public:
|
||||
explicit CliqueKey(std::vector<GlobalDeviceId> devices);
|
||||
explicit CliqueKey(absl::Span<const GlobalDeviceId> devices);
|
||||
virtual ~CliqueKey() = default;
|
||||
|
||||
CliqueKey(const CliqueKey& other) = default;
|
||||
|
||||
15
third_party/xla/xla/pjrt/cpu/BUILD
vendored
15
third_party/xla/xla/pjrt/cpu/BUILD
vendored
@@ -151,6 +151,7 @@ cc_library(
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla:xla_proto_cc",
|
||||
"//xla/backends/cpu/codegen:cpu_features",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/backends/cpu/runtime:buffer_allocations",
|
||||
"//xla/backends/cpu/runtime:thread_pool_task_runner",
|
||||
"//xla/backends/cpu/runtime:thunk",
|
||||
@@ -186,7 +187,6 @@ cc_library(
|
||||
"//xla/service:hlo_proto_cc",
|
||||
"//xla/service:hlo_value",
|
||||
"//xla/service:maybe_owning_device_memory",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/service/cpu:cpu_compiler",
|
||||
"//xla/service/cpu:cpu_event",
|
||||
"//xla/service/cpu:cpu_executable",
|
||||
@@ -302,11 +302,12 @@ cc_library(
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/backends/cpu/collectives:gloo_communicator",
|
||||
"//xla/core/collectives:clique_id",
|
||||
"//xla/core/collectives:clique_key",
|
||||
"//xla/core/collectives:communicator",
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:statusor",
|
||||
@@ -334,12 +335,16 @@ xla_cc_test(
|
||||
":gloo_kv_store",
|
||||
"//xla:executable_run_options",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/backends/cpu/collectives:cpu_clique_key",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/core/collectives:clique_id",
|
||||
"//xla/core/collectives:clique_key",
|
||||
"//xla/core/collectives:communicator",
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/pjrt/distributed:in_memory_key_value_store",
|
||||
"//xla/pjrt/distributed:key_value_store_interface",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/tsl/lib/core:status_test_util",
|
||||
"//xla/tsl/platform:env",
|
||||
@@ -384,11 +389,13 @@ cc_library(
|
||||
"//xla:types",
|
||||
"//xla:util",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/backends/cpu/collectives:mpi_communicator",
|
||||
"//xla/core/collectives:clique_id",
|
||||
"//xla/core/collectives:clique_key",
|
||||
"//xla/core/collectives:communicator",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/log",
|
||||
|
||||
4
third_party/xla/xla/pjrt/cpu/cpu_client.cc
vendored
4
third_party/xla/xla/pjrt/cpu/cpu_client.cc
vendored
@@ -49,6 +49,7 @@ limitations under the License.
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "xla/array.h"
|
||||
#include "xla/backends/cpu/codegen/cpu_features.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/backends/cpu/runtime/buffer_allocations.h"
|
||||
#include "xla/backends/cpu/runtime/thread_pool_task_runner.h"
|
||||
#include "xla/backends/cpu/runtime/thunk.h"
|
||||
@@ -85,7 +86,6 @@ limitations under the License.
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/compiler.h"
|
||||
#include "xla/service/computation_placer.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/cpu/cpu_compiler.h"
|
||||
#include "xla/service/cpu/cpu_event.h"
|
||||
#include "xla/service/cpu/cpu_executable.h"
|
||||
@@ -311,7 +311,7 @@ static tsl::ThreadOptions GetThreadOptions() {
|
||||
|
||||
TfrtCpuClient::TfrtCpuClient(
|
||||
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
|
||||
std::shared_ptr<cpu::CollectivesInterface> collectives, size_t num_threads,
|
||||
std::shared_ptr<cpu::CpuCollectives> collectives, size_t num_threads,
|
||||
bool asynchronous,
|
||||
std::function<void(HloModuleConfig&)> customize_hlo_module_config)
|
||||
: process_index_(process_index),
|
||||
|
||||
8
third_party/xla/xla/pjrt/cpu/cpu_client.h
vendored
8
third_party/xla/xla/pjrt/cpu/cpu_client.h
vendored
@@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/hlo/builder/xla_computation.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
@@ -57,7 +58,6 @@ limitations under the License.
|
||||
#include "xla/pjrt/transpose.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/computation_placer.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/cpu/cpu_event.h"
|
||||
#include "xla/service/executable.h"
|
||||
#include "xla/service/hlo.pb.h"
|
||||
@@ -77,8 +77,8 @@ class TfrtCpuClient final : public PjRtClient {
|
||||
public:
|
||||
TfrtCpuClient(
|
||||
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
|
||||
std::shared_ptr<cpu::CollectivesInterface> collectives,
|
||||
size_t num_threads, bool asynchronous,
|
||||
std::shared_ptr<cpu::CpuCollectives> collectives, size_t num_threads,
|
||||
bool asynchronous,
|
||||
std::function<void(HloModuleConfig&)> customize_hlo_module_config);
|
||||
~TfrtCpuClient() override;
|
||||
|
||||
@@ -288,7 +288,7 @@ class TfrtCpuClient final : public PjRtClient {
|
||||
absl::Mutex transpose_mu_;
|
||||
TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_);
|
||||
|
||||
std::shared_ptr<cpu::CollectivesInterface> collectives_;
|
||||
std::shared_ptr<cpu::CpuCollectives> collectives_;
|
||||
|
||||
xla::CpuTopologyDescription topology_;
|
||||
|
||||
|
||||
71
third_party/xla/xla/pjrt/cpu/gloo_collectives.cc
vendored
71
third_party/xla/xla/pjrt/cpu/gloo_collectives.cc
vendored
@@ -15,10 +15,12 @@ limitations under the License.
|
||||
|
||||
#include "xla/pjrt/cpu/gloo_collectives.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <exception>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@@ -27,7 +29,6 @@ limitations under the License.
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "gloo/context.h"
|
||||
#include "gloo/rendezvous/context.h"
|
||||
@@ -35,6 +36,8 @@ limitations under the License.
|
||||
#include "gloo/rendezvous/store.h"
|
||||
#include "gloo/transport/device.h"
|
||||
#include "xla/backends/cpu/collectives/gloo_communicator.h"
|
||||
#include "xla/core/collectives/clique_id.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
@@ -48,42 +51,38 @@ GlooCollectives::GlooCollectives(
|
||||
|
||||
GlooCollectives::~GlooCollectives() = default;
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Communicator>> GlooCollectives::GetCommunicator(
|
||||
absl::Span<GlobalDeviceId const> global_devices, int rank) {
|
||||
Context* context;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto& context_ref = contexts_[std::make_tuple(
|
||||
std::vector<GlobalDeviceId>(global_devices.begin(),
|
||||
global_devices.end()),
|
||||
rank)];
|
||||
if (!context_ref) {
|
||||
context_ref = std::make_unique<Context>();
|
||||
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
|
||||
GlooCollectives::CreateCommunicators(int32_t nranks,
|
||||
const CliqueKey& clique_key,
|
||||
const std::optional<CliqueId>& clique_id,
|
||||
absl::Span<const DeviceRank> ranks,
|
||||
const Config& config) {
|
||||
std::vector<std::unique_ptr<Communicator>> communicators;
|
||||
for (auto& device_rank : ranks) {
|
||||
size_t rank = device_rank.rank.value();
|
||||
|
||||
auto gloo_context = std::make_shared<gloo::rendezvous::Context>(
|
||||
rank, clique_key.num_devices());
|
||||
auto prefix_store = gloo::rendezvous::PrefixStore(
|
||||
absl::StrCat("gloo/",
|
||||
absl::StrJoin(clique_key.devices(), ",",
|
||||
[](std::string* out, GlobalDeviceId id) {
|
||||
absl::StrAppend(out, id.value());
|
||||
})),
|
||||
*store_);
|
||||
|
||||
try {
|
||||
gloo_context->connectFullMesh(prefix_store, device_);
|
||||
} catch (std::exception& e) {
|
||||
return absl::UnknownError(
|
||||
absl::StrCat("Gloo context initialization failed: ", e.what()));
|
||||
}
|
||||
context = context_ref.get();
|
||||
|
||||
communicators.push_back(std::make_unique<GlooCommunicator>(
|
||||
std::move(gloo_context), rank, clique_key.num_devices()));
|
||||
}
|
||||
absl::MutexLock context_lock(&context->mu);
|
||||
if (context->communicator) {
|
||||
return context->communicator;
|
||||
}
|
||||
auto gloo_context =
|
||||
std::make_shared<gloo::rendezvous::Context>(rank, global_devices.size());
|
||||
auto prefix_store = gloo::rendezvous::PrefixStore(
|
||||
absl::StrCat("gloo/",
|
||||
absl::StrJoin(global_devices, ",",
|
||||
[](std::string* out, GlobalDeviceId id) {
|
||||
absl::StrAppend(out, id.value());
|
||||
})),
|
||||
*store_);
|
||||
try {
|
||||
gloo_context->connectFullMesh(prefix_store, device_);
|
||||
} catch (std::exception& e) {
|
||||
return absl::UnknownError(
|
||||
absl::StrCat("Gloo context initialization failed: ", e.what()));
|
||||
}
|
||||
context->communicator = std::make_shared<GlooCommunicator>(
|
||||
std::move(gloo_context), rank, global_devices.size());
|
||||
return context->communicator;
|
||||
|
||||
return communicators;
|
||||
}
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
||||
32
third_party/xla/xla/pjrt/cpu/gloo_collectives.h
vendored
32
third_party/xla/xla/pjrt/cpu/gloo_collectives.h
vendored
@@ -16,49 +16,39 @@ limitations under the License.
|
||||
#ifndef XLA_PJRT_CPU_GLOO_COLLECTIVES_H_
|
||||
#define XLA_PJRT_CPU_GLOO_COLLECTIVES_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "gloo/context.h"
|
||||
#include "gloo/rendezvous/store.h"
|
||||
#include "gloo/transport/device.h"
|
||||
#include "xla/backends/cpu/collectives/gloo_communicator.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/core/collectives/clique_id.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
|
||||
class GlooCollectives : public CollectivesInterface {
|
||||
class GlooCollectives : public CpuCollectives {
|
||||
public:
|
||||
GlooCollectives(std::unique_ptr<gloo::rendezvous::Store> store,
|
||||
std::shared_ptr<gloo::transport::Device> device);
|
||||
~GlooCollectives() override;
|
||||
|
||||
// Thread-safe.
|
||||
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
|
||||
absl::Span<GlobalDeviceId const> devices, int rank) override;
|
||||
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
|
||||
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
|
||||
const std::optional<CliqueId>& clique_id,
|
||||
absl::Span<const DeviceRank> ranks,
|
||||
const Config& config) final;
|
||||
|
||||
private:
|
||||
struct Context {
|
||||
absl::Mutex mu;
|
||||
std::shared_ptr<GlooCommunicator> communicator;
|
||||
};
|
||||
|
||||
std::unique_ptr<gloo::rendezvous::Store> store_;
|
||||
std::shared_ptr<gloo::transport::Device> device_;
|
||||
|
||||
absl::Mutex mu_;
|
||||
absl::flat_hash_map<std::tuple<std::vector<GlobalDeviceId>, int>,
|
||||
std::unique_ptr<Context>>
|
||||
contexts_ ABSL_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
||||
@@ -20,18 +20,22 @@ limitations under the License.
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_clique_key.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/core/collectives/rank_id.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/pjrt/cpu/gloo_kv_store.h"
|
||||
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
|
||||
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
@@ -59,7 +63,7 @@ constexpr int kNumParticipants = 2;
|
||||
constexpr size_t kBufferSize = 256;
|
||||
constexpr absl::Duration kTimeout = absl::Seconds(5);
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
|
||||
absl::StatusOr<std::unique_ptr<Communicator>> GetCommunicator(
|
||||
size_t kNumParticipants, absl::Span<GlobalDeviceId const> global_devices,
|
||||
const std::shared_ptr<xla::KeyValueStoreInterface>& kv_store, int rank) {
|
||||
auto collectives = std::make_shared<cpu::GlooCollectives>(
|
||||
@@ -69,7 +73,16 @@ absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
|
||||
#elif defined(__APPLE__)
|
||||
gloo::transport::uv::CreateDevice(gloo::transport::uv::attr()));
|
||||
#endif // defined(__linux__)
|
||||
return collectives->GetCommunicator(global_devices, rank);
|
||||
|
||||
CpuCliqueKey clique_key(global_devices);
|
||||
CpuCollectives::DeviceRank device_rank(nullptr, RankId(rank));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto communicators,
|
||||
collectives->CreateCommunicators(
|
||||
global_devices.size(), clique_key, std::nullopt,
|
||||
{device_rank}, CpuCollectives::Config()));
|
||||
|
||||
return std::move(communicators[0]);
|
||||
}
|
||||
|
||||
RendezvousKey MakeRendezvousKey(std::vector<GlobalDeviceId> global_devices) {
|
||||
|
||||
47
third_party/xla/xla/pjrt/cpu/mpi_collectives.cc
vendored
47
third_party/xla/xla/pjrt/cpu/mpi_collectives.cc
vendored
@@ -15,8 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "xla/pjrt/cpu/mpi_collectives.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/log.h"
|
||||
@@ -25,8 +27,9 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "mpi.h"
|
||||
#include "xla/backends/cpu/collectives/mpi_communicator.h"
|
||||
#include "xla/core/collectives/clique_id.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
@@ -39,13 +42,13 @@ void MpiCollectives::Init() {
|
||||
VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_;
|
||||
}
|
||||
|
||||
void MpiCollectives::Finalize() {
|
||||
contexts_.clear();
|
||||
MPI_Finalize();
|
||||
}
|
||||
void MpiCollectives::Finalize() { MPI_Finalize(); }
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Communicator>> MpiCollectives::GetCommunicator(
|
||||
absl::Span<GlobalDeviceId const> global_devices, int rank) {
|
||||
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
|
||||
MpiCollectives::CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
|
||||
const std::optional<CliqueId>& clique_id,
|
||||
absl::Span<const DeviceRank> ranks,
|
||||
const Config& config) {
|
||||
int flag;
|
||||
MPI_Is_thread_main(&flag);
|
||||
if (!flag) {
|
||||
@@ -55,23 +58,21 @@ absl::StatusOr<std::shared_ptr<Communicator>> MpiCollectives::GetCommunicator(
|
||||
"threads/devices per process are not yet supported.");
|
||||
}
|
||||
|
||||
auto& context = contexts_[std::make_tuple(
|
||||
std::vector<GlobalDeviceId>(global_devices.begin(), global_devices.end()),
|
||||
rank)];
|
||||
if (context) {
|
||||
return context;
|
||||
std::vector<std::unique_ptr<Communicator>> communicators;
|
||||
for (auto& device_rank : ranks) {
|
||||
size_t rank = device_rank.rank.value();
|
||||
int color;
|
||||
int key = 0;
|
||||
if (clique_key.num_devices() > 0) {
|
||||
color = static_cast<int>(clique_key.devices().at(0).value());
|
||||
key = rank;
|
||||
} else {
|
||||
color = MPI_UNDEFINED;
|
||||
}
|
||||
communicators.push_back(std::make_unique<MpiCommunicator>(color, key));
|
||||
}
|
||||
|
||||
int color;
|
||||
int key = 0;
|
||||
if (global_devices.size() > 0) {
|
||||
color = static_cast<int>(global_devices.at(0).value());
|
||||
key = rank;
|
||||
} else {
|
||||
color = MPI_UNDEFINED;
|
||||
}
|
||||
context = std::make_shared<MpiCommunicator>(color, key);
|
||||
return context;
|
||||
return communicators;
|
||||
}
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
||||
21
third_party/xla/xla/pjrt/cpu/mpi_collectives.h
vendored
21
third_party/xla/xla/pjrt/cpu/mpi_collectives.h
vendored
@@ -16,23 +16,24 @@ limitations under the License.
|
||||
#ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_
|
||||
#define XLA_PJRT_CPU_MPI_COLLECTIVES_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/collectives/mpi_communicator.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/core/collectives/clique_id.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
|
||||
class MpiCollectives : public CollectivesInterface {
|
||||
class MpiCollectives : public CpuCollectives {
|
||||
public:
|
||||
/*
|
||||
The user has to explicitly call Init() and Finalize() before and
|
||||
@@ -46,8 +47,11 @@ class MpiCollectives : public CollectivesInterface {
|
||||
void Init();
|
||||
void Finalize();
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
|
||||
absl::Span<GlobalDeviceId const> global_devices, int rank) override;
|
||||
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
|
||||
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
|
||||
const std::optional<CliqueId>& clique_id,
|
||||
absl::Span<const DeviceRank> ranks,
|
||||
const Config& config) final;
|
||||
|
||||
private:
|
||||
absl::Status ExchangeGlobalDeviceIds(
|
||||
@@ -55,9 +59,6 @@ class MpiCollectives : public CollectivesInterface {
|
||||
|
||||
int mpi_world_rank_;
|
||||
int mpi_world_size_;
|
||||
absl::flat_hash_map<std::tuple<std::vector<GlobalDeviceId>, int>,
|
||||
std::shared_ptr<MpiCommunicator>>
|
||||
contexts_;
|
||||
};
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
||||
@@ -40,8 +40,8 @@ cc_library(
|
||||
srcs = [],
|
||||
hdrs = ["cpu_client_options.h"],
|
||||
deps = [
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/service:hlo_module_config",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/service/hlo_module_config.h"
|
||||
|
||||
namespace xla {
|
||||
@@ -45,7 +45,7 @@ struct CpuClientOptions {
|
||||
|
||||
// Distributed collectives implementation. Optional. If not provided, an
|
||||
// in-process collectives implementation will be used.
|
||||
std::shared_ptr<cpu::CollectivesInterface> collectives;
|
||||
std::shared_ptr<cpu::CpuCollectives> collectives;
|
||||
|
||||
// If defined this function will be called on the HloModuleConfig before
|
||||
// compilation, and allows users to set custom flags.
|
||||
|
||||
2
third_party/xla/xla/python/BUILD
vendored
2
third_party/xla/xla/python/BUILD
vendored
@@ -1307,6 +1307,7 @@ tsl_pybind_extension(
|
||||
"//xla:shape_util",
|
||||
"//xla:types",
|
||||
"//xla:util",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/ffi:ffi_api",
|
||||
"//xla/pjrt:exceptions",
|
||||
"//xla/pjrt:mlir_to_hlo",
|
||||
@@ -1333,7 +1334,6 @@ tsl_pybind_extension(
|
||||
"//xla/python/pjrt_ifrt",
|
||||
"//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
|
||||
"//xla/python/pjrt_ifrt:xla_ifrt",
|
||||
"//xla/service/cpu:collectives_interface",
|
||||
"//xla/tsl/concurrency:ref_count",
|
||||
"//xla/tsl/distributed_runtime/preemption:preemption_sync_manager",
|
||||
"//xla/tsl/platform/cloud:gcs_file_system",
|
||||
|
||||
13
third_party/xla/xla/python/xla.cc
vendored
13
third_party/xla/xla/python/xla.cc
vendored
@@ -46,6 +46,7 @@ limitations under the License.
|
||||
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/variant.h" // IWYU pragma: keep
|
||||
#include "nanobind/stl/vector.h" // IWYU pragma: keep
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||
#include "xla/pjrt/distributed/client.h"
|
||||
#include "xla/pjrt/distributed/distributed.h"
|
||||
@@ -63,7 +64,6 @@ limitations under the License.
|
||||
#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h"
|
||||
#include "xla/python/py_client.h"
|
||||
#include "xla/python/py_program.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/tsl/concurrency/ref_count.h"
|
||||
#include "xla/tsl/python/lib/core/numpy.h" // NOLINT
|
||||
|
||||
@@ -259,8 +259,7 @@ NB_MODULE(xla_extension, m) {
|
||||
|
||||
jax::BuildWeakrefLRUCacheAPI(m);
|
||||
|
||||
nb::class_<xla::cpu::CollectivesInterface> cpu_collectives(m,
|
||||
"CpuCollectives");
|
||||
nb::class_<xla::cpu::CpuCollectives> cpu_collectives(m, "CpuCollectives");
|
||||
|
||||
m.def(
|
||||
"make_gloo_tcp_collectives",
|
||||
@@ -268,7 +267,7 @@ NB_MODULE(xla_extension, m) {
|
||||
|
||||
std::optional<std::string> hostname,
|
||||
std::optional<std::string> interface)
|
||||
-> std::shared_ptr<xla::cpu::CollectivesInterface> {
|
||||
-> std::shared_ptr<xla::cpu::CpuCollectives> {
|
||||
#if defined(__linux__)
|
||||
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
|
||||
if (distributed_client != nullptr) {
|
||||
@@ -321,7 +320,7 @@ NB_MODULE(xla_extension, m) {
|
||||
});
|
||||
#else // !_WIN32 && !PLATFORM_GOOGLE
|
||||
m.def("make_mpi_collectives",
|
||||
[]() -> std::shared_ptr<xla::cpu::CollectivesInterface> {
|
||||
[]() -> std::shared_ptr<xla::cpu::CpuCollectives> {
|
||||
throw xla::XlaRuntimeError(
|
||||
"make_mpi_collectives is not implemented for Windows");
|
||||
});
|
||||
@@ -332,7 +331,7 @@ NB_MODULE(xla_extension, m) {
|
||||
[](bool asynchronous,
|
||||
std::shared_ptr<DistributedRuntimeClient> distributed_client,
|
||||
int node_id, int num_nodes,
|
||||
std::shared_ptr<xla::cpu::CollectivesInterface> collectives,
|
||||
std::shared_ptr<xla::cpu::CpuCollectives> collectives,
|
||||
std::optional<int> num_devices) -> nb_class_ptr<PyClient> {
|
||||
std::unique_ptr<ifrt::PjRtClient> ifrt_client;
|
||||
{
|
||||
@@ -363,7 +362,7 @@ NB_MODULE(xla_extension, m) {
|
||||
nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr,
|
||||
nb::arg("node_id") = 0, nb::arg("num_nodes") = 1,
|
||||
nb::arg("collectives").none() =
|
||||
std::shared_ptr<xla::cpu::CollectivesInterface>(),
|
||||
std::shared_ptr<xla::cpu::CpuCollectives>(),
|
||||
nb::arg("num_devices").none() = std::nullopt);
|
||||
m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool {
|
||||
absl::StatusOr<const PJRT_Api*> pjrt_api = pjrt::PjrtApi(platform_name);
|
||||
|
||||
37
third_party/xla/xla/service/cpu/BUILD
vendored
37
third_party/xla/xla/service/cpu/BUILD
vendored
@@ -1001,7 +1001,6 @@ cc_library(
|
||||
],
|
||||
copts = runtime_copts(),
|
||||
deps = [
|
||||
":collectives_interface",
|
||||
":cpu_executable_run_options",
|
||||
":in_process_collectives",
|
||||
"//xla:executable_run_options",
|
||||
@@ -1009,7 +1008,11 @@ cc_library(
|
||||
"//xla:types",
|
||||
"//xla:util",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/backends/cpu/collectives:cpu_clique",
|
||||
"//xla/backends/cpu/collectives:cpu_clique_key",
|
||||
"//xla/backends/cpu/collectives:cpu_cliques",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/core/collectives:communicator",
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/hlo/parser:hlo_parser",
|
||||
"//xla/service:collective_ops_utils",
|
||||
@@ -1017,6 +1020,8 @@ cc_library(
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/stream_executor:stream_executor_h",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:logging",
|
||||
"//xla/tsl/platform:status",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
@@ -1029,9 +1034,6 @@ cc_library(
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@local_tsl//tsl/platform:errors",
|
||||
"@local_tsl//tsl/platform:logging",
|
||||
"@local_tsl//tsl/platform:status",
|
||||
"@local_tsl//tsl/profiler/lib:traceme",
|
||||
],
|
||||
)
|
||||
@@ -1957,34 +1959,11 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "collectives_interface",
|
||||
hdrs = ["collectives_interface.h"],
|
||||
deps = [
|
||||
"//xla:util",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/core/collectives:clique_id",
|
||||
"//xla/core/collectives:clique_key",
|
||||
"//xla/core/collectives:communicator",
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/service:collective_ops_utils",
|
||||
"//xla/service:global_device_id",
|
||||
"//xla/stream_executor:device_memory",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "in_process_collectives",
|
||||
srcs = ["in_process_collectives.cc"],
|
||||
hdrs = ["in_process_collectives.h"],
|
||||
deps = [
|
||||
":collectives_interface",
|
||||
"//xla:refcounting_hash_map",
|
||||
"//xla:shape_util",
|
||||
"//xla:status_macros",
|
||||
@@ -1992,6 +1971,8 @@ cc_library(
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/backends/cpu/collectives:cpu_collectives",
|
||||
"//xla/backends/cpu/collectives:in_process_communicator",
|
||||
"//xla/core/collectives:clique_id",
|
||||
"//xla/core/collectives:clique_key",
|
||||
"//xla/core/collectives:communicator",
|
||||
"//xla/core/collectives:rank_id",
|
||||
"//xla/service:collective_ops_utils",
|
||||
@@ -2014,7 +1995,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "cpu_executable_run_options",
|
||||
hdrs = ["cpu_executable_run_options.h"],
|
||||
deps = [":collectives_interface"],
|
||||
deps = ["//xla/backends/cpu/collectives:cpu_collectives"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
/* Copyright 2023 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_
|
||||
#define XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/core/collectives/clique_id.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/core/collectives/rank_id.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/util.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// An adapter from a shared_ptr<Communicator> to a Communicator.
|
||||
class CommunicatorWrapper final : public Communicator {
|
||||
public:
|
||||
explicit CommunicatorWrapper(std::shared_ptr<Communicator> comm)
|
||||
: comm_(std::move(comm)) {}
|
||||
|
||||
absl::Status AllReduce(stream_executor::DeviceMemoryBase send_buffer,
|
||||
stream_executor::DeviceMemoryBase recv_buffer,
|
||||
PrimitiveType dtype, size_t count,
|
||||
ReductionKind reduction_kind,
|
||||
const Executor& executor) final {
|
||||
return comm_->AllReduce(send_buffer, recv_buffer, dtype, count,
|
||||
reduction_kind, executor);
|
||||
}
|
||||
|
||||
absl::Status Broadcast(se::DeviceMemoryBase send_buffer,
|
||||
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
|
||||
size_t count, RankId root,
|
||||
const Executor& executor) final {
|
||||
return comm_->Broadcast(send_buffer, recv_buffer, dtype, count, root,
|
||||
executor);
|
||||
}
|
||||
|
||||
absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
|
||||
se::DeviceMemoryBase recv_buffer,
|
||||
PrimitiveType dtype, size_t count,
|
||||
ReductionKind reduction_kind,
|
||||
const Executor& executor) final {
|
||||
return comm_->ReduceScatter(send_buffer, recv_buffer, dtype, count,
|
||||
reduction_kind, executor);
|
||||
}
|
||||
|
||||
absl::Status AllGather(se::DeviceMemoryBase send_buffer,
|
||||
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
|
||||
size_t count, const Executor& executor) final {
|
||||
return comm_->AllGather(send_buffer, recv_buffer, dtype, count, executor);
|
||||
}
|
||||
|
||||
absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
|
||||
se::DeviceMemoryBase recv_buffer,
|
||||
PrimitiveType dtype, size_t count,
|
||||
std::optional<RankId> source_rank,
|
||||
absl::Span<const RankId> target_ranks,
|
||||
const Executor& executor) final {
|
||||
return comm_->CollectivePermute(send_buffer, recv_buffer, dtype, count,
|
||||
source_rank, target_ranks, executor);
|
||||
}
|
||||
|
||||
absl::Status AllToAll(absl::Span<const se::DeviceMemoryBase> send_buffers,
|
||||
absl::Span<const se::DeviceMemoryBase> recv_buffers,
|
||||
PrimitiveType dtype, size_t count,
|
||||
const Executor& executor) final {
|
||||
return comm_->AllToAll(send_buffers, recv_buffers, dtype, count, executor);
|
||||
}
|
||||
|
||||
absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype,
|
||||
size_t count, RankId peer, const Executor& executor) final {
|
||||
return comm_->Send(send_buffer, dtype, count, peer, executor);
|
||||
}
|
||||
|
||||
absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
|
||||
size_t count, RankId peer, const Executor& executor) final {
|
||||
return comm_->Recv(recv_buffer, dtype, count, peer, executor);
|
||||
}
|
||||
|
||||
absl::StatusOr<size_t> NumRanks() const final { return comm_->NumRanks(); }
|
||||
|
||||
std::string ToString() const final { return comm_->ToString(); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Communicator> comm_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
class CollectivesInterface : public CpuCollectives {
|
||||
public:
|
||||
virtual ~CollectivesInterface() = default;
|
||||
|
||||
// Builds a context for a collective group.
|
||||
// Args:
|
||||
// devices: the devices participating in this collective.
|
||||
// rank: the rank of this process.
|
||||
virtual absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
|
||||
absl::Span<GlobalDeviceId const> devices, int rank) = 0;
|
||||
|
||||
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
|
||||
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
|
||||
const std::optional<CliqueId>& clique_id,
|
||||
absl::Span<const DeviceRank> ranks,
|
||||
const Config& config) final {
|
||||
// We expect to create CPU communicators lazily one at a time.
|
||||
if (ranks.size() != 1) {
|
||||
return InvalidArgument("Expected 1 rank, got %d", ranks.size());
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto comm, GetCommunicator(clique_key.devices(),
|
||||
ranks[0].rank.value()));
|
||||
|
||||
std::vector<std::unique_ptr<Communicator>> comms;
|
||||
comms.reserve(1);
|
||||
comms.push_back(std::make_unique<internal::CommunicatorWrapper>(comm));
|
||||
return comms;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
||||
#endif // XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_
|
||||
@@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_
|
||||
#define XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_
|
||||
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
|
||||
namespace xla::cpu {
|
||||
|
||||
@@ -25,16 +25,16 @@ namespace xla::cpu {
|
||||
// dependencies to ExecutableRunOptions.
|
||||
class CpuExecutableRunOptions {
|
||||
public:
|
||||
CpuExecutableRunOptions& set_collectives(CollectivesInterface* collectives) {
|
||||
CpuExecutableRunOptions& set_collectives(CpuCollectives* collectives) {
|
||||
collectives_ = collectives;
|
||||
return *this;
|
||||
}
|
||||
CollectivesInterface* collectives() const { return collectives_; }
|
||||
CpuCollectives* collectives() const { return collectives_; }
|
||||
|
||||
private:
|
||||
// For cross-process collectives, use this collective implementation to
|
||||
// communicate.
|
||||
CollectivesInterface* collectives_;
|
||||
CpuCollectives* collectives_;
|
||||
};
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
||||
51
third_party/xla/xla/service/cpu/cpu_runtime.cc
vendored
51
third_party/xla/xla/service/cpu/cpu_runtime.cc
vendored
@@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
@@ -40,7 +39,10 @@ limitations under the License.
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_clique_key.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_cliques.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/core/collectives/rank_id.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/hlo/parser/hlo_parser.h"
|
||||
@@ -48,7 +50,6 @@ limitations under the License.
|
||||
#include "xla/primitive_util.h"
|
||||
#include "xla/service/collective_ops_utils.h"
|
||||
#include "xla/service/computation_placer.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/cpu/cpu_executable_run_options.h"
|
||||
#include "xla/service/cpu/in_process_collectives.h"
|
||||
#include "xla/service/cpu/xfeed_manager.h"
|
||||
@@ -56,12 +57,11 @@ limitations under the License.
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/stream_executor/stream_executor.h"
|
||||
#include "xla/tsl/platform/errors.h"
|
||||
#include "xla/tsl/platform/logging.h"
|
||||
#include "xla/tsl/platform/status.h"
|
||||
#include "xla/util.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
#include "tsl/platform/errors.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
#include "tsl/platform/status.h"
|
||||
#include "tsl/profiler/lib/traceme.h"
|
||||
|
||||
namespace xla {
|
||||
@@ -339,13 +339,12 @@ RendezvousKey GetRendezvousKey(const ExecutableRunOptions* run_options,
|
||||
num_local_participants, op_kind, op_id};
|
||||
}
|
||||
|
||||
CollectivesInterface* GetInProcessCollectivesImpl() {
|
||||
CpuCollectives* GetInProcessCollectivesImpl() {
|
||||
static InProcessCollectives* c = new InProcessCollectives();
|
||||
return c;
|
||||
}
|
||||
|
||||
CollectivesInterface* GetCollectivesImpl(
|
||||
const ExecutableRunOptions* run_options) {
|
||||
CpuCollectives* GetCollectivesImpl(const ExecutableRunOptions* run_options) {
|
||||
if (run_options->cpu_executable_run_options() &&
|
||||
run_options->cpu_executable_run_options()->collectives()) {
|
||||
return run_options->cpu_executable_run_options()->collectives();
|
||||
@@ -386,14 +385,16 @@ void AllToAllImpl(const ExecutableRunOptions* run_options,
|
||||
|
||||
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
|
||||
|
||||
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
|
||||
CpuCollectives* collectives = GetCollectivesImpl(run_options);
|
||||
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(source_buffers,
|
||||
sizeof(void*) * num_buffers);
|
||||
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(destination_buffers,
|
||||
sizeof(void*) * num_buffers);
|
||||
auto communicator =
|
||||
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
|
||||
|
||||
CpuCliqueKey clique_key(rendezvous_key.global_devices);
|
||||
Communicator* communicator =
|
||||
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
|
||||
|
||||
CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout());
|
||||
|
||||
@@ -428,10 +429,11 @@ void AllGatherImpl(const ExecutableRunOptions* run_options,
|
||||
|
||||
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
|
||||
|
||||
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
|
||||
CpuCollectives* collectives = GetCollectivesImpl(run_options);
|
||||
|
||||
auto communicator =
|
||||
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
|
||||
CpuCliqueKey clique_key(rendezvous_key.global_devices);
|
||||
Communicator* communicator =
|
||||
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
|
||||
|
||||
se::DeviceMemoryBase input_buffer_data(source_buffer, buffer_size);
|
||||
se::DeviceMemoryBase output_buffer_data(destination_buffer, buffer_size);
|
||||
@@ -461,10 +463,11 @@ void ReduceScatterImpl(const ExecutableRunOptions* run_options,
|
||||
|
||||
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
|
||||
|
||||
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
|
||||
CpuCollectives* collectives = GetCollectivesImpl(run_options);
|
||||
|
||||
auto communicator =
|
||||
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
|
||||
CpuCliqueKey clique_key(rendezvous_key.global_devices);
|
||||
Communicator* communicator =
|
||||
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
|
||||
|
||||
auto dtype = static_cast<PrimitiveType>(element_type);
|
||||
|
||||
@@ -506,10 +509,11 @@ void AllReduceImpl(const ExecutableRunOptions* run_options,
|
||||
|
||||
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
|
||||
|
||||
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
|
||||
CpuCollectives* collectives = GetCollectivesImpl(run_options);
|
||||
|
||||
auto communicator =
|
||||
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
|
||||
CpuCliqueKey clique_key(rendezvous_key.global_devices);
|
||||
Communicator* communicator =
|
||||
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
|
||||
|
||||
// Convert input/output buffers to DeviceMemoryBase.
|
||||
std::vector<se::DeviceMemoryBase> input_buffers_data;
|
||||
@@ -569,10 +573,11 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options,
|
||||
|
||||
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
|
||||
|
||||
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
|
||||
CpuCollectives* collectives = GetCollectivesImpl(run_options);
|
||||
|
||||
auto communicator =
|
||||
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
|
||||
CpuCliqueKey clique_key(rendezvous_key.global_devices);
|
||||
Communicator* communicator =
|
||||
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
|
||||
|
||||
CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout());
|
||||
|
||||
|
||||
@@ -15,23 +15,29 @@ limitations under the License.
|
||||
|
||||
#include "xla/service/cpu/in_process_collectives.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/collectives/in_process_communicator.h"
|
||||
#include "xla/core/collectives/clique_id.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
namespace xla::cpu::runtime {
|
||||
|
||||
absl::StatusOr<std::shared_ptr<Communicator>>
|
||||
InProcessCollectives::GetCommunicator(absl::Span<GlobalDeviceId const> devices,
|
||||
int rank) {
|
||||
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
|
||||
InProcessCollectives::CreateCommunicators(
|
||||
int32_t nranks, const CliqueKey& clique_key,
|
||||
const std::optional<CliqueId>& clique_id,
|
||||
absl::Span<const DeviceRank> ranks, const Config& config) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
||||
std::shared_ptr<InProcessCommunicator::State> state = state_.lock();
|
||||
@@ -40,9 +46,14 @@ InProcessCollectives::GetCommunicator(absl::Span<GlobalDeviceId const> devices,
|
||||
state_ = state;
|
||||
}
|
||||
|
||||
// We don't care about devices here: we share rendezvous state globally.
|
||||
return std::make_shared<InProcessCommunicator>(std::move(state), rank,
|
||||
devices.size());
|
||||
std::vector<std::unique_ptr<Communicator>> communicators;
|
||||
for (auto& device_rank : ranks) {
|
||||
size_t rank = device_rank.rank.value();
|
||||
communicators.push_back(std::make_unique<InProcessCommunicator>(
|
||||
state, rank, clique_key.num_devices()));
|
||||
}
|
||||
|
||||
return communicators;
|
||||
}
|
||||
|
||||
} // namespace xla::cpu::runtime
|
||||
|
||||
@@ -16,25 +16,31 @@ limitations under the License.
|
||||
#ifndef XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_
|
||||
#define XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/backends/cpu/collectives/cpu_collectives.h"
|
||||
#include "xla/backends/cpu/collectives/in_process_communicator.h"
|
||||
#include "xla/core/collectives/clique_id.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/core/collectives/communicator.h"
|
||||
#include "xla/service/cpu/collectives_interface.h"
|
||||
#include "xla/service/global_device_id.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
namespace xla::cpu::runtime {
|
||||
|
||||
class InProcessCollectives : public CollectivesInterface {
|
||||
class InProcessCollectives : public CpuCollectives {
|
||||
public:
|
||||
// Thread-safe.
|
||||
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
|
||||
absl::Span<GlobalDeviceId const> devices, int rank) override;
|
||||
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
|
||||
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
|
||||
const std::optional<CliqueId>& clique_id,
|
||||
absl::Span<const DeviceRank> ranks,
|
||||
const Config& config) final;
|
||||
|
||||
private:
|
||||
absl::Mutex mu_;
|
||||
|
||||
@@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
#include "xla/tsl/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
|
||||
Reference in New Issue
Block a user