[Mlir-commits] [mlir] 5d0c5c6 - [MLIR][ARITH] Adds missing foldings for truncf (#128096)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 21 15:37:22 PST 2025
Author: Zahi Moudallal
Date: 2025-02-21T15:37:19-08:00
New Revision: 5d0c5c638ad2f34552f1188c6e9b9ff5406580f9
URL: https://github.com/llvm/llvm-project/commit/5d0c5c638ad2f34552f1188c6e9b9ff5406580f9
DIFF: https://github.com/llvm/llvm-project/commit/5d0c5c638ad2f34552f1188c6e9b9ff5406580f9.diff
LOG: [MLIR][ARITH] Adds missing foldings for truncf (#128096)
This patch adds folds for `truncf` of an `extf` result. Both require `a` to
be exactly representable in the intermediate type produced by `extf`:
`truncf(extf(a))` -> `a`, if `a` has the same type as the result
`truncf(extf(a))` -> `truncf(a)`, if `a` has a larger bitwidth than the
result
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..e9545c3146b2f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1518,6 +1518,27 @@ LogicalResult arith::TruncIOp::verify() {
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+ if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+ Value src = extOp.getIn();
+ auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
+ auto intermediateType =
+ cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
+ // Check if the srcType is representable in the intermediateType.
+ if (llvm::APFloatBase::isRepresentableBy(
+ srcType.getFloatSemantics(),
+ intermediateType.getFloatSemantics())) {
+ // truncf(extf(a)) -> truncf(a)
+ if (srcType.getWidth() > resElemType.getWidth()) {
+ setOperand(src);
+ return getResult();
+ }
+
+ // truncf(extf(a)) -> a
+ if (srcType == resElemType)
+ return src;
+ }
+ }
+
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..f0b2731707d18 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,45 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
return %0 : vector<2xf128>
}
+// truncf(extf(a)) -> a: f32 -> f64 is exact and the result type is f32 again.
+// CHECK-LABEL: @truncExtf
+// CHECK-NEXT:    return %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %trunc = arith.truncf %extf : f64 to f32
+ return %trunc : f32
+}
+
+// truncf(extf(a)) -> a: bf16 -> f32 is exact and the result type is bf16 again.
+// CHECK-LABEL: @truncExtf1
+// CHECK-NEXT:    return %arg0
+func.func @truncExtf1(%arg0: bf16) -> bf16 {
+ %extf = arith.extf %arg0 : bf16 to f32
+ %trunc = arith.truncf %extf : f32 to bf16
+ return %trunc : bf16
+}
+
+// Not folded: the source type bf16 is neither the result type nor wider than f16.
+// CHECK-LABEL: @truncExtf2
+// CHECK-NEXT:    %[[EXTF:.+]] = arith.extf %arg0 : bf16 to f32
+// CHECK-NEXT:    %[[TRUNCF:.+]] = arith.truncf %[[EXTF]] : f32 to f16
+// CHECK-NEXT:    return %[[TRUNCF]]
+func.func @truncExtf2(%arg0: bf16) -> f16 {
+ %extf = arith.extf %arg0 : bf16 to f32
+ %trunc = arith.truncf %extf : f32 to f16
+ return %trunc : f16
+}
+
+// truncf(extf(a)) -> truncf(a): f32 -> f64 is exact and f32 is wider than f16.
+// CHECK-LABEL: @truncExtf3
+// CHECK-NEXT:    %[[TRUNCF:.+]] = arith.truncf %arg0 : f32 to f16
+// CHECK-NEXT:    return %[[TRUNCF]]
+func.func @truncExtf3(%arg0: f32) -> f16 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %truncf = arith.truncf %extf : f64 to f16
+ return %truncf : f16
+}
+
// TODO: We should also add a test for not folding arith.extf on information loss.
// This may happen when extending f8E5M2FNUZ to f16.
More information about the Mlir-commits
mailing list