mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[Caffe2] Fix shape inference for element-wise operators (#33431)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33431 Some elementwise operators don't have shape and type inference specified for the output tensor: `BitwiseOr`, `BitwiseAnd`, `BitwiseXor`, `Not`, `Sign`. This change fixes this issue: - For `Not` and `Sign` operators, the output has the same type and shape as the input, so `IdenticalTypeAndShapeOfInput` function is used to specify that. - For bitwise operators created by `CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP` macro, the type and shape inference rules should be the same as for other binary element-wise operators, so `TensorInferenceFunction(ElementwiseOpShapeInference)` is used to specify that. Also some tests were modified to ensure that the shape and type are inferred (`ensure_outputs_are_inferred` parameter) Test Plan: ``` CAFFE2_ASSERT_SHAPEINFERENCE=1 buck test caffe2/caffe2/python/operator_test:elementwise_ops_test CAFFE2_ASSERT_SHAPEINFERENCE=1 buck test caffe2/caffe2/python/operator_test:math_ops_test ``` Note that the tests have to be executed with `CAFFE2_ASSERT_SHAPEINFERENCE=1` in order to fail upon shape inference failure. Reviewed By: idning Differential Revision: D19880164 fbshipit-source-id: 5d7902e045d79e5669e5e98dfb13a39711294939
This commit is contained in:
committed by
Facebook Github Bot
parent
819ca2c285
commit
ee23944f46
@@ -861,12 +861,13 @@ Both input operands should be of type `bool`.
|
||||
};
|
||||
}
|
||||
|
||||
#define CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(name, symbol) \
|
||||
OPERATOR_SCHEMA(name) \
|
||||
.NumInputs(2) \
|
||||
.NumOutputs(1) \
|
||||
.AllowInplace({{0, 0}}) \
|
||||
.FillUsing(BitwiseDocGenerator(symbol)); \
|
||||
#define CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(name, symbol) \
|
||||
OPERATOR_SCHEMA(name) \
|
||||
.NumInputs(2) \
|
||||
.NumOutputs(1) \
|
||||
.AllowInplace({{0, 0}}) \
|
||||
.FillUsing(BitwiseDocGenerator(symbol)) \
|
||||
.TensorInferenceFunction(ElementwiseOpShapeInference); \
|
||||
SHOULD_NOT_DO_GRADIENT(name)
|
||||
|
||||
CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(BitwiseOr, "bitwise_or");
|
||||
@@ -878,6 +879,7 @@ CAFFE2_SCHEMA_FOR_BINARY_BITWISE_OP(BitwiseXor, "bitwise_xor");
|
||||
OPERATOR_SCHEMA(Not)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.IdenticalTypeAndShapeOfInput(0)
|
||||
.SetDoc(R"DOC(
|
||||
Performs element-wise negation on input tensor `X`.
|
||||
|
||||
@@ -934,6 +936,7 @@ SHOULD_NOT_DO_GRADIENT(Not);
|
||||
OPERATOR_SCHEMA(Sign)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.IdenticalTypeAndShapeOfInput(0)
|
||||
.SetDoc(R"DOC(
|
||||
Computes sign for each element of the input: -1, 0 or 1.
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=abs_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
self.assertGradientChecks(gc, op, [X], 0, [0])
|
||||
@@ -50,6 +51,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=exp_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
self.assertGradientChecks(gc, op, [X], 0, [0])
|
||||
@@ -74,6 +76,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=log_op,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
|
||||
self.assertGradientChecks(
|
||||
@@ -106,7 +109,8 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
inputs=[X, Y],
|
||||
reference=powt_op,
|
||||
output_to_grad="Z",
|
||||
grad_reference=powt_grad)
|
||||
grad_reference=powt_grad,
|
||||
ensure_outputs_are_inferred=True)
|
||||
|
||||
@given(n=st.integers(0, 6), m=st.integers(4, 6),
|
||||
seed=st.integers(0, 1000), **hu.gcs)
|
||||
@@ -128,6 +132,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=sqr_op,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
|
||||
self.assertGradientChecks(
|
||||
@@ -156,6 +161,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=sqrt_op,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
# stepsize need to be smaller than the possible minimum X, so the
|
||||
@@ -179,6 +185,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=softsign_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
if not inplace:
|
||||
@@ -201,6 +208,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=rsqrt_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
self.assertGradientChecks(gc, op, [X], 0, [0], stepsize=5e-3)
|
||||
@@ -228,6 +236,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
reference=cube_ref,
|
||||
output_to_grad="Y",
|
||||
grad_reference=cube_grad_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
|
||||
@@ -247,6 +256,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=cbrt_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
|
||||
@given(X=hu.tensor(elements=st.floats(1.0, 10.0), dtype=np.float32),
|
||||
@@ -281,6 +291,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=swish,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
|
||||
self.assertGradientChecks(
|
||||
@@ -331,6 +342,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=sigmoid_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
self.assertGradientChecks(gc, op, [X], 0, [0])
|
||||
@@ -370,6 +382,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=hard_sigmoid_ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
self.assertGradientChecks(
|
||||
@@ -390,6 +403,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X, Y],
|
||||
reference=eq,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
|
||||
workspace.FeedBlob('X', X)
|
||||
@@ -419,6 +433,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X, Y],
|
||||
reference=eq,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
|
||||
workspace.FeedBlob('X', X)
|
||||
@@ -446,6 +461,7 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=inputs,
|
||||
reference=ref,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, inputs, [0])
|
||||
if test_grad:
|
||||
@@ -459,7 +475,8 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=inputs,
|
||||
reference=ref,
|
||||
)
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, inputs, [0])
|
||||
if test_grad:
|
||||
for i in range(len(inputs)):
|
||||
@@ -667,11 +684,32 @@ class TestElementwiseOps(hu.HypothesisTestCase):
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=reciprocal_op,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
self.assertGradientChecks(
|
||||
gc, op, [X], 0, [0], stepsize=1e-3, threshold=0.05)
|
||||
|
||||
@given(X=hu.tensor(dtype=np.bool), **hu.gcs)
|
||||
def test_not(self, X, gc, dc):
|
||||
def not_op(X):
|
||||
return [np.logical_not(X)]
|
||||
|
||||
op = core.CreateOperator(
|
||||
"Not",
|
||||
["X"],
|
||||
["Y"],
|
||||
)
|
||||
|
||||
self.assertReferenceChecks(
|
||||
device_option=gc,
|
||||
op=op,
|
||||
inputs=[X],
|
||||
reference=not_op,
|
||||
ensure_outputs_are_inferred=True,
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
||||
@@ -32,7 +32,8 @@ class TestMathOps(serial.SerializedTestCase):
|
||||
|
||||
self.assertReferenceChecks(gc, op, [X], powf,
|
||||
output_to_grad="Y",
|
||||
grad_reference=powf_grad),
|
||||
grad_reference=powf_grad,
|
||||
ensure_outputs_are_inferred=True)
|
||||
|
||||
@serial.given(X=hu.tensor(),
|
||||
exponent=st.floats(min_value=-3.0, max_value=3.0),
|
||||
@@ -44,7 +45,8 @@ class TestMathOps(serial.SerializedTestCase):
|
||||
op = core.CreateOperator(
|
||||
"Sign", ["X"], ["Y"])
|
||||
|
||||
self.assertReferenceChecks(gc, op, [X], signf),
|
||||
self.assertReferenceChecks(
|
||||
gc, op, [X], signf, ensure_outputs_are_inferred=True)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
|
||||
|
||||
|
||||
@@ -229,6 +229,7 @@ class SerializedTestCase(hu.HypothesisTestCase):
|
||||
grad_reference=None,
|
||||
atol=None,
|
||||
outputs_to_check=None,
|
||||
ensure_outputs_are_inferred=False,
|
||||
):
|
||||
outs = super(SerializedTestCase, self).assertReferenceChecks(
|
||||
device_option,
|
||||
@@ -241,6 +242,7 @@ class SerializedTestCase(hu.HypothesisTestCase):
|
||||
grad_reference,
|
||||
atol,
|
||||
outputs_to_check,
|
||||
ensure_outputs_are_inferred,
|
||||
)
|
||||
if not getattr(_output_context, 'disable_serialized_check', False):
|
||||
grad_ops = _getGradientOrNone(op)
|
||||
|
||||
Reference in New Issue
Block a user