mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Cherry pick 2.2 Add missing valuidation to FusedBatchNorm
This commit is contained in:
committed by
Geeta Chavan
parent
e522a1924b
commit
882c7ff305
@@ -1267,6 +1267,33 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
context, estimated_variance.dims() == 1,
|
||||
errors::InvalidArgument("estimated_variance must be 1-dimensional",
|
||||
estimated_variance.shape().DebugString()));
|
||||
|
||||
const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
|
||||
OP_REQUIRES(
|
||||
context, scale.NumElements() == num_channels,
|
||||
errors::InvalidArgument("scale must have the same number of elements "
|
||||
"as the channels of x, got ",
|
||||
scale.NumElements(), " and ", num_channels));
|
||||
OP_REQUIRES(
|
||||
context, offset.NumElements() == num_channels,
|
||||
errors::InvalidArgument("offset must have the same number of elements "
|
||||
"as the channels of x, got ",
|
||||
offset.NumElements(), " and ", num_channels));
|
||||
if (estimated_mean.NumElements() != 0) {
|
||||
OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
|
||||
errors::InvalidArgument(
|
||||
"mean must be empty or have the same number of "
|
||||
"elements as the channels of x, got ",
|
||||
estimated_mean.NumElements(), " and ", num_channels));
|
||||
}
|
||||
if (estimated_variance.NumElements() != 0) {
|
||||
OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
|
||||
errors::InvalidArgument(
|
||||
"variance must be empty or have the same number of "
|
||||
"elements as the channels of x, got ",
|
||||
estimated_variance.NumElements(), " and ", num_channels));
|
||||
}
|
||||
|
||||
if (has_side_input_) {
|
||||
OP_REQUIRES(context, side_input->shape() == x.shape(),
|
||||
errors::InvalidArgument(
|
||||
@@ -1279,7 +1306,7 @@ class FusedBatchNormOpBase : public OpKernel {
|
||||
// NOTE(ezhulenev): This requirement is coming from implementation
|
||||
// details of cudnnBatchNormalizationForwardTrainingEx.
|
||||
OP_REQUIRES(
|
||||
context, !is_training_ || x.dim_size(3) % 4 == 0,
|
||||
context, !is_training_ || num_channels % 4 == 0,
|
||||
errors::InvalidArgument("FusedBatchNorm with activation requires "
|
||||
"channel dimension to be a multiple of 4."));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user