[Mlir-commits] [mlir] Fix tosa::TransposeOp::inferReturnTypeComponents() (PR #88656)

Maya Amrami llvmlistbot at llvm.org
Tue Apr 16 05:11:49 PDT 2024


https://github.com/amrami updated https://github.com/llvm/llvm-project/pull/88656

>From d98dbd1b7ef8f3cac4a79b06e19be6f9ea39fab3 Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Sun, 14 Apr 2024 11:48:56 +0300
Subject: [PATCH] [mlir] Enable tosa::TransposeOp verification

The InferTensorType interface was added to the op.
The op already implements InferShapedTypeOpInterface and InferTypeOpInterface,
so the verifier is now generated automatically. In addition, more overloads
of tosa::TransposeOp::build are generated. If the caller supplies a result
type, it is verified; otherwise, it is inferred.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 16 +++++-----
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 13 +++++---
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  9 ------
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  | 21 ++++---------
 mlir/test/Dialect/Tosa/invalid.mlir           | 16 ++++++++--
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 30 ++-----------------
 6 files changed, 38 insertions(+), 67 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 306e4a43952088..b0e90d32389cda 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1501,7 +1501,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
@@ -1651,7 +1651,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -1707,7 +1707,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 //===----------------------------------------------------------------------===//
 // Operator: transpose
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
+def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", [InferTensorType]> {
   let summary = "Transpose operator";
 
   let description = [{
@@ -1834,9 +1834,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 
     | Mode                     | Input   | Output  |
     |--------------------------|---------|---------|
-    | signed 8 to bool         | int8    | Boolean | 
-    | signed 16 to bool        | int16   | Boolean | 
-    | signed 32 to bool        | int32   | Boolean | 
+    | signed 8 to bool         | int8    | Boolean |
+    | signed 16 to bool        | int16   | Boolean |
+    | signed 32 to bool        | int32   | Boolean |
     | bool to 8                | Boolean | int8    |
     | bool to 16               | Boolean | int16   |
     | bool to 32               | Boolean | int32   |
@@ -1850,8 +1850,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
     | float to signed 16       | float   | int16   |
     | signed 8 to float        | int8    | float   |
     | signed 16 to float       | int16   | float   |
-    | float 32 to float 64     | float32 | float64 | 
-    | float 64 to float 32     | float64 | float32 | 
+    | float 32 to float 64     | float32 | float64 |
+    | float 64 to float 32     | float64 | float32 |
   }];
 
   let arguments = (ins
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..e270363d3f3139 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1012,6 +1012,9 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   if (permsShape.hasRank() && permsShape.getRank() == 0)
     return failure();
 
+  Type inputType =
+      adaptor.getInput1().getType().cast<TensorType>().getElementType();
+
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
   if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -1029,7 +1032,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   SmallVector<int64_t> outputShape;
   // Rank-0 means no permutations matter.
   if (inputShape.getRank() == 0) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
@@ -1046,12 +1050,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // permutation.
   if (allTheSame) {
     outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
-  // If the permuations are a constant we can directly determine the output
+  // If the permutations are a constant we can directly determine the output
   // shape.
   DenseIntElementsAttr attr;
   if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
@@ -1075,7 +1080,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6eac759a083645..4e7dadad7db578 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -425,15 +425,6 @@ func.func @reshape_canonicalize_quant() -> (tensor<1x3x!quant.uniform<i8:f32, 1.
   return %1 :  tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
 }
 
-// CHECK-LABEL: @transpose_canonicalize_strip_quant
-func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3xi8>) {
-  // CHECK: "tosa.const"() <{value = dense<0> : tensor<2x1x3xi8>}> : () -> tensor<2x1x3xi8>
-  %perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
-  %0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
-  %1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3xi8>
-  return %1 :  tensor<2x1x3xi8>
-}
-
 // CHECK-LABEL: @slice_fold
 func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index de752f31fcbaa1..ca7337be386a24 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -20,10 +20,10 @@ func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
 }
 
 // CHECK-LABEL: @transpose_nofold_shape
-func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
+func.func @transpose_nofold_shape(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: tosa.transpose
   %0 = arith.constant dense<[1, 0]> : tensor<2xi32>
-  %1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  %1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
@@ -87,11 +87,11 @@ func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2
 }
 
 // CHECK-LABEL: @transpose_nofold_non_cst_perms
-func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
+func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<?x?xf32> {
   %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
   // CHECK: tosa.transpose
-  %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
-  return %1 : tensor<3x2xf32>
+  %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
 }
 
 // CHECK-LABEL: @transpose_nofold_multi_users
@@ -103,15 +103,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
   return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
 }
 
-// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
-  %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
-  // CHECK: tosa.transpose
-  %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-  return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-}
-
 // CHECK-LABEL: @transpose_nofold_dense_resource
 func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
   %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
@@ -1078,7 +1069,7 @@ func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
   // AGGRESIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
   // AGGRESIVE:       %[[VAL_0:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
   // AGGRESIVE:       return %[[VAL_0:.*]] : tensor<1x3xi32>
-  
+
   // CHECK-LABEL:     func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
   // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
   // CHECK:           %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 730ac41dd7a8d3..82b64eab3f09bb 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -72,10 +72,10 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
 
 // -----
 
-func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
+func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<?x?x?xf32> {
   // expected-error @+1 {{'tosa.transpose' op perms of transpose is not constant}}
-  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
-  return %0 : tensor<3x13x21xf32>
+  %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
 }
 
 // -----
@@ -413,3 +413,13 @@ func.func @test_tile_invalid_multiples() {
   %1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
   return
 }
+
+// -----
+
+func.func @transpose_wrong_sizes(%arg0: tensor<1x1x1x1xi16>) -> (tensor<2x1x1x1xi16>) {
+  %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+  // expected-error @+2 {{'tosa.transpose' op inferred type(s) 'tensor<1x1x1x1xi16>' are incompatible with return type(s) of operation 'tensor<2x1x1x1xi16>'}}
+  // expected-error @+1 {{'tosa.transpose' op failed to infer returned types}}
+  %1 = tosa.transpose %arg0, %0 : (tensor<1x1x1x1xi16>, tensor<4xi32>) -> tensor<2x1x1x1xi16>
+  return %1 : tensor<2x1x1x1xi16>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 2be120439ed68e..c018a770a7bf08 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -553,7 +553,7 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
 // CHECK-LABEL: @test_transpose_same
 func.func @test_transpose_same(%arg0 : tensor<4x4x4xi32>, %arg1 : tensor<3xi32>) -> () {
   // CHECK: tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<4x4x4xi32>
-  %0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
+  %0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<4x4x4xi32>
   return
 }
 
@@ -572,7 +572,7 @@ func.func @test_transpose_perm_unknown(%arg0 : tensor<4x4x5xi32>, %arg1 : tensor
 func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
   %0 = arith.constant dense<[2, 1, 0]> : tensor<3xi32>
   // CHECK: tosa.transpose %arg0, %cst : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<5x4x3xi32>
-  %1 = tosa.transpose %arg0, %0 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
+  %1 = tosa.transpose %arg0, %0 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<5x4x3xi32>
   return
 }
 
@@ -1374,29 +1374,3 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
   %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
   return %1 : tensor<?x16x16x16xf32>
 }
-
-// -----
-
-// CHECK-LABEL: test_rank_size_constant_permutation
-func.func @test_rank_size_constant_permutation() {
-  %c6 = arith.constant 6 : index
-  %cst_26 = arith.constant dense<[0, 2]> : tensor<2xi32>
-  %14 = tensor.empty(%c6) : tensor<?x27xi64>
-  // Fail to infer the shape but not crash.
-  // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: test_large_constant_permutation
-func.func @test_large_constant_permutation() {
-  %c6 = arith.constant 6 : index
-  %cst_26 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
-  %14 = tensor.empty(%c6) : tensor<?x27xi64>
-  // Fail to infer the shape but not crash.
-  // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
-  return
-}



More information about the Mlir-commits mailing list