[Mlir-commits] [mlir] Fix tosa::TransposeOp::inferReturnTypeComponents() (PR #88656)
Maya Amrami
llvmlistbot at llvm.org
Tue Apr 16 05:11:49 PDT 2024
https://github.com/amrami updated https://github.com/llvm/llvm-project/pull/88656
>From d98dbd1b7ef8f3cac4a79b06e19be6f9ea39fab3 Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Sun, 14 Apr 2024 11:48:56 +0300
Subject: [PATCH] [mlir] Enable tosa::TransposeOp verification
The interface InferTensorType was added to the op.
The op already implements InferShapedTypeOpInterface and InferTypeOpInterface,
thus the verifier is now generated automatically. In addition, more versions
of tosa::TransposeOp::build are generated: if a result type is given, it is
verified; otherwise, it is inferred.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 16 +++++-----
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 +++++---
mlir/test/Dialect/Tosa/canonicalize.mlir | 9 ------
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 21 ++++---------
mlir/test/Dialect/Tosa/invalid.mlir | 16 ++++++++--
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 30 ++-----------------
6 files changed, 38 insertions(+), 67 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 306e4a43952088..b0e90d32389cda 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1501,7 +1501,7 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
let hasFolder = 1;
let hasVerifier = 1;
-
+
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
@@ -1651,7 +1651,7 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
let hasFolder = 1;
let hasVerifier = 1;
-
+
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
@@ -1707,7 +1707,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
//===----------------------------------------------------------------------===//
// Operator: transpose
//===----------------------------------------------------------------------===//
-def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose"> {
+def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose", [InferTensorType]> {
let summary = "Transpose operator";
let description = [{
@@ -1834,9 +1834,9 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| Mode | Input | Output |
|--------------------------|---------|---------|
- | signed 8 to bool | int8 | Boolean |
- | signed 16 to bool | int16 | Boolean |
- | signed 32 to bool | int32 | Boolean |
+ | signed 8 to bool | int8 | Boolean |
+ | signed 16 to bool | int16 | Boolean |
+ | signed 32 to bool | int32 | Boolean |
| bool to 8 | Boolean | int8 |
| bool to 16 | Boolean | int16 |
| bool to 32 | Boolean | int32 |
@@ -1850,8 +1850,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| float to signed 16 | float | int16 |
| signed 8 to float | int8 | float |
| signed 16 to float | int16 | float |
- | float 32 to float 64 | float32 | float64 |
- | float 64 to float 32 | float64 | float32 |
+ | float 32 to float 64 | float32 | float64 |
+ | float 64 to float 32 | float64 | float32 |
}];
let arguments = (ins
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e06ac9a27ae4cc..e270363d3f3139 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1012,6 +1012,9 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
if (permsShape.hasRank() && permsShape.getRank() == 0)
return failure();
+ Type inputType =
+ adaptor.getInput1().getType().cast<TensorType>().getElementType();
+
// If input rank and permutation length is unknown, the output rank is
// unknown.
if (!inputShape.hasRank() || !permsShape.hasRank() ||
@@ -1029,7 +1032,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
SmallVector<int64_t> outputShape;
// Rank-0 means no permutations matter.
if (inputShape.getRank() == 0) {
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outputShape, inputType));
return success();
}
@@ -1046,12 +1050,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
// permutation.
if (allTheSame) {
outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outputShape, inputType));
return success();
}
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
- // If the permuations are a constant we can directly determine the output
+ // If the permutations are a constant we can directly determine the output
// shape.
DenseIntElementsAttr attr;
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
@@ -1075,7 +1080,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
}
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 6eac759a083645..4e7dadad7db578 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -425,15 +425,6 @@ func.func @reshape_canonicalize_quant() -> (tensor<1x3x!quant.uniform<i8:f32, 1.
return %1 : tensor<1x3x!quant.uniform<i8:f32, 1.000000e+00>>
}
-// CHECK-LABEL: @transpose_canonicalize_strip_quant
-func.func @transpose_canonicalize_strip_quant() -> (tensor<2x1x3xi8>) {
- // CHECK: "tosa.const"() <{value = dense<0> : tensor<2x1x3xi8>}> : () -> tensor<2x1x3xi8>
- %perms = "tosa.const"() {value = dense<[1, 0, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
- %0 = "tosa.const"() {value = dense<0> : tensor<1x2x3xi8>} : ()-> tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>
- %1 = tosa.transpose %0, %perms : (tensor<1x2x3x!quant.uniform<i8:f32, 1.000000e+00>>, tensor<3xi32>) -> tensor<2x1x3xi8>
- return %1 : tensor<2x1x3xi8>
-}
-
// CHECK-LABEL: @slice_fold
func.func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: return %arg0
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index de752f31fcbaa1..ca7337be386a24 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -20,10 +20,10 @@ func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
}
// CHECK-LABEL: @transpose_nofold_shape
-func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
+func.func @transpose_nofold_shape(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: tosa.transpose
%0 = arith.constant dense<[1, 0]> : tensor<2xi32>
- %1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+ %1 = tosa.transpose %arg0, %0 { perms = [1, 0] }: (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
@@ -87,11 +87,11 @@ func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2
}
// CHECK-LABEL: @transpose_nofold_non_cst_perms
-func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> {
+func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<?x?xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
// CHECK: tosa.transpose
- %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
- return %1 : tensor<3x2xf32>
+ %1 = tosa.transpose %input, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @transpose_nofold_multi_users
@@ -103,15 +103,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
}
-// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
- %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
- %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
- // CHECK: tosa.transpose
- %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
- return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-}
-
// CHECK-LABEL: @transpose_nofold_dense_resource
func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
%0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
@@ -1078,7 +1069,7 @@ func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
// AGGRESIVE-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
// AGGRESIVE: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
// AGGRESIVE: return %[[VAL_0:.*]] : tensor<1x3xi32>
-
+
// CHECK-LABEL: func.func @reduce_sum_constant_aggressive() -> tensor<1x3xi32> {
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
// CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 730ac41dd7a8d3..82b64eab3f09bb 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -72,10 +72,10 @@ func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> t
// -----
-func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
+func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<?x?x?xf32> {
// expected-error at +1 {{'tosa.transpose' op perms of transpose is not constant}}
- %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
- return %0 : tensor<3x13x21xf32>
+ %0 = tosa.transpose %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
}
// -----
@@ -413,3 +413,13 @@ func.func @test_tile_invalid_multiples() {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
return
}
+
+// -----
+
+func.func @transpose_wrong_sizes(%arg0: tensor<1x1x1x1xi16>) -> (tensor<2x1x1x1xi16>) {
+ %0 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
+ // expected-error at +2 {{'tosa.transpose' op inferred type(s) 'tensor<1x1x1x1xi16>' are incompatible with return type(s) of operation 'tensor<2x1x1x1xi16>'}}
+ // expected-error at +1 {{'tosa.transpose' op failed to infer returned types}}
+ %1 = tosa.transpose %arg0, %0 : (tensor<1x1x1x1xi16>, tensor<4xi32>) -> tensor<2x1x1x1xi16>
+ return %1 : tensor<2x1x1x1xi16>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 2be120439ed68e..c018a770a7bf08 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -553,7 +553,7 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
// CHECK-LABEL: @test_transpose_same
func.func @test_transpose_same(%arg0 : tensor<4x4x4xi32>, %arg1 : tensor<3xi32>) -> () {
// CHECK: tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<4x4x4xi32>
- %0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
+ %0 = tosa.transpose %arg0, %arg1 : (tensor<4x4x4xi32>, tensor<3xi32>) -> tensor<4x4x4xi32>
return
}
@@ -572,7 +572,7 @@ func.func @test_transpose_perm_unknown(%arg0 : tensor<4x4x5xi32>, %arg1 : tensor
func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
%0 = arith.constant dense<[2, 1, 0]> : tensor<3xi32>
// CHECK: tosa.transpose %arg0, %cst : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<5x4x3xi32>
- %1 = tosa.transpose %arg0, %0 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
+ %1 = tosa.transpose %arg0, %0 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<5x4x3xi32>
return
}
@@ -1374,29 +1374,3 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
%1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
return %1 : tensor<?x16x16x16xf32>
}
-
-// -----
-
-// CHECK-LABEL: test_rank_size_constant_permutation
-func.func @test_rank_size_constant_permutation() {
- %c6 = arith.constant 6 : index
- %cst_26 = arith.constant dense<[0, 2]> : tensor<2xi32>
- %14 = tensor.empty(%c6) : tensor<?x27xi64>
- // Fail to infer the shape but not crash.
- // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
- %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
- return
-}
-
-// -----
-
-// CHECK-LABEL: test_large_constant_permutation
-func.func @test_large_constant_permutation() {
- %c6 = arith.constant 6 : index
- %cst_26 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
- %14 = tensor.empty(%c6) : tensor<?x27xi64>
- // Fail to infer the shape but not crash.
- // CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
- %72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
- return
-}
More information about the Mlir-commits
mailing list