mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Move mutability tracking from Tracked buffers to RawBuffers as this is only used for importing foreign memory and clutters the API.
PiperOrigin-RevId: 846916941
This commit is contained in:
committed by
TensorFlower Gardener
parent
f4a923fa82
commit
9d833374f9
41
third_party/xla/xla/pjrt/common_pjrt_client.cc
vendored
41
third_party/xla/xla/pjrt/common_pjrt_client.cc
vendored
@@ -158,8 +158,7 @@ CommonPjRtClient::BufferFromHostLiteral(const LiteralSlice& literal,
|
||||
HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
raw_buffer));
|
||||
return DefineBuffer(device_shape, memory_space, std::move(raw_buffer),
|
||||
{std::move(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true);
|
||||
{std::move(definition_event)});
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<PjRtBuffer>>
|
||||
@@ -193,8 +192,7 @@ CommonPjRtClient::CreateUninitializedBuffer(const Shape& shape,
|
||||
raw_buffer->MakeAllocationReadyEvent());
|
||||
TF_ASSIGN_OR_RETURN(auto output_buffer,
|
||||
DefineBuffer(device_shape, memory_space, raw_buffer,
|
||||
{std::move(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{std::move(definition_event)}));
|
||||
return output_buffer;
|
||||
}
|
||||
|
||||
@@ -270,8 +268,7 @@ CommonPjRtClient::CreateAliasBuffer(const Shape& shape,
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto result_buffer,
|
||||
DefineBuffer(shape, memory_space, std::move(raw_buffer),
|
||||
{std::move(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{std::move(definition_event)}));
|
||||
|
||||
return std::make_pair(std::move(result_buffer), std::move(fulfill_cb));
|
||||
}
|
||||
@@ -302,14 +299,14 @@ CommonPjRtClient::BufferFromHostBuffer(
|
||||
ImportForeignMemory(
|
||||
const_cast<void*>(data), // CONST_CAST_OK=flag controlled.
|
||||
std::move(on_done_with_host_buffer), on_device_bytes_count,
|
||||
memory_space));
|
||||
memory_space,
|
||||
host_buffer_semantics ==
|
||||
PjRtClient::HostBufferSemantics::kMutableZeroCopy));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto output_buffer,
|
||||
DefineBuffer(
|
||||
device_shape, memory_space, raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{},
|
||||
/*raw_buffer_is_mutable=*/host_buffer_semantics ==
|
||||
PjRtClient::HostBufferSemantics::kMutableZeroCopy));
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{}));
|
||||
return output_buffer;
|
||||
}
|
||||
}
|
||||
@@ -327,8 +324,7 @@ CommonPjRtClient::BufferFromHostBuffer(
|
||||
std::move(on_done_with_host_buffer), device_shape, raw_buffer));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> output_buffer,
|
||||
DefineBuffer(device_shape, memory_space, raw_buffer,
|
||||
{std::move(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{std::move(definition_event)}));
|
||||
return output_buffer;
|
||||
}
|
||||
|
||||
@@ -351,12 +347,13 @@ CommonPjRtClient::CreateViewOfDeviceBuffer(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto raw_buffer,
|
||||
ImportForeignMemory(device_ptr, std::move(on_delete_callback),
|
||||
on_device_bytes_count, memory_space));
|
||||
on_device_bytes_count, memory_space,
|
||||
/*is_mutable=*/false));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto output_buffer,
|
||||
DefineBuffer(device_shape, memory_space, raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{},
|
||||
/*raw_buffer_is_mutable=*/false));
|
||||
DefineBuffer(
|
||||
device_shape, memory_space, raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>{}));
|
||||
return output_buffer;
|
||||
}
|
||||
|
||||
@@ -705,9 +702,9 @@ static std::unique_ptr<PjRtBuffer> CreateOutputLeafBuffer(
|
||||
CHECK(memory_space) << "No memory space found for device: "
|
||||
<< device->DebugString() << " kind: " << kind_id;
|
||||
}
|
||||
auto buffer_or = client->DefineBuffer(
|
||||
output_leaf_shape, memory_space, std::move(leaf_buffer),
|
||||
{definition_event}, /*raw_buffer_is_mutable=*/true);
|
||||
auto buffer_or =
|
||||
client->DefineBuffer(output_leaf_shape, memory_space,
|
||||
std::move(leaf_buffer), {definition_event});
|
||||
CHECK_OK(buffer_or);
|
||||
return *std::move(buffer_or);
|
||||
}
|
||||
@@ -1154,8 +1151,7 @@ CommonPjRtBufferImpl::CopyToCpuMemorySpace(const xla::Shape& dst_shape,
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto buffer,
|
||||
dst_client->DefineBuffer(dst_shape, dst_memory_space, dst_raw_buffer,
|
||||
{std::move(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{std::move(definition_event)}));
|
||||
auto* base_ptr = dst_raw_buffer->GetHostPointer();
|
||||
std::unique_ptr<MutableLiteralBase> literal;
|
||||
bool needs_second_copy = false;
|
||||
@@ -1265,8 +1261,7 @@ static absl::Status CommonCopyToMemorySpace(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
dst_buffer,
|
||||
dst_client->DefineBuffer(dst_shape, dst_memory_space, dst_raw_buffer,
|
||||
{std::move(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{std::move(definition_event)}));
|
||||
TF_RETURN_IF_ERROR(src_buffer->AcquireScopedRawBuffer(
|
||||
[&](tsl::RCReference<CommonPjRtRawBuffer> buf_raw_buffer,
|
||||
std::vector<tsl::RCReference<tsl::AsyncValue>>
|
||||
|
||||
@@ -87,7 +87,7 @@ class CommonPjRtClient : public PjRtClient {
|
||||
ImportForeignMemory(void* device_ptr,
|
||||
absl::AnyInvocable<void() &&> on_delete_callback,
|
||||
size_t on_device_bytes_count,
|
||||
PjRtMemorySpace* memory_space) {
|
||||
PjRtMemorySpace* memory_space, bool is_mutable) {
|
||||
return absl::UnimplementedError("ImportForeignMemory is not supported");
|
||||
}
|
||||
|
||||
@@ -105,8 +105,7 @@ class CommonPjRtClient : public PjRtClient {
|
||||
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
|
||||
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
|
||||
definition_device_events,
|
||||
bool raw_buffer_is_mutable) {
|
||||
definition_device_events) {
|
||||
return absl::UnimplementedError("DefineBuffer is not supported");
|
||||
}
|
||||
|
||||
|
||||
34
third_party/xla/xla/pjrt/cpu/cpu_client.cc
vendored
34
third_party/xla/xla/pjrt/cpu/cpu_client.cc
vendored
@@ -888,10 +888,11 @@ static bool IsAlignedData(void* ptr) {
|
||||
absl::StatusOr<tsl::RCReference<CommonPjRtRawBuffer>>
|
||||
PjRtCpuClient::ImportForeignMemory(
|
||||
void* device_ptr, absl::AnyInvocable<void() &&> on_delete_callback,
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space) {
|
||||
return CpuRawBuffer::ImportForeignMemory(device_ptr,
|
||||
std::move(on_delete_callback),
|
||||
on_device_bytes_count, memory_space);
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
|
||||
bool is_mutable) {
|
||||
return CpuRawBuffer::ImportForeignMemory(
|
||||
device_ptr, std::move(on_delete_callback), on_device_bytes_count,
|
||||
memory_space, is_mutable);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::CreateErrorBuffer(
|
||||
@@ -910,7 +911,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::CreateErrorBuffer(
|
||||
return std::make_unique<CommonPjRtBufferImpl>(
|
||||
shape,
|
||||
std::make_unique<TrackedCpuDeviceBuffer>(
|
||||
/*owns_buffers=*/true, std::move(raw_buffer),
|
||||
std::move(raw_buffer),
|
||||
tsl::AsyncValueRef<CpuEvent>(
|
||||
tsl::MakeErrorAsyncValueRef(std::move(error)))),
|
||||
memory_space);
|
||||
@@ -995,8 +996,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::DefineBuffer(
|
||||
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
|
||||
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
|
||||
definition_device_events,
|
||||
bool raw_buffer_is_mutable) {
|
||||
definition_device_events) {
|
||||
if (raw_buffer && raw_buffer->memory_space() != memory_space) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("DefineBuffer: Mismatch in memory spaces: %s vs %s",
|
||||
@@ -1006,7 +1006,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> PjRtCpuClient::DefineBuffer(
|
||||
return std::unique_ptr<PjRtBuffer>(std::make_unique<CommonPjRtBufferImpl>(
|
||||
on_device_shape,
|
||||
std::make_unique<TrackedCpuDeviceBuffer>(
|
||||
/*owns_buffers=*/raw_buffer_is_mutable, std::move(raw_buffer),
|
||||
std::move(raw_buffer),
|
||||
CpuTrackedDeviceEvent::AfterAll(definition_device_events)),
|
||||
memory_space));
|
||||
}
|
||||
@@ -1029,7 +1029,7 @@ PjRtCpuClient::CreateRawBufferChannel(PjRtMemorySpace* memory_space,
|
||||
auto buffer_promise = tsl::MakeIndirectAsyncValue();
|
||||
auto raw_buffer = tsl::MakeRef<CpuRawBuffer>(
|
||||
memory_space, tsl::AsyncValueRef<CpuDeviceMemory>(buffer_promise),
|
||||
on_device_bytes_count);
|
||||
on_device_bytes_count, /*is_mutable=*/true);
|
||||
|
||||
auto buffer_promise_cb =
|
||||
[buffer_promise = std::move(buffer_promise), memory_space](
|
||||
@@ -1250,7 +1250,9 @@ static absl::StatusOr<BufferInfo> MemoryForAllocation(
|
||||
// If we don't own the buffer, we can't overwrite it or donate it. For
|
||||
// example we might be pointing to a buffer owned by the client whose
|
||||
// lifetime will not extend past the lifetime of the donated input buffer.
|
||||
if ((!can_donate || (arg && !arg->owns_buffers())) &&
|
||||
if ((!can_donate ||
|
||||
(arg && !tensorflow::down_cast<CpuRawBuffer*>(arg->raw_buffer().get())
|
||||
->is_mutable())) &&
|
||||
!allocation.is_readonly()) {
|
||||
auto copy = CpuDeviceMemory::CreateDelayedMemory();
|
||||
|
||||
@@ -1265,7 +1267,9 @@ static absl::StatusOr<BufferInfo> MemoryForAllocation(
|
||||
}
|
||||
|
||||
buffer_info.buffer = out.CopyRef();
|
||||
buffer_info.owns_buffer = !arg || arg->owns_buffers();
|
||||
buffer_info.owns_buffer =
|
||||
!arg || tensorflow::down_cast<CpuRawBuffer*>(arg->raw_buffer().get())
|
||||
->is_mutable();
|
||||
buffer_info.buffer_size = buffer_size;
|
||||
return buffer_info;
|
||||
|
||||
@@ -1856,10 +1860,10 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
|
||||
// Program execution writes to output buffers so it's a definition event.
|
||||
auto leaf_tracked_device_buffer =
|
||||
std::make_unique<TrackedCpuDeviceBuffer>(
|
||||
result_buffers_info[i].owns_buffer,
|
||||
tsl::MakeRef<CpuRawBuffer>(
|
||||
memory_space, std::move(result_buffers_info[i].buffer),
|
||||
result_buffers_info[i].buffer_size),
|
||||
result_buffers_info[i].buffer_size,
|
||||
result_buffers_info[i].owns_buffer),
|
||||
execute_event.CopyRef());
|
||||
auto leaf_buffer = std::make_unique<CommonPjRtBufferImpl>(
|
||||
result_shape.tuple_shapes(i), std::move(leaf_tracked_device_buffer),
|
||||
@@ -1870,10 +1874,10 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
|
||||
CHECK_EQ(result_buffers_info.size(), 1);
|
||||
// Program execution writes to output buffers so it's a definition event.
|
||||
auto tracked_device_buffer = std::make_unique<TrackedCpuDeviceBuffer>(
|
||||
result_buffers_info[0].owns_buffer,
|
||||
tsl::MakeRef<CpuRawBuffer>(memory_space,
|
||||
std::move(result_buffers_info[0].buffer),
|
||||
result_buffers_info[0].buffer_size),
|
||||
result_buffers_info[0].buffer_size,
|
||||
result_buffers_info[0].owns_buffer),
|
||||
/*definition_event=*/execute_event);
|
||||
auto output_buffer = std::make_unique<CommonPjRtBufferImpl>(
|
||||
result_shape, std::move(tracked_device_buffer), memory_space);
|
||||
|
||||
6
third_party/xla/xla/pjrt/cpu/cpu_client.h
vendored
6
third_party/xla/xla/pjrt/cpu/cpu_client.h
vendored
@@ -174,7 +174,8 @@ class PjRtCpuClient final : public CommonPjRtClient {
|
||||
|
||||
absl::StatusOr<tsl::RCReference<CommonPjRtRawBuffer>> ImportForeignMemory(
|
||||
void* device_ptr, absl::AnyInvocable<void() &&> on_delete_callback,
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space) override;
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
|
||||
bool is_mutable) override;
|
||||
|
||||
tsl::thread::ThreadPool* pjrt_client_thread_pool() const {
|
||||
return pjrt_client_thread_pool_.get();
|
||||
@@ -234,8 +235,7 @@ class PjRtCpuClient final : public CommonPjRtClient {
|
||||
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
|
||||
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
|
||||
definition_device_events,
|
||||
bool raw_buffer_is_mutable) override;
|
||||
definition_device_events) override;
|
||||
|
||||
absl::StatusOr<int64_t> GetOnDeviceBytesCount(
|
||||
PjRtMemorySpace* memory_space, const xla::Shape& shape) const override;
|
||||
|
||||
9
third_party/xla/xla/pjrt/cpu/raw_buffer.cc
vendored
9
third_party/xla/xla/pjrt/cpu/raw_buffer.cc
vendored
@@ -130,14 +130,15 @@ CpuRawBuffer::Allocate(PjRtMemorySpace* memory_space, size_t size_bytes,
|
||||
const CpuDeviceMemory::Allocator& allocator) {
|
||||
TF_ASSIGN_OR_RETURN(auto memory,
|
||||
CpuDeviceMemory::Allocate(size_bytes, allocator));
|
||||
return tsl::MakeRef<CpuRawBuffer>(memory_space, std::move(memory),
|
||||
size_bytes);
|
||||
return tsl::MakeRef<CpuRawBuffer>(memory_space, std::move(memory), size_bytes,
|
||||
/*is_mutable=*/true);
|
||||
}
|
||||
|
||||
/*static*/ absl::StatusOr<tsl::RCReference<CpuRawBuffer>>
|
||||
CpuRawBuffer::ImportForeignMemory(
|
||||
void* data, absl::AnyInvocable<void() &&> on_delete_callback,
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space) {
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
|
||||
bool is_mutable) {
|
||||
if ((absl::bit_cast<std::uintptr_t>(data) & (cpu::MinAlign() - 1)) != 0) {
|
||||
return InvalidArgument(
|
||||
"Can't create a view of buffer with unaligned data, ptr: %#x is not "
|
||||
@@ -148,7 +149,7 @@ CpuRawBuffer::ImportForeignMemory(
|
||||
memory_space,
|
||||
CpuDeviceMemory::CreateForeignMemory(data, on_device_bytes_count,
|
||||
std::move(on_delete_callback)),
|
||||
on_device_bytes_count);
|
||||
on_device_bytes_count, is_mutable);
|
||||
}
|
||||
|
||||
size_t CpuRawBuffer::GetOnDeviceSizeInBytes() const { return buffer_size_; }
|
||||
|
||||
12
third_party/xla/xla/pjrt/cpu/raw_buffer.h
vendored
12
third_party/xla/xla/pjrt/cpu/raw_buffer.h
vendored
@@ -95,10 +95,12 @@ class CpuTrackedDeviceEvent : public PjRtDeviceEvent {
|
||||
class CpuRawBuffer : public CommonPjRtRawBuffer {
|
||||
public:
|
||||
CpuRawBuffer(PjRtMemorySpace* memory_space,
|
||||
tsl::AsyncValueRef<CpuDeviceMemory> buffer, size_t buffer_size)
|
||||
tsl::AsyncValueRef<CpuDeviceMemory> buffer, size_t buffer_size,
|
||||
bool is_mutable)
|
||||
: memory_space_(memory_space),
|
||||
buffer_(std::move(buffer)),
|
||||
buffer_size_(buffer_size) {}
|
||||
buffer_size_(buffer_size),
|
||||
is_mutable_(is_mutable) {}
|
||||
|
||||
absl::Status ValidateSlice(int64_t offset, int64_t slice_size);
|
||||
|
||||
@@ -111,7 +113,8 @@ class CpuRawBuffer : public CommonPjRtRawBuffer {
|
||||
// Imports foreign memory.
|
||||
static absl::StatusOr<tsl::RCReference<CpuRawBuffer>> ImportForeignMemory(
|
||||
void* data, absl::AnyInvocable<void() &&> on_delete_callback,
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space);
|
||||
size_t on_device_bytes_count, PjRtMemorySpace* memory_space,
|
||||
bool is_mutable);
|
||||
|
||||
size_t GetOnDeviceSizeInBytes() const override;
|
||||
|
||||
@@ -129,6 +132,8 @@ class CpuRawBuffer : public CommonPjRtRawBuffer {
|
||||
|
||||
PjRtMemorySpace* memory_space() const override { return memory_space_; }
|
||||
|
||||
bool is_mutable() const { return is_mutable_; }
|
||||
|
||||
absl::StatusOr<tsl::RCReference<PjRtDeviceEvent>>
|
||||
CopyRawHostToDeviceAndReturnEvent(const void* src, int64_t offset,
|
||||
int64_t transfer_size) override;
|
||||
@@ -175,6 +180,7 @@ class CpuRawBuffer : public CommonPjRtRawBuffer {
|
||||
PjRtMemorySpace* const memory_space_;
|
||||
tsl::AsyncValueRef<CpuDeviceMemory> buffer_;
|
||||
size_t buffer_size_;
|
||||
bool is_mutable_;
|
||||
};
|
||||
|
||||
absl::StatusOr<xla::Shape> MakeDefaultCpuBufferShape(xla::Shape shape,
|
||||
|
||||
@@ -189,10 +189,9 @@ absl::Status CpuDeviceMemory::AllocateInto(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TrackedCpuDeviceBuffer::TrackedCpuDeviceBuffer(
|
||||
bool owns_buffers, tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
tsl::AsyncValueRef<CpuEvent> definition_event)
|
||||
: AbstractTrackedDeviceBuffer(std::move(raw_buffer)),
|
||||
owns_buffers_(owns_buffers),
|
||||
definition_event_(std::move(definition_event)) {
|
||||
DCHECK(definition_event_);
|
||||
}
|
||||
|
||||
@@ -141,8 +141,7 @@ class CpuDeviceMemory {
|
||||
class TrackedCpuDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
||||
public:
|
||||
// Variant with single definition event.
|
||||
TrackedCpuDeviceBuffer(bool owns_buffers,
|
||||
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
TrackedCpuDeviceBuffer(tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
tsl::AsyncValueRef<CpuEvent> definition_event);
|
||||
|
||||
TrackedCpuDeviceBuffer(TrackedCpuDeviceBuffer&&) noexcept = default;
|
||||
@@ -170,8 +169,6 @@ class TrackedCpuDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
||||
absl::InlinedVector<tsl::AsyncValueRef<CpuEvent>, 4>
|
||||
LockUseAndTransferUsageEvents();
|
||||
|
||||
bool owns_buffers() const { return owns_buffers_; }
|
||||
|
||||
std::vector<tsl::RCReference<tsl::AsyncValue>> GetAsyncValueDefinitionEvents()
|
||||
override;
|
||||
|
||||
@@ -190,8 +187,6 @@ class TrackedCpuDeviceBuffer : public AbstractTrackedDeviceBuffer {
|
||||
private:
|
||||
void ConfirmDonation() override;
|
||||
|
||||
bool owns_buffers_;
|
||||
|
||||
// The definition event are associated with CPU operations that write to the
|
||||
// buffers.
|
||||
tsl::AsyncValueRef<CpuEvent> definition_event_;
|
||||
|
||||
@@ -57,8 +57,7 @@ TEST(TrackedCpuDeviceBufferTest, Basic) {
|
||||
definition_event.SetStateConcrete();
|
||||
});
|
||||
|
||||
TrackedCpuDeviceBuffer tracked_buffer(
|
||||
/*owns_buffers=*/true, buffer, definition_event);
|
||||
TrackedCpuDeviceBuffer tracked_buffer(buffer, definition_event);
|
||||
|
||||
BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue());
|
||||
|
||||
@@ -85,8 +84,7 @@ TEST(TrackedCpuDeviceBufferTest, BasicError) {
|
||||
Internal("tracked_cpu_device_buffer_test error."));
|
||||
});
|
||||
|
||||
TrackedCpuDeviceBuffer tracked_buffer(
|
||||
/*owns_buffers=*/true, buffer, definition_event);
|
||||
TrackedCpuDeviceBuffer tracked_buffer(buffer, definition_event);
|
||||
|
||||
BlockUntilReady(tracked_buffer.definition_event().GetAsyncValue());
|
||||
|
||||
@@ -108,8 +106,8 @@ TEST(TrackedCpuDeviceBufferTest, DelayedAllocation) {
|
||||
|
||||
auto definition_event = MakeConstructedAsyncValueRef<CpuEvent>();
|
||||
TrackedCpuDeviceBuffer tracked_buffer(
|
||||
/*owns_buffers=*/true,
|
||||
tsl::MakeRef<CpuRawBuffer>(memory_space, buffer, expected.size()),
|
||||
tsl::MakeRef<CpuRawBuffer>(memory_space, buffer, expected.size(),
|
||||
/*is_mutable=*/true),
|
||||
definition_event);
|
||||
auto result = tracked_buffer.buffer();
|
||||
ASSERT_FALSE(result.IsAvailable());
|
||||
|
||||
@@ -611,8 +611,7 @@ absl::StatusOr<PreparedReceive> PrepareReceive(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer> buffer,
|
||||
client->DefineBuffer(on_device_shape, memory_space,
|
||||
raw_buffer, {definition_event},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
raw_buffer, {definition_event}));
|
||||
definition_event->AndThen([raw_buffer]() {});
|
||||
|
||||
return PreparedReceive(client, std::move(clique_key), std::move(buffer),
|
||||
@@ -917,8 +916,7 @@ StreamExecutorGpuClient::PrepareReceiveBuffer(PjRtDevice* device, Shape shape) {
|
||||
auto buffer,
|
||||
DefineBuffer(
|
||||
on_device_shape, memory_space, raw_buffer,
|
||||
{tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(definition_event)}));
|
||||
|
||||
return PrepareReceiveBufferResult{std::move(buffer), std::move(raw_buffer),
|
||||
local_device, stream,
|
||||
|
||||
@@ -2875,8 +2875,7 @@ TEST(StreamExecutorGpuClientTest, LinkedEventPromise) {
|
||||
client->CreateLinkedEventPromise(memory_space, ""));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto buffer, client->DefineBuffer(device_shape, memory_space, raw_buffer,
|
||||
{std::move(event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{std::move(event)}));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto definition_event,
|
||||
|
||||
@@ -154,8 +154,7 @@ class CommonAsyncHostToDeviceTransferManager
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto buffer,
|
||||
client->DefineBuffer(device_shape, memory_space, raw_buffer,
|
||||
{std::move(definition_event)},
|
||||
/*raw_buffer_is_mutable=*/true));
|
||||
{std::move(definition_event)}));
|
||||
device_shapes.push_back(std::move(device_shape));
|
||||
buffers.push_back(std::move(buffer));
|
||||
undispatched_buffer_refs.push_back(raw_buffer);
|
||||
|
||||
@@ -512,8 +512,7 @@ PjRtStreamExecutorClient::DefineBuffer(
|
||||
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
|
||||
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
|
||||
definition_device_events,
|
||||
bool raw_buffer_is_mutable) {
|
||||
definition_device_events) {
|
||||
if (raw_buffer && raw_buffer->memory_space() != memory_space) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("DefineBuffer: Mismatch in memory spaces: %s vs %s",
|
||||
|
||||
@@ -396,8 +396,7 @@ class PjRtStreamExecutorClient : public CommonPjRtClient {
|
||||
const Shape& on_device_shape, PjRtMemorySpace* memory_space,
|
||||
tsl::RCReference<CommonPjRtRawBuffer> raw_buffer,
|
||||
absl::InlinedVector<tsl::RCReference<PjRtDeviceEvent>, 4>
|
||||
definition_device_events,
|
||||
bool raw_buffer_is_mutable) override;
|
||||
definition_device_events) override;
|
||||
|
||||
absl::StatusOr<std::pair<tsl::RCReference<CommonPjRtRawBuffer>,
|
||||
PjRtFulfillAliasRawBufferCallback>>
|
||||
|
||||
Reference in New Issue
Block a user