[xla:cpu] Replace xla::cpu::CollectivesInterface with xla::cpu::CpuCollectives

PiperOrigin-RevId: 713661518
Eugene Zhulenev authored 2025-01-09 07:14:11 -08:00; committed by TensorFlower Gardener
parent cf43bb53b5
commit 7b49ba401a
31 changed files with 211 additions and 370 deletions
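
At a glance, the API change: the old CollectivesInterface::GetCommunicator(devices, rank) entry point is removed, and callers instead build a CpuCliqueKey and call CreateCommunicators on a CpuCollectives implementation. The sketch below mirrors the call pattern used by the updated gloo_collectives_test in this diff; the MakeCommunicator helper, its signature, and the surrounding scaffolding are illustrative only, not part of the commit.

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_clique_key.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/global_device_id.h"
#include "xla/tsl/platform/statusor.h"

// Illustrative helper (not part of this commit): builds the communicator for
// `rank` within `global_devices` using the new CpuCollectives API.
absl::StatusOr<std::unique_ptr<xla::Communicator>> MakeCommunicator(
    xla::cpu::CpuCollectives* collectives,
    absl::Span<const xla::GlobalDeviceId> global_devices, int rank) {
  // The clique key replaces the old (devices, rank) pair as the group identity.
  xla::cpu::CpuCliqueKey clique_key(global_devices);
  xla::cpu::CpuCollectives::DeviceRank device_rank(/*device=*/nullptr,
                                                   xla::RankId(rank));
  // CreateCommunicators returns one communicator per requested rank.
  TF_ASSIGN_OR_RETURN(
      auto communicators,
      collectives->CreateCommunicators(global_devices.size(), clique_key,
                                       /*clique_id=*/std::nullopt,
                                       {device_rank},
                                       xla::cpu::CpuCollectives::Config()));
  return std::move(communicators[0]);
}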

@@ -112,7 +112,6 @@ cc_library(
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
@@ -184,7 +183,6 @@ cc_library(
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",

@@ -150,7 +150,6 @@ cc_library(
"//xla/ffi:execution_context",
"//xla/runtime:buffer_use",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/service/cpu:cpu_executable_run_options",
"//xla/service/cpu:cpu_runtime",
"//xla/service/cpu:in_process_collectives",
@@ -193,13 +192,13 @@ xla_cc_test(
deps = [
":thunk",
"//xla:executable_run_options",
"//xla/service/cpu:collectives_interface",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/service/cpu:cpu_executable_run_options",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"//xla/tsl/platform:test_main",
"@com_google_absl//absl/status",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/platform:test",
"@local_tsl//tsl/platform:test_main",
],
)
@@ -337,7 +336,6 @@ cc_library(
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service/cpu:collectives_interface",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
@@ -466,7 +464,6 @@ cc_library(
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service/cpu:collectives_interface",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/algorithm:container",
@@ -501,7 +498,6 @@ cc_library(
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service/cpu:collectives_interface",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
@@ -533,7 +529,6 @@ cc_library(
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service/cpu:collectives_interface",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
@@ -567,7 +562,6 @@ cc_library(
"//xla/service:buffer_assignment",
"//xla/service:collective_ops_utils",
"//xla/service:computation_placer",
"//xla/service/cpu:collectives_interface",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
@@ -606,7 +600,6 @@ cc_library(
"//xla/service:collective_ops_utils",
"//xla/service:computation_placer",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",

@@ -29,7 +29,6 @@ limitations under the License.
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"

@@ -33,7 +33,6 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"

@@ -28,7 +28,6 @@ limitations under the License.
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"

@@ -37,7 +37,6 @@ limitations under the License.
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/computation_placer.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"

@@ -31,7 +31,6 @@ limitations under the License.
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"

@@ -30,7 +30,6 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"

@@ -25,7 +25,6 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/executable_run_options.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/cpu/cpu_executable_run_options.h"
#include "xla/service/cpu/in_process_collectives.h"
#include "xla/service/global_device_id.h"
@@ -103,7 +102,7 @@ Thunk::CollectiveExecuteParams::Create(
// Default implementation of a collectives interface that can execute
// collective operations within the same process.
static CollectivesInterface* in_process_collectives =
static CpuCollectives* in_process_collectives =
new runtime::InProcessCollectives();
// If CPU executable run options are set, use the collectives interface
@@ -111,7 +110,7 @@ Thunk::CollectiveExecuteParams::Create(
// in-process collectives interface.
const CpuExecutableRunOptions* cpu_run_options =
run_options->cpu_executable_run_options();
CollectivesInterface* collectives =
CpuCollectives* collectives =
cpu_run_options && cpu_run_options->collectives()
? cpu_run_options->collectives()
: in_process_collectives;

@@ -17,11 +17,11 @@ limitations under the License.
#include <utility>
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/executable_run_options.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/cpu/cpu_executable_run_options.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
namespace xla::cpu {
namespace {
@@ -93,13 +93,12 @@ TEST(ThunkTest, CollectiveExecuteParams) {
// Test forwarding collectives interface from CpuExecutableRunOptions.
CpuExecutableRunOptions cpu_run_options;
cpu_run_options.set_collectives(
reinterpret_cast<CollectivesInterface*>(0x12345678));
reinterpret_cast<CpuCollectives*>(0x12345678));
run_options.set_cpu_executable_run_options(&cpu_run_options);
TF_ASSERT_OK_AND_ASSIGN(params,
Thunk::CollectiveExecuteParams::Create(&run_options));
EXPECT_EQ(params.collectives,
reinterpret_cast<CollectivesInterface*>(0x12345678));
EXPECT_EQ(params.collectives, reinterpret_cast<CpuCollectives*>(0x12345678));
}
} // namespace

@@ -17,7 +17,6 @@ limitations under the License.
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
@@ -27,8 +26,8 @@ limitations under the License.
namespace xla {
CliqueKey::CliqueKey(std::vector<GlobalDeviceId> devices)
: devices_(std::move(devices)) {}
CliqueKey::CliqueKey(absl::Span<const GlobalDeviceId> devices)
: devices_(devices.begin(), devices.end()) {}
absl::Span<const GlobalDeviceId> CliqueKey::devices() const { return devices_; }

@@ -40,7 +40,7 @@ namespace xla {
// these cliques launch operations (device kernels) on different device streams.
class CliqueKey {
public:
explicit CliqueKey(std::vector<GlobalDeviceId> devices);
explicit CliqueKey(absl::Span<const GlobalDeviceId> devices);
virtual ~CliqueKey() = default;
CliqueKey(const CliqueKey& other) = default;

@@ -151,6 +151,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"//xla/backends/cpu/codegen:cpu_features",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/backends/cpu/runtime:buffer_allocations",
"//xla/backends/cpu/runtime:thread_pool_task_runner",
"//xla/backends/cpu/runtime:thunk",
@@ -186,7 +187,6 @@ cc_library(
"//xla/service:hlo_proto_cc",
"//xla/service:hlo_value",
"//xla/service:maybe_owning_device_memory",
"//xla/service/cpu:collectives_interface",
"//xla/service/cpu:cpu_compiler",
"//xla/service/cpu:cpu_event",
"//xla/service/cpu:cpu_executable",
@@ -302,11 +302,12 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/backends/cpu/collectives:gloo_communicator",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
@@ -334,12 +335,16 @@ xla_cc_test(
":gloo_kv_store",
"//xla:executable_run_options",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_clique_key",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/pjrt/distributed:in_memory_key_value_store",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"//xla/stream_executor:device_memory",
"//xla/tsl/lib/core:status_test_util",
"//xla/tsl/platform:env",
@@ -384,11 +389,13 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/backends/cpu/collectives:mpi_communicator",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:communicator",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",

@@ -49,6 +49,7 @@ limitations under the License.
#include "mlir/IR/BuiltinOps.h"
#include "xla/array.h"
#include "xla/backends/cpu/codegen/cpu_features.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/buffer_allocations.h"
#include "xla/backends/cpu/runtime/thread_pool_task_runner.h"
#include "xla/backends/cpu/runtime/thunk.h"
@@ -85,7 +86,6 @@ limitations under the License.
#include "xla/service/buffer_assignment.h"
#include "xla/service/compiler.h"
#include "xla/service/computation_placer.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/cpu/cpu_compiler.h"
#include "xla/service/cpu/cpu_event.h"
#include "xla/service/cpu/cpu_executable.h"
@@ -311,7 +311,7 @@ static tsl::ThreadOptions GetThreadOptions() {
TfrtCpuClient::TfrtCpuClient(
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives, size_t num_threads,
std::shared_ptr<cpu::CpuCollectives> collectives, size_t num_threads,
bool asynchronous,
std::function<void(HloModuleConfig&)> customize_hlo_module_config)
: process_index_(process_index),

@@ -38,6 +38,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "mlir/IR/BuiltinOps.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
@@ -57,7 +58,6 @@ limitations under the License.
#include "xla/pjrt/transpose.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/computation_placer.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/cpu/cpu_event.h"
#include "xla/service/executable.h"
#include "xla/service/hlo.pb.h"
@@ -77,8 +77,8 @@ class TfrtCpuClient final : public PjRtClient {
public:
TfrtCpuClient(
int process_index, std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
std::shared_ptr<cpu::CollectivesInterface> collectives,
size_t num_threads, bool asynchronous,
std::shared_ptr<cpu::CpuCollectives> collectives, size_t num_threads,
bool asynchronous,
std::function<void(HloModuleConfig&)> customize_hlo_module_config);
~TfrtCpuClient() override;
@@ -288,7 +288,7 @@ class TfrtCpuClient final : public PjRtClient {
absl::Mutex transpose_mu_;
TransposePlanCache transpose_cache_ ABSL_GUARDED_BY(transpose_mu_);
std::shared_ptr<cpu::CollectivesInterface> collectives_;
std::shared_ptr<cpu::CpuCollectives> collectives_;
xla::CpuTopologyDescription topology_;

@@ -15,10 +15,12 @@ limitations under the License.
#include "xla/pjrt/cpu/gloo_collectives.h"
#include <cstddef>
#include <cstdint>
#include <exception>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
@@ -27,7 +29,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "gloo/context.h"
#include "gloo/rendezvous/context.h"
@@ -35,6 +36,8 @@ limitations under the License.
#include "gloo/rendezvous/store.h"
#include "gloo/transport/device.h"
#include "xla/backends/cpu/collectives/gloo_communicator.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/global_device_id.h"
#include "xla/xla_data.pb.h"
@@ -48,42 +51,38 @@ GlooCollectives::GlooCollectives(
GlooCollectives::~GlooCollectives() = default;
absl::StatusOr<std::shared_ptr<Communicator>> GlooCollectives::GetCommunicator(
absl::Span<GlobalDeviceId const> global_devices, int rank) {
Context* context;
{
absl::MutexLock lock(&mu_);
auto& context_ref = contexts_[std::make_tuple(
std::vector<GlobalDeviceId>(global_devices.begin(),
global_devices.end()),
rank)];
if (!context_ref) {
context_ref = std::make_unique<Context>();
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
GlooCollectives::CreateCommunicators(int32_t nranks,
const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) {
std::vector<std::unique_ptr<Communicator>> communicators;
for (auto& device_rank : ranks) {
size_t rank = device_rank.rank.value();
auto gloo_context = std::make_shared<gloo::rendezvous::Context>(
rank, clique_key.num_devices());
auto prefix_store = gloo::rendezvous::PrefixStore(
absl::StrCat("gloo/",
absl::StrJoin(clique_key.devices(), ",",
[](std::string* out, GlobalDeviceId id) {
absl::StrAppend(out, id.value());
})),
*store_);
try {
gloo_context->connectFullMesh(prefix_store, device_);
} catch (std::exception& e) {
return absl::UnknownError(
absl::StrCat("Gloo context initialization failed: ", e.what()));
}
context = context_ref.get();
communicators.push_back(std::make_unique<GlooCommunicator>(
std::move(gloo_context), rank, clique_key.num_devices()));
}
absl::MutexLock context_lock(&context->mu);
if (context->communicator) {
return context->communicator;
}
auto gloo_context =
std::make_shared<gloo::rendezvous::Context>(rank, global_devices.size());
auto prefix_store = gloo::rendezvous::PrefixStore(
absl::StrCat("gloo/",
absl::StrJoin(global_devices, ",",
[](std::string* out, GlobalDeviceId id) {
absl::StrAppend(out, id.value());
})),
*store_);
try {
gloo_context->connectFullMesh(prefix_store, device_);
} catch (std::exception& e) {
return absl::UnknownError(
absl::StrCat("Gloo context initialization failed: ", e.what()));
}
context->communicator = std::make_shared<GlooCommunicator>(
std::move(gloo_context), rank, global_devices.size());
return context->communicator;
return communicators;
}
} // namespace xla::cpu

@@ -16,49 +16,39 @@ limitations under the License.
#ifndef XLA_PJRT_CPU_GLOO_COLLECTIVES_H_
#define XLA_PJRT_CPU_GLOO_COLLECTIVES_H_
#include <cstdint>
#include <memory>
#include <tuple>
#include <optional>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "gloo/context.h"
#include "gloo/rendezvous/store.h"
#include "gloo/transport/device.h"
#include "xla/backends/cpu/collectives/gloo_communicator.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu {
class GlooCollectives : public CollectivesInterface {
class GlooCollectives : public CpuCollectives {
public:
GlooCollectives(std::unique_ptr<gloo::rendezvous::Store> store,
std::shared_ptr<gloo::transport::Device> device);
~GlooCollectives() override;
// Thread-safe.
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
absl::Span<GlobalDeviceId const> devices, int rank) override;
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) final;
private:
struct Context {
absl::Mutex mu;
std::shared_ptr<GlooCommunicator> communicator;
};
std::unique_ptr<gloo::rendezvous::Store> store_;
std::shared_ptr<gloo::transport::Device> device_;
absl::Mutex mu_;
absl::flat_hash_map<std::tuple<std::vector<GlobalDeviceId>, int>,
std::unique_ptr<Context>>
contexts_ ABSL_GUARDED_BY(mu_);
};
} // namespace xla::cpu

@@ -20,18 +20,22 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_clique_key.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/executable_run_options.h"
#include "xla/pjrt/cpu/gloo_kv_store.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/lib/core/status_test_util.h"
@@ -59,7 +63,7 @@ constexpr int kNumParticipants = 2;
constexpr size_t kBufferSize = 256;
constexpr absl::Duration kTimeout = absl::Seconds(5);
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
absl::StatusOr<std::unique_ptr<Communicator>> GetCommunicator(
size_t kNumParticipants, absl::Span<GlobalDeviceId const> global_devices,
const std::shared_ptr<xla::KeyValueStoreInterface>& kv_store, int rank) {
auto collectives = std::make_shared<cpu::GlooCollectives>(
@@ -69,7 +73,16 @@ absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
#elif defined(__APPLE__)
gloo::transport::uv::CreateDevice(gloo::transport::uv::attr()));
#endif // defined(__linux__)
return collectives->GetCommunicator(global_devices, rank);
CpuCliqueKey clique_key(global_devices);
CpuCollectives::DeviceRank device_rank(nullptr, RankId(rank));
TF_ASSIGN_OR_RETURN(auto communicators,
collectives->CreateCommunicators(
global_devices.size(), clique_key, std::nullopt,
{device_rank}, CpuCollectives::Config()));
return std::move(communicators[0]);
}
RendezvousKey MakeRendezvousKey(std::vector<GlobalDeviceId> global_devices) {

@@ -15,8 +15,10 @@ limitations under the License.
#include "xla/pjrt/cpu/mpi_collectives.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <tuple>
#include <optional>
#include <vector>
#include "absl/log/log.h"
@@ -25,8 +27,9 @@ limitations under the License.
#include "absl/types/span.h"
#include "mpi.h"
#include "xla/backends/cpu/collectives/mpi_communicator.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/global_device_id.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu {
@@ -39,13 +42,13 @@ void MpiCollectives::Init() {
VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_;
}
void MpiCollectives::Finalize() {
contexts_.clear();
MPI_Finalize();
}
void MpiCollectives::Finalize() { MPI_Finalize(); }
absl::StatusOr<std::shared_ptr<Communicator>> MpiCollectives::GetCommunicator(
absl::Span<GlobalDeviceId const> global_devices, int rank) {
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
MpiCollectives::CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) {
int flag;
MPI_Is_thread_main(&flag);
if (!flag) {
@@ -55,23 +58,21 @@ absl::StatusOr<std::shared_ptr<Communicator>> MpiCollectives::GetCommunicator(
"threads/devices per process are not yet supported.");
}
auto& context = contexts_[std::make_tuple(
std::vector<GlobalDeviceId>(global_devices.begin(), global_devices.end()),
rank)];
if (context) {
return context;
std::vector<std::unique_ptr<Communicator>> communicators;
for (auto& device_rank : ranks) {
size_t rank = device_rank.rank.value();
int color;
int key = 0;
if (clique_key.num_devices() > 0) {
color = static_cast<int>(clique_key.devices().at(0).value());
key = rank;
} else {
color = MPI_UNDEFINED;
}
communicators.push_back(std::make_unique<MpiCommunicator>(color, key));
}
int color;
int key = 0;
if (global_devices.size() > 0) {
color = static_cast<int>(global_devices.at(0).value());
key = rank;
} else {
color = MPI_UNDEFINED;
}
context = std::make_shared<MpiCommunicator>(color, key);
return context;
return communicators;
}
} // namespace xla::cpu

@@ -16,23 +16,24 @@ limitations under the License.
#ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_
#define XLA_PJRT_CPU_MPI_COLLECTIVES_H_
#include <cstdint>
#include <memory>
#include <tuple>
#include <optional>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/mpi_communicator.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu {
class MpiCollectives : public CollectivesInterface {
class MpiCollectives : public CpuCollectives {
public:
/*
The user has to explicitly call Init() and Finalize() before and
@@ -46,8 +47,11 @@ class MpiCollectives : public CollectivesInterface {
void Init();
void Finalize();
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
absl::Span<GlobalDeviceId const> global_devices, int rank) override;
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) final;
private:
absl::Status ExchangeGlobalDeviceIds(
@@ -55,9 +59,6 @@ class MpiCollectives : public CollectivesInterface {
int mpi_world_rank_;
int mpi_world_size_;
absl::flat_hash_map<std::tuple<std::vector<GlobalDeviceId>, int>,
std::shared_ptr<MpiCommunicator>>
contexts_;
};
} // namespace xla::cpu

@@ -40,8 +40,8 @@ cc_library(
srcs = [],
hdrs = ["cpu_client_options.h"],
deps = [
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/service:hlo_module_config",
"//xla/service/cpu:collectives_interface",
],
)

@@ -20,7 +20,7 @@ limitations under the License.
#include <memory>
#include <optional>
#include "xla/service/cpu/collectives_interface.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/service/hlo_module_config.h"
namespace xla {
@@ -45,7 +45,7 @@ struct CpuClientOptions {
// Distributed collectives implementation. Optional. If not provided, an
// in-process collectives implementation will be used.
std::shared_ptr<cpu::CollectivesInterface> collectives;
std::shared_ptr<cpu::CpuCollectives> collectives;
// If defined this function will be called on the HloModuleConfig before
// compilation, and allows users to set custom flags.

@@ -1307,6 +1307,7 @@ tsl_pybind_extension(
"//xla:shape_util",
"//xla:types",
"//xla:util",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/ffi:ffi_api",
"//xla/pjrt:exceptions",
"//xla/pjrt:mlir_to_hlo",
@@ -1333,7 +1334,6 @@ tsl_pybind_extension(
"//xla/python/pjrt_ifrt",
"//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"//xla/python/pjrt_ifrt:xla_ifrt",
"//xla/service/cpu:collectives_interface",
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/distributed_runtime/preemption:preemption_sync_manager",
"//xla/tsl/platform/cloud:gcs_file_system",

@@ -46,6 +46,7 @@ limitations under the License.
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/variant.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
@@ -63,7 +64,6 @@ limitations under the License.
#include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h"
#include "xla/python/py_client.h"
#include "xla/python/py_program.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/python/lib/core/numpy.h" // NOLINT
@@ -259,8 +259,7 @@ NB_MODULE(xla_extension, m) {
jax::BuildWeakrefLRUCacheAPI(m);
nb::class_<xla::cpu::CollectivesInterface> cpu_collectives(m,
"CpuCollectives");
nb::class_<xla::cpu::CpuCollectives> cpu_collectives(m, "CpuCollectives");
m.def(
"make_gloo_tcp_collectives",
@@ -268,7 +267,7 @@ NB_MODULE(xla_extension, m) {
std::optional<std::string> hostname,
std::optional<std::string> interface)
-> std::shared_ptr<xla::cpu::CollectivesInterface> {
-> std::shared_ptr<xla::cpu::CpuCollectives> {
#if defined(__linux__)
std::shared_ptr<KeyValueStoreInterface> kv_store = nullptr;
if (distributed_client != nullptr) {
@@ -321,7 +320,7 @@ NB_MODULE(xla_extension, m) {
});
#else // !_WIN32 && !PLATFORM_GOOGLE
m.def("make_mpi_collectives",
[]() -> std::shared_ptr<xla::cpu::CollectivesInterface> {
[]() -> std::shared_ptr<xla::cpu::CpuCollectives> {
throw xla::XlaRuntimeError(
"make_mpi_collectives is not implemented for Windows");
});
@@ -332,7 +331,7 @@ NB_MODULE(xla_extension, m) {
[](bool asynchronous,
std::shared_ptr<DistributedRuntimeClient> distributed_client,
int node_id, int num_nodes,
std::shared_ptr<xla::cpu::CollectivesInterface> collectives,
std::shared_ptr<xla::cpu::CpuCollectives> collectives,
std::optional<int> num_devices) -> nb_class_ptr<PyClient> {
std::unique_ptr<ifrt::PjRtClient> ifrt_client;
{
@@ -363,7 +362,7 @@ NB_MODULE(xla_extension, m) {
nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr,
nb::arg("node_id") = 0, nb::arg("num_nodes") = 1,
nb::arg("collectives").none() =
std::shared_ptr<xla::cpu::CollectivesInterface>(),
std::shared_ptr<xla::cpu::CpuCollectives>(),
nb::arg("num_devices").none() = std::nullopt);
m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool {
absl::StatusOr<const PJRT_Api*> pjrt_api = pjrt::PjrtApi(platform_name);

@@ -1001,7 +1001,6 @@ cc_library(
],
copts = runtime_copts(),
deps = [
":collectives_interface",
":cpu_executable_run_options",
":in_process_collectives",
"//xla:executable_run_options",
@@ -1009,7 +1008,11 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_clique",
"//xla/backends/cpu/collectives:cpu_clique_key",
"//xla/backends/cpu/collectives:cpu_cliques",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/hlo/parser:hlo_parser",
"//xla/service:collective_ops_utils",
@@ -1017,6 +1020,8 @@ cc_library(
"//xla/service:global_device_id",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:stream_executor_h",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
@@ -1029,9 +1034,6 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/profiler/lib:traceme",
],
)
@@ -1957,34 +1959,11 @@ cc_library(
],
)
cc_library(
name = "collectives_interface",
hdrs = ["collectives_interface.h"],
deps = [
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "in_process_collectives",
srcs = ["in_process_collectives.cc"],
hdrs = ["in_process_collectives.h"],
deps = [
":collectives_interface",
"//xla:refcounting_hash_map",
"//xla:shape_util",
"//xla:status_macros",
@@ -1992,6 +1971,8 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/collectives:cpu_collectives",
"//xla/backends/cpu/collectives:in_process_communicator",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
"//xla/service:collective_ops_utils",
@@ -2014,7 +1995,7 @@ cc_library(
cc_library(
name = "cpu_executable_run_options",
hdrs = ["cpu_executable_run_options.h"],
deps = [":collectives_interface"],
deps = ["//xla/backends/cpu/collectives:cpu_collectives"],
)
cc_library(

@@ -1,154 +0,0 @@
/* Copyright 2023 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_
#define XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/global_device_id.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu {
namespace internal {
// An adapter from a shared_ptr<Communicator> to a Communicator.
class CommunicatorWrapper final : public Communicator {
public:
explicit CommunicatorWrapper(std::shared_ptr<Communicator> comm)
: comm_(std::move(comm)) {}
absl::Status AllReduce(stream_executor::DeviceMemoryBase send_buffer,
stream_executor::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
ReductionKind reduction_kind,
const Executor& executor) final {
return comm_->AllReduce(send_buffer, recv_buffer, dtype, count,
reduction_kind, executor);
}
absl::Status Broadcast(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, RankId root,
const Executor& executor) final {
return comm_->Broadcast(send_buffer, recv_buffer, dtype, count, root,
executor);
}
absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
ReductionKind reduction_kind,
const Executor& executor) final {
return comm_->ReduceScatter(send_buffer, recv_buffer, dtype, count,
reduction_kind, executor);
}
absl::Status AllGather(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, const Executor& executor) final {
return comm_->AllGather(send_buffer, recv_buffer, dtype, count, executor);
}
absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
std::optional<RankId> source_rank,
absl::Span<const RankId> target_ranks,
const Executor& executor) final {
return comm_->CollectivePermute(send_buffer, recv_buffer, dtype, count,
source_rank, target_ranks, executor);
}
absl::Status AllToAll(absl::Span<const se::DeviceMemoryBase> send_buffers,
absl::Span<const se::DeviceMemoryBase> recv_buffers,
PrimitiveType dtype, size_t count,
const Executor& executor) final {
return comm_->AllToAll(send_buffers, recv_buffers, dtype, count, executor);
}
absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype,
size_t count, RankId peer, const Executor& executor) final {
return comm_->Send(send_buffer, dtype, count, peer, executor);
}
absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
size_t count, RankId peer, const Executor& executor) final {
return comm_->Recv(recv_buffer, dtype, count, peer, executor);
}
absl::StatusOr<size_t> NumRanks() const final { return comm_->NumRanks(); }
std::string ToString() const final { return comm_->ToString(); }
private:
std::shared_ptr<Communicator> comm_;
};
} // namespace internal
class CollectivesInterface : public CpuCollectives {
public:
virtual ~CollectivesInterface() = default;
// Builds a context for a collective group.
// Args:
// devices: the devices participating in this collective.
// rank: the rank of this process.
virtual absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
absl::Span<GlobalDeviceId const> devices, int rank) = 0;
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) final {
// We expect to create CPU communicators lazily one at a time.
if (ranks.size() != 1) {
return InvalidArgument("Expected 1 rank, got %d", ranks.size());
}
TF_ASSIGN_OR_RETURN(auto comm, GetCommunicator(clique_key.devices(),
ranks[0].rank.value()));
std::vector<std::unique_ptr<Communicator>> comms;
comms.reserve(1);
comms.push_back(std::make_unique<internal::CommunicatorWrapper>(comm));
return comms;
}
};
} // namespace xla::cpu
#endif // XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_

@@ -16,7 +16,7 @@ limitations under the License.
#ifndef XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_
#define XLA_SERVICE_CPU_CPU_EXECUTABLE_RUN_OPTIONS_H_
#include "xla/service/cpu/collectives_interface.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
namespace xla::cpu {
@@ -25,16 +25,16 @@ namespace xla::cpu {
// dependencies to ExecutableRunOptions.
class CpuExecutableRunOptions {
public:
CpuExecutableRunOptions& set_collectives(CollectivesInterface* collectives) {
CpuExecutableRunOptions& set_collectives(CpuCollectives* collectives) {
collectives_ = collectives;
return *this;
}
CollectivesInterface* collectives() const { return collectives_; }
CpuCollectives* collectives() const { return collectives_; }
private:
// For cross-process collectives, use this collective implementation to
// communicate.
CollectivesInterface* collectives_;
CpuCollectives* collectives_;
};
} // namespace xla::cpu

@@ -20,7 +20,6 @@ limitations under the License.
#include <cstdio>
#include <cstring>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -40,7 +39,10 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_clique_key.h"
#include "xla/backends/cpu/collectives/cpu_cliques.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/parser/hlo_parser.h"
@@ -48,7 +50,6 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/computation_placer.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/cpu/cpu_executable_run_options.h"
#include "xla/service/cpu/in_process_collectives.h"
#include "xla/service/cpu/xfeed_manager.h"
@@ -56,12 +57,11 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/status.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/status.h"
#include "tsl/profiler/lib/traceme.h"
namespace xla {
@@ -339,13 +339,12 @@ RendezvousKey GetRendezvousKey(const ExecutableRunOptions* run_options,
num_local_participants, op_kind, op_id};
}
CollectivesInterface* GetInProcessCollectivesImpl() {
CpuCollectives* GetInProcessCollectivesImpl() {
static InProcessCollectives* c = new InProcessCollectives();
return c;
}
CollectivesInterface* GetCollectivesImpl(
const ExecutableRunOptions* run_options) {
CpuCollectives* GetCollectivesImpl(const ExecutableRunOptions* run_options) {
if (run_options->cpu_executable_run_options() &&
run_options->cpu_executable_run_options()->collectives()) {
return run_options->cpu_executable_run_options()->collectives();
@@ -386,14 +385,16 @@ void AllToAllImpl(const ExecutableRunOptions* run_options,
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
CpuCollectives* collectives = GetCollectivesImpl(run_options);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(source_buffers,
sizeof(void*) * num_buffers);
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(destination_buffers,
sizeof(void*) * num_buffers);
auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
CpuCliqueKey clique_key(rendezvous_key.global_devices);
Communicator* communicator =
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout());
@@ -428,10 +429,11 @@ void AllGatherImpl(const ExecutableRunOptions* run_options,
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
CpuCollectives* collectives = GetCollectivesImpl(run_options);
auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
CpuCliqueKey clique_key(rendezvous_key.global_devices);
Communicator* communicator =
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
se::DeviceMemoryBase input_buffer_data(source_buffer, buffer_size);
se::DeviceMemoryBase output_buffer_data(destination_buffer, buffer_size);
@@ -461,10 +463,11 @@ void ReduceScatterImpl(const ExecutableRunOptions* run_options,
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
CpuCollectives* collectives = GetCollectivesImpl(run_options);
auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
CpuCliqueKey clique_key(rendezvous_key.global_devices);
Communicator* communicator =
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
auto dtype = static_cast<PrimitiveType>(element_type);
@@ -506,10 +509,11 @@ void AllReduceImpl(const ExecutableRunOptions* run_options,
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
CpuCollectives* collectives = GetCollectivesImpl(run_options);
auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
CpuCliqueKey clique_key(rendezvous_key.global_devices);
Communicator* communicator =
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
// Convert input/output buffers to DeviceMemoryBase.
std::vector<se::DeviceMemoryBase> input_buffers_data;
@@ -569,10 +573,11 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options,
int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
CollectivesInterface* collectives = GetCollectivesImpl(run_options);
CpuCollectives* collectives = GetCollectivesImpl(run_options);
auto communicator =
collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
CpuCliqueKey clique_key(rendezvous_key.global_devices);
Communicator* communicator =
AcquireCommunicator(collectives, clique_key, RankId(rank)).value();
CpuCollectives::Executor executor(rendezvous_key, DefaultCollectiveTimeout());

@@ -15,23 +15,29 @@ limitations under the License.
#include "xla/service/cpu/in_process_collectives.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
#include <optional>
#include <vector>
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/in_process_communicator.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/global_device_id.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu::runtime {
absl::StatusOr<std::shared_ptr<Communicator>>
InProcessCollectives::GetCommunicator(absl::Span<GlobalDeviceId const> devices,
int rank) {
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
InProcessCollectives::CreateCommunicators(
int32_t nranks, const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks, const Config& config) {
absl::MutexLock lock(&mu_);
std::shared_ptr<InProcessCommunicator::State> state = state_.lock();
@@ -40,9 +46,14 @@ InProcessCollectives::GetCommunicator(absl::Span<GlobalDeviceId const> devices,
state_ = state;
}
// We don't care about devices here: we share rendezvous state globally.
return std::make_shared<InProcessCommunicator>(std::move(state), rank,
devices.size());
std::vector<std::unique_ptr<Communicator>> communicators;
for (auto& device_rank : ranks) {
size_t rank = device_rank.rank.value();
communicators.push_back(std::make_unique<InProcessCommunicator>(
state, rank, clique_key.num_devices()));
}
return communicators;
}
} // namespace xla::cpu::runtime

@@ -16,25 +16,31 @@ limitations under the License.
#ifndef XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_
#define XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_
#include <cstdint>
#include <memory>
#include <optional>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/collectives/in_process_communicator.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/cpu/collectives_interface.h"
#include "xla/service/global_device_id.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu::runtime {
class InProcessCollectives : public CollectivesInterface {
class InProcessCollectives : public CpuCollectives {
public:
// Thread-safe.
absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
absl::Span<GlobalDeviceId const> devices, int rank) override;
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
const std::optional<CliqueId>& clique_id,
absl::Span<const DeviceRank> ranks,
const Config& config) final;
private:
absl::Mutex mu_;

@@ -23,7 +23,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "tsl/platform/logging.h"
#include "xla/tsl/platform/logging.h"
namespace xla {
namespace cpu {