[Mlir-commits] [mlir] [mlir][arith] Improve `truncf` folding (PR #80206)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jan 31 14:17:49 PST 2024


================
@@ -1393,23 +1395,20 @@ LogicalResult arith::TruncIOp::verify() {
 // TruncFOp
 //===----------------------------------------------------------------------===//
 
-/// Perform safe const propagation for truncf, i.e. only propagate if FP value
+/// Perform safe const propagation for truncf, i.e., only propagate if FP value
 /// can be represented without precision loss or rounding.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = adaptor.getIn();
-  if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
-    return {};
-
-  // Convert to target type via 'double'.
-  double sourceValue =
-      llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
-  auto targetAttr = FloatAttr::get(getType(), sourceValue);
-
-  // Propagate if constant's value does not change after truncation.
-  if (sourceValue == targetAttr.getValue().convertToDouble())
-    return targetAttr;
-
-  return {};
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+  return constFoldCastOp<FloatAttr, FloatAttr>(
+      adaptor.getOperands(), getType(),
+      [&targetSemantics](APFloat a, bool &castStatus) {
+        bool losesInfo = false;
+        auto status = a.convert(
+            targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+        castStatus = !losesInfo && status == APFloat::opOK;
+        return a;
----------------
kuhar wrote:

arith doesn't specify the rounding mode for `truncf`, and SPIR-V doesn't have a single static default either. To hardcode NearestTiesToEven here, we would first have to change the op semantics and then decide what the SPIR-V lowering should do, so that the runtime semantics and constant folding don't diverge.

https://github.com/llvm/llvm-project/pull/80206

