Allow IFRT-proxy to expand error-status payloads that are specific to Pathways.

PiperOrigin-RevId: 826656416
This commit is contained in:
A. Unique TensorFlower
2025-10-31 15:47:38 -07:00
committed by TensorFlower Gardener
parent c655468288
commit 15e235f79b
10 changed files with 457 additions and 5 deletions

View File

@@ -1153,7 +1153,11 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
],
)
@@ -1194,6 +1198,7 @@ cc_library(
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:string_view",
],
alwayslink = True,
)
cc_library(

View File

@@ -15,12 +15,18 @@ limitations under the License.
#include "xla/python/ifrt/user_context_registry.h"
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "xla/python/ifrt/user_context.h"
@@ -97,5 +103,27 @@ void UserContextRegistry::Unregister(
}
}
// Returns the process-wide registry instance. The instance is a function-local
// static, so it is constructed on first use; this lets static initializers in
// other translation units call `Get()` regardless of initialization order.
CustomStatusExpanderRegistry& CustomStatusExpanderRegistry::Get() {
  static absl::NoDestructor<CustomStatusExpanderRegistry> instance;
  return *instance.get();
}
// Registers `expander` under the key (`process_order`, `payload_name`).
//
// When `process_order` is not supplied, the expander is keyed with
// `std::numeric_limits<int>::max()`, i.e. it sorts after every expander that
// was registered with an explicit smaller order.
//
// CHECK-fails if `expander` is empty, or if an expander with the same
// (`process_order`, `payload_name`) pair has already been registered. Note
// that uniqueness is enforced on the pair, not on `payload_name` alone.
void CustomStatusExpanderRegistry::Register(absl::string_view payload_name,
                                            PayloadExpanderFn expander,
                                            std::optional<int> process_order) {
  // Reject an empty std::function here, where the offending payload name is
  // known, instead of failing later inside Process() on an unrelated error.
  CHECK(expander != nullptr)
      << "Null status expander registered for payload '" << payload_name
      << "'";

  const int order = process_order.value_or(std::numeric_limits<int>::max());
  std::pair<int, std::string> key = {order, std::string(payload_name)};

  absl::WriterMutexLock lock(mu_);
  const bool inserted =
      registry_.insert({std::move(key), std::move(expander)}).second;
  CHECK(inserted) << "Duplicate status expander registration for payload '"
                  << payload_name << "' with process order " << order;
}
// Runs every registered expander on `status`, in ascending key order of the
// underlying map (process order first, then payload name). Each expander may
// modify `status` in place.
//
// NOTE(review): expanders are invoked while the reader lock on `mu_` is held,
// so an expander that calls Register() would try to take the writer lock on
// the same mutex -- presumably unsupported; confirm.
void CustomStatusExpanderRegistry::Process(absl::Status& status) {
  absl::ReaderMutexLock lock(mu_);
  for (auto it = registry_.begin(); it != registry_.end(); ++it) {
    const PayloadExpanderFn& expand = it->second;
    expand(status);
  }
}
} // namespace ifrt
} // namespace xla

View File

@@ -16,13 +16,19 @@ limitations under the License.
#ifndef XLA_PYTHON_IFRT_USER_CONTEXT_REGISTRY_H_
#define XLA_PYTHON_IFRT_USER_CONTEXT_REGISTRY_H_
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "xla/python/ifrt/user_context.h"
@@ -122,6 +128,36 @@ class TrackedUserContext {
absl_nonnull const UserContextRef user_context_;
};
// CustomStatusExpanderRegistry holds 'payload expanders'. Errors returned by
// the IFRT backend are run through every registered expander before the error
// message is handed back to IFRT users.
//
// `Register()` takes the internal mutex as a writer and `Process()` takes it
// as a reader.
class CustomStatusExpanderRegistry {
 public:
  // Signature of a payload expander: it receives the status and edits it in
  // place.
  using PayloadExpanderFn = std::function<void(absl::Status&)>;

  // Returns the process-wide registry.
  static CustomStatusExpanderRegistry& Get();

  // Registers a payload expander for `payload_name`. The `expander` receives
  // the whole `absl::Status`; it is expected to remove its payload from the
  // status and to update the rest of the status accordingly.
  //
  // `process_order`, when given, positions this expander relative to the
  // others: lower values run first. Use a positive value unless you have
  // agreed with the IFRT maintainers that you are writing a critical expander
  // that must run earlier. The relative order of expanders that share the same
  // `process_order` is unspecified.
  void Register(absl::string_view payload_name, PayloadExpanderFn expander,
                std::optional<int> process_order = std::nullopt);

  // Invokes all registered expanders on the given status.
  void Process(absl::Status& status);

 private:
  mutable absl::Mutex mu_;

  // Expanders keyed by (process order, payload name); iteration order of this
  // map is the order in which `Process()` runs them.
  absl::btree_map<std::pair<int, std::string>, PayloadExpanderFn> registry_
      ABSL_GUARDED_BY(mu_);
};
} // namespace ifrt
} // namespace xla

View File

@@ -100,14 +100,15 @@ absl::Status ReattachUserContextRefs(absl::Status status) {
return status;
}
absl::Status ExpandUserContexts(absl::Status status) {
static void ExpandStandardUserContext(absl::Status& status) {
if (status.ok()) {
return status;
return;
}
std::optional<absl::Cord> payload =
status.GetPayload(kIfrtUserContextPayloadUrl);
if (!payload.has_value()) {
return status;
return;
}
status.ErasePayload(kIfrtUserContextPayloadUrl);
@@ -117,7 +118,7 @@ absl::Status ExpandUserContexts(absl::Status status) {
tsl::errors::AppendToMessage(
&status, "\n(failed to parse a user context ID: ", payload->Flatten(),
")");
return status;
return;
}
TrackedUserContextRef user_context =
UserContextRegistry::Get().Lookup(UserContextId(user_context_id));
@@ -125,12 +126,23 @@ absl::Status ExpandUserContexts(absl::Status status) {
tsl::errors::AppendToMessage(
&status, "\n(failed to find a user context for ID: ", user_context_id,
")");
return status;
return;
}
tsl::errors::AppendToMessage(&status, "\n",
user_context->user_context()->DebugString());
}
// Returns `status` after running it through every expander registered with
// `CustomStatusExpanderRegistry`. `status` is taken by value, so the caller's
// object is left untouched.
absl::Status ExpandUserContexts(absl::Status status) {
  CustomStatusExpanderRegistry& expanders = CustomStatusExpanderRegistry::Get();
  expanders.Process(status);
  return status;
}
// Static-initialization hook that installs the standard user-context expander.
// It is registered with process order -1, which sorts ahead of expanders that
// use the default order or any non-negative order.
static const bool register_standard_user_context = [] {
  auto& registry = xla::ifrt::CustomStatusExpanderRegistry::Get();
  registry.Register(kIfrtUserContextPayloadUrl, ExpandStandardUserContext,
                    /*process_order=*/-1);
  return true;
}();
} // namespace ifrt
} // namespace xla

View File

@@ -174,6 +174,7 @@ cc_library(
"//xla/python/ifrt_proxy/common:ifrt_service_proto_cc",
"//xla/python/ifrt_proxy/common:types",
"//xla/python/ifrt_proxy/common:versions",
"//xla/python/ifrt_proxy/contrib/pathways:status_annotator_util", # build_cleaner: keep
"//xla/python/pjrt_ifrt:pjrt_attribute_map_util",
"//xla/tsl/concurrency:future",
"//xla/tsl/concurrency:ref_count",

View File

@@ -0,0 +1,65 @@
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//xla/python/ifrt_proxy/common:ifrt_proxy.bzl", "cc_library", "default_ifrt_proxy_visibility", "ifrt_proxy_cc_test")
load("//xla/tsl/platform:build_config.bzl", "tf_proto_library")
# Package-wide defaults: every target below inherits the IFRT-proxy visibility
# unless it sets its own.
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = default_ifrt_proxy_visibility,
)
# Helpers for attaching a Pathways object-store dump to an absl::Status, plus
# the expander that renders that payload into the status message.
cc_library(
name = "status_annotator_util",
srcs = ["status_annotator_util.cc"],
hdrs = ["status_annotator_util.h"],
deps = [
":status_annotator_proto_cc",
"//xla/python/ifrt:user_context",
"//xla/python/ifrt:user_context_registry",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status_to_from_proto",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/time",
],
# status_annotator_util.cc registers its expander from a static initializer.
# alwayslink keeps that object file in the final binary even when no symbol
# from this library is referenced directly.
alwayslink = 1,
)
# Unit tests that annotate a status with an object-store dump and check the
# text produced by xla::ifrt::ExpandUserContexts().
ifrt_proxy_cc_test(
name = "status_annotator_util_test",
srcs = ["status_annotator_util_test.cc"],
deps = [
":status_annotator_proto_cc",
":status_annotator_util",
"//xla/python/ifrt:test_util",
"//xla/python/ifrt:user_context_registry",
"//xla/python/ifrt:user_context_status_util",
"//xla/tsl/platform:status_to_from_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
# Proto schema for the object-store dump payload.
# NOTE(review): presumably this macro also generates the
# `:status_annotator_proto_cc` target that the rules above depend on -- confirm
# against tf_proto_library.
tf_proto_library(
name = "status_annotator_proto",
srcs = ["status_annotator.proto"],
visibility = default_ifrt_proxy_visibility,
deps = ["//xla/tsl/protobuf:status_proto"],
)

View File

@@ -0,0 +1,49 @@
// Copyright 2025 The OpenXLA Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
edition = "2023";
package ifrt_proxy_contrib_pathways;
import "xla/tsl/protobuf/status.proto";
option java_multiple_files = true;
// Summarizes a Pathways worker-side ObjectStore in a way relevant to users
// debugging an HBM OOM.
//
// Instances are attached to an error status as a payload and later rendered
// into the status message by the expander in status_annotator_util.cc.
message ObjectStoreDumpProto {
// Statistics for all objects attributed to one error (user) context.
// TODO(b/456800998): Rename 'ErrorContext' to 'UserContext'.
message PerErrorContext {
// Statistics for the objects created by one kind of operation.
message PerCreator {
string creator = 1; // The kind of operation that created this object.
// Count and size of all objects whose `GetReadyFuture().ready()` is true.
uint64 ready_obj_count = 2;
uint64 ready_total_size = 3;
// Count and size of objects whose `GetReadyFuture().ready()` is false.
uint64 not_ready_obj_count = 4;
uint64 not_ready_total_size = 5;
}
// ID of the context. The expander looks this value up as an
// `xla::ifrt::UserContextId` in `xla::ifrt::UserContextRegistry` to print
// the corresponding user stack; unknown IDs are reported as an unknown
// user stack.
uint64 error_context_id = 1;
repeated PerCreator per_creator = 2;
}
// Device the dump was taken for; printed verbatim in the expanded message.
string device = 1;
// Time of the dump; the expander interprets this as nanoseconds since the
// Unix epoch.
fixed64 dump_timestamp_ns = 2;
// One of 'dump_failed' or 'per_error_context' is filled.
// When 'dump_failed' converts to a non-OK status, the expander reports that
// error and does not print any 'per_error_context' entries.
tensorflow.StatusProto dump_failed = 3;
repeated PerErrorContext per_error_context = 4;
}

View File

@@ -0,0 +1,118 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/python/ifrt_proxy/contrib/pathways/status_annotator_util.h"
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "xla/python/ifrt/user_context.h"
#include "xla/python/ifrt/user_context_registry.h"
#include "xla/python/ifrt_proxy/contrib/pathways/status_annotator.pb.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/status_to_from_proto.h"
namespace ifrt_proxy_contrib_pathways {
// Payload key under which a serialized `ObjectStoreDumpProto` is attached to
// an `absl::Status`. The suffix is the proto's fully-qualified message name
// (package `ifrt_proxy_contrib_pathways`, message `ObjectStoreDumpProto`).
static constexpr absl::string_view kObjectStoreDumpPayloadUrl =
"type.googleapis.com/ifrt_proxy_contrib_pathways.ObjectStoreDumpProto";
// Payload expander for `kObjectStoreDumpPayloadUrl`.
//
// If `status` carries an object-store dump payload, the payload is removed and
// a human-readable summary of it is appended to the status message:
//   * if the payload cannot be parsed, a warning is appended and the raw
//     payload is written to the log;
//   * if the dump itself recorded a failure (`dump_failed` is non-OK), that
//     error is appended;
//   * otherwise one section per error context is appended, listing its
//     per-creator buffer statistics, followed by the user stacks that the
//     sections cite by number.
// Statuses without the payload are left untouched.
static void ExpandObjectStoreDump(absl::Status& status) {
  std::optional<absl::Cord> payload =
      status.GetPayload(kObjectStoreDumpPayloadUrl);
  if (!payload.has_value()) {
    return;
  }
  // `payload` is a copy, so the payload can be dropped from the status right
  // away. Erasing it on every path (including the parse-failure path below)
  // honors the expander contract of removing the payload and keeps a second
  // expansion of the same status from appending the same text again.
  status.ErasePayload(kObjectStoreDumpPayloadUrl);

  const absl::string_view serialized = payload->Flatten();
  ObjectStoreDumpProto object_store_dump;
  if (!object_store_dump.ParseFromString(serialized)) {
    LOG(WARNING) << "Unable to expand string to ObjectStoreDumpProto: "
                 << serialized;
    tsl::errors::AppendToMessage(
        &status,
        "\nWARNING: Unable to parse attached payload string to "
        "ObjectStoreDumpProto. Please see logs for the actual payload string.");
    return;
  }

  std::string header = absl::StrCat(
      "Pathways object-store summary for device ", object_store_dump.device(),
      " at ", absl::FromUnixNanos(object_store_dump.dump_timestamp_ns()));

  absl::Status error = tsl::StatusFromProto(object_store_dump.dump_failed());
  if (!error.ok()) {
    tsl::errors::AppendToMessage(
        &status, "\n", header, " got error while dumping: ", error.ToString());
    return;
  }

  // User stacks referenced by the sections below; printed once at the end and
  // cited by their 1-based position in this vector.
  std::vector<std::string> cited_traces;
  tsl::errors::AppendToMessage(&status, "\n", header, ":");

  for (const auto& per_error_context : object_store_dump.per_error_context()) {
    xla::ifrt::TrackedUserContextRef tracked_user_context =
        xla::ifrt::UserContextRegistry::Get().Lookup(
            xla::ifrt::UserContextId(per_error_context.error_context_id()));
    if (tracked_user_context != nullptr) {
      // Build the trace directly inside the vector to avoid an extra copy.
      std::string& trace = cited_traces.emplace_back(
          tracked_user_context->user_context()->DebugString());
      // Re-indent continuation lines of the trace.
      // NOTE(review): going through "\t" also turns tabs that were already in
      // the trace into line breaks -- confirm this is intended.
      absl::StrReplaceAll({{"\n", "\t"}}, &trace);
      absl::StrReplaceAll({{"\t", "\n "}}, &trace);
      tsl::errors::AppendToMessage(
          &status, " - The following entries arise from user stack [",
          cited_traces.size(), "]:");
    } else {
      tsl::errors::AppendToMessage(
          &status,
          " - The following entries arise from an unknown user stack:");
    }
    for (const auto& per_creator : per_error_context.per_creator()) {
      tsl::errors::AppendToMessage(&status, " + ", per_creator.creator(),
                                   " with ", per_creator.ready_obj_count(),
                                   " 'ready' buffers of total size ",
                                   per_creator.ready_total_size(), " and ",
                                   per_creator.not_ready_obj_count(),
                                   " 'not ready' buffers of total size ",
                                   per_creator.not_ready_total_size());
    }
  }

  // Emit the cited stacks, numbered to match the "[N]" references above.
  int trace_number = 0;
  for (const std::string& trace : cited_traces) {
    ++trace_number;
    tsl::errors::AppendToMessage(
        &status, absl::StrFormat("[%3d] %s", trace_number, trace));
  }
}
// Serializes `object_store_dump` and stores it on `status` as the payload
// keyed by `kObjectStoreDumpPayloadUrl`.
void AnnotateIfrtUserStatusWithObjectStoreDump(
    absl::Status& status, const ObjectStoreDumpProto& object_store_dump) {
  const absl::Cord serialized_dump = object_store_dump.SerializeAsCord();
  status.SetPayload(kObjectStoreDumpPayloadUrl, serialized_dump);
}
// Static-initialization hook that installs `ExpandObjectStoreDump` in the IFRT
// status-expander registry. No process order is passed, so the registry's
// default order applies.
static const bool register_expanders = [] {
  auto& registry = xla::ifrt::CustomStatusExpanderRegistry::Get();
  registry.Register(kObjectStoreDumpPayloadUrl, ExpandObjectStoreDump);
  return true;
}();
} // namespace ifrt_proxy_contrib_pathways

View File

@@ -0,0 +1,30 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_PYTHON_IFRT_PROXY_CONTRIB_PATHWAYS_STATUS_ANNOTATOR_UTIL_H_
#define XLA_PYTHON_IFRT_PROXY_CONTRIB_PATHWAYS_STATUS_ANNOTATOR_UTIL_H_
#include "absl/status/status.h"
#include "xla/python/ifrt_proxy/contrib/pathways/status_annotator.pb.h"
namespace ifrt_proxy_contrib_pathways {
// Attaches the given `object_store_dump` to the given `status` as a payload.
//
// The dump is stored in serialized form. The implementation file also
// registers an expander with `xla::ifrt::CustomStatusExpanderRegistry` that
// replaces this payload with a human-readable summary in the status message
// when the status is passed through `xla::ifrt::ExpandUserContexts()`.
void AnnotateIfrtUserStatusWithObjectStoreDump(
absl::Status& status, const ObjectStoreDumpProto& object_store_dump);
} // namespace ifrt_proxy_contrib_pathways
#endif // XLA_PYTHON_IFRT_PROXY_CONTRIB_PATHWAYS_STATUS_ANNOTATOR_UTIL_H_

View File

@@ -0,0 +1,108 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/python/ifrt_proxy/contrib/pathways/status_annotator_util.h"
#include <string>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "xla/python/ifrt/test_util.h"
#include "xla/python/ifrt/user_context_registry.h"
#include "xla/python/ifrt/user_context_status_util.h"
#include "xla/python/ifrt_proxy/contrib/pathways/status_annotator.pb.h"
#include "xla/tsl/platform/status_to_from_proto.h"
namespace ifrt_proxy_contrib_pathways {
namespace {
using ::testing::HasSubstr;
// A dump that only names a device must surface that device in the expanded
// status message.
TEST(StatusAnnotatorUtilTest, SimpleAnnotateAndExpand) {
  ObjectStoreDumpProto dump;
  dump.set_device("tpu:0");

  absl::Status annotated = absl::InternalError("test error");
  AnnotateIfrtUserStatusWithObjectStoreDump(annotated, dump);

  const absl::Status expanded = xla::ifrt::ExpandUserContexts(annotated);
  EXPECT_THAT(expanded.message(), HasSubstr("tpu:0"));
}
// A dump that recorded its own failure must surface both the device and the
// failure text in the expanded status message.
TEST(StatusAnnotatorUtilTest, ExpandFailedDump) {
  ObjectStoreDumpProto dump;
  dump.set_device("tpu:0");
  *dump.mutable_dump_failed() =
      tsl::StatusToProto(absl::InternalError("dump failed"));

  absl::Status annotated = absl::InternalError("test error");
  AnnotateIfrtUserStatusWithObjectStoreDump(annotated, dump);

  // `ExpandUserContexts` takes its argument by value, so expanding once and
  // checking the result twice is equivalent to expanding twice.
  const absl::Status expanded = xla::ifrt::ExpandUserContexts(annotated);
  EXPECT_THAT(expanded.message(), HasSubstr("tpu:0"));
  EXPECT_THAT(expanded.message(), HasSubstr("dump failed"));
}
// A dump with several user contexts, each with several creators, must expand
// to a message that names every user context, every creator, and every
// ready / not-ready count and size pair.
TEST(StatusAnnotatorUtilTest, DumpWithManyContentsExpandsToAllDetails) {
  const std::vector<int> context_ids = {1000, 2000, 3000};
  const std::vector<int> creator_ids = {100, 200};

  ObjectStoreDumpProto dump;
  dump.set_device("tpu:0");

  // Keeps the registered user contexts alive until the status is expanded.
  std::vector<xla::ifrt::TrackedUserContextRef> live_contexts;
  for (const int context_id : context_ids) {
    live_contexts.push_back(xla::ifrt::UserContextRegistry::Get().Register(
        xla::ifrt::test_util::MakeUserContext(context_id)));

    auto* context_entry = dump.add_per_error_context();
    context_entry->set_error_context_id(context_id);
    for (const int creator_id : creator_ids) {
      // Every counter gets a distinct value derived from both IDs so that the
      // expectations below can tell the entries apart.
      const int base = context_id + creator_id;
      auto* creator_entry = context_entry->add_per_creator();
      creator_entry->set_creator(absl::StrCat("creator", creator_id));
      creator_entry->set_ready_obj_count(base + 1);
      creator_entry->set_ready_total_size(base + 2);
      creator_entry->set_not_ready_obj_count(base + 3);
      creator_entry->set_not_ready_total_size(base + 4);
    }
  }

  absl::Status annotated = absl::InternalError("test error");
  AnnotateIfrtUserStatusWithObjectStoreDump(annotated, dump);
  const std::string message(
      xla::ifrt::ExpandUserContexts(annotated).message());

  for (const int context_id : context_ids) {
    EXPECT_THAT(message,
                HasSubstr(absl::StrCat("TestUserContext(", context_id, ")")));
    for (const int creator_id : creator_ids) {
      const int base = context_id + creator_id;
      EXPECT_THAT(message, HasSubstr(absl::StrCat("creator", creator_id)));
      EXPECT_THAT(message,
                  HasSubstr(absl::StrCat(base + 1,
                                         " 'ready' buffers of total size ",
                                         base + 2)));
      EXPECT_THAT(message,
                  HasSubstr(absl::StrCat(base + 3,
                                         " 'not ready' buffers of total size ",
                                         base + 4)));
    }
  }
}
} // namespace
} // namespace ifrt_proxy_contrib_pathways