[Mlir-commits] [mlir] [mlir][arith] Canonicalize truncf(sitofp) -> sitofp, and uitofp. (PR #139925)

Christian Sigg llvmlistbot at llvm.org
Mon May 19 01:29:05 PDT 2025


https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/139925

>From 4381c875221839cebf1b0278bfc50c14fc9a4037 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Wed, 14 May 2025 17:45:09 +0200
Subject: [PATCH 1/3] [mlir][arith] Canonicalize truncf(sitofp) -> sitofp, or
 uitofp.

Add canonicalization patterns that simplify `truncf(sitofp(x))` to
`sitofp(x)` and `truncf(uitofp(x))` to `uitofp(x)`.

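This assumes that every value of the truncf destination type is representable in the
truncf source type. Because truncf requires the destination type to be narrower than
the source type, this holds for all fp types I can think of, but one could probably
construct an artificial counterexample.

The rewrite also assumes that rounding twice (int -> wide float -> narrow float) gives
the same result as rounding once (int -> narrow float). That is guaranteed when the
integer converts exactly to the wide type (e.g. i32 -> f64), but not in general. For
example, for the i64 value x = 2^62 + 2^38 + 1, sitofp to f32 rounds up to 2^62 + 2^39.
In contrast, sitofp to f64 yields exactly the f32 midpoint 2^62 + 2^38, which truncf then
rounds (ties-to-even) down to 2^62.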
---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td     |  1 +
 mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 14 ++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp             |  5 +++++
 mlir/test/Dialect/Arith/canonicalize.mlir          | 10 ++++++++++
 4 files changed, 30 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index d50b6aeca15c9..599b3b982ec7f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1273,6 +1273,7 @@ def Arith_TruncFOp :
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
   let assemblyFormat = [{ $in ($roundingmode^)?
                           (`fastmath` `` $fastmath^)?
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7e212df9029d1..7be73c4343639 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -419,6 +419,20 @@ def TruncIShrUIMulIToMulUIExtended :
        (ValueWiderThan $mul, $x),
        (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
 
+//===----------------------------------------------------------------------===//
+// TruncIOp
+//===----------------------------------------------------------------------===//
+
+// truncf(sitofp(x)) -> sitofp(x).
+def TruncFSIToFPToSIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_SIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_SIToFPOp $x)>;
+
+// truncf(sitofp(x)) -> sitofp(x).
+def TruncFUIToFPToUIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_UIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_UIToFPOp $x)>;
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 3b308716c84dc..41f2d0f3425e2 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1552,6 +1552,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
+}
+
 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
 }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index d62c5b18fd041..a5dab73a62fac 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -753,6 +753,16 @@ func.func @truncExtf3(%arg0: f32) -> f16 {
   return %truncf : f16
 }
 
+// CHECK-LABEL: @truncSitofp
+//       CHECK-NOT: truncf
+//       CHECK:     %[[SITOFP:.*]] = arith.sitofp %[[ARG0:.*]] : i32 to f32
+//       CHECK:     return %[[SITOFP]]
+func.func @truncSitofp(%arg0: i32) -> f32 {
+  %sitofp = arith.sitofp %arg0 : i32 to f64
+  %trunc = arith.truncf %sitofp : f64 to f32
+  return %trunc : f32
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 

>From c3875c18fe2cca8a90cc5d45e020615098cb1922 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Fri, 16 May 2025 10:55:14 +0200
Subject: [PATCH 2/3] Restrict to default rounding mode.

---
 mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7be73c4343639..13eb97a910bd4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -423,15 +423,17 @@ def TruncIShrUIMulIToMulUIExtended :
 // TruncIOp
 //===----------------------------------------------------------------------===//
 
-// truncf(sitofp(x)) -> sitofp(x).
+// truncf(sitofp(x)) -> sitofp(x) if default rounding mode.
 def TruncFSIToFPToSIToFP :
     Pat<(Arith_TruncFOp:$tr (Arith_SIToFPOp:$fp $x), $rmf, $fmf),
-        (Arith_SIToFPOp $x)>;
+        (Arith_SIToFPOp $x),
+        [(Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rmf)]>;
 
-// truncf(sitofp(x)) -> sitofp(x).
+// truncf(uitofp(x)) -> uitofp(x) if default rounding mode.
 def TruncFUIToFPToUIToFP :
     Pat<(Arith_TruncFOp:$tr (Arith_UIToFPOp:$fp $x), $rmf, $fmf),
-        (Arith_UIToFPOp $x)>;
+        (Arith_UIToFPOp $x),
+        [(Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rmf)]>;
 
 //===----------------------------------------------------------------------===//
 // MulFOp

>From bd58ab77958f1e89b285b5b2bcfcdb713dbde670 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Mon, 19 May 2025 10:27:49 +0200
Subject: [PATCH 3/3] Add test for constrained (with rounding mode) case.

---
 mlir/test/Dialect/Arith/canonicalize.mlir | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a5dab73a62fac..b6188c81ff912 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -754,8 +754,8 @@ func.func @truncExtf3(%arg0: f32) -> f16 {
 }
 
 // CHECK-LABEL: @truncSitofp
-//       CHECK-NOT: truncf
 //       CHECK:     %[[SITOFP:.*]] = arith.sitofp %[[ARG0:.*]] : i32 to f32
+//       CHECK-NOT: truncf
 //       CHECK:     return %[[SITOFP]]
 func.func @truncSitofp(%arg0: i32) -> f32 {
   %sitofp = arith.sitofp %arg0 : i32 to f64
@@ -763,6 +763,14 @@ func.func @truncSitofp(%arg0: i32) -> f32 {
   return %trunc : f32
 }
 
+// CHECK-LABEL: @truncSitofpConstrained
+//       CHECK: truncf
+func.func @truncSitofpConstrained(%arg0: i32) -> f32 {
+  %sitofp = arith.sitofp %arg0 : i32 to f64
+  %trunc = arith.truncf %sitofp to_nearest_even : f64 to f32
+  return %trunc : f32
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 



More information about the Mlir-commits mailing list