[Mlir-commits] [mlir] [mlir][tosa] Fix tosa.transpose identity canonicalization type (PR #188020)
llvmlistbot at llvm.org
Mon Mar 23 04:42:19 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-tosa
Author: Hocky Yudhiono (hockyy)
<details>
<summary>Changes</summary>
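Fixes #187974. The bug is triggered by an assertion that the value returned by a fold has the same type as the op result it replaces. The identity fold for `tosa.transpose` returned its input even when the result type differed from the input type, for example when the result is an unranked tensor. This PR makes the fold bail out in that case, so such an op is no longer folded to its input; the added test expects a `tosa.reshape` to the unranked type instead.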
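The type check described above can be illustrated outside of MLIR. The following is a minimal sketch, not the actual implementation: `ToyTensorType`, `isIdentityPermutation`, and `canFoldTransposeToInput` are hypothetical names invented for illustration, and a missing shape stands in for an unranked tensor such as `tensor<*xi32>`.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor type (not an MLIR class).
// An empty optional models an unranked tensor such as tensor<*xi32>.
struct ToyTensorType {
  std::optional<std::vector<int64_t>> shape;
  std::string elementType;

  bool operator==(const ToyTensorType &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
  bool operator!=(const ToyTensorType &other) const {
    return !(*this == other);
  }
};

// Returns true if perms is exactly 0, 1, ..., n-1.
bool isIdentityPermutation(const std::vector<int32_t> &perms) {
  for (std::size_t i = 0; i < perms.size(); ++i)
    if (perms[i] != static_cast<int32_t>(i))
      return false;
  return true;
}

// Toy version of the fold decision: true means "replace the op with its
// input". The second check mirrors the guard added by this patch.
bool canFoldTransposeToInput(const ToyTensorType &inputTy,
                             const ToyTensorType &resultTy,
                             const std::vector<int32_t> &perms) {
  if (!isIdentityPermutation(perms))
    return false;
  // Folding to the input is only valid if the types match exactly.
  if (resultTy != inputTy)
    return false;
  return true;
}
```

With this model, an identity permutation from `tensor<3x2xi32>` to `tensor<3x2xi32>` is foldable, while the same permutation with an unranked result is not, which corresponds to the case exercised by the new test.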
---
Full diff: https://github.com/llvm/llvm-project/pull/188020.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+3)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+12)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..7c990b54e8119 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1920,6 +1920,9 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
     return {};
+  if (resultTy != getInput1().getType())
+    return {};
+
   return getInput1();
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..4f5aaa776e57e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1388,3 +1388,15 @@ func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21
   %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
   return %1 : tensor<13x21x3xi16>
 }
+
+// -----
+
+// CHECK-LABEL: test_canonicalize_different_type_identity_transpose(
+// CHECK: tosa.const_shape
+// CHECK: %[[RESHAPE:.+]] = tosa.reshape
+// CHECK-SAME: -> tensor<*xi32>
+// CHECK: return %[[RESHAPE]]
+func.func @test_canonicalize_different_type_identity_transpose(%arg0: tensor<3x2xi32>) -> tensor<*xi32> {
+  %0 = tosa.transpose %arg0 {perms = array<i32: 0, 1>} : (tensor<3x2xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/188020