[Mlir-commits] [mlir] [mlir][tosa] Add CastFolder for tosa.table (PR #170044)
Tomer Solomon
llvmlistbot at llvm.org
Sun Nov 30 10:38:10 PST 2025
https://github.com/recursion-man created https://github.com/llvm/llvm-project/pull/170044
Push a tensor.cast on the input operand past tosa.table when the cast goes from a more static type to a less static (more dynamic) type. This lets the table operate on more refined types, enabling better optimizations and type inference in downstream operations. The pattern inserts a cast back to the original result type so that existing users remain compatible.
For example:
```mlir
%cast = tensor.cast %input : tensor<6x256xi8> to tensor<?x256xi8>
%table_out = tosa.table %cast, %table_tensor
: (tensor<?x256xi8>, tensor<256xi8>) -> tensor<?x256xi8>
```
Can be folded to:
```mlir
%table_out = tosa.table %input, %table_tensor
: (tensor<6x256xi8>, tensor<256xi8>) -> tensor<6x256xi8>
%cast = tensor.cast %table_out
: tensor<6x256xi8> to tensor<?x256xi8>
```
>From 5374b4dd681bfc74d009cbb11da6750fae7c1265 Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomer.solomon at mobileye.com>
Date: Sun, 30 Nov 2025 18:29:06 +0200
Subject: [PATCH] [mlir][tosa] Add CastFolder for tosa.table
Push a tensor.cast operation past tosa.table when the cast
goes from a more static type to a less static (more dynamic) type. This
allows the table to operate on more refined types, enabling better
optimizations and type inference in downstream operations.
The pattern adds a cast back to the original result type for compatibility
with existing users.
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 52 +++++++++++++++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 11 ++++
3 files changed, 65 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index bb8faf01802fa..e31e1ddf0613e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1194,6 +1194,8 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
}];
let hasVerifier = 1;
+
+ let hasCanonicalizer = 1;
let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..a86725cf15b12 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -416,6 +416,58 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
+/// Fold a `tensor.cast` feeding the input of `tosa.table`.
+///
+/// This pattern pushes a tensor.cast on the table's `input1` operand past the
+/// table when the cast goes from a more static type to a less static (more
+/// dynamic) type. This allows the table to operate on more refined types,
+/// enabling better optimizations and type inference in downstream operations.
+/// A cast back to the original result type is inserted so that existing users
+/// see an unchanged type.
+///
+/// Only `input1` is inspected: the shape of the table result follows its
+/// input, so a cast on the `table` operand is left untouched. Each result
+/// dimension is taken from whichever of the cast source and the original
+/// result type is static, so static information already present on the
+/// result is never dropped.
+///
+/// For example:
+/// ```mlir
+/// %cast = tensor.cast %input : tensor<6x256xi8> to tensor<?x256xi8>
+/// %table_out = tosa.table %cast, %table_tensor
+///     : (tensor<?x256xi8>, tensor<256xi8>) -> tensor<?x256xi8>
+/// ```
+/// Can be folded to:
+/// ```mlir
+/// %table_out = tosa.table %input, %table_tensor
+///     : (tensor<6x256xi8>, tensor<256xi8>) -> tensor<6x256xi8>
+/// %cast = tensor.cast %table_out
+///     : tensor<6x256xi8> to tensor<?x256xi8>
+/// ```
+/// The result cast may be folded away in subsequent canonicalization if users
+/// can accept the more static type.
+struct TableOpCastFolder : public OpRewritePattern<tosa::TableOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TableOp tableOp,
+                                PatternRewriter &rewriter) const override {
+    // Inspect `input1` directly. tensor::hasFoldableTensorCastOperand()
+    // accepts a foldable cast on *any* operand, so a cast on the `table`
+    // operand alone would pass it, and cast<tensor::CastOp>() on `input1`'s
+    // (possibly null) defining op would then assert.
+    auto castOp = tableOp.getInput1().getDefiningOp<tensor::CastOp>();
+    if (!castOp || !tensor::canFoldIntoConsumerOp(castOp))
+      return rewriter.notifyMatchFailure(
+          tableOp, "input1 is not produced by a foldable tensor.cast");
+
+    // canFoldIntoConsumerOp() only accepts ranked source and result types, so
+    // the source cast below is safe. The table result type is not guaranteed
+    // to be ranked, so check it instead of asserting.
+    auto srcType = cast<RankedTensorType>(castOp.getSource().getType());
+    auto oldResultType = dyn_cast<RankedTensorType>(tableOp.getType());
+    if (!oldResultType || oldResultType.getRank() != srcType.getRank())
+      return rewriter.notifyMatchFailure(
+          tableOp, "result is not a ranked tensor matching the input rank");
+
+    // Take each dimension from whichever side is static. Conflicting static
+    // sizes would make the cast back to the old result type invalid.
+    SmallVector<int64_t> newShape = llvm::to_vector(srcType.getShape());
+    ArrayRef<int64_t> oldShape = oldResultType.getShape();
+    for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+      if (ShapedType::isDynamic(newShape[i]))
+        newShape[i] = oldShape[i];
+      else if (!ShapedType::isDynamic(oldShape[i]) &&
+               oldShape[i] != newShape[i])
+        return rewriter.notifyMatchFailure(tableOp,
+                                           "conflicting static dimensions");
+    }
+    auto newResultType = RankedTensorType::get(
+        newShape, oldResultType.getElementType(), oldResultType.getEncoding());
+
+    auto newTableOp =
+        tosa::TableOp::create(rewriter, tableOp.getLoc(), newResultType,
+                              castOp.getSource(), tableOp.getTable());
+
+    // Cast back so existing users keep seeing the original result type.
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(tableOp, oldResultType,
+                                                newTableOp);
+    return success();
+  }
+};
+
+void TableOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                          MLIRContext *context) {
+  results.add<TableOpCastFolder>(context);
+}
+
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..a16e5f4a189cd 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -411,6 +411,17 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
// -----
+// CHECK-LABEL: @fold_cast_into_table
+func.func @fold_cast_into_table(%arg0: tensor<6x256xi8>, %arg1: tensor<256xi8>) -> tensor<?x256xi8> {
+  // CHECK-NOT: tensor.cast %arg0
+  // CHECK: %[[TABLE:.*]] = tosa.table %arg0, %arg1 : (tensor<6x256xi8>, tensor<256xi8>) -> tensor<6x256xi8>
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[TABLE]] : tensor<6x256xi8> to tensor<?x256xi8>
+  // CHECK: return %[[CAST]] : tensor<?x256xi8>
+  %0 = tensor.cast %arg0 : tensor<6x256xi8> to tensor<?x256xi8>
+  %1 = tosa.table %0, %arg1 : (tensor<?x256xi8>, tensor<256xi8>) -> tensor<?x256xi8>
+  return %1 : tensor<?x256xi8>
+}
+
+// -----
+
+// Only a cast feeding the input operand is folded; a foldable cast on the
+// table operand alone must be left in place (and must not crash the pattern).
+// CHECK-LABEL: @no_fold_cast_into_table_operand
+func.func @no_fold_cast_into_table_operand(%arg0: tensor<6x256xi8>, %arg1: tensor<256xi8>) -> tensor<6x256xi8> {
+  // CHECK: %[[CAST:.*]] = tensor.cast %arg1 : tensor<256xi8> to tensor<?xi8>
+  // CHECK: tosa.table %arg0, %[[CAST]] : (tensor<6x256xi8>, tensor<?xi8>) -> tensor<6x256xi8>
+  %0 = tensor.cast %arg1 : tensor<256xi8> to tensor<?xi8>
+  %1 = tosa.table %arg0, %0 : (tensor<6x256xi8>, tensor<?xi8>) -> tensor<6x256xi8>
+  return %1 : tensor<6x256xi8>
+}
+
+// -----
+
// CHECK-LABEL: @conv2d_stride_2
func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
// CHECK: tosa.conv2d
More information about the Mlir-commits
mailing list