[Mlir-commits] [mlir] [mlir][tosa] Add CastFolder for tosa.table (PR #170044)

Tomer Solomon llvmlistbot at llvm.org
Sun Nov 30 10:38:10 PST 2025


https://github.com/recursion-man created https://github.com/llvm/llvm-project/pull/170044

Push a tensor.cast operation past tosa.table when the cast goes from a more static type to a less static (more dynamic) type. This allows the table to operate on more refined types, enabling better optimizations and type inference in downstream operations. The pattern adds a cast back to the original result type for compatibility with existing users.

 For example:
 ```mlir
   %cast = tensor.cast %input : tensor<6x256xi8> to tensor<?x256xi8>
   %table_out = tosa.table %cast, %table_tensor
     : (tensor<?x256xi8>, tensor<256xi8>) -> tensor<?x256xi8>
 ```
 Can be folded to:
 ```mlir
   %table_out = tosa.table %input, %table_tensor
     : (tensor<6x256xi8>, tensor<256xi8>) -> tensor<6x256xi8>
   %cast = tensor.cast %table_out
     : tensor<6x256xi8> to tensor<?x256xi8>
 ```

>From 5374b4dd681bfc74d009cbb11da6750fae7c1265 Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomer.solomon at mobileye.com>
Date: Sun, 30 Nov 2025 18:29:06 +0200
Subject: [PATCH] [mlir][tosa] Add CastFolder for tosa.table

Push a tensor.cast operation past tosa.table when the cast
goes from a more static type to a less static (more dynamic) type. This
allows the table to operate on more refined types, enabling better
optimizations and type inference in downstream operations.
The pattern adds a cast back to the original result type for compatibility
with existing users.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  2 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 52 +++++++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 11 ++++
 3 files changed, 65 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index bb8faf01802fa..e31e1ddf0613e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1194,6 +1194,8 @@ def Tosa_TableOp : Tosa_InferShapedTypeOp<"table"> {
   }];
 
   let hasVerifier = 1;
+  
+  let hasCanonicalizer = 1;
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..a86725cf15b12 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -416,6 +416,58 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
+/// Fold a `tensor.cast` feeding the input of `tosa.table` into the table op.
+///
+/// This pattern pushes a tensor.cast that produces the `input1` operand of
+/// tosa.table past the table when the cast goes from a more static type to a
+/// less static (more dynamic) type. This allows the table to operate on more
+/// refined types, enabling better optimizations and type inference in
+/// downstream operations. The pattern adds a cast back to the original result
+/// type for compatibility with existing users.
+///
+/// Only the `input1` operand is considered. A cast on the `table` operand is
+/// left untouched; the rewrite takes the new result shape from the input
+/// cast's source.
+/// For example:
+/// ```mlir
+///   %cast = tensor.cast %input : tensor<6x256xi8> to tensor<?x256xi8>
+///   %table_out = tosa.table %cast, %table_tensor
+///     : (tensor<?x256xi8>, tensor<256xi8>) -> tensor<?x256xi8>
+/// ```
+/// Can be folded to:
+/// ```mlir
+///   %table_out = tosa.table %input, %table_tensor
+///     : (tensor<6x256xi8>, tensor<256xi8>) -> tensor<6x256xi8>
+///   %cast = tensor.cast %table_out
+///     : tensor<6x256xi8> to tensor<?x256xi8>
+/// ```
+/// The result cast may be folded away in subsequent canonicalization if users
+/// can accept the more static type.
+struct TableOpCastFolder : public OpRewritePattern<tosa::TableOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TableOp tableOp,
+                                PatternRewriter &rewriter) const override {
+    // Inspect `input1` directly rather than using
+    // `tensor::hasFoldableTensorCastOperand`. That helper also fires for a
+    // cast on the `table` operand, in which case `input1` need not be defined
+    // by a tensor.cast at all (it may even be a block argument).
+    auto castOp = tableOp.getInput1().getDefiningOp<tensor::CastOp>();
+    if (!castOp || !tensor::canFoldIntoConsumerOp(castOp))
+      return rewriter.notifyMatchFailure(tableOp,
+                                         "input is not a foldable tensor.cast");
+
+    // The new result type reuses the old result's element type and encoding;
+    // an encoding only exists on ranked tensors, so bail out on unranked.
+    auto oldResultType = dyn_cast<RankedTensorType>(tableOp.getType());
+    if (!oldResultType)
+      return rewriter.notifyMatchFailure(tableOp, "result type is unranked");
+
+    // `canFoldIntoConsumerOp` only accepts casts whose source and result are
+    // both ranked tensors, so this cast cannot fail.
+    auto srcType = cast<RankedTensorType>(castOp.getSource().getType());
+    auto newResultType = RankedTensorType::get(srcType.getShape(),
+                                               oldResultType.getElementType(),
+                                               oldResultType.getEncoding());
+
+    auto newTableOp =
+        tosa::TableOp::create(rewriter, tableOp.getLoc(), newResultType,
+                              castOp.getSource(), tableOp.getTable());
+
+    // Cast back so existing users keep seeing the original result type.
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(tableOp, oldResultType,
+                                                newTableOp);
+    return success();
+  }
+};
+
+/// Registers the canonicalization patterns for `tosa.table`; at present this
+/// is only `TableOpCastFolder`.
+void TableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                          MLIRContext *ctx) {
+  patterns.add<TableOpCastFolder>(ctx);
+}
+
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..a16e5f4a189cd 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -411,6 +411,17 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// The cast on the table input is pushed past tosa.table: the table runs on the
+// static type and a cast back to the dynamic type feeds the return.
+// CHECK-LABEL: @fold_cast_into_table
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9_]+]]: tensor<6x256xi8>, %[[TABLE:[a-zA-Z0-9_]+]]: tensor<256xi8>
+func.func @fold_cast_into_table(%arg0: tensor<6x256xi8>, %arg1: tensor<256xi8>) -> tensor<?x256xi8> {
+  // CHECK-NOT:       tensor.cast %[[INPUT]]
+  // CHECK:           %[[TABLE_OUT:.*]] = tosa.table %[[INPUT]], %[[TABLE]] : (tensor<6x256xi8>, tensor<256xi8>) -> tensor<6x256xi8>
+  // CHECK:           %[[CAST:.*]] = tensor.cast %[[TABLE_OUT]] : tensor<6x256xi8> to tensor<?x256xi8>
+  // CHECK:           return %[[CAST]] : tensor<?x256xi8>
+  %0 = tensor.cast %arg0 : tensor<6x256xi8> to tensor<?x256xi8>
+  %1 = tosa.table %0, %arg1 : (tensor<?x256xi8>, tensor<256xi8>) -> tensor<?x256xi8>
+  return %1 : tensor<?x256xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_stride_2
 func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
   // CHECK: tosa.conv2d



More information about the Mlir-commits mailing list