[Mlir-commits] [mlir] [mlir][tosa] Add a canonicalization to optimize cast cast sequences (PR #176904)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 20 03:47:12 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit introduces a new canonicalization over a sequence of cast operations. A cast->cast sequence can be simplified to a single cast when neither cast narrows below the bit width of the original input. This optimization is limited to integer types, since floating-point casts may change numerical behaviour.

---
Full diff: https://github.com/llvm/llvm-project/pull/176904.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+49) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+21) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index edd8f0fc266bb..0005b6f5a0c63 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2507,6 +2507,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..ab864e9b9b6a3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -885,6 +885,55 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
               SliceDynamicSizeCanonicalization>(context);
 }
 
+/// Folds cast(cast(x)) -> cast(x) for signless integers when neither cast
+/// narrows below x's bit width. Signed/unsigned types are excluded since
+/// mixing signedness changes extension kind (i8->ui16->i32 != i8->i32).
+struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
+  using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    auto innerCastOp = castOp.getInput().getDefiningOp<tosa::CastOp>();
+    if (!innerCastOp)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "input must be cast operation");
+
+    const Value innerCastInput = innerCastOp.getInput();
+    const auto innerInputType =
+        llvm::cast<ShapedType>(innerCastInput.getType());
+    const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
+    const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+
+    // Floating-point casts may round, and signed/unsigned element types alter
+    // whether values are zero- or sign-extended, so require signless ints.
+    const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
+                                              outerOutputType};
+    if (llvm::any_of(types, [](const ShapedType type) {
+          return !type.getElementType().isSignlessInteger();
+        }))
+      return rewriter.notifyMatchFailure(
+          castOp, "only signless integer types are supported");
+
+    // Neither cast may narrow below the original input's bit width.
+    const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
+    if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "inner cast operation is narrowing");
+    if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "outer cast operation is narrowing");
+
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
+                                              innerCastInput);
+    return success();
+  }
+};
+
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *ctx) {
+  patterns.add<NonNarrowingCastsOptimization>(ctx);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..aae5c0dbbb867 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1466,3 +1466,24 @@ func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
   %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
   return %1 : tensor<i1>
 }
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i32
+// CHECK: %[[OUT:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+// CHECK: return %[[OUT]] : tensor<13x21x3xi32>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i32(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+  %widened = "tosa.cast"(%arg0) : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %result = "tosa.cast"(%widened) : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i8
+// CHECK: return %arg0 : tensor<13x21x3xi8>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %widened = "tosa.cast"(%arg0) : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %narrowed = "tosa.cast"(%widened) : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+  return %narrowed : tensor<13x21x3xi8>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/176904


More information about the Mlir-commits mailing list