[Caffe2] Fix shape inference for element-wise operators (#33431)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33431

Several element-wise operators don't have shape and type inference specified for their output tensor: `BitwiseOr`, `BitwiseAnd`, `BitwiseXor`, `Not`, and `Sign`.

This change fixes this issue:
- For the `Not` and `Sign` operators, the output has the same type and shape as the input, so the `IdenticalTypeAndShapeOfInput` function is used to declare that.
- For the bitwise operators created by the `CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP` macro, the type and shape inference rules are the same as for other binary element-wise operators, so `TensorInferenceFunction(ElementwiseOpShapeInference)` is used; a quick way to exercise these registrations from Python is sketched below.
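Not part of this change, but as a hedged illustration of what these registrations enable: Caffe2's Python shape inference can now produce shapes/types for the outputs of a small net containing the affected operators. The net and blob names below are purely illustrative, and the sketch assumes the `caffe2.python.workspace.InferShapesAndTypes` helper.

```
# Hedged sketch: check that output shapes/types are now inferred for Not and
# BitwiseAnd. Net and blob names are illustrative only.
import numpy as np
from caffe2.python import core, workspace

workspace.ResetWorkspace()
# Feed inputs so the inference pass can read their shapes and types.
workspace.FeedBlob("X", np.random.rand(4, 5) > 0.5)  # bool input for Not
workspace.FeedBlob("A", np.random.randint(8, size=(4, 5)).astype(np.int32))
workspace.FeedBlob("B", np.random.randint(8, size=(4, 5)).astype(np.int32))

net = core.Net("shape_inference_demo")
net.Not(["X"], ["X_not"])
net.BitwiseAnd(["A", "B"], ["A_and_B"])

# With the inference functions registered, both outputs should get a shape.
shapes, types = workspace.InferShapesAndTypes([net])
print(shapes.get("X_not"), shapes.get("A_and_B"))  # expected: [4, 5] [4, 5]
```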

Some tests were also modified to verify that the output shape and type are actually inferred, by passing the `ensure_outputs_are_inferred` parameter to `assertReferenceChecks`.
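For illustration only (this snippet is not part of the diff; the test class, blob names, and reference function are made up), the pattern used throughout the modified tests looks roughly like this:

```
# A hedged sketch of the updated test pattern; names below are illustrative.
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu


class ExampleShapeInferenceTest(hu.HypothesisTestCase):
    def test_sign_output_is_inferred(self):
        X = np.random.randn(3, 4).astype(np.float32)
        op = core.CreateOperator("Sign", ["X"], ["Y"])
        self.assertReferenceChecks(
            device_option=caffe2_pb2.DeviceOption(),  # default device: CPU
            op=op,
            inputs=[X],
            reference=lambda x: [np.sign(x)],
            # Newly exercised by this change: also assert that Y's shape and
            # type are produced by shape inference.
            ensure_outputs_are_inferred=True,
        )
```

(Per the Test Plan note below, the inference check only causes a hard failure when `CAFFE2_ASSERT_SHAPEINFERENCE=1` is set.)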

Test Plan:
```
CAFFE2_ASSERT_SHAPEINFERENCE=1 buck test caffe2/caffe2/python/operator_test:elementwise_ops_test
CAFFE2_ASSERT_SHAPEINFERENCE=1 buck test caffe2/caffe2/python/operator_test:math_ops_test
```

Note that the tests have to be executed with `CAFFE2_ASSERT_SHAPEINFERENCE=1` so that shape-inference failures actually cause the tests to fail.

Reviewed By: idning

Differential Revision: D19880164

fbshipit-source-id: 5d7902e045d79e5669e5e98dfb13a39711294939
Alex Cheparukhin
2020-02-25 09:01:00 -08:00
committed by Facebook Github Bot
parent 819ca2c285
commit ee23944f46
4 changed files with 55 additions and 10 deletions

View File

@@ -861,12 +861,13 @@ Both input operands should be of type `bool`.
};
}
-#define CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(name, symbol) \
-  OPERATOR_SCHEMA(name) \
-      .NumInputs(2) \
-      .NumOutputs(1) \
-      .AllowInplace({{0, 0}}) \
-      .FillUsing(BitwiseDocGenerator(symbol)); \
+#define CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(name, symbol) \
+  OPERATOR_SCHEMA(name) \
+      .NumInputs(2) \
+      .NumOutputs(1) \
+      .AllowInplace({{0, 0}}) \
+      .FillUsing(BitwiseDocGenerator(symbol)) \
+      .TensorInferenceFunction(ElementwiseOpShapeInference); \
SHOULD_NOT_DO_GRADIENT(name)
CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(BitwiseOr, "bitwise_or");
@@ -878,6 +879,7 @@ CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(BitwiseXor, "bitwise_xor");
OPERATOR_SCHEMA(Not)
.NumInputs(1)
.NumOutputs(1)
+    .IdenticalTypeAndShapeOfInput(0)
.SetDoc(R"DOC(
Performs element-wise negation on input tensor `X`.
@@ -934,6 +936,7 @@ SHOULD_NOT_DO_GRADIENT(Not);
OPERATOR_SCHEMA(Sign)
.NumInputs(1)
.NumOutputs(1)
+    .IdenticalTypeAndShapeOfInput(0)
.SetDoc(R"DOC(
Computes sign for each element of the input: -1, 0 or 1.

View File

@@ -30,6 +30,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=abs_ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0])
@@ -50,6 +51,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=exp_ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0])
@@ -74,6 +76,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=log_op,
+ensure_outputs_are_inferred=True,
)
self.assertGradientChecks(
@@ -106,7 +109,8 @@ class TestElementwiseOps(hu.HypothesisTestCase):
inputs=[X, Y],
reference=powt_op,
output_to_grad="Z",
-grad_reference=powt_grad)
+grad_reference=powt_grad,
+ensure_outputs_are_inferred=True)
@given(n=st.integers(0, 6), m=st.integers(4, 6),
seed=st.integers(0, 1000), **hu.gcs)
@@ -128,6 +132,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=sqr_op,
+ensure_outputs_are_inferred=True,
)
self.assertGradientChecks(
@@ -156,6 +161,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=sqrt_op,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
# stepsize need to be smaller than the possible minimum X, so the
@@ -179,6 +185,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=softsign_ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
if not inplace:
@@ -201,6 +208,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=rsqrt_ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=5e-3)
@@ -228,6 +236,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
reference=cube_ref,
output_to_grad="Y",
grad_reference=cube_grad_ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
@@ -247,6 +256,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=cbrt_ref,
+ensure_outputs_are_inferred=True,
)
@given(X=hu.tensor(elements=st.floats(1.0, 10.0), dtype=np.float32),
@@ -281,6 +291,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=swish,
+ensure_outputs_are_inferred=True,
)
self.assertGradientChecks(
@@ -331,6 +342,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=sigmoid_ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(gc, op, [X], 0, [0])
@@ -370,6 +382,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=hard_sigmoid_ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(
@@ -390,6 +403,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X, Y],
reference=eq,
+ensure_outputs_are_inferred=True,
)
workspace.FeedBlob('X', X)
@@ -419,6 +433,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X, Y],
reference=eq,
+ensure_outputs_are_inferred=True,
)
workspace.FeedBlob('X', X)
@@ -446,6 +461,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=inputs,
reference=ref,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, inputs, [0])
if test_grad:
@@ -459,7 +475,8 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=inputs,
reference=ref,
-)
+ensure_outputs_are_inferred=True,
+)
self.assertDeviceChecks(dc, op, inputs, [0])
if test_grad:
for i in range(len(inputs)):
@@ -667,11 +684,32 @@ class TestElementwiseOps(hu.HypothesisTestCase):
op=op,
inputs=[X],
reference=reciprocal_op,
+ensure_outputs_are_inferred=True,
)
self.assertDeviceChecks(dc, op, [X], [0])
self.assertGradientChecks(
gc, op, [X], 0, [0], stepsize=1e-3, threshold=0.05)
+    @given(X=hu.tensor(dtype=np.bool), **hu.gcs)
+    def test_not(self, X, gc, dc):
+        def not_op(X):
+            return [np.logical_not(X)]
+
+        op = core.CreateOperator(
+            "Not",
+            ["X"],
+            ["Y"],
+        )
+
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=[X],
+            reference=not_op,
+            ensure_outputs_are_inferred=True,
+        )
+        self.assertDeviceChecks(dc, op, [X], [0])
if __name__ == "__main__":
import unittest

View File

@@ -32,7 +32,8 @@ class TestMathOps(serial.SerializedTestCase):
self.assertReferenceChecks(gc, op, [X], powf,
output_to_grad="Y",
-grad_reference=powf_grad),
+grad_reference=powf_grad,
+ensure_outputs_are_inferred=True)
@serial.given(X=hu.tensor(),
exponent=st.floats(min_value=-3.0, max_value=3.0),
@@ -44,7 +45,8 @@ class TestMathOps(serial.SerializedTestCase):
op = core.CreateOperator(
"Sign", ["X"], ["Y"])
-self.assertReferenceChecks(gc, op, [X], signf),
+self.assertReferenceChecks(
+    gc, op, [X], signf, ensure_outputs_are_inferred=True)
self.assertDeviceChecks(dc, op, [X], [0])

View File

@@ -229,6 +229,7 @@ class SerializedTestCase(hu.HypothesisTestCase):
grad_reference=None,
atol=None,
outputs_to_check=None,
+ensure_outputs_are_inferred=False,
):
outs = super(SerializedTestCase, self).assertReferenceChecks(
device_option,
@@ -241,6 +242,7 @@ class SerializedTestCase(hu.HypothesisTestCase):
grad_reference,
atol,
outputs_to_check,
+ensure_outputs_are_inferred,
)
if not getattr(_output_context, 'disable_serialized_check', False):
grad_ops = _getGradientOrNone(op)