Move mutability tracking from tracked device buffers to raw buffers: the flag only matters when importing foreign memory, and passing it through every DefineBuffer call clutters the API.

PiperOrigin-RevId: 846916941
This commit is contained in:
Parker Schuh
2025-12-19 17:06:23 -08:00
committed by TensorFlower Gardener
parent f4a923fa82
commit 9d833374f9
14 changed files with 68 additions and 77 deletions

View File

@@ -158,8 +158,7 @@ CommonPjRtClient::BufferFromHostLiteral(const LiteralSlice& literal,
HostBufferSemantics::kImmutableUntilTransferCompletes,
raw_buffer));
return DefineBuffer(device_shape, memory_space, std::move(raw_buffer),
{std::move(definition_event)},
/*raw_buffer_is_mutable=*/true);
{std::move(definition_event)});
}
absl::StatusOr<std::unique_ptr<PjRtBuffer>>
@@ -193,8 +192,7 @@ CommonPjRtClient::CreateUninitializedBuffer(const Shape& shape,
raw_buffer->MakeAllocationReadyEvent());
TF_ASSIGN_OR_RETURN(auto output_buffer,
DefineBuffer(device_shape, memory_space, raw_buffer,
{std::move(definition_event)},
/*raw_buffer_is_mutable=*/true));
{std::move(definition_event)}));
return output_buffer;
}
@@ -270,8 +268,7 @@ CommonPjRtClient::CreateAliasBuffer(const Shape& shape,
TF_ASSIGN_OR_RETURN(auto result_buffer,
DefineBuffer(shape, memory_space, std::move(raw_buffer),
{std::move(definition_event)},
/*raw_buffer_is_mutable=*/true));
{std::move(definition_event)}));
return std::make_pair(std::move(result_buffer), std::move(fulfill_cb));
}
@@ -302,14 +299,14 @@ CommonPjRtClient::BufferFromHostBuffer(
ImportForeignMemory(
const_cast<void*>(data), // CONST_CAST_OK=flag controlled.
std::move(on_done_with_host_buffer), on_device_bytes_count,
memory_space));
memory_space,
host_buffer_semantics ==
PjRtClient::HostBufferSemantics::kMutableZeroCopy));
TF_ASSIGN_OR_RETURN(
auto output_buffer,
DefineBuffer(
device_shape, memory_space, raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{},
/*raw_buffer_is_mutable=*/host_buffer_semantics ==
PjRtClient::HostBufferSemantics::kMutableZeroCopy));
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{}));
return output_buffer;
}
}
@@ -327,8 +324,7 @@ CommonPjRtClient::BufferFromHostBuffer(
std::move(on_done_with_host_buffer), device_shape, raw_buffer));
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> output_buffer,
DefineBuffer(device_shape, memory_space, raw_buffer,
{std::move(definition_event)},
/*raw_buffer_is_mutable=*/true));
{std::move(definition_event)}));
return output_buffer;
}
@@ -351,12 +347,13 @@ CommonPjRtClient::CreateViewOfDeviceBuffer(
TF_ASSIGN_OR_RETURN(
auto raw_buffer,
ImportForeignMemory(device_ptr, std::move(on_delete_callback),
on_device_bytes_count, memory_space));
on_device_bytes_count, memory_space,
/*is_mutable=*/false));
TF_ASSIGN_OR_RETURN(
auto output_buffer,
DefineBuffer(device_shape, memory_space, raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{},
/*raw_buffer_is_mutable=*/false));
DefineBuffer(
device_shape, memory_space, raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{}));
return output_buffer;
}
@@ -705,9 +702,9 @@ static std::unique_ptr<PjRtBuffer> CreateOutputLeafBuffer(
CHECK(memory_space) << "No memory space found for device: "
<< device->DebugString() << " kind: " << kind_id;
}
auto buffer_or = client->DefineBuffer(
output_leaf_shape, memory_space, std::move(leaf_buffer),
{definition_event}, /*raw_buffer_is_mutable=*/true);
auto buffer_or =
client->DefineBuffer(output_leaf_shape, memory_space,
std::move(leaf_buffer), {definition_event});
CHECK_OK(buffer_or);
return *std::move(buffer_or);
}
@@ -1154,8 +1151,7 @@ CommonPjRtBufferImpl::CopyToCpuMemorySpace(const xla::Shape& dst_shape,
TF_ASSIGN_OR_RETURN(
auto buffer,
dst_client->DefineBuffer(dst_shape, dst_memory_space, dst_raw_buffer,
{std::move(definition_event)},
/*raw_buffer_is_mutable=*/true));
{std::move(definition_event)}));
auto* base_ptr = dst_raw_buffer->GetHostPointer();
std::unique_ptr<MutableLiteralBase> literal;
bool needs_second_copy = false;
@@ -1265,8 +1261,7 @@ static absl::Status CommonCopyToMemorySpace(
TF_ASSIGN_OR_RETURN(
dst_buffer,
dst_client->DefineBuffer(dst_shape, dst_memory_space, dst_raw_buffer,
{std::move(definition_event)},
/*raw_buffer_is_mutable=*/true));
{std::move(definition_event)}));
TF_RETURN_IF_ERROR(src_buffer->AcquireScopedRawBuffer(
[&](tsl::RCReference<CommonPjRtRawBuffer> buf_raw_buffer,
std::vector<tsl::RCReference<tsl::AsyncValue>>

View File

@@ -87,7 +87,7 @@ class CommonPjRtClient : public PjRtClient {
ImportForeignMemory(void* device_ptr,
absl::AnyInvocable<void() &&> on_delete_callback,
size_t on_device_bytes_count,
PjRtMemorySpace* memory_space) {
PjRtMemorySpace* memory_space, bool is_mutable) {
return absl::UnimplementedError("ImportForeignMemory is not supported");
}
@@ -105,8 +105,7 @@ class CommonPjRtClient : public PjRtClient {
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
definition_device_events,
bool raw_buffer_is_mutable) {
definition_device_events) {
return absl::UnimplementedError("DefineBuffer is not supported");
}

View File

@@ -888,10 +888,11 @@ static bool IsAlignedData(void* ptr) {
absl::StatusOr<tsl::RCReference<CommonPjRtRawBuffer>>
PjRtCpuClient::ImportForeignMemory(
void* device_ptr, absl::AnyInvocable<void() &&> on_delete_callback,
size_t on_device_bytes_count, PjRtMemorySpace* memory_space) {
return CpuRawBuffer::ImportForeignMemory(device_ptr,
std::move(on_delete_callback),
on_device_bytes_count, memory_space);
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
bool is_mutable) {
return CpuRawBuffer::ImportForeignMemory(
device_ptr, std::move(on_delete_callback), on_device_bytes_count,
memory_space, is_mutable);
}
absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::CreateErrorBuffer(
@@ -910,7 +911,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::CreateErrorBuffer(
return std::make_unique<CommonPjRtBufferImpl>(
shape,
std::make_unique<TrackedCpuDeviceBuffer>(
/*owns_buffers=*/true, std::move(raw_buffer),
std::move(raw_buffer),
tsl::AsyncValueRef<CpuEvent>(
tsl::MakeErrorAsyncValueRef(std::move(error)))),
memory_space);
@@ -995,8 +996,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::DefineBuffer(
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
definition_device_events,
bool raw_buffer_is_mutable) {
definition_device_events) {
if (raw_buffer && raw_buffer->memory_space() != memory_space) {
return absl::InvalidArgumentError(
absl::StrFormat("DefineBuffer: Mismatch in memory spaces: %s vs %s",
@@ -1006,7 +1006,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::DefineBuffer(
return std::unique_ptr<PjRtBuffer>(std::make_unique<CommonPjRtBufferImpl>(
on_device_shape,
std::make_unique<TrackedCpuDeviceBuffer>(
/*owns_buffers=*/raw_buffer_is_mutable, std::move(raw_buffer),
std::move(raw_buffer),
CpuTrackedDeviceEvent::AfterAll(definition_device_events)),
memory_space));
}
@@ -1029,7 +1029,7 @@ PjRtCpuClient::CreateRawBufferChannel(PjRtMemorySpace* memory_space,
auto buffer_promise = tsl::MakeIndirectAsyncValue();
auto raw_buffer = tsl::MakeRef<CpuRawBuffer>(
memory_space, tsl::AsyncValueRef<CpuDeviceMemory>(buffer_promise),
on_device_bytes_count);
on_device_bytes_count, /*is_mutable=*/true);
auto buffer_promise_cb =
[buffer_promise = std::move(buffer_promise), memory_space](
@@ -1250,7 +1250,9 @@ static absl::StatusOr<BufferInfo> MemoryForAllocation(
// If we don't own the buffer, we can't overwrite it or donate it. For
// example we might be pointing to a buffer owned by the client whose
// lifetime will not extend past the lifetime of the donated input buffer.
if ((!can_donate || (arg && !arg->owns_buffers())) &&
if ((!can_donate ||
(arg && !tensorflow::down_cast<CpuRawBuffer*>(arg->raw_buffer().get())
->is_mutable())) &&
!allocation.is_readonly()) {
auto copy = CpuDeviceMemory::CreateDelayedMemory();
@@ -1265,7 +1267,9 @@ static absl::StatusOr<BufferInfo> MemoryForAllocation(
}
buffer_info.buffer = out.CopyRef();
buffer_info.owns_buffer = !arg || arg->owns_buffers();
buffer_info.owns_buffer =
!arg || tensorflow::down_cast<CpuRawBuffer*>(arg->raw_buffer().get())
->is_mutable();
buffer_info.buffer_size = buffer_size;
return buffer_info;
@@ -1856,10 +1860,10 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
// Program execution writes to output buffers so it's a definition event.
auto leaf_tracked_device_buffer =
std::make_unique<TrackedCpuDeviceBuffer>(
result_buffers_info[i].owns_buffer,
tsl::MakeRef<CpuRawBuffer>(
memory_space, std::move(result_buffers_info[i].buffer),
result_buffers_info[i].buffer_size),
result_buffers_info[i].buffer_size,
result_buffers_info[i].owns_buffer),
execute_event.CopyRef());
auto leaf_buffer = std::make_unique<CommonPjRtBufferImpl>(
result_shape.tuple_shapes(i), std::move(leaf_tracked_device_buffer),
@@ -1870,10 +1874,10 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
CHECK_EQ(result_buffers_info.size(), 1);
// Program execution writes to output buffers so it's a definition event.
auto tracked_device_buffer = std::make_unique<TrackedCpuDeviceBuffer>(
result_buffers_info[0].owns_buffer,
tsl::MakeRef<CpuRawBuffer>(memory_space,
std::move(result_buffers_info[0].buffer),
result_buffers_info[0].buffer_size),
result_buffers_info[0].buffer_size,
result_buffers_info[0].owns_buffer),
/*definition_event=*/execute_event);
auto output_buffer = std::make_unique<CommonPjRtBufferImpl>(
result_shape, std::move(tracked_device_buffer), memory_space);

View File

@@ -174,7 +174,8 @@ class PjRtCpuClient final : public CommonPjRtClient {
absl::StatusOr<tsl::RCReference<CommonPjRtRawBuffer>> ImportForeignMemory(
void* device_ptr, absl::AnyInvocable<void() &&> on_delete_callback,
size_t on_device_bytes_count, PjRtMemorySpace* memory_space) override;
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
bool is_mutable) override;
tsl::thread::ThreadPool* pjrt_client_thread_pool() const {
return pjrt_client_thread_pool_.get();
@@ -234,8 +235,7 @@ class PjRtCpuClient final : public CommonPjRtClient {
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
definition_device_events,
bool raw_buffer_is_mutable) override;
definition_device_events) override;
absl::StatusOr<int64_t> GetOnDeviceBytesCount(
PjRtMemorySpace* memory_space, const xla::Shape& shape) const override;

View File

@@ -130,14 +130,15 @@ CpuRawBuffer::Allocate(PjRtMemorySpace* memory_space, size_t size_bytes,
const CpuDeviceMemory::Allocator& allocator) {
TF_ASSIGN_OR_RETURN(auto memory,
CpuDeviceMemory::Allocate(size_bytes, allocator));
return tsl::MakeRef<CpuRawBuffer>(memory_space, std::move(memory),
size_bytes);
return tsl::MakeRef<CpuRawBuffer>(memory_space, std::move(memory), size_bytes,
/*is_mutable=*/true);
}
/*static*/ absl::StatusOr<tsl::RCReference<CpuRawBuffer>>
CpuRawBuffer::ImportForeignMemory(
void* data, absl::AnyInvocable<void() &&> on_delete_callback,
size_t on_device_bytes_count, PjRtMemorySpace* memory_space) {
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
bool is_mutable) {
if ((absl::bit_cast<std::uintptr_t>(data) & (cpu::MinAlign() - 1)) != 0) {
return InvalidArgument(
"Can't create a view of buffer with unaligned data, ptr: %#x is not "
@@ -148,7 +149,7 @@ CpuRawBuffer::ImportForeignMemory(
memory_space,
CpuDeviceMemory::CreateForeignMemory(data, on_device_bytes_count,
std::move(on_delete_callback)),
on_device_bytes_count);
on_device_bytes_count, is_mutable);
}
size_t CpuRawBuffer::GetOnDeviceSizeInBytes() const { return buffer_size_; }

View File

@@ -95,10 +95,12 @@ class CpuTrackedDeviceEvent : public PjRtDeviceEvent {
class CpuRawBuffer : public CommonPjRtRawBuffer {
public:
CpuRawBuffer(PjRtMemorySpace* memory_space,
tsl::AsyncValueRef<CpuDeviceMemory> buffer, size_t buffer_size)
tsl::AsyncValueRef<CpuDeviceMemory> buffer, size_t buffer_size,
bool is_mutable)
: memory_space_(memory_space),
buffer_(std::move(buffer)),
buffer_size_(buffer_size) {}
buffer_size_(buffer_size),
is_mutable_(is_mutable) {}
absl::Status ValidateSlice(int64_t offset, int64_t slice_size);
@@ -111,7 +113,8 @@ class CpuRawBuffer : public CommonPjRtRawBuffer {
// Imports foreign memory.
static absl::StatusOr<tsl::RCReference<CpuRawBuffer>> ImportForeignMemory(
void* data, absl::AnyInvocable<void() &&> on_delete_callback,
size_t on_device_bytes_count, PjRtMemorySpace* memory_space);
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
bool is_mutable);
size_t GetOnDeviceSizeInBytes() const override;
@@ -129,6 +132,8 @@ class CpuRawBuffer : public CommonPjRtRawBuffer {
PjRtMemorySpace* memory_space() const override { return memory_space_; }
bool is_mutable() const { return is_mutable_; }
absl::StatusOr<tsl::RCReference<PjRtDeviceEvent>>
CopyRawHostToDeviceAndReturnEvent(const void* src, int64_t offset,
int64_t transfer_size) override;
@@ -175,6 +180,7 @@ class CpuRawBuffer : public CommonPjRtRawBuffer {
PjRtMemorySpace* const memory_space_;
tsl::AsyncValueRef<CpuDeviceMemory> buffer_;
size_t buffer_size_;
bool is_mutable_;
};
absl::StatusOr<xla::Shape> MakeDefaultCpuBufferShape(xla::Shape shape,

View File

@@ -189,10 +189,9 @@ absl::Status CpuDeviceMemory::AllocateInto(
//===----------------------------------------------------------------------===//
TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer(
bool owns_buffers, tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
tsl::AsyncValueRef<CpuEvent> definition_event)
: AbstractTrackedDeviceBuffer(std::move(raw_buffer)),
owns_buffers_(owns_buffers),
definition_event_(std::move(definition_event)) {
DCHECK(definition_event_);
}

View File

@@ -141,8 +141,7 @@ class CpuDeviceMemory {
class TrackedCpuDeviceBuffer : public AbstractTrackedDeviceBuffer {
public:
// Variant with single definition event.
TrackedCpuDeviceBuffer(bool owns_buffers,
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
TrackedCpuDeviceBuffer(tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
tsl::AsyncValueRef<CpuEvent> definition_event);
TrackedCpuDeviceBuffer(TrackedCpuDeviceBuffer&&) noexcept = default;
@@ -170,8 +169,6 @@ class TrackedCpuDeviceBuffer : public AbstractTrackedDeviceBuffer {
absl::InlinedVector<tsl::AsyncValueRef<CpuEvent>, 4>
LockUseAndTransferUsageEvents();
bool owns_buffers() const { return owns_buffers_; }
std::vector<tsl::RCReference<tsl::AsyncValue>> GetAsyncValueDefinitionEvents()
override;
@@ -190,8 +187,6 @@ class TrackedCpuDeviceBuffer : public AbstractTrackedDeviceBuffer {
private:
void ConfirmDonation() override;
bool owns_buffers_;
// The definition event are associated with CPU operations that write to the
// buffers.
tsl::AsyncValueRef<CpuEvent> definition_event_;

View File

@@ -57,8 +57,7 @@ TEST(TrackedCpuDeviceBufferTest, Basic) {
definition_event.SetStateConcrete();
});
TrackedCpuDeviceBuffer tracked_buffer(
/*owns_buffers=*/true, buffer, definition_event);
TrackedCpuDeviceBuffer tracked_buffer(buffer, definition_event);
BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue());
@@ -85,8 +84,7 @@ TEST(TrackedCpuDeviceBufferTest, BasicError) {
Internal("tracked_cpu_device_buffer_test error."));
});
TrackedCpuDeviceBuffer tracked_buffer(
/*owns_buffers=*/true, buffer, definition_event);
TrackedCpuDeviceBuffer tracked_buffer(buffer, definition_event);
BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue());
@@ -108,8 +106,8 @@ TEST(TrackedCpuDeviceBufferTest, DelayedAllocation) {
auto definition_event = MakeConstructedAsyncValueRef<CpuEvent>();
TrackedCpuDeviceBuffer tracked_buffer(
/*owns_buffers=*/true,
tsl::MakeRef<CpuRawBuffer>(memory_space, buffer, expected.size()),
tsl::MakeRef<CpuRawBuffer>(memory_space, buffer, expected.size(),
/*is_mutable=*/true),
definition_event);
auto result = tracked_buffer.buffer();
ASSERT_FALSE(result.IsAvailable());

View File

@@ -611,8 +611,7 @@ absl::StatusOr<PreparedReceive> PrepareReceive(
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer,
client->DefineBuffer(on_device_shape, memory_space,
raw_buffer, {definition_event},
/*raw_buffer_is_mutable=*/true));
raw_buffer, {definition_event}));
definition_event->AndThen([raw_buffer]() {});
return PreparedReceive(client, std::move(clique_key), std::move(buffer),
@@ -917,8 +916,7 @@ StreamExecutorGpuClient::PrepareReceiveBuffer(PjRtDevice* device, Shape shape) {
auto buffer,
DefineBuffer(
on_device_shape, memory_space, raw_buffer,
{tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(definition_event)},
/*raw_buffer_is_mutable=*/true));
{tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(definition_event)}));
return PrepareReceiveBufferResult{std::move(buffer), std::move(raw_buffer),
local_device, stream,

View File

@@ -2875,8 +2875,7 @@ TEST(StreamExecutorGpuClientTest, LinkedEventPromise) {
client->CreateLinkedEventPromise(memory_space, ""));
TF_ASSERT_OK_AND_ASSIGN(
auto buffer, client->DefineBuffer(device_shape, memory_space, raw_buffer,
{std::move(event)},
/*raw_buffer_is_mutable=*/true));
{std::move(event)}));
TF_ASSERT_OK_AND_ASSIGN(
auto definition_event,

View File

@@ -154,8 +154,7 @@ class CommonAsyncHostToDeviceTransferManager
TF_ASSIGN_OR_RETURN(
auto buffer,
client->DefineBuffer(device_shape, memory_space, raw_buffer,
{std::move(definition_event)},
/*raw_buffer_is_mutable=*/true));
{std::move(definition_event)}));
device_shapes.push_back(std::move(device_shape));
buffers.push_back(std::move(buffer));
undispatched_buffer_refs.push_back(raw_buffer);

View File

@@ -512,8 +512,7 @@ PjRtStreamExecutorClient::DefineBuffer(
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
definition_device_events,
bool raw_buffer_is_mutable) {
definition_device_events) {
if (raw_buffer && raw_buffer->memory_space() != memory_space) {
return absl::InvalidArgumentError(
absl::StrFormat("DefineBuffer: Mismatch in memory spaces: %s vs %s",

View File

@@ -396,8 +396,7 @@ class PjRtStreamExecutorClient : public CommonPjRtClient {
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
definition_device_events,
bool raw_buffer_is_mutable) override;
definition_device_events) override;
absl::StatusOr<std::pair<tsl::RCReference<CommonPjRtRawBuffer>,
PjRtFulfillAliasRawBufferCallback>>