[Mlir-commits] [mlir] [mlir][tosa] Fix tosa.transpose identity canonicalization type (PR #188020)

llvmlistbot at llvm.org
Mon Mar 23 04:42:19 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-tosa

Author: Hocky Yudhiono (hockyy)

<details>
<summary>Changes</summary>

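Fixes #187974. The bug is triggered by an assertion requiring that the value returned by a fold have the same type as the folded op's result. An identity `tosa.transpose` whose result type differs from its input type was folded to its input, which violated this assertion. One example is a `tensor<3x2xi32>` input with an unranked `tensor<*xi32>` result. This PR makes the fold bail out when the two types differ. Such identity `tosa.transpose` ops are instead canonicalized to a `tosa.reshape`, as checked by the new test.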
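
The guard added to `TransposeOp::fold` can be illustrated outside MLIR. Below is a minimal sketch in plain C++ that models the logic. `TensorType`, `Value`, `isIdentityPermutation` and `foldTranspose` are hypothetical stand-ins, not MLIR APIs. The sketch also assumes that an unranked type can be represented by an absent shape. The fold returns its input only when the permutation is the identity and the result type equals the input type. Otherwise it declines to fold.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR tensor type (not an MLIR API).
// An empty `shape` models an unranked type such as tensor<*xi32>.
struct TensorType {
  std::optional<std::vector<int64_t>> shape;
  std::string elementType;

  bool operator==(const TensorType &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
  bool operator!=(const TensorType &other) const { return !(*this == other); }
};

// Hypothetical stand-in for an SSA value: an identifier plus its type.
struct Value {
  int id;
  TensorType type;
};

// True if perms is exactly 0, 1, ..., n-1 (mirrors the llvm::seq check).
bool isIdentityPermutation(const std::vector<int32_t> &perms) {
  for (std::size_t i = 0; i < perms.size(); ++i)
    if (perms[i] != static_cast<int32_t>(i))
      return false;
  return true;
}

// Hypothetical model of the fold: std::nullopt means "do not fold".
std::optional<Value> foldTranspose(const Value &input,
                                   const std::vector<int32_t> &perms,
                                   const TensorType &resultTy) {
  // A non-identity permutation is not handled by this fold.
  if (!isIdentityPermutation(perms))
    return std::nullopt;
  // The guard from the patch: replacing the op with its input is only
  // valid when the input already has the op's result type.
  if (resultTy != input.type)
    return std::nullopt;
  return input;
}
```

Returning `std::nullopt` here plays the role of `return {};` (an empty `OpFoldResult`) in the real fold. The op is left in place for other canonicalization patterns to rewrite.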

---
Full diff: https://github.com/llvm/llvm-project/pull/188020.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+3) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+12) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..7c990b54e8119 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1920,6 +1920,9 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
     return {};
 
+  if (resultTy != getInput1().getType())
+    return {};
+
   return getInput1();
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..4f5aaa776e57e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1388,3 +1388,15 @@ func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21
   %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
   return %1 : tensor<13x21x3xi16>
 }
+
+// -----
+
+// CHECK-LABEL: test_canonicalize_different_type_identity_transpose(
+// CHECK: tosa.const_shape
+// CHECK: %[[RESHAPE:.+]] = tosa.reshape
+// CHECK-SAME: -> tensor<*xi32>
+// CHECK: return %[[RESHAPE]]
+func.func @test_canonicalize_different_type_identity_transpose(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+  %0 = tosa.transpose %arg0 {perms = array<i32: 0, 1>} : (tensor<3x2xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/188020

