[Mlir-commits] [mlir] [mlir][tosa] Add a canonicalization to optimize cast cast sequences (PR #176904)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 20 03:47:11 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit introduces a new canonicalization for sequences of cast operations. A cast->cast sequence can be simplified to a single cast when the intermediate type is at least as wide as the input type and the final type is no narrower than the input type, so no bits of the input are lost in between. This optimization is limited to integer types, since floating-point casts may impact numerical behaviour.
---
Full diff: https://github.com/llvm/llvm-project/pull/176904.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+49)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+21)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index edd8f0fc266bb..0005b6f5a0c63 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2507,6 +2507,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..ab864e9b9b6a3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -885,6 +885,55 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
SliceDynamicSizeCanonicalization>(context);
}
+/// Folds `cast(cast(x))` into a single cast when the pair cannot lose
+/// information relative to casting `x` directly.
+///
+/// With A, B and C denoting the element types of `x`, of the inner cast result
+/// and of the outer cast result, the rewrite fires only when:
+///  - A, B and C are all integer types. Floating-point types are excluded
+///    because they may change numerical behaviour even when no bit width
+///    shrinks (e.g. i32 -> f32 -> i32 loses precision for large magnitudes).
+///  - None of A, B, C is unsigned. This is conservative: if unsigned values
+///    are zero-extended while signed/signless values are sign-extended, a
+///    negative value routed through an unsigned intermediate
+///    (e.g. i8 -> ui16 -> i32) would not match a direct i8 -> i32 cast.
+///  - width(A) <= width(B), i.e. the inner cast is non-narrowing.
+///  - width(A) <= width(C), i.e. the outer cast does not drop below the width
+///    of the original input.
+/// If the outer result type equals the type of `x`, both casts are removed;
+/// otherwise the pair is replaced by one cast from `x` to the outer type.
+struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
+  using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    const Value castInput = castOp.getInput();
+    auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>();
+    if (!innerCastOp)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "input must be cast operation");
+
+    const Value innerCastInput = innerCastOp.getInput();
+
+    const auto innerInputType =
+        llvm::cast<ShapedType>(innerCastInput.getType());
+    const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
+    const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+
+    const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType,
+                                              outerOutputType};
+    if (llvm::any_of(types, [](const ShapedType type) {
+          return !type.getElementType().isInteger();
+        }))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "only integer types are supported");
+
+    // `Type::isInteger()` accepts every signedness, so unsigned types have to
+    // be filtered out separately: mixing zero- and sign-extension across the
+    // two casts can change the result.
+    if (llvm::any_of(types, [](const ShapedType type) {
+          return type.getElementType().isUnsignedInteger();
+        }))
+      return rewriter.notifyMatchFailure(
+          castOp, "unsigned integer types are not supported");
+
+    // Check inner cast is non-narrowing.
+    const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
+    if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "inner cast operation is narrowing");
+
+    // Check outer cast is non-narrowing from the inner cast input.
+    if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "outer cast operation is narrowing");
+
+    // The pair round-trips back to the original type: forward the input
+    // instead of materializing an identity cast that would need a later fold.
+    if (innerInputType == outerOutputType) {
+      rewriter.replaceOp(castOp, innerCastInput);
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
+                                              innerCastInput);
+    return success();
+  }
+};
+
+/// Registers the canonicalization patterns for tosa.cast. At present this is
+/// only NonNarrowingCastsOptimization, which merges cast(cast(x)) pairs.
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<NonNarrowingCastsOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..aae5c0dbbb867 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1466,3 +1466,24 @@ func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i32
+// CHECK: %[[OUT:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+// CHECK: return %[[OUT]] : tensor<13x21x3xi32>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i32(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32>
+  return %1 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i8
+// CHECK: return %arg0 : tensor<13x21x3xi8>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// The inner cast drops the upper 24 bits, so the pair must be kept.
+// CHECK-LABEL: @test_no_canonicalize_narrowing_inner_cast
+// CHECK: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
+// CHECK: %[[OUT:.*]] = tosa.cast %[[CAST]] : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+// CHECK: return %[[OUT]] : tensor<13x21x3xi32>
+func.func @test_no_canonicalize_narrowing_inner_cast(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
+  %1 = tosa.cast %0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+  return %1 : tensor<13x21x3xi32>
+}
+
+// -----
+
+// i32 -> f32 is not narrowing by bit width but is lossy, so the pair must be kept.
+// CHECK-LABEL: @test_no_canonicalize_cast_through_float
+// CHECK: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
+// CHECK: %[[OUT:.*]] = tosa.cast %[[CAST]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xi32>
+// CHECK: return %[[OUT]] : tensor<13x21x3xi32>
+func.func @test_no_canonicalize_cast_through_float(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xi32>
+  return %1 : tensor<13x21x3xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/176904
More information about the Mlir-commits mailing list