[Mlir-commits] [mlir] [mlir][arith] Improve `truncf` folding (PR #80206)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jan 31 17:22:34 PST 2024
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/80206
From de125761be4758080e0e74cb5428a530aa07e102 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 31 Jan 2024 16:52:26 -0500
Subject: [PATCH 1/3] [mlir][arith] Improve `truncf` folding
* Use the APFloat conversion function instead of converting through double to
  check whether the fold results in information loss.
* Support folding vector constants.
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 30 +++++++++++------------
mlir/test/Dialect/Arith/canonicalize.mlir | 9 +++++++
2 files changed, 24 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ff72becc8dfa7..a02f7d6dd5053 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -22,8 +22,10 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
@@ -1393,23 +1395,21 @@ LogicalResult arith::TruncIOp::verify() {
// TruncFOp
//===----------------------------------------------------------------------===//
-/// Perform safe const propagation for truncf, i.e. only propagate if FP value
+/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss or rounding.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
- auto constOperand = adaptor.getIn();
- if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
- return {};
-
- // Convert to target type via 'double'.
- double sourceValue =
- llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
- auto targetAttr = FloatAttr::get(getType(), sourceValue);
-
- // Propagate if constant's value does not change after truncation.
- if (sourceValue == targetAttr.getValue().convertToDouble())
- return targetAttr;
-
- return {};
+ auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+ const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
+ return constFoldCastOp<FloatAttr, FloatAttr>(
+ adaptor.getOperands(), getType(),
+ [&targetSemantics](APFloat a, bool &castStatus) {
+ bool loosesInfo = false;
+ auto status =
+ a.convert(targetSemantics, llvm::RoundingMode::NearestTiesToEven,
+ &loosesInfo);
+ castStatus = !loosesInfo && status == APFloat::opOK;
+ return a;
+ });
}
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 10050d87d7568..44df11ab2433a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -825,6 +825,15 @@ func.func @truncFPConstant() -> bf16 {
return %0 : bf16
}
+// CHECK-LABEL: @truncFPVectorConstant
+// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
+// CHECK: return %[[cres]]
+func.func @truncFPVectorConstant() -> vector<2xbf16> {
+ %cst = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xf32>
+ %0 = arith.truncf %cst : vector<2xf32> to vector<2xbf16>
+ return %0 : vector<2xbf16>
+}
+
// Test that cases with rounding are NOT propagated
// CHECK-LABEL: @truncFPConstantRounding
// CHECK: arith.constant 1.444000e+25 : f32
From 6d3302a0c4f56a643d1652418f428a4a5fc064f8 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 31 Jan 2024 16:58:28 -0500
Subject: [PATCH 2/3] Fix typo (loosesInfo -> losesInfo)
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a02f7d6dd5053..c4af95e8443cb 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1403,11 +1403,10 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
[&targetSemantics](APFloat a, bool &castStatus) {
- bool loosesInfo = false;
- auto status =
- a.convert(targetSemantics, llvm::RoundingMode::NearestTiesToEven,
- &loosesInfo);
- castStatus = !loosesInfo && status == APFloat::opOK;
+ bool losesInfo = false;
+ auto status = a.convert(
+ targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+ castStatus = !losesInfo && status == APFloat::opOK;
return a;
});
}
From 85e1de23f6615c0dc2447697663d9786a0856d0e Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 31 Jan 2024 20:22:23 -0500
Subject: [PATCH 3/3] Clarify why we care about information loss
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index c4af95e8443cb..270df3f3f9e99 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1396,7 +1396,8 @@ LogicalResult arith::TruncIOp::verify() {
//===----------------------------------------------------------------------===//
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
-/// can be represented without precision loss or rounding.
+/// can be represented without precision loss or rounding. The semantics of
+/// `arith.truncf` do not fix a rounding mode, so inexact values are not folded.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
More information about the Mlir-commits
mailing list