[Mlir-commits] [mlir] [MLIR][ARITH] Adds missing foldings for truncf (PR #128096)
Zahi Moudallal
llvmlistbot at llvm.org
Thu Feb 20 16:48:06 PST 2025
https://github.com/zahimoud created https://github.com/llvm/llvm-project/pull/128096
This patch mainly adds missing folds for `arith.truncf`, as follows:
truncf(extf(a)) -> a, if `a` has the same type as the result
truncf(extf(a)) -> truncf(a), if `a` has a larger bitwidth than the result
truncf(truncf(a)) -> truncf(a), in any case
>From 913cf19fd8a64a5da8ae4a94f88ad5f12bb72717 Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 13:55:14 -0800
Subject: [PATCH 1/2] [MLIR][ARITH] Fold extf followed by truncf and truncf
followed by truncf
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 22 +++++++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 29 +++++++++++++++++++++++
2 files changed, 51 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..28de22d38571d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,6 +1517,28 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+ if (matchPattern(getOperand(), m_Op<arith::ExtFOp>())) {
+ Value src = getOperand().getDefiningOp()->getOperand(0);
+ Type srcType = getElementTypeOrSelf(src.getType());
+ Type dstType = getElementTypeOrSelf(getType());
+ // truncf(extf(a)) -> truncf(a)
+ if (llvm::cast<FloatType>(srcType).getWidth() >
+ llvm::cast<FloatType>(dstType).getWidth()) {
+ setOperand(src);
+ return getResult();
+ }
+
+ // truncf(extf(a)) -> a
+ if (srcType == dstType)
+ return src;
+ }
+
+ // truncf(truncf(a)) -> truncf(a)
+ if (matchPattern(getOperand(), m_Op<arith::TruncFOp>())) {
+ setOperand(getOperand().getDefiningOp()->getOperand(0));
+ return getResult();
+ }
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..24c6fce636097 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,35 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
return %0 : vector<2xf128>
}
+// CHECK-LABEL: @truncExtf
+// CHECK-NOT: truncf
+// CHECK: return %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %trunc = arith.truncf %extf : f64 to f32
+ return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf2
+// CHECK: %[[ARG0:.+]]: f32
+// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
+// CHECK: return %[[CST:.*]]
+func.func @truncExtf2(%arg0: f32) -> f16 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %truncf = arith.truncf %extf : f64 to f16
+ return %truncf : f16
+}
+
+// CHECK-LABEL: @truncExtf3
+// CHECK: %[[ARG0:.+]]: f8
+// CHECK: %[[CST:.*]] = arith.extf %[[ARG0:.+]] : f8 to f16
+// CHECK: return %[[CST:.*]] : f16
+func.func @truncExtf3(%arg0: f8) -> f16 {
+ %extf = arith.extf %arg0 : f8 to f32
+ %truncf = arith.truncf %extf : f32 to f16
+ return %trunci : f16
+}
+
// TODO: We should also add a test for not folding arith.extf on information loss.
// This may happen when extending f8E5M2FNUZ to f16.
>From 695c83ffc99df97ec45c45b3661be9dd2f5d043b Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 16:42:36 -0800
Subject: [PATCH 2/2] Use getDefiningOp<> in truncf folds and update tests
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 10 +++++-----
.../Conversion/ArithToEmitC/arith-to-emitc.mlir | 3 +--
mlir/test/Dialect/Arith/canonicalize.mlir | 16 ++++++++--------
3 files changed, 14 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 28de22d38571d..494985fbce94e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,8 +1517,8 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
- if (matchPattern(getOperand(), m_Op<arith::ExtFOp>())) {
- Value src = getOperand().getDefiningOp()->getOperand(0);
+ if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+ Value src = extOp.getIn();
Type srcType = getElementTypeOrSelf(src.getType());
Type dstType = getElementTypeOrSelf(getType());
// truncf(extf(a)) -> truncf(a)
@@ -1534,11 +1534,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
}
// truncf(truncf(a)) -> truncf(a)
- if (matchPattern(getOperand(), m_Op<arith::TruncFOp>())) {
- setOperand(getOperand().getDefiningOp()->getOperand(0));
+ if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
+ setOperand(truncOp.getIn());
return getResult();
}
-
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..cebdebef85dc9 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,9 +764,8 @@ func.func @arith_extf(%arg0: f16) -> f64 {
func.func @arith_truncf(%arg0: f64) -> f16 {
// CHECK-LABEL: arith_truncf
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
- // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+ // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
%truncd0 = arith.truncf %arg0 : f64 to f32
- // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
%truncd1 = arith.truncf %truncd0 : f32 to f16
return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 24c6fce636097..aa4136cd6361e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -733,14 +733,14 @@ func.func @truncExtf2(%arg0: f32) -> f16 {
return %truncf : f16
}
-// CHECK-LABEL: @truncExtf3
-// CHECK: %[[ARG0:.+]]: f8
-// CHECK: %[[CST:.*]] = arith.extf %[[ARG0:.+]] : f8 to f16
-// CHECK: return %[[CST:.*]] : f16
-func.func @truncExtf3(%arg0: f8) -> f16 {
- %extf = arith.extf %arg0 : f8 to f32
- %truncf = arith.truncf %extf : f32 to f16
- return %trunci : f16
+// CHECK-LABEL: @truncTruncf
+// CHECK: %[[ARG0:.+]]: f64
+// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
+// CHECK: return %[[CST:.*]]
+func.func @truncTruncf(%arg0: f64) -> f16 {
+ %truncf = arith.truncf %arg0 : f64 to f32
+ %truncf1 = arith.truncf %truncf : f32 to f16
+ return %truncf1 : f16
}
// TODO: We should also add a test for not folding arith.extf on information loss.
More information about the Mlir-commits
mailing list