[MHLO:Linalg] Add support for lowering reshape of unsigned tensors

PiperOrigin-RevId: 373461627
Change-Id: Icad96f7e001567eb920696f17c81d6f48d5b8c2c
This commit is contained in:
Hanhan Wang
2021-05-12 15:13:20 -07:00
committed by TensorFlower Gardener
parent 370566e547
commit 47f2d76ab5
2 changed files with 29 additions and 0 deletions

View File

@@ -782,6 +782,9 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
return failure();
result_type = this->typeConverter->convertType(result_type)
.template cast<ShapedType>();
// Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> src_shape =
(operand_type.getRank() > result_type.getRank()

View File

@@ -482,6 +482,19 @@ func @reshape_0D_1D(%arg0: tensor<i32>) -> tensor<1xi32> {
// -----
func @reshape_0D_1D_unsigned(%arg0: tensor<ui32>) -> tensor<1xui32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<ui32>) -> tensor<1xui32>
return %0 : tensor<1xui32>
}
// CHECK-LABEL: func @reshape_0D_1D_unsigned
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<ui32> to tensor<i32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<i32> into tensor<1xi32>
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<1xi32> to tensor<1xui32>
// CHECK: return %[[RET_UNSIGNED]] : tensor<1xui32>
// -----
// CHECK-LABEL: func @reshape_1D_0D
func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<i32>
@@ -491,6 +504,19 @@ func @reshape_1D_0D(%arg0: tensor<1xi32>) -> tensor<i32> {
// -----
func @reshape_1D_0D_unsigned(%arg0: tensor<1xui32>) -> tensor<ui32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<1xui32>) -> tensor<ui32>
return %0 : tensor<ui32>
}
// CHECK-LABEL: func @reshape_1D_0D_unsigned
// CHECK-SAME: %[[ARG_UNSIGNED:[a-zA-Z0-9_]*]]
// CHECK: %[[ARG_SIGNLESS:.*]] = unrealized_conversion_cast %[[ARG_UNSIGNED]] : tensor<1xui32> to tensor<1xi32>
// CHECK: %[[RET_SIGNLESS:.*]] = linalg.tensor_reshape %[[ARG_SIGNLESS]] [] : tensor<1xi32> into tensor<i32>
// CHECK: %[[RET_UNSIGNED:.*]] = unrealized_conversion_cast %[[RET_SIGNLESS]] : tensor<i32> to tensor<ui32>
// CHECK: return %[[RET_UNSIGNED]] : tensor<ui32>
// -----
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>