[Mlir-commits] [mlir] [MLIR][ARITH] Adds missing foldings for truncf (PR #128096)
Zahi Moudallal
llvmlistbot at llvm.org
Fri Feb 21 11:47:03 PST 2025
https://github.com/zahimoud updated https://github.com/llvm/llvm-project/pull/128096
>From 913cf19fd8a64a5da8ae4a94f88ad5f12bb72717 Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 13:55:14 -0800
Subject: [PATCH 1/5] [MLIR][ARITH] Fold extf followed by truncf and truncf
followed by truncf
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 22 +++++++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 29 +++++++++++++++++++++++
2 files changed, 51 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..28de22d38571d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,6 +1517,28 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+ if (matchPattern(getOperand(), m_Op<arith::ExtFOp>())) {
+ Value src = getOperand().getDefiningOp()->getOperand(0);
+ Type srcType = getElementTypeOrSelf(src.getType());
+ Type dstType = getElementTypeOrSelf(getType());
+ // truncf(extf(a)) -> truncf(a)
+ if (llvm::cast<FloatType>(srcType).getWidth() >
+ llvm::cast<FloatType>(dstType).getWidth()) {
+ setOperand(src);
+ return getResult();
+ }
+
+ // truncf(extf(a)) -> a
+ if (srcType == dstType)
+ return src;
+ }
+
+ // truncf(truncf(a)) -> truncf(a)
+ if (matchPattern(getOperand(), m_Op<arith::TruncFOp>())) {
+ setOperand(getOperand().getDefiningOp()->getOperand(0));
+ return getResult();
+ }
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..24c6fce636097 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,35 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
return %0 : vector<2xf128>
}
+// CHECK-LABEL: @truncExtf
+// CHECK-NOT: truncf
+// CHECK: return %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %trunc = arith.truncf %extf : f64 to f32
+ return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf2
+// CHECK: %[[ARG0:.+]]: f32
+// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
+// CHECK: return %[[CST:.*]]
+func.func @truncExtf2(%arg0: f32) -> f16 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %truncf = arith.truncf %extf : f64 to f16
+ return %truncf : f16
+}
+
+// CHECK-LABEL: @truncExtf3
+// CHECK: %[[ARG0:.+]]: f8
+// CHECK: %[[CST:.*]] = arith.extf %[[ARG0:.+]] : f8 to f16
+// CHECK: return %[[CST:.*]] : f16
+func.func @truncExtf3(%arg0: f8) -> f16 {
+ %extf = arith.extf %arg0 : f8 to f32
+ %truncf = arith.truncf %extf : f32 to f16
+ return %trunci : f16
+}
+
// TODO: We should also add a test for not folding arith.extf on information loss.
// This may happen when extending f8E5M2FNUZ to f16.
>From 695c83ffc99df97ec45c45b3661be9dd2f5d043b Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 16:42:36 -0800
Subject: [PATCH 2/5] Fixes
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 10 +++++-----
.../Conversion/ArithToEmitC/arith-to-emitc.mlir | 3 +--
mlir/test/Dialect/Arith/canonicalize.mlir | 16 ++++++++--------
3 files changed, 14 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 28de22d38571d..494985fbce94e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,8 +1517,8 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
- if (matchPattern(getOperand(), m_Op<arith::ExtFOp>())) {
- Value src = getOperand().getDefiningOp()->getOperand(0);
+ if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+ Value src = extOp.getIn();
Type srcType = getElementTypeOrSelf(src.getType());
Type dstType = getElementTypeOrSelf(getType());
// truncf(extf(a)) -> truncf(a)
@@ -1534,11 +1534,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
}
// truncf(truncf(a)) -> truncf(a)
- if (matchPattern(getOperand(), m_Op<arith::TruncFOp>())) {
- setOperand(getOperand().getDefiningOp()->getOperand(0));
+ if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
+ setOperand(truncOp.getIn());
return getResult();
}
-
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..cebdebef85dc9 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,9 +764,8 @@ func.func @arith_extf(%arg0: f16) -> f64 {
func.func @arith_truncf(%arg0: f64) -> f16 {
// CHECK-LABEL: arith_truncf
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
- // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+ // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
%truncd0 = arith.truncf %arg0 : f64 to f32
- // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
%truncd1 = arith.truncf %truncd0 : f32 to f16
return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 24c6fce636097..aa4136cd6361e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -733,14 +733,14 @@ func.func @truncExtf2(%arg0: f32) -> f16 {
return %truncf : f16
}
-// CHECK-LABEL: @truncExtf3
-// CHECK: %[[ARG0:.+]]: f8
-// CHECK: %[[CST:.*]] = arith.extf %[[ARG0:.+]] : f8 to f16
-// CHECK: return %[[CST:.*]] : f16
-func.func @truncExtf3(%arg0: f8) -> f16 {
- %extf = arith.extf %arg0 : f8 to f32
- %truncf = arith.truncf %extf : f32 to f16
- return %trunci : f16
+// CHECK-LABEL: @truncTruncf
+// CHECK: %[[ARG0:.+]]: f64
+// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
+// CHECK: return %[[CST:.*]]
+func.func @truncTruncf(%arg0: f64) -> f16 {
+ %truncf = arith.truncf %arg0 : f64 to f32
+ %truncf1 = arith.truncf %truncf : f32 to f16
+ return %truncf1 : f16
}
// TODO: We should also add a test for not folding arith.extf on information loss.
>From f66b1d374b2c1f0b6d7861346785dd86068768ee Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 17:19:45 -0800
Subject: [PATCH 3/5] Remove truncf(truncf(a)) folding
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 6 ------
mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir | 3 ++-
mlir/test/Dialect/Arith/canonicalize.mlir | 10 ----------
3 files changed, 2 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 494985fbce94e..a3e12f164fa19 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1533,12 +1533,6 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
return src;
}
- // truncf(truncf(a)) -> truncf(a)
- if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
- setOperand(truncOp.getIn());
- return getResult();
- }
-
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cebdebef85dc9..cb1d092918f03 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,8 +764,9 @@ func.func @arith_extf(%arg0: f16) -> f64 {
func.func @arith_truncf(%arg0: f64) -> f16 {
// CHECK-LABEL: arith_truncf
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
- // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
+ // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
%truncd0 = arith.truncf %arg0 : f64 to f32
+ // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
%truncd1 = arith.truncf %truncd0 : f32 to f16
return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index aa4136cd6361e..a09991d966cf0 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -733,16 +733,6 @@ func.func @truncExtf2(%arg0: f32) -> f16 {
return %truncf : f16
}
-// CHECK-LABEL: @truncTruncf
-// CHECK: %[[ARG0:.+]]: f64
-// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
-// CHECK: return %[[CST:.*]]
-func.func @truncTruncf(%arg0: f64) -> f16 {
- %truncf = arith.truncf %arg0 : f64 to f32
- %truncf1 = arith.truncf %truncf : f32 to f16
- return %truncf1 : f16
-}
-
// TODO: We should also add a test for not folding arith.extf on information loss.
// This may happen when extending f8E5M2FNUZ to f16.
>From 35a039b9086190be62c33b62aa9a3a3c97eba5ae Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Fri, 21 Feb 2025 11:40:04 -0800
Subject: [PATCH 4/5] Some fixes
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 26 ++++++++++++-----------
mlir/test/Dialect/Arith/canonicalize.mlir | 22 ++++++++++++++++++-
2 files changed, 35 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a3e12f164fa19..01ce8f2970ba3 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,23 +1517,25 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+ auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
Value src = extOp.getIn();
- Type srcType = getElementTypeOrSelf(src.getType());
- Type dstType = getElementTypeOrSelf(getType());
- // truncf(extf(a)) -> truncf(a)
- if (llvm::cast<FloatType>(srcType).getWidth() >
- llvm::cast<FloatType>(dstType).getWidth()) {
- setOperand(src);
- return getResult();
- }
+ auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
+ auto intermediateType = cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
+ // Check if the srcType is representable in the intermediateType
+ if(llvm::APFloatBase::isRepresentableBy(srcType.getFloatSemantics(), intermediateType.getFloatSemantics())) {
+ // truncf(extf(a)) -> truncf(a)
+ if (srcType.getWidth() > resElemType.getWidth()) {
+ setOperand(src);
+ return getResult();
+ }
- // truncf(extf(a)) -> a
- if (srcType == dstType)
- return src;
+ // truncf(extf(a)) -> a
+ if (srcType == resElemType)
+ return src;
+ }
}
- auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a09991d966cf0..f0b2731707d18 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -723,11 +723,31 @@ func.func @truncExtf(%arg0: f32) -> f32 {
return %trunc : f32
}
+// CHECK-LABEL: @truncExtf1
+// CHECK-NOT: truncf
+// CHECK: return %arg0
+func.func @truncExtf1(%arg0: bf16) -> bf16 {
+ %extf = arith.extf %arg0 : bf16 to f32
+ %trunc = arith.truncf %extf : f32 to bf16
+ return %trunc : bf16
+}
+
// CHECK-LABEL: @truncExtf2
+// CHECK: %[[ARG0:.+]]: bf16
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG0:.+]] : bf16 to f32
+// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[EXTF:.*]] : f32 to f16
+// CHECK: return %[[TRUNCF:.*]]
+func.func @truncExtf2(%arg0: bf16) -> f16 {
+ %extf = arith.extf %arg0 : bf16 to f32
+ %trunc = arith.truncf %extf : f32 to f16
+ return %trunc : f16
+}
+
+// CHECK-LABEL: @truncExtf3
// CHECK: %[[ARG0:.+]]: f32
// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
// CHECK: return %[[CST:.*]]
-func.func @truncExtf2(%arg0: f32) -> f16 {
+func.func @truncExtf3(%arg0: f32) -> f16 {
%extf = arith.extf %arg0 : f32 to f64
%truncf = arith.truncf %extf : f64 to f16
return %truncf : f16
>From b8c749c205305e7f61e6f73fa0e846ee1643e19f Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Fri, 21 Feb 2025 11:46:48 -0800
Subject: [PATCH 5/5] Format
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 21 ++++++++++++---------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 01ce8f2970ba3..afa0540013b71 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1521,9 +1521,12 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
Value src = extOp.getIn();
auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
- auto intermediateType = cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
+ auto intermediateType =
+ cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
// Check if the srcType is representable in the intermediateType
- if(llvm::APFloatBase::isRepresentableBy(srcType.getFloatSemantics(), intermediateType.getFloatSemantics())) {
+ if (llvm::APFloatBase::isRepresentableBy(
+ srcType.getFloatSemantics(),
+ intermediateType.getFloatSemantics())) {
// truncf(extf(a)) -> truncf(a)
if (srcType.getWidth() > resElemType.getWidth()) {
setOperand(src);
@@ -2397,12 +2400,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
- if (auto cond =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
- if (auto lhs =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
- if (auto rhs =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+ if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getCondition())) {
+ if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getTrueValue())) {
+ if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+ adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2670,7 +2673,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minimumf:
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
- case AtomicRMWKind::maxnumf:
+ case AtomicRMWKind::maxnumf:
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
case AtomicRMWKind::minnumf:
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
More information about the Mlir-commits
mailing list