[Mlir-commits] [mlir] [MLIR][ARITH] Adds missing foldings for truncf (PR #128096)

Zahi Moudallal llvmlistbot at llvm.org
Fri Feb 21 11:47:03 PST 2025


https://github.com/zahimoud updated https://github.com/llvm/llvm-project/pull/128096

>From 913cf19fd8a64a5da8ae4a94f88ad5f12bb72717 Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 13:55:14 -0800
Subject: [PATCH 1/5] [MLIR][ARITH] Fold extf followed by truncf and truncf
 followed by truncf

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp    | 22 +++++++++++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir | 29 +++++++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..28de22d38571d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,6 +1517,28 @@ LogicalResult arith::TruncIOp::verify() {
 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
 /// can be represented without precision loss.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+  if (matchPattern(getOperand(), m_Op<arith::ExtFOp>())) {
+    Value src = getOperand().getDefiningOp()->getOperand(0);
+    Type srcType = getElementTypeOrSelf(src.getType());
+    Type dstType = getElementTypeOrSelf(getType());
+    // truncf(extf(a)) -> truncf(a)
+    if (llvm::cast<FloatType>(srcType).getWidth() >
+        llvm::cast<FloatType>(dstType).getWidth()) {
+      setOperand(src);
+      return getResult();
+    }
+
+    // truncf(extf(a)) -> a
+    if (srcType == dstType)
+      return src;
+  }
+
+  // truncf(truncf(a)) -> truncf(a)
+  if (matchPattern(getOperand(), m_Op<arith::TruncFOp>())) {
+    setOperand(getOperand().getDefiningOp()->getOperand(0));
+    return getResult();
+  }
+  
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..24c6fce636097 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,35 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
   return %0 : vector<2xf128>
 }
 
+// CHECK-LABEL: @truncExtf
+//       CHECK-NOT:  truncf
+//       CHECK:   return  %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %trunc = arith.truncf %extf : f64 to f32
+  return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf2
+//       CHECK:  %[[ARG0:.+]]: f32
+//       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncExtf2(%arg0: f32) -> f16 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %truncf = arith.truncf %extf : f64 to f16
+  return %truncf : f16
+}
+
+// CHECK-LABEL: @truncExtf3
+//       CHECK:  %[[ARG0:.+]]: f8
+//       CHECK:  %[[CST:.*]] = arith.extf %[[ARG0:.+]] : f8 to f16
+//       CHECK:   return  %[[CST:.*]] : f16
+func.func @truncExtf3(%arg0: f8) -> f16 {
+  %extf = arith.extf %arg0 : f8 to f32
+  %truncf = arith.truncf %extf : f32 to f16
+  return %trunci : f16
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 

>From 695c83ffc99df97ec45c45b3661be9dd2f5d043b Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 16:42:36 -0800
Subject: [PATCH 2/5] Use getDefiningOp in truncf folds and fix tests

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp           | 10 +++++-----
 .../Conversion/ArithToEmitC/arith-to-emitc.mlir  |  3 +--
 mlir/test/Dialect/Arith/canonicalize.mlir        | 16 ++++++++--------
 3 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 28de22d38571d..494985fbce94e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,8 +1517,8 @@ LogicalResult arith::TruncIOp::verify() {
 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
 /// can be represented without precision loss.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
-  if (matchPattern(getOperand(), m_Op<arith::ExtFOp>())) {
-    Value src = getOperand().getDefiningOp()->getOperand(0);
+  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+    Value src = extOp.getIn();
     Type srcType = getElementTypeOrSelf(src.getType());
     Type dstType = getElementTypeOrSelf(getType());
     // truncf(extf(a)) -> truncf(a)
@@ -1534,11 +1534,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   }
 
   // truncf(truncf(a)) -> truncf(a)
-  if (matchPattern(getOperand(), m_Op<arith::TruncFOp>())) {
-    setOperand(getOperand().getDefiningOp()->getOperand(0));
+  if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
+    setOperand(truncOp.getIn());
     return getResult();
   }
-  
+
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..cebdebef85dc9 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,9 +764,8 @@ func.func @arith_extf(%arg0: f16) -> f64 {
 func.func @arith_truncf(%arg0: f64) -> f16 {
   // CHECK-LABEL: arith_truncf
   // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
-  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
   %truncd0 = arith.truncf %arg0 : f64 to f32
-  // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
   %truncd1 = arith.truncf %truncd0 : f32 to f16
 
   return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 24c6fce636097..aa4136cd6361e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -733,14 +733,14 @@ func.func @truncExtf2(%arg0: f32) -> f16 {
   return %truncf : f16
 }
 
-// CHECK-LABEL: @truncExtf3
-//       CHECK:  %[[ARG0:.+]]: f8
-//       CHECK:  %[[CST:.*]] = arith.extf %[[ARG0:.+]] : f8 to f16
-//       CHECK:   return  %[[CST:.*]] : f16
-func.func @truncExtf3(%arg0: f8) -> f16 {
-  %extf = arith.extf %arg0 : f8 to f32
-  %truncf = arith.truncf %extf : f32 to f16
-  return %trunci : f16
+// CHECK-LABEL: @truncTruncf
+//       CHECK:  %[[ARG0:.+]]: f64
+//       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncTruncf(%arg0: f64) -> f16 {
+  %truncf = arith.truncf %arg0 : f64 to f32
+  %truncf1 = arith.truncf %truncf : f32 to f16
+  return %truncf1 : f16
 }
 
 // TODO: We should also add a test for not folding arith.extf on information loss.

>From f66b1d374b2c1f0b6d7861346785dd86068768ee Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Thu, 20 Feb 2025 17:19:45 -0800
Subject: [PATCH 3/5] Remove truncf(truncf(a)) folding

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp                |  6 ------
 mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir |  3 ++-
 mlir/test/Dialect/Arith/canonicalize.mlir             | 10 ----------
 3 files changed, 2 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 494985fbce94e..a3e12f164fa19 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1533,12 +1533,6 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
       return src;
   }
 
-  // truncf(truncf(a)) -> truncf(a)
-  if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
-    setOperand(truncOp.getIn());
-    return getResult();
-  }
-
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cebdebef85dc9..cb1d092918f03 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,8 +764,9 @@ func.func @arith_extf(%arg0: f16) -> f64 {
 func.func @arith_truncf(%arg0: f64) -> f16 {
   // CHECK-LABEL: arith_truncf
   // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
-  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
+  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
   %truncd0 = arith.truncf %arg0 : f64 to f32
+  // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
   %truncd1 = arith.truncf %truncd0 : f32 to f16
 
   return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index aa4136cd6361e..a09991d966cf0 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -733,16 +733,6 @@ func.func @truncExtf2(%arg0: f32) -> f16 {
   return %truncf : f16
 }
 
-// CHECK-LABEL: @truncTruncf
-//       CHECK:  %[[ARG0:.+]]: f64
-//       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
-//       CHECK:   return  %[[CST:.*]]
-func.func @truncTruncf(%arg0: f64) -> f16 {
-  %truncf = arith.truncf %arg0 : f64 to f32
-  %truncf1 = arith.truncf %truncf : f32 to f16
-  return %truncf1 : f16
-}
-
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 

>From 35a039b9086190be62c33b62aa9a3a3c97eba5ae Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Fri, 21 Feb 2025 11:40:04 -0800
Subject: [PATCH 4/5] Only fold truncf(extf(a)) when the extf is exact; add bf16 tests

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp    | 26 ++++++++++++-----------
 mlir/test/Dialect/Arith/canonicalize.mlir | 22 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a3e12f164fa19..01ce8f2970ba3 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,23 +1517,25 @@ LogicalResult arith::TruncIOp::verify() {
 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
 /// can be represented without precision loss.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
     Value src = extOp.getIn();
-    Type srcType = getElementTypeOrSelf(src.getType());
-    Type dstType = getElementTypeOrSelf(getType());
-    // truncf(extf(a)) -> truncf(a)
-    if (llvm::cast<FloatType>(srcType).getWidth() >
-        llvm::cast<FloatType>(dstType).getWidth()) {
-      setOperand(src);
-      return getResult();
-    }
+    auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
+    auto intermediateType = cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
+    // Check if the srcType is representable in the intermediateType
+    if(llvm::APFloatBase::isRepresentableBy(srcType.getFloatSemantics(), intermediateType.getFloatSemantics())) {
+      // truncf(extf(a)) -> truncf(a)
+      if (srcType.getWidth() > resElemType.getWidth()) {
+        setOperand(src);
+        return getResult();
+      }
 
-    // truncf(extf(a)) -> a
-    if (srcType == dstType)
-      return src;
+      // truncf(extf(a)) -> a
+      if (srcType == resElemType)
+        return src;
+    }
   }
 
-  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a09991d966cf0..f0b2731707d18 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -723,11 +723,31 @@ func.func @truncExtf(%arg0: f32) -> f32 {
   return %trunc : f32
 }
 
+// CHECK-LABEL: @truncExtf1
+//       CHECK-NOT:  truncf
+//       CHECK:   return  %arg0
+func.func @truncExtf1(%arg0: bf16) -> bf16 {
+  %extf = arith.extf %arg0 : bf16 to f32
+  %trunc = arith.truncf %extf : f32 to bf16
+  return %trunc : bf16
+}
+
 // CHECK-LABEL: @truncExtf2
+//       CHECK:  %[[ARG0:.+]]: bf16
+//       CHECK:  %[[EXTF:.*]] = arith.extf %[[ARG0:.+]] : bf16 to f32
+//       CHECK:  %[[TRUNCF:.*]] = arith.truncf %[[EXTF:.*]] : f32 to f16
+//       CHECK:   return  %[[TRUNCF:.*]]
+func.func @truncExtf2(%arg0: bf16) -> f16 {
+  %extf = arith.extf %arg0 : bf16 to f32
+  %trunc = arith.truncf %extf : f32 to f16
+  return %trunc : f16
+}
+
+// CHECK-LABEL: @truncExtf3
 //       CHECK:  %[[ARG0:.+]]: f32
 //       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
 //       CHECK:   return  %[[CST:.*]]
-func.func @truncExtf2(%arg0: f32) -> f16 {
+func.func @truncExtf3(%arg0: f32) -> f16 {
   %extf = arith.extf %arg0 : f32 to f64
   %truncf = arith.truncf %extf : f64 to f16
   return %truncf : f16

>From b8c749c205305e7f61e6f73fa0e846ee1643e19f Mon Sep 17 00:00:00 2001
From: Zahi Moudallal <zahi at openai.com>
Date: Fri, 21 Feb 2025 11:46:48 -0800
Subject: [PATCH 5/5] Format

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 01ce8f2970ba3..afa0540013b71 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1521,9 +1521,12 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
     Value src = extOp.getIn();
     auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
-    auto intermediateType = cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
+    auto intermediateType =
+        cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
     // Check if the srcType is representable in the intermediateType
-    if(llvm::APFloatBase::isRepresentableBy(srcType.getFloatSemantics(), intermediateType.getFloatSemantics())) {
+    if (llvm::APFloatBase::isRepresentableBy(
+            srcType.getFloatSemantics(),
+            intermediateType.getFloatSemantics())) {
       // truncf(extf(a)) -> truncf(a)
       if (srcType.getWidth() > resElemType.getWidth()) {
         setOperand(src);
@@ -2397,12 +2400,12 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
 
   // Constant-fold constant operands over non-splat constant condition.
   // select %cst_vec, %cst0, %cst1 => %cst2
-  if (auto cond =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
-    if (auto lhs =
-            llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
-      if (auto rhs =
-              llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
+  if (auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(
+          adaptor.getCondition())) {
+    if (auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+            adaptor.getTrueValue())) {
+      if (auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(
+              adaptor.getFalseValue())) {
         SmallVector<Attribute> results;
         results.reserve(static_cast<size_t>(cond.getNumElements()));
         auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
@@ -2670,7 +2673,7 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
     return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minimumf:
     return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
-   case AtomicRMWKind::maxnumf:
+  case AtomicRMWKind::maxnumf:
     return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
   case AtomicRMWKind::minnumf:
     return builder.create<arith::MinNumFOp>(loc, lhs, rhs);



More information about the Mlir-commits mailing list