mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Merge pull request #49881 from geetachavan1/cherrypicks_2LIXF
Fix heap OOB / undefined behavior in `RaggedTensorToTensor`
This commit is contained in:
@@ -207,8 +207,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
||||
DCHECK_EQ(result->size(), first_dimension);
|
||||
}
|
||||
|
||||
void CalculateOutputIndexRowSplit(
|
||||
OpKernelContext* context, const RowPartitionTensor& row_split,
|
||||
Status CalculateOutputIndexRowSplit(
|
||||
const RowPartitionTensor& row_split,
|
||||
const vector<INDEX_TYPE>& parent_output_index,
|
||||
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
|
||||
vector<INDEX_TYPE>* result) {
|
||||
@@ -232,10 +232,11 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
||||
result->push_back(-1);
|
||||
}
|
||||
}
|
||||
if (row_split_size > 0) {
|
||||
OP_REQUIRES(context, result->size() == row_split(row_split_size - 1),
|
||||
errors::InvalidArgument("Invalid row split size."));
|
||||
if (row_split_size > 0 && result->size() != row_split(row_split_size - 1)) {
|
||||
return errors::InvalidArgument("Invalid row split size.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Calculate the output index of the first element of a list.
|
||||
@@ -259,20 +260,26 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
||||
// result[6] = -1 because parent_output_index[value_rowids[6]] == -1
|
||||
// result[7] = -1 because parent_output_index[value_rowids[6]] == -1
|
||||
// result[8] = parent_output_index[value_rowids[7]]
|
||||
void CalculateOutputIndexValueRowID(
|
||||
OpKernelContext* context, const RowPartitionTensor& value_rowids,
|
||||
Status CalculateOutputIndexValueRowID(
|
||||
const RowPartitionTensor& value_rowids,
|
||||
const vector<INDEX_TYPE>& parent_output_index,
|
||||
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
|
||||
vector<INDEX_TYPE>* result) {
|
||||
const INDEX_TYPE index_size = value_rowids.size();
|
||||
result->reserve(index_size);
|
||||
if (index_size == 0) {
|
||||
return;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
INDEX_TYPE current_output_column = 0;
|
||||
INDEX_TYPE current_value_rowid = value_rowids(0);
|
||||
DCHECK_LT(current_value_rowid, parent_output_index.size());
|
||||
|
||||
if (current_value_rowid >= parent_output_index.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Got current_value_rowid=", current_value_rowid,
|
||||
" which is not less than ", parent_output_index.size());
|
||||
}
|
||||
|
||||
INDEX_TYPE current_output_index = parent_output_index[current_value_rowid];
|
||||
result->push_back(current_output_index);
|
||||
for (INDEX_TYPE i = 1; i < index_size; ++i) {
|
||||
@@ -289,13 +296,23 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
||||
} else {
|
||||
current_output_column = 0;
|
||||
current_value_rowid = next_value_rowid;
|
||||
DCHECK_LT(next_value_rowid, parent_output_index.size());
|
||||
|
||||
if (next_value_rowid >= parent_output_index.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Got next_value_rowid=", next_value_rowid,
|
||||
" which is not less than ", parent_output_index.size());
|
||||
}
|
||||
|
||||
current_output_index = parent_output_index[next_value_rowid];
|
||||
}
|
||||
result->push_back(current_output_index);
|
||||
}
|
||||
OP_REQUIRES(context, result->size() == value_rowids.size(),
|
||||
errors::InvalidArgument("Invalid row ids."));
|
||||
|
||||
if (result->size() != value_rowids.size()) {
|
||||
return errors::InvalidArgument("Invalid row ids.");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CalculateOutputIndex(OpKernelContext* context, int dimension,
|
||||
@@ -308,10 +325,9 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
||||
auto partition_type = GetRowPartitionTypeByDimension(dimension);
|
||||
switch (partition_type) {
|
||||
case RowPartitionType::VALUE_ROWIDS:
|
||||
CalculateOutputIndexValueRowID(
|
||||
context, row_partition_tensor, parent_output_index,
|
||||
output_index_multiplier, output_size, result);
|
||||
return tensorflow::Status::OK();
|
||||
return CalculateOutputIndexValueRowID(
|
||||
row_partition_tensor, parent_output_index, output_index_multiplier,
|
||||
output_size, result);
|
||||
case RowPartitionType::ROW_SPLITS:
|
||||
if (row_partition_tensor.size() - 1 > parent_output_index.size()) {
|
||||
return errors::InvalidArgument(
|
||||
@@ -319,10 +335,9 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
|
||||
row_partition_tensor.size() - 1, " > ",
|
||||
parent_output_index.size());
|
||||
}
|
||||
CalculateOutputIndexRowSplit(
|
||||
context, row_partition_tensor, parent_output_index,
|
||||
output_index_multiplier, output_size, result);
|
||||
return tensorflow::Status::OK();
|
||||
return CalculateOutputIndexRowSplit(
|
||||
row_partition_tensor, parent_output_index, output_index_multiplier,
|
||||
output_size, result);
|
||||
default:
|
||||
return errors::InvalidArgument(
|
||||
"Unsupported partition type:",
|
||||
|
||||
Reference in New Issue
Block a user