[Mlir-commits] [mlir] f010621 - [mlir][tosa] Add a canonicalization to optimize cast cast sequences (#176904)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 06:16:12 PST 2026
Author: Luke Hutton
Date: 2026-01-23T14:16:07Z
New Revision: f010621809b3fab9f3d64fa39699ea5c0c11e699
URL: https://github.com/llvm/llvm-project/commit/f010621809b3fab9f3d64fa39699ea5c0c11e699
DIFF: https://github.com/llvm/llvm-project/commit/f010621809b3fab9f3d64fa39699ea5c0c11e699.diff
LOG: [mlir][tosa] Add a canonicalization to optimize cast cast sequences (#176904)
This commit introduces a new canonicalization over a sequence of cast
operations. cast->cast sequences can be simplified to a single cast when
no narrowing is performed in between. This optimization is limited to
integer types, since floating point casts may impact numerical
behaviour.
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index edd8f0fc266bb..0005b6f5a0c63 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2507,6 +2507,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 14639ef5925ae..5f41c8c3f300f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -885,6 +885,55 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
SliceDynamicSizeCanonicalization>(context);
}
+/// Fuses `cast(cast(x))` into a single `cast(x)` when no bits of `x` are
+/// discarded along the way.
+///
+/// Let A, B and C be the element bit widths of the inner cast input, the inner
+/// cast result and the outer cast result. The rewrite fires only when
+/// A <= B (the inner cast does not narrow) and A <= C (the outer result can
+/// still hold every bit of the original input). For example,
+/// i8 -> i16 -> i32 becomes i8 -> i32, and i8 -> i16 -> i8 becomes i8 -> i8.
+///
+/// Only signless integer element types are accepted:
+///  - floating-point casts may round, so fusing them can change results;
+///  - explicitly signed/unsigned integer types are skipped because whether a
+///    cast zero- or sign-extends may depend on the signedness of its input,
+///    so a chain that changes signedness mid-way (e.g. ui8 -> i8 -> i32) is
+///    not guaranteed to equal the direct cast (ui8 -> i32).
+struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
+  using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    auto innerCastOp = castOp.getInput().getDefiningOp<tosa::CastOp>();
+    if (!innerCastOp)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "input must be cast operation");
+
+    const Value innerCastInput = innerCastOp.getInput();
+    const auto innerInputType =
+        llvm::cast<ShapedType>(innerCastInput.getType());
+    const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType());
+    const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType());
+
+    // `isInteger()` would also accept si*/ui* element types, whose casts may
+    // extend differently; only signless integers are safe to fuse.
+    const auto isSignlessIntTensor = [](ShapedType type) {
+      return type.getElementType().isSignlessInteger();
+    };
+    if (!isSignlessIntTensor(innerInputType) ||
+        !isSignlessIntTensor(innerOutputType) ||
+        !isSignlessIntTensor(outerOutputType))
+      return rewriter.notifyMatchFailure(
+          castOp, "only signless integer types are supported");
+
+    const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth();
+
+    // The inner cast must keep every bit of its input (A <= B).
+    if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "inner cast operation is narrowing");
+
+    // The final result must still hold every bit of the original input
+    // (A <= C); the intermediate width B is then irrelevant.
+    if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth())
+      return rewriter.notifyMatchFailure(castOp,
+                                         "outer cast operation is narrowing");
+
+    // Cast straight from the original input. The inner cast is not erased
+    // here: it may have other users, and once it has none it is a dead op.
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType,
+                                              innerCastInput);
+    return success();
+  }
+};
+
+/// Registers the tosa.cast canonicalization patterns, i.e. the fusion of
+/// non-narrowing integer cast chains (NonNarrowingCastsOptimization).
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<NonNarrowingCastsOptimization>(context, /*benefit=*/1);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 84776c47b628d..81e537babf9ab 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1466,3 +1466,57 @@ func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
+
+// -----
+
+// i8 -> i16 -> i32 only ever widens, so the pair should collapse into a
+// single i8 -> i32 cast of the function argument.
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i32
+// CHECK: %[[OUT:.*]] = tosa.cast %arg0 : (tensor<13x21x3xi8>) -> tensor<13x21x3xi32>
+// CHECK: return %[[OUT]] : tensor<13x21x3xi32>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i32(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+  %widened = "tosa.cast"(%arg0) : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %result = "tosa.cast"(%widened) : (tensor<13x21x3xi16>) -> tensor<13x21x3xi32>
+  return %result : tensor<13x21x3xi32>
+}
+
+// -----
+
+// i8 -> i16 -> i8 is a lossless round trip: the fused i8 -> i8 cast is
+// expected to fold away entirely, leaving the argument as the result.
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_i8_to_i8
+// CHECK: return %arg0 : tensor<13x21x3xi8>
+func.func @test_canonicalize_non_narrowing_cast_i8_to_i8(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %widened = "tosa.cast"(%arg0) : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  %result = "tosa.cast"(%widened) : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+  return %result : tensor<13x21x3xi8>
+}
+
+// -----
+
+// Floating-point cast chains are never fused, even when no step narrows.
+// f8E5M2 -> f16 -> f32 would satisfy both bit-width conditions (8 <= 16 and
+// 8 <= 32), so both casts surviving shows that the element-type restriction,
+// not the width checks, is what rejects it.
+// CHECK-LABEL: @test_canonicalize_non_narrowing_cast_f8_to_f32
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_non_narrowing_cast_f8_to_f32(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf32> {
+  %0 = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
+  %1 = tosa.cast %0 : (tensor<13x21x3xf16>) -> tensor<13x21x3xf32>
+  return %1 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// The inner i32 -> i16 cast already narrows, so the pattern must not fire
+// and both casts are kept.
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i32_to_i8
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_narrowing_cast_i32_to_i8(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> {
+  %narrowed = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi16>
+  %result = "tosa.cast"(%narrowed) : (tensor<13x21x3xi16>) -> tensor<13x21x3xi8>
+  return %result : tensor<13x21x3xi8>
+}
+
+// -----
+
+// i32 -> i8 narrows, so re-widening to i16 cannot recover the discarded bits;
+// the pair must not be fused into a direct i32 -> i16 cast.
+// CHECK-LABEL: @test_canonicalize_narrowing_cast_i32_to_i8_to_i16
+// CHECK: tosa.cast
+// CHECK: tosa.cast
+func.func @test_canonicalize_narrowing_cast_i32_to_i8_to_i16(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi16> {
+  %narrowed = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
+  %result = "tosa.cast"(%narrowed) : (tensor<13x21x3xi8>) -> tensor<13x21x3xi16>
+  return %result : tensor<13x21x3xi16>
+}
More information about the Mlir-commits
mailing list