[Mlir-commits] [mlir] [mlir][tosa] Add a canonicalization to optimize cast cast sequences (PR #176904)
Luke Hutton
llvmlistbot at llvm.org
Thu Jan 22 01:17:22 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/176904
>From 2f47e2e94f88d143d2225a6930debf232a9c8158 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 20 Jan 2026 11:36:07 +0000
Subject: [PATCH 1/2] [mlir][tosa] Add a canonicalization to optimize cast cast
sequences
This commit introduces a new canonicalization over a sequence of cast
operations. A cast->cast sequence can be simplified to a single cast
when no narrowing is performed in between, i.e. when neither cast
reduces the bit width below that of the original input. This
optimization is limited to integer types, since floating point casts
may impact numerical behaviour.
Change-Id: I63af06838162a882b2739b08cafeb42301dd1a25
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 +
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 49 +++++++++++++++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 21 ++++++++
3 files changed, 71 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index edd8f0fc266bb..0005b6f5a0c63 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2507,6 +2507,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..6bc8a05d8fe1d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -885,6 +885,55 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
SliceDynamicSizeCanonicalization>(context);
}
+/// Collapses a `tosa.cast` whose input is produced by another `tosa.cast`
+/// into a single cast from the original input: cast(cast(x)) -> cast(x).
+///
+/// The rewrite only fires when composing the two casts cannot differ from
+/// casting directly, i.e. when:
+///  - the input, intermediate and result element types are all integers
+///    (floating-point casts are excluded since they may change values);
+///  - the inner cast does not change signedness;
+///  - the inner cast does not narrow the input; and
+///  - the outer cast does not narrow below the input's bit width.
+struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
+  using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    auto innerCastOp = castOp.getInput().getDefiningOp<tosa::CastOp>();
+    if (!innerCastOp)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "input must be cast operation");
+
+    const Value innerCastInput = innerCastOp.getInput();
+
+    const auto innerInputType =
+        llvm::cast<ShapedType>(innerCastInput.getType());
+    const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
+    const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+
+    const Type innerInputElemType = innerInputType.getElementType();
+    const Type innerOutputElemType = innerOutputType.getElementType();
+    const Type outerOutputElemType = outerOutputType.getElementType();
+
+    // isInteger() accepts signless, signed and unsigned integer types alike.
+    if (!innerInputElemType.isInteger() || !innerOutputElemType.isInteger() ||
+        !outerOutputElemType.isInteger())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "only integer types are supported");
+
+    // Conservatively require the intermediate type to keep the input's
+    // signedness. A cast from an unsigned type may zero-extend where a cast
+    // from a signless/signed type sign-extends, so a chain such as
+    // i8 -> ui16 -> i32 (or ui8 -> i8 -> i32) could otherwise produce a
+    // different value than the direct cast.
+    if (innerInputElemType.isUnsignedInteger() !=
+        innerOutputElemType.isUnsignedInteger())
+      return rewriter.notifyMatchFailure(
+          castOp, "inner cast operation changes signedness");
+
+    // Check inner cast is non-narrowing, so no input bits are dropped.
+    const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
+    if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "inner cast operation is narrowing");
+
+    // Check outer cast is non-narrowing relative to the inner cast input.
+    if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "outer cast operation is narrowing");
+
+    // The inner cast is left in place; it is erased separately if it has no
+    // other users.
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
+                                              innerCastInput);
+    return success();
+  }
+};
+
+/// Registers the canonicalization patterns for tosa.cast (enabled by
+/// `hasCanonicalizer = 1` on Tosa_CastOp). Currently this is only the
+/// cast->cast collapsing pattern NonNarrowingCastsOptimization.
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<NonNarrowingCastsOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..aae5c0dbbb867 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1466,3 +1466,24 @@ func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
+
+// -----
+
+// Two widening integer casts (i8 -> i16 -> i32) are replaced by a single
+// i8 -> i32 cast of the function argument.
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i32
+// CHECK: %[[OUT:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+// CHECK: return %[[OUT]] : tensor<13x21x3xi32>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i32(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+ %1 = tosa.cast %0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32>
+ return %1 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// A widen-then-truncate round trip (i8 -> i16 -> i8) loses no bits, so the
+// whole chain disappears and the function returns its argument directly.
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i8
+// CHECK: return %arg0 : tensor<13x21x3xi8>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+ %1 = tosa.cast %0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+ return %1 : tensor<13x21x3xi8>
+}
>From a3375d0e581092a4dcd8c5f266f32693fdc86d7b Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 22 Jan 2026 09:15:46 +0000
Subject: [PATCH 2/2] Add negative test cases
Change-Id: I20cbc2e4b4f002f5aed57791c2ba20d6683f503f
---
mlir/test/Dialect/Tosa/canonicalize.mlir | 33 ++++++++++++++++++++++++
1 file changed, 33 insertions(+)
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index aae5c0dbbb867..81e537babf9ab 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1487,3 +1487,36 @@ func.func @test_canonicalize_non_narrowing_cast_i8_to_i8(%arg0: tensor<13x21x3xi
%1 = tosa.cast %0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
return %1 : tensor<13x21x3xi8>
}
+
+// -----
+
+// Negative test: the float chain f8E5M2 -> f16 -> f32 never narrows, so the
+// only thing keeping both casts is the integer-only restriction.
+// CHECK-LABEL: @test_no_canonicalize_non_narrowing_cast_f8_to_f32
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_no_canonicalize_non_narrowing_cast_f8_to_f32(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+  return %1 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// Negative test: the inner cast narrows (i32 -> i16), so both casts must
+// remain.
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i32_to_i8
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_narrowing_cast_i32_to_i8(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi16>
+ %1 = tosa.cast %0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+ return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// Negative test: the inner cast narrows (i32 -> i8) and drops high bits, so
+// both casts must remain even though the outer cast widens again.
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i32_to_i8_to_i16
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi16> {
+ %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
+ %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+ return %1 : tensor<13x21x3xi16>
+}
More information about the Mlir-commits
mailing list