[Mlir-commits] [mlir] [mlir][arith] Improve `extf` folder (PR #80232)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Feb 2 13:35:50 PST 2024
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/80232
>From 66b97d42051b733c6450daaf1a0d8ab1096584e7 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 31 Jan 2024 22:16:17 -0500
Subject: [PATCH 1/3] [mlir][arith] Improve `extf` folder
* Use APFloat conversion function to avoid losing information by
converting to `double`. This would be the case with large types like
`f80` or `f128`.
* Check for potential information loss and skip folding when the conversion
is inexact. This is intended for small floating-point types that may have
values not representable in larger ones (e.g., negative zero).
* Support folding vector constants.
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 47 +++++++++++++++++------
mlir/test/Dialect/Arith/canonicalize.mlir | 9 +++++
2 files changed, 45 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 270df3f3f9e99..7b026310ab7d6 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -9,6 +9,7 @@
#include <cassert>
#include <cstdint>
#include <functional>
+#include <optional>
#include <utility>
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1258,6 +1259,21 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
srcType.getIntOrFloatBitWidth());
}
+/// Attempts to convert `sourceValue` to an APFloat value with
+/// `targetSemantics`, without any information loss or rounding. Return
+/// std::nullopt on failure.
+static std::optional<APFloat>
+convertFloatValue(APFloat sourceValue,
+ const llvm::fltSemantics &targetSemantics) {
+ bool losesInfo = false;
+ auto status = sourceValue.convert(
+ targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+ if (losesInfo || status != APFloat::opOK)
+ return std::nullopt;
+
+ return sourceValue;
+}
+
//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
@@ -1321,14 +1337,21 @@ LogicalResult arith::ExtSIOp::verify() {
// ExtFOp
//===----------------------------------------------------------------------===//
-/// Always fold extension of FP constants.
+/// Fold extension of float constants when there is no information loss due to
+/// the difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
- auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
- if (!constOperand)
- return {};
+ auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+ const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+ return constFoldCastOp<FloatAttr, FloatAttr>(
+ adaptor.getOperands(), getType(),
+ [&targetSemantics](const APFloat &a, bool &castStatus) {
+ if (std::optional<APFloat> result =
+ convertFloatValue(a, targetSemantics))
+ return *result;
- // Convert to target type via 'double'.
- return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
+ castStatus = false;
+ return a;
+ });
}
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1403,11 +1426,12 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
- [&targetSemantics](APFloat a, bool &castStatus) {
- bool losesInfo = false;
- auto status = a.convert(
- targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
- castStatus = !losesInfo && status == APFloat::opOK;
+ [&targetSemantics](const APFloat &a, bool &castStatus) {
+ if (std::optional<APFloat> result =
+ convertFloatValue(a, targetSemantics))
+ return *result;
+
+ castStatus = false;
return a;
});
}
@@ -1496,6 +1520,7 @@ OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
return apf;
});
}
+
//===----------------------------------------------------------------------===//
// FPToUIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 44df11ab2433a..57a71bcc8feeb 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -701,6 +701,15 @@ func.func @extFPConstant() -> f64 {
return %0 : f64
}
+// CHECK-LABEL: @extFPVectorConstant
+// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf128>
+// CHECK: return %[[cres]]
+func.func @extFPVectorConstant() -> vector<2xf128> {
+ %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf80>
+ %0 = arith.extf %cst : vector<2xf80> to vector<2xf128>
+ return %0 : vector<2xf128>
+}
+
// CHECK-LABEL: @truncConstant
// CHECK: %[[cres:.+]] = arith.constant -2 : i16
// CHECK: return %[[cres]]
>From 04b05a2f031942baa8d76d92f174a1c11e74271e Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 2 Feb 2024 16:21:58 -0500
Subject: [PATCH 2/3] Address comments
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 32 +++++++++++------------
mlir/test/Dialect/Arith/canonicalize.mlir | 3 +++
2 files changed, 19 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7b026310ab7d6..5d8d40c877142 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
@@ -1260,16 +1261,15 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
}
/// Attempts to convert `sourceValue` to an APFloat value with
-/// `targetSemantics`, without any information loss or rounding. Return
-/// std::nullopt on failure.
-static std::optional<APFloat>
+/// `targetSemantics`, without any information loss or rounding.
+static FailureOr<APFloat>
convertFloatValue(APFloat sourceValue,
const llvm::fltSemantics &targetSemantics) {
bool losesInfo = false;
auto status = sourceValue.convert(
targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
if (losesInfo || status != APFloat::opOK)
- return std::nullopt;
+ return failure();
return sourceValue;
}
@@ -1345,12 +1345,12 @@ OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
[&targetSemantics](const APFloat &a, bool &castStatus) {
- if (std::optional<APFloat> result =
- convertFloatValue(a, targetSemantics))
- return *result;
-
- castStatus = false;
- return a;
+ FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
+ if (failed(result)) {
+ castStatus = false;
+ return a;
+ }
+ return *result;
});
}
@@ -1427,12 +1427,12 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
[&targetSemantics](const APFloat &a, bool &castStatus) {
- if (std::optional<APFloat> result =
- convertFloatValue(a, targetSemantics))
- return *result;
-
- castStatus = false;
- return a;
+ FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
+ if (failed(result)) {
+ castStatus = false;
+ return a;
+ }
+ return *result;
});
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 57a71bcc8feeb..f128b13e9f732 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -710,6 +710,9 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
return %0 : vector<2xf128>
}
+// TODO: We should also add a test for not folding arith.extf on information loss.
+// This may happen when extending f8E5M2FNUZ to f16.
+
// CHECK-LABEL: @truncConstant
// CHECK: %[[cres:.+]] = arith.constant -2 : i16
// CHECK: return %[[cres]]
>From ae227f8658549dacb5c7f6c3a6195faeb2aae774 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 2 Feb 2024 16:35:40 -0500
Subject: [PATCH 3/3] Drop unused include
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5d8d40c877142..275c2debe9a6f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -9,7 +9,6 @@
#include <cassert>
#include <cstdint>
#include <functional>
-#include <optional>
#include <utility>
#include "mlir/Dialect/Arith/IR/Arith.h"
More information about the Mlir-commits
mailing list