[Mlir-commits] [mlir] [mlir][tosa] Fold tensor.cast into tosa.transpose (PR #170029)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 30 01:26:57 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Tomer Solomon (recursion-man)

<details>
<summary>Changes</summary>

Push `tensor.cast` operations past `tosa.transpose` when the cast goes from a more static type to a more dynamic one. The transpose then operates on the more static input type, which preserves shape information and enables better type inference downstream. A `tensor.cast` back to the original result type is inserted so existing users of the transpose remain valid.

For example:

```mlir
%cast = tensor.cast %input : tensor<6x256x40xi8> to tensor<6x256x?xi8>
%t = tosa.transpose %cast {perms = array<i32: 0, 2, 1>} : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
```
is canonicalized to:
```mlir
%t = tosa.transpose %input {perms = array<i32: 0, 2, 1>} : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
%cast = tensor.cast %t : tensor<6x40x256xi8> to tensor<6x?x256xi8>
```
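For anyone wanting to try the rewrite locally, it is exercised by the regular canonicalizer. A minimal standalone sketch of a reproducer follows (hypothetical function name; assumes the usual `mlir-opt --canonicalize | FileCheck` setup used by the TOSA canonicalization tests):

```mlir
// RUN: mlir-opt %s --canonicalize | FileCheck %s

// CHECK-LABEL: @push_cast_past_transpose
// CHECK: %[[T:.*]] = tosa.transpose %arg0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
// CHECK: tensor.cast %[[T]] : tensor<6x40x256xi8> to tensor<6x?x256xi8>
func.func @push_cast_past_transpose(%arg0: tensor<6x256x40xi8>) -> tensor<6x?x256xi8> {
  %0 = tensor.cast %arg0 : tensor<6x256x40xi8> to tensor<6x256x?xi8>
  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
  return %1 : tensor<6x?x256xi8>
}
```

The new transpose result type is simply the source shape permuted by `perms` (output dim i takes source dim perms[i]), so 6x256x40 under [0, 2, 1] becomes 6x40x256; the trailing cast then relaxes it back to the original 6x?x256 result type.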


---
Full diff: https://github.com/llvm/llvm-project/pull/170029.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+56-1) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+32) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..6d1e4601475ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -411,9 +411,64 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
   }
 };
 
+/// Pattern to fold a tensor.cast into a tosa.transpose operation.
+///
+/// This pattern pushes tensor.cast operations past transpose when the cast
+/// goes from a more static type to a less static (more dynamic) type. This
+/// allows the transpose to operate on more refined types, enabling better
+/// optimizations and type inference in downstream operations.
+///
+/// The pattern adds a cast back to the original result type for compatibility
+/// with existing users.
+///
+/// Example:
+/// ```
+///   %cast = tensor.cast %input : tensor<6x256x40xi8> to tensor<6x256x?xi8>
+///   %transpose = tosa.transpose %cast {perms = [0, 2, 1]}
+///     : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
+/// ```
+/// is canonicalized to:
+/// ```
+///   %transpose = tosa.transpose %input {perms = [0, 2, 1]}
+///     : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
+///   %cast = tensor.cast %transpose
+///     : tensor<6x40x256xi8> to tensor<6x?x256xi8>
+/// ```
+struct TransposeOpCastFolder : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!tensor::hasFoldableTensorCastOperand(transposeOp))
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "no foldable cast operand");
+
+    auto castOp = cast<tensor::CastOp>(transposeOp.getInput1().getDefiningOp());
+    auto srcType = cast<RankedTensorType>(castOp.getSource().getType());
+    auto oldResultType = cast<RankedTensorType>(transposeOp.getType());
+
+    ArrayRef<int32_t> perms = transposeOp.getPerms();
+    assert(perms.size() == static_cast<size_t>(srcType.getRank()) &&
+           "permutation size must match source rank");
+    SmallVector<int64_t> newShape;
+    newShape.reserve(srcType.getRank());
+    for (int32_t perm : perms)
+      newShape.push_back(srcType.getDimSize(perm));
+    auto newResultType = RankedTensorType::get(
+        newShape, srcType.getElementType(), srcType.getEncoding());
+    auto newTransposeOp = tosa::TransposeOp::create(
+        rewriter, transposeOp.getLoc(), newResultType, castOp.getSource(), perms);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        transposeOp, oldResultType, newTransposeOp);
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
+  results.add<ConsolidateTransposeOptimization, TransposeIsReshape,
+              TransposeOpCastFolder>(context);
 }
 
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..07fc8a38c3157 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -411,6 +411,38 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_relaxing_cast_into_transpose
+func.func @fold_relaxing_cast_into_transpose(%arg0: tensor<6x256x40xi8>) -> tensor<6x?x256xi8> {
+// CHECK:           %[[VAL_1:.*]] = tosa.transpose %arg0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
+// CHECK:           tensor.cast %[[VAL_1]] : tensor<6x40x256xi8> to tensor<6x?x256xi8>
+  %0 = tensor.cast %arg0 : tensor<6x256x40xi8> to tensor<6x256x?xi8>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
+  return %1 : tensor<6x?x256xi8>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @no_fold_refining_cast_into_transpose(
+func.func @no_fold_refining_cast_into_transpose(%arg0: tensor<?x?x256xf32>) -> tensor<?x256x8xf32> {
+// CHECK:           %[[VAL_1:.*]] = tensor.cast %arg0 : tensor<?x?x256xf32> to tensor<?x8x256xf32>
+// CHECK:           tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  %0 = tensor.cast %arg0 : tensor<?x?x256xf32> to tensor<?x8x256xf32>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  return %1 : tensor<?x256x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @elide_identity_cast_before_transpose
+func.func @elide_identity_cast_before_transpose(%arg0: tensor<?x8x256xf32>) -> tensor<?x256x8xf32> {
+// CHECK-NOT:           tensor.cast
+  %0 = tensor.cast %arg0 : tensor<?x8x256xf32> to tensor<?x8x256xf32>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  return %1 : tensor<?x256x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_stride_2
 func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
   // CHECK: tosa.conv2d

``````````

</details>


https://github.com/llvm/llvm-project/pull/170029


More information about the Mlir-commits mailing list