[Mlir-commits] [mlir] f010621 - [mlir][tosa] Add a canonicalization to optimize cast cast sequences (#176904)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 23 06:16:12 PST 2026


Author: Luke Hutton
Date: 2026-01-23T14:16:07Z
New Revision: f010621809b3fab9f3d64fa39699ea5c0c11e699

URL: https://github.com/llvm/llvm-project/commit/f010621809b3fab9f3d64fa39699ea5c0c11e699
DIFF: https://github.com/llvm/llvm-project/commit/f010621809b3fab9f3d64fa39699ea5c0c11e699.diff

LOG: [mlir][tosa] Add a canonicalization to optimize cast cast sequences (#176904)

This commit introduces a new canonicalization over a sequence of cast
operations. A cast->cast sequence can be simplified to a single cast when
neither cast narrows the value below the bit width of the original input
(e.g. i8 -> i16 -> i32 becomes i8 -> i32). This optimization is limited to
integer types, since floating-point casts may impact numerical
behaviour.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index edd8f0fc266bb..0005b6f5a0c63 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2507,6 +2507,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 14639ef5925ae..5f41c8c3f300f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -885,6 +885,55 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
               SliceDynamicSizeCanonicalization>(context);
 }
 
+/// Collapses cast(cast(x)) into cast(x) when every element type involved is an
+/// integer and neither cast narrows below the bit width of x's element type,
+/// e.g. i8 -> i16 -> i32 is rewritten to a single i8 -> i32 cast.
+struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
+  using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    auto producer = castOp.getInput().getDefiningOp<tosa::CastOp>();
+    if (!producer)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "input must be cast operation");
+
+    Value source = producer.getInput();
+    auto srcTy = llvm::cast<ShapedType>(source.getType());
+    auto midTy = llvm::cast<ShapedType>(producer.getType());
+    auto dstTy = llvm::cast<ShapedType>(castOp.getType());
+
+    // Floating-point chains are left alone: their casts may change numerics.
+    auto hasIntElements = [](ShapedType ty) {
+      return ty.getElementType().isInteger();
+    };
+    if (!hasIntElements(srcTy) || !hasIntElements(midTy) ||
+        !hasIntElements(dstTy))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "only integer types are supported");
+
+    // Both casts are measured against the width of the original source.
+    const unsigned srcWidth = srcTy.getElementTypeBitWidth();
+    if (midTy.getElementTypeBitWidth() < srcWidth)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "inner cast operation is narrowing");
+
+    // The outer cast may narrow the intermediate value (e.g. i16 -> i8) as
+    // long as the result is not narrower than the original source.
+    if (dstTy.getElementTypeBitWidth() < srcWidth)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "outer cast operation is narrowing");
+
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, dstTy, source);
+    return success();
+  }
+};
+
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<NonNarrowingCastsOptimization>(context); // Folds cast chains.
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..81e537babf9ab 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1466,3 +1466,57 @@ func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
   %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
   return %1 : tensor<i1>
 }
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i32
+// CHECK: %[[SINGLE:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+// CHECK: return %[[SINGLE]] : tensor<13x21x3xi32>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i32(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+  %as_i16 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %as_i32 = tosa.cast %as_i16 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32>
+  return %as_i32 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i8
+// CHECK: return %arg0 : tensor<13x21x3xi8>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %wide = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %back = tosa.cast %wide : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+  return %back : tensor<13x21x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8_to_f32
+// CHECK: %[[MID:.*]] = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+// CHECK: tosa.cast %[[MID]] : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+func.func @test_canonicalize_non_narrowing_cast_f8_to_f32(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> // widening
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32> // float: kept
+  return %1 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i16_to_i32_to_i8
+// CHECK: %[[WIDE:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32>
+// CHECK: tosa.cast %[[WIDE]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
+func.func @test_canonicalize_narrowing_cast_i16_to_i32_to_i8(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi8> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32> // inner widens
+  %1 = tosa.cast %0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8> // i8 < i16 source
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i16_to_i8_to_i32
+// CHECK: %[[NARROW:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+// CHECK: tosa.cast %[[NARROW]] : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+func.func @test_canonicalize_narrowing_cast_i16_to_i8_to_i32(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8> // i8 < i16 source
+  %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32> // outer widens
+  return %1 : tensor<13x21x3xi32>
+}


        


More information about the Mlir-commits mailing list