Cherry pick 2.2 Add missing valuidation to FusedBatchNorm

This commit is contained in:
Mihai Maruseac
2021-05-06 17:45:51 -07:00
committed by Geeta Chavan
parent e522a1924b
commit 882c7ff305

View File

@@ -1267,6 +1267,33 @@ class FusedBatchNormOpBase : public OpKernel {
context, estimated_variance.dims() == 1,
errors::InvalidArgument("estimated_variance must be 1-dimensional",
estimated_variance.shape().DebugString()));
const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
OP_REQUIRES(
context, scale.NumElements() == num_channels,
errors::InvalidArgument("scale must have the same number of elements "
"as the channels of x, got ",
scale.NumElements(), " and ", num_channels));
OP_REQUIRES(
context, offset.NumElements() == num_channels,
errors::InvalidArgument("offset must have the same number of elements "
"as the channels of x, got ",
offset.NumElements(), " and ", num_channels));
if (estimated_mean.NumElements() != 0) {
OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
errors::InvalidArgument(
"mean must be empty or have the same number of "
"elements as the channels of x, got ",
estimated_mean.NumElements(), " and ", num_channels));
}
if (estimated_variance.NumElements() != 0) {
OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
errors::InvalidArgument(
"variance must be empty or have the same number of "
"elements as the channels of x, got ",
estimated_variance.NumElements(), " and ", num_channels));
}
if (has_side_input_) {
OP_REQUIRES(context, side_input->shape() == x.shape(),
errors::InvalidArgument(
@@ -1279,7 +1306,7 @@ class FusedBatchNormOpBase : public OpKernel {
// NOTE(ezhulenev): This requirement is coming from implementation
// details of cudnnBatchNormalizationForwardTrainingEx.
OP_REQUIRES(
context, !is_training_ || x.dim_size(3) % 4 == 0,
context, !is_training_ || num_channels % 4 == 0,
errors::InvalidArgument("FusedBatchNorm with activation requires "
"channel dimension to be a multiple of 4."));
}