[Mlir-commits] [mlir] [mlir][tosa] Add `AllElementTypesMatch` trait for `tosa.transpose` (PR #120964)

Longsheng Mou llvmlistbot at llvm.org
Mon Dec 23 05:38:43 PST 2024


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/120964

This PR adds the `AllElementTypesMatch` trait to `tosa.transpose` to ensure that the output tensor has the same element type as the input tensor. Fixes #119364.

>From 09691665d988eefd86b7a25b41b036b359401927 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Mon, 23 Dec 2024 19:34:04 +0800
Subject: [PATCH] [mlir][tosa] Add `AllElementTypesMatch` trait for
 `tosa.transpose`

This PR adds the `AllElementTypesMatch` trait to `tosa.transpose` to ensure that the output tensor has the same element type as the input tensor.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td       | 3 ++-
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ----
 mlir/test/Dialect/Tosa/constant-op-fold.mlir       | 9 ---------
 mlir/test/Dialect/Tosa/invalid.mlir                | 9 +++++++++
 4 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..8ae5d3ab417b69 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1698,7 +1698,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
-                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+                [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+                 AllElementTypesMatch<["input1", "output"]>]> {
   let summary = "Transpose operator";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..f51c3dbce6eefe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1002,10 +1002,6 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
       return input.reshape(resultTy);
   }
 
-  // Transpose does not change the input type.
-  if (getInput1().getType() != getType())
-    return {};
-
   // Transpose is not the identity transpose.
   SmallVector<int32_t> perms;
   if (getConstantPerms(perms).failed())
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 2902c4a62009e9..8198903b78ac05 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -117,15 +117,6 @@ func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>)
   return %1, %input : tensor<3x2xf32>, tensor<2x3xf32>
 }
 
-// CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
-  %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
-  // CHECK: tosa.transpose
-  %0 = tosa.transpose %input, %perms : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-  return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
-}
-
 // CHECK-LABEL: @transpose_nofold_dense_resource
 func.func @transpose_nofold_dense_resource() -> tensor<2x2xf32> {
   %0 = "tosa.const"() <{value = dense_resource<resource> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index cca50b25d14d6b..b796a6343e5ed1 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -206,6 +206,15 @@ func.func @test_transpose_invalid_permutation_types_dynamic_dim_ok(%arg0: tensor
 
 // -----
 
+func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tensor<3x2xf32> {
+  %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error @+1 {{'tosa.transpose' op failed to verify that all of {input1, output} have same element type}}
+  %1 = tosa.transpose %arg0, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
 func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
   %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>



More information about the Mlir-commits mailing list