[Mlir-commits] [mlir] 5294ad1 - [mlir][arith] Improve `extf` folder (#80232)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 2 15:05:51 PST 2024


Author: Jakub Kuderski
Date: 2024-02-02T18:05:47-05:00
New Revision: 5294ad1d5c995850ecc903ff2c3464d37cfb49c2

URL: https://github.com/llvm/llvm-project/commit/5294ad1d5c995850ecc903ff2c3464d37cfb49c2
DIFF: https://github.com/llvm/llvm-project/commit/5294ad1d5c995850ecc903ff2c3464d37cfb49c2.diff

LOG: [mlir][arith] Improve `extf` folder (#80232)

* Use the APFloat conversion function instead of converting through
`double`, which loses information for large types like `f80` or
`f128`.
* Check for potential information loss. This is intended for small
floating point types that may have values not present in larger ones
(e.g., `f8E5M2FNUZ` and `f16`).
* Support folding vector constants.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 270df3f3f9e99..275c2debe9a6f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
@@ -1258,6 +1259,20 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
                                      srcType.getIntOrFloatBitWidth());
 }
 
+/// Attempts to convert `sourceValue` to an APFloat value with
+/// `targetSemantics`, without any information loss or rounding.
+static FailureOr<APFloat>
+convertFloatValue(APFloat sourceValue,
+                  const llvm::fltSemantics &targetSemantics) {
+  bool losesInfo = false;
+  auto status = sourceValue.convert(
+      targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+  if (losesInfo || status != APFloat::opOK)
+    return failure();
+
+  return sourceValue;
+}
+
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
@@ -1321,14 +1336,21 @@ LogicalResult arith::ExtSIOp::verify() {
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
-/// Always fold extension of FP constants.
+/// Fold extension of float constants when there is no information loss due the
+/// difference in fp semantics.
 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
-  auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
-  if (!constOperand)
-    return {};
-
-  // Convert to target type via 'double'.
-  return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+  return constFoldCastOp<FloatAttr, FloatAttr>(
+      adaptor.getOperands(), getType(),
+      [&targetSemantics](const APFloat &a, bool &castStatus) {
+        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
+        if (failed(result)) {
+          castStatus = false;
+          return a;
+        }
+        return *result;
+      });
 }
 
 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1403,12 +1425,13 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
-      [&targetSemantics](APFloat a, bool &castStatus) {
-        bool losesInfo = false;
-        auto status = a.convert(
-            targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
-        castStatus = !losesInfo && status == APFloat::opOK;
-        return a;
+      [&targetSemantics](const APFloat &a, bool &castStatus) {
+        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
+        if (failed(result)) {
+          castStatus = false;
+          return a;
+        }
+        return *result;
       });
 }
 
@@ -1496,6 +1519,7 @@ OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
         return apf;
       });
 }
+
 //===----------------------------------------------------------------------===//
 // FPToUIOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 44df11ab2433a..f128b13e9f732 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -701,6 +701,18 @@ func.func @extFPConstant() -> f64 {
   return %0 : f64
 }
 
+// CHECK-LABEL: @extFPVectorConstant
+//       CHECK:   %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf128>
+//       CHECK:   return %[[cres]]
+func.func @extFPVectorConstant() -> vector<2xf128> {
+  %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf80>
+  %0 = arith.extf %cst : vector<2xf80> to vector<2xf128>
+  return %0 : vector<2xf128>
+}
+
+// TODO: We should also add a test for not folding arith.extf on information loss.
+// This may happen when extending f8E5M2FNUZ to f16.
+
 // CHECK-LABEL: @truncConstant
 //       CHECK:   %[[cres:.+]] = arith.constant -2 : i16
 //       CHECK:   return %[[cres]]


        


More information about the Mlir-commits mailing list