[Mlir-commits] [mlir] [MLIR][ARITH] Adds missing foldings for truncf (PR #128096)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 20 16:48:37 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Zahi Moudallal (zahimoud)

<details>
<summary>Changes</summary>

This patch adds the following missing folds for `arith.truncf`:
- `truncf(extf(a)) -> a`, if `a` has the same type as the result
- `truncf(extf(a)) -> truncf(a)`, if `a` has a larger bitwidth than the result
- `truncf(truncf(a)) -> truncf(a)`, in any case

---
Full diff: https://github.com/llvm/llvm-project/pull/128096.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+22) 
- (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+1-2) 
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+29) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..494985fbce94e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,6 +1517,28 @@ LogicalResult arith::TruncIOp::verify() {
 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
 /// can be represented without precision loss.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+    Value src = extOp.getIn();
+    Type srcType = getElementTypeOrSelf(src.getType());
+    Type dstType = getElementTypeOrSelf(getType());
+    // truncf(extf(a)) -> truncf(a)
+    if (llvm::cast<FloatType>(srcType).getWidth() >
+        llvm::cast<FloatType>(dstType).getWidth()) {
+      setOperand(src);
+      return getResult();
+    }
+
+    // truncf(extf(a)) -> a
+    if (srcType == dstType)
+      return src;
+  }
+
+  // truncf(truncf(a)) -> truncf(a)
+  if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
+    setOperand(truncOp.getIn());
+    return getResult();
+  }
+
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..cebdebef85dc9 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,9 +764,8 @@ func.func @arith_extf(%arg0: f16) -> f64 {
 func.func @arith_truncf(%arg0: f64) -> f16 {
   // CHECK-LABEL: arith_truncf
   // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
-  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
   %truncd0 = arith.truncf %arg0 : f64 to f32
-  // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
   %truncd1 = arith.truncf %truncd0 : f32 to f16
 
   return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..aa4136cd6361e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,35 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
   return %0 : vector<2xf128>
 }
 
+// CHECK-LABEL: @truncExtf
+//       CHECK-NOT:  truncf
+//       CHECK:   return  %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %trunc = arith.truncf %extf : f64 to f32
+  return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf2
+//       CHECK:  %[[ARG0:.+]]: f32
+//       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncExtf2(%arg0: f32) -> f16 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %truncf = arith.truncf %extf : f64 to f16
+  return %truncf : f16
+}
+
+// CHECK-LABEL: @truncTruncf
+//       CHECK:  %[[ARG0:.+]]: f64
+//       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncTruncf(%arg0: f64) -> f16 {
+  %truncf = arith.truncf %arg0 : f64 to f32
+  %truncf1 = arith.truncf %truncf : f32 to f16
+  return %truncf1 : f16
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/128096


More information about the Mlir-commits mailing list