[Mlir-commits] [mlir] c56e7f2 - [mlir][arith] Canonicalize sitofp(truncf) -> sitofp, and uitofp. (#139925)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 19 06:07:33 PDT 2025
Author: Christian Sigg
Date: 2025-05-19T15:07:30+02:00
New Revision: c56e7f22f06ac52d2ef3ea487910ab60a1256138
URL: https://github.com/llvm/llvm-project/commit/c56e7f22f06ac52d2ef3ea487910ab60a1256138
DIFF: https://github.com/llvm/llvm-project/commit/c56e7f22f06ac52d2ef3ea487910ab60a1256138.diff
LOG: [mlir][arith] Canonicalize truncf(sitofp) -> sitofp, and truncf(uitofp) -> uitofp. (#139925)
Add canonicalization patterns that simplify `truncf(sitofp(x))` to
`sitofp(x)` and `truncf(uitofp(x))` to `uitofp(x)` when truncf uses the default rounding mode.
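This assumes that every value of the truncf destination type is
representable in the intermediate type.
Since truncf requires the destination type to be narrower than the
source type, this holds for all types I can think of, although one
could probably construct an artificial counterexample.
Soundness additionally requires the integer-to-intermediate conversion
to be exact. Otherwise the value is rounded twice, which can differ from
a single rounding. For example, x = 2^62 + 2^38 + 1 becomes 2^62 via
i64 -> f64 -> f32, but 2^62 + 2^39 via a direct i64 -> f32 conversion.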
Somewhat related: https://github.com/llvm/llvm-project/pull/128096
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index d50b6aeca15c9..599b3b982ec7f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1273,6 +1273,7 @@ def Arith_TruncFOp :
];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
let assemblyFormat = [{ $in ($roundingmode^)?
(`fastmath` `` $fastmath^)?
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7e212df9029d1..13eb97a910bd4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -419,6 +419,22 @@ def TruncIShrUIMulIToMulUIExtended :
(ValueWiderThan $mul, $x),
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+//===----------------------------------------------------------------------===//
+// TruncFOp
+//===----------------------------------------------------------------------===//
+
+// Holds if every value of the integer element type of $0 converts exactly
+// (without rounding or overflow) to the float element type of $1.
+// - A signed iN needs N-1 significand bits. Its most negative value,
+//   -2^(N-1), is a power of two.
+// - An unsigned iN needs N significand bits.
+// The exponent-range check is slightly conservative for unsigned integers.
+//
+// This condition makes truncf(itofp(x)) -> itofp(x) sound. If itofp into the
+// intermediate type is exact, truncf performs the only rounding, exactly as a
+// direct conversion would. Otherwise x is rounded twice, which can change the
+// result. For x = 2^62 + 2^38 + 1:
+// - i64 -> f64 -> f32 yields 2^62, because the f64 value is an f32 tie that
+//   rounds to even.
+// - i64 -> f32 directly yields 2^62 + 2^39.
+class IntExactlyRepresentableInFloat<bit isSigned> : Constraint<CPred<
+    "[&]() -> bool {"
+    "  unsigned bits = ::mlir::getElementTypeOrSelf($0.getType())"
+    "      .getIntOrFloatBitWidth()" # !if(isSigned, " - 1", "") # ";"
+    "  const ::llvm::fltSemantics &sem = ::llvm::cast<::mlir::FloatType>("
+    "      ::mlir::getElementTypeOrSelf($1.getType())).getFloatSemantics();"
+    "  return bits <= ::llvm::APFloat::semanticsPrecision(sem) &&"
+    "         static_cast<int>(bits) <="
+    "             ::llvm::APFloat::semanticsMaxExponent(sem);"
+    "}()">,
+    "integer exactly representable in the intermediate float type">;
+
+// truncf(sitofp(x)) -> sitofp(x) if the rounding mode is the default and
+// sitofp(x) is exact in the intermediate type (a single rounding step).
+def TruncFSIToFPToSIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_SIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_SIToFPOp $x),
+        [(Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rmf),
+         (IntExactlyRepresentableInFloat<1> $x, $fp)]>;
+
+// truncf(uitofp(x)) -> uitofp(x) if the rounding mode is the default and
+// uitofp(x) is exact in the intermediate type (a single rounding step).
+def TruncFUIToFPToUIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_UIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_UIToFPOp $x),
+        [(Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rmf),
+         (IntExactlyRepresentableInFloat<0> $x, $fp)]>;
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 3b308716c84dc..41f2d0f3425e2 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1552,6 +1552,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
});
}
+/// Registers the DRR-generated truncf canonicalizations declared in
+/// ArithCanonicalization.td: truncf(sitofp(x)) -> sitofp(x) and
+/// truncf(uitofp(x)) -> uitofp(x). Their applicability conditions are
+/// expressed as constraints on the patterns themselves.
+void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<TruncFSIToFPToSIToFP>(context);
+  patterns.add<TruncFUIToFPToUIToFP>(context);
+}
+
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  // truncf must narrow its operand. The width comparison is delegated to
  // checkWidthChangeCast with std::less. That helper is defined elsewhere in
  // this file; presumably it compares output vs. input float bit widths.
  const bool narrows =
      checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
  return narrows;
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index d62c5b18fd041..b6188c81ff912 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -753,6 +753,24 @@ func.func @truncExtf3(%arg0: f32) -> f16 {
return %truncf : f16
}
+// i32 converts exactly to f64, so truncf(sitofp) collapses to one sitofp.
+// CHECK-LABEL: @truncSitofp
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+// CHECK: %[[SITOFP:.*]] = arith.sitofp %[[ARG0]] : i32 to f32
+// CHECK-NOT: truncf
+// CHECK: return %[[SITOFP]]
+func.func @truncSitofp(%arg0: i32) -> f32 {
+  %sitofp = arith.sitofp %arg0 : i32 to f64
+  %trunc = arith.truncf %sitofp : f64 to f32
+  return %trunc : f32
+}
+
+// Same for the unsigned conversion: truncf(uitofp) collapses to one uitofp.
+// CHECK-LABEL: @truncUitofp
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+// CHECK: %[[UITOFP:.*]] = arith.uitofp %[[ARG0]] : i32 to f32
+// CHECK-NOT: truncf
+// CHECK: return %[[UITOFP]]
+func.func @truncUitofp(%arg0: i32) -> f32 {
+  %uitofp = arith.uitofp %arg0 : i32 to f64
+  %trunc = arith.truncf %uitofp : f64 to f32
+  return %trunc : f32
+}
+
+// An explicit rounding mode on truncf blocks the rewrite; both ops must stay.
+// CHECK-LABEL: @truncSitofpConstrained
+// CHECK: %[[SITOFP:.*]] = arith.sitofp %{{.*}} : i32 to f64
+// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[SITOFP]] to_nearest_even : f64 to f32
+// CHECK: return %[[TRUNCF]]
+func.func @truncSitofpConstrained(%arg0: i32) -> f32 {
+  %sitofp = arith.sitofp %arg0 : i32 to f64
+  %trunc = arith.truncf %sitofp to_nearest_even : f64 to f32
+  return %trunc : f32
+}
+
// TODO: We should also add a test for not folding arith.extf on information loss.
// This may happen when extending f8E5M2FNUZ to f16.
More information about the Mlir-commits
mailing list