mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Validate that a and b are proper sparse tensors
PiperOrigin-RevId: 373248068 Change-Id: I0a2041a0747901b3f00387a6a3bce9bca6b0b3b1
This commit is contained in:
committed by
Geeta Chavan
parent
f1b51e0af2
commit
9c0d5a842a
@@ -44,6 +44,11 @@ class SparseAddOp : public OpKernel {
|
||||
b_indices->shape().DebugString()));
|
||||
const int64 a_nnz = a_indices->dim_size(0);
|
||||
const int64 b_nnz = b_indices->dim_size(0);
|
||||
const int num_dims = a_indices->dim_size(1);
|
||||
OP_REQUIRES(ctx, b_indices->dim_size(1) == num_dims,
|
||||
errors::InvalidArgument(
|
||||
"Input indices must have the same dimension, got ",
|
||||
num_dims, " and ", b_indices->dim_size(1)));
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values_t));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("b_values", &b_values_t));
|
||||
@@ -72,6 +77,13 @@ class SparseAddOp : public OpKernel {
|
||||
"Input shapes should be a vector but received shapes ",
|
||||
a_shape->shape().DebugString(), " and ",
|
||||
b_shape->shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, a_shape->NumElements() == num_dims,
|
||||
errors::InvalidArgument("Second dimension of a_indices and length of "
|
||||
"a_shape must match, got ",
|
||||
num_dims, " and ", a_shape->NumElements()));
|
||||
OP_REQUIRES(ctx, num_dims > 0,
|
||||
errors::InvalidArgument("Tesors must not be empty"));
|
||||
OP_REQUIRES(
|
||||
ctx, a_shape->IsSameSize(*b_shape),
|
||||
errors::InvalidArgument(
|
||||
@@ -100,11 +112,6 @@ class SparseAddOp : public OpKernel {
|
||||
std::vector<std::pair<bool, int64>> entries_to_copy; // from_a?, idx
|
||||
entries_to_copy.reserve(a_nnz + b_nnz);
|
||||
std::vector<T> out_values;
|
||||
const int num_dims = a_shape->dim_size(0);
|
||||
|
||||
OP_REQUIRES(ctx, num_dims > 0,
|
||||
errors::InvalidArgument("Invalid input_a shape. Received: ",
|
||||
a_shape->DebugString()));
|
||||
|
||||
// The input and output sparse tensors are assumed to be ordered along
|
||||
// increasing dimension number.
|
||||
|
||||
Reference in New Issue
Block a user