[Mlir-commits] [mlir] [MLIR][ARITH] Adds missing foldings for truncf (PR #128096)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 20 16:48:37 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir
Author: Zahi Moudallal (zahimoud)
<details>
<summary>Changes</summary>
This patch mainly adds the following folds for `arith.truncf`:
truncf(extf(a)) -> a, if `a` has the same type as the result
truncf(extf(a)) -> truncf(a), if `a` has a larger bitwidth than the result
truncf(truncf(a)) -> truncf(a), in all cases
---
Full diff: https://github.com/llvm/llvm-project/pull/128096.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+22)
- (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+1-2)
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+29)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..494985fbce94e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,6 +1517,28 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+ if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+ Value src = extOp.getIn();
+ Type srcType = getElementTypeOrSelf(src.getType());
+ Type dstType = getElementTypeOrSelf(getType());
+ // truncf(extf(a)) -> truncf(a)
+ if (llvm::cast<FloatType>(srcType).getWidth() >
+ llvm::cast<FloatType>(dstType).getWidth()) {
+ setOperand(src);
+ return getResult();
+ }
+
+ // truncf(extf(a)) -> a
+ if (srcType == dstType)
+ return src;
+ }
+
+ // truncf(truncf(a)) -> truncf(a)
+ if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
+ setOperand(truncOp.getIn());
+ return getResult();
+ }
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..cebdebef85dc9 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,9 +764,8 @@ func.func @arith_extf(%arg0: f16) -> f64 {
func.func @arith_truncf(%arg0: f64) -> f16 {
// CHECK-LABEL: arith_truncf
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
- // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+ // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
%truncd0 = arith.truncf %arg0 : f64 to f32
- // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
%truncd1 = arith.truncf %truncd0 : f32 to f16
return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..aa4136cd6361e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,35 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
return %0 : vector<2xf128>
}
+// CHECK-LABEL: @truncExtf
+// CHECK-NOT: arith.{{(extf|truncf)}}
+// CHECK: return %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %trunc = arith.truncf %extf : f64 to f32
+ return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf2
+// CHECK-NOT: arith.extf
+// CHECK: %[[TRUNC:[^ ]+]] = arith.truncf %arg0 : f32 to f16
+// CHECK: return %[[TRUNC]]
+func.func @truncExtf2(%arg0: f32) -> f16 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %truncf = arith.truncf %extf : f64 to f16
+ return %truncf : f16
+}
+
+// CHECK-LABEL: @truncTruncf
+// NOTE(review): f64->f32->f16 may double-round vs. a direct f64->f16; confirm this fold is meant to be exact.
+// CHECK: %[[TRUNC:[^ ]+]] = arith.truncf %arg0 : f64 to f16
+// CHECK: return %[[TRUNC]]
+func.func @truncTruncf(%arg0: f64) -> f16 {
+ %truncf = arith.truncf %arg0 : f64 to f32
+ %truncf1 = arith.truncf %truncf : f32 to f16
+ return %truncf1 : f16
+}
+
// TODO: We should also add a test for not folding arith.extf on information loss.
// This may happen when extending f8E5M2FNUZ to f16.
``````````
</details>
https://github.com/llvm/llvm-project/pull/128096
More information about the Mlir-commits mailing list