[Mlir-commits] [mlir] 730f498 - [mlir][arith] Improve `truncf` folding (#80206)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 31 17:28:05 PST 2024


Author: Jakub Kuderski
Date: 2024-01-31T20:28:01-05:00
New Revision: 730f498c961f29691a605028f9b1cd6d9e232460

URL: https://github.com/llvm/llvm-project/commit/730f498c961f29691a605028f9b1cd6d9e232460
DIFF: https://github.com/llvm/llvm-project/commit/730f498c961f29691a605028f9b1cd6d9e232460.diff

LOG: [mlir][arith] Improve `truncf` folding (#80206)

* Use the APFloat conversion function, instead of round-tripping through
double, to check whether the fold results in information loss.
* Support folding vector constants.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ff72becc8dfa7..270df3f3f9e99 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -22,8 +22,10 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
@@ -1393,23 +1395,21 @@ LogicalResult arith::TruncIOp::verify() {
 // TruncFOp
 //===----------------------------------------------------------------------===//
 
-/// Perform safe const propagation for truncf, i.e. only propagate if FP value
-/// can be represented without precision loss or rounding.
+/// Perform safe const propagation for truncf, i.e., only fold when every FP
+/// value (scalar or vector element) converts exactly, with no precision loss or
+/// rounding, because `arith.truncf` does not assume a specific rounding mode.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = adaptor.getIn();
-  if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
-    return {};
-
-  // Convert to target type via 'double'.
-  double sourceValue =
-      llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
-  auto targetAttr = FloatAttr::get(getType(), sourceValue);
-
-  // Propagate if constant's value does not change after truncation.
-  if (sourceValue == targetAttr.getValue().convertToDouble())
-    return targetAttr;
-
-  return {};
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+  return constFoldCastOp<FloatAttr, FloatAttr>(
+      adaptor.getOperands(), getType(),
+      [&targetSemantics](APFloat a, bool &castStatus) {
+        bool losesInfo = false;  // Out-parameter of a.convert() below.
+        auto status = a.convert(
+            targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+        castStatus = !losesInfo && status == APFloat::opOK;
+        return a;
+      });
 }
 
 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 10050d87d7568..44df11ab2433a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -825,6 +825,15 @@ func.func @truncFPConstant() -> bf16 {
   return %0 : bf16
 }
 
+// CHECK-LABEL: @truncFPVectorConstant
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
+//       CHECK:   return %[[cres]]
+func.func @truncFPVectorConstant() -> vector<2xbf16> {
+  %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf32>
+  %0 = arith.truncf %cst : vector<2xf32> to vector<2xbf16>  // Folds: 0.0 and 1.0 are exact in bf16.
+  return %0 : vector<2xbf16>
+}
+
 // Test that cases with rounding are NOT propagated
 // CHECK-LABEL: @truncFPConstantRounding
 //       CHECK:   arith.constant 1.444000e+25 : f32


        


More information about the Mlir-commits mailing list