[Mlir-commits] [mlir] [mlir][tosa] Fold tensor.cast into tosa.transpose (PR #170029)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 30 01:26:57 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Tomer Solomon (recursion-man)
<details>
<summary>Changes</summary>
Push tensor.cast operations past tosa.transpose when the cast goes from a more static type to a more dynamic one. The transpose then operates on the more specific input type, so its result keeps the static shape information that the cast would otherwise discard. A tensor.cast back to the original result type is inserted after the transpose so that existing users still see the type they expect.
For example:
```mlir
%cast = tensor.cast %input : tensor<6x256x40xi8> to tensor<6x256x?xi8>
%t = tosa.transpose %cast {perms = [0, 2, 1]} : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
```
is canonicalized to:
```mlir
%t = tosa.transpose %input {perms = [0, 2, 1]} : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
%cast = tensor.cast %t : tensor<6x40x256xi8> to tensor<6x?x256xi8>
```
---
Full diff: https://github.com/llvm/llvm-project/pull/170029.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+56-1)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+32)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..6d1e4601475ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -411,9 +411,64 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
}
};
+/// Pattern to fold a tensor.cast into a tosa.transpose operation.
+///
+/// When the input of a transpose is produced by a tensor.cast that only
+/// relaxes static information (as decided by `tensor::canFoldIntoConsumerOp`,
+/// i.e. the cast source is at least as static as the cast result), the
+/// transpose is rebuilt on the cast's source. Its new result type is the
+/// permuted source shape. Dimensions still dynamic after permuting are taken
+/// from the original result type when that type is ranked and static there,
+/// so the rewrite never makes the transpose result less static than before.
+///
+/// A tensor.cast back to the original result type is inserted so existing
+/// users keep seeing the same type. The pattern does not fire when that cast
+/// would be illegal (for example, when a static source dimension contradicts
+/// a static dimension of the original result).
+///
+/// Example:
+/// ```
+/// %cast = tensor.cast %input : tensor<6x256x40xi8> to tensor<6x256x?xi8>
+/// %transpose = tosa.transpose %cast {perms = [0, 2, 1]}
+///   : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
+/// ```
+/// is canonicalized to:
+/// ```
+/// %transpose = tosa.transpose %input {perms = [0, 2, 1]}
+///   : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
+/// %cast = tensor.cast %transpose
+///   : tensor<6x40x256xi8> to tensor<6x?x256xi8>
+/// ```
+struct TransposeOpCastFolder : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Only casts that drop static information may be moved; a refining cast
+    // must stay in front of the transpose.
+    auto castOp = transposeOp.getInput1().getDefiningOp<tensor::CastOp>();
+    if (!castOp || !tensor::canFoldIntoConsumerOp(castOp))
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "no foldable cast operand");
+
+    // canFoldIntoConsumerOp only accepts casts between ranked tensor types.
+    auto srcType = cast<RankedTensorType>(castOp.getSource().getType());
+    // The transpose result may be unranked. Keep it as a plain Type (the
+    // trailing tensor.cast can target any tensor type) instead of asserting
+    // rankedness.
+    Type oldResultType = transposeOp.getType();
+    auto oldRankedType = dyn_cast<RankedTensorType>(oldResultType);
+
+    ArrayRef<int32_t> perms = transposeOp.getPerms();
+    assert(perms.size() == static_cast<size_t>(srcType.getRank()) &&
+           "permutation size must match source rank");
+    const bool canRefineFromOld =
+        oldRankedType &&
+        oldRankedType.getRank() == static_cast<int64_t>(perms.size());
+
+    // Result dim i is source dim perms[i]. If that is dynamic but the old
+    // result already knew a static extent, keep the static extent.
+    SmallVector<int64_t> newShape;
+    newShape.reserve(perms.size());
+    for (size_t i = 0, e = perms.size(); i < e; ++i) {
+      int64_t dim = srcType.getDimSize(perms[i]);
+      if (ShapedType::isDynamic(dim) && canRefineFromOld &&
+          !oldRankedType.isDynamicDim(i))
+        dim = oldRankedType.getDimSize(i);
+      newShape.push_back(dim);
+    }
+    auto newResultType = RankedTensorType::get(
+        newShape, srcType.getElementType(), srcType.getEncoding());
+
+    // Never emit IR that fails verification: the cast back to the original
+    // result type must be legal (same element type, compatible shapes).
+    if (!tensor::CastOp::areCastCompatible(newResultType, oldResultType))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "transposed source type is not cast-compatible with "
+                       "the original result type");
+
+    auto newTransposeOp = tosa::TransposeOp::create(
+        rewriter, transposeOp.getLoc(), newResultType, castOp.getSource(), perms);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(transposeOp, oldResultType,
+                                                newTransposeOp);
+    return success();
+  }
+};
+
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
+ // TransposeOpCastFolder moves a static-info-relaxing tensor.cast from the
+ // transpose input to after the transpose result.
+ results.add<ConsolidateTransposeOptimization, TransposeIsReshape,
+ TransposeOpCastFolder>(context);
}
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..07fc8a38c3157 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -411,6 +411,38 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
// -----
+// The cast only relaxes a static dimension, so it is pushed below the
+// transpose and the transpose result becomes fully static.
+// CHECK-LABEL: @fold_relaxing_cast_into_transpose
+func.func @fold_relaxing_cast_into_transpose(%arg0: tensor<6x256x40xi8>) -> tensor<6x?x256xi8> {
+// CHECK: %[[TRANSPOSE:.*]] = tosa.transpose %arg0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[TRANSPOSE]] : tensor<6x40x256xi8> to tensor<6x?x256xi8>
+// CHECK: return %[[CAST]] : tensor<6x?x256xi8>
+  %0 = tensor.cast %arg0 : tensor<6x256x40xi8> to tensor<6x256x?xi8>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
+  return %1 : tensor<6x?x256xi8>
+}
+
+// -----
+
+// The cast refines a dynamic dimension to a static one, so it must stay in
+// front of the transpose.
+// CHECK-LABEL: @no_fold_refining_cast_into_transpose
+func.func @no_fold_refining_cast_into_transpose(%arg0: tensor<?x?x256xf32>) -> tensor<?x256x8xf32> {
+// CHECK: %[[CAST:.*]] = tensor.cast %arg0 : tensor<?x?x256xf32> to tensor<?x8x256xf32>
+// CHECK: %[[TRANSPOSE:.*]] = tosa.transpose %[[CAST]] {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+// CHECK: return %[[TRANSPOSE]] : tensor<?x256x8xf32>
+  %0 = tensor.cast %arg0 : tensor<?x?x256xf32> to tensor<?x8x256xf32>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  return %1 : tensor<?x256x8xf32>
+}
+
+// -----
+
+// An identity cast disappears entirely, and the transpose must survive
+// operating directly on the function argument.
+// CHECK-LABEL: @elide_identity_cast_before_transpose
+func.func @elide_identity_cast_before_transpose(%arg0: tensor<?x8x256xf32>) -> tensor<?x256x8xf32> {
+// CHECK-NOT: tensor.cast
+// CHECK: %[[TRANSPOSE:.*]] = tosa.transpose %arg0 {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[TRANSPOSE]] : tensor<?x256x8xf32>
+  %0 = tensor.cast %arg0 : tensor<?x8x256xf32> to tensor<?x8x256xf32>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  return %1 : tensor<?x256x8xf32>
+}
+
+// -----
+
// CHECK-LABEL: @conv2d_stride_2
func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
// CHECK: tosa.conv2d
``````````
</details>
https://github.com/llvm/llvm-project/pull/170029
More information about the Mlir-commits mailing list