[llvm] [InstCombine] Eliminate fptrunc/fpext if fast math flags allow it (PR #115027)
John Brawn via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 6 06:16:55 PST 2024
https://github.com/john-brawn-arm updated https://github.com/llvm/llvm-project/pull/115027
>From 78e1634b60254f3c71fea51740ba65b57270e8fe Mon Sep 17 00:00:00 2001
From: John Brawn <john.brawn at arm.com>
Date: Fri, 1 Nov 2024 15:56:51 +0000
Subject: [PATCH] [InstCombine] Eliminate fptrunc/fpext if fast math flags
allow it
When expressions of a floating-point type are evaluated at a higher
precision (e.g. _Float16 being evaluated as float), this results in an
fptrunc followed by an fpext between each operation. With the appropriate
fast math flags (nnan, ninf and contract on the fpext, and contract on the
fptrunc) we can eliminate these cast instructions.
---
.../InstCombine/InstCombineCasts.cpp | 25 +++++++
llvm/test/Transforms/InstCombine/fpextend.ll | 73 +++++++++++++++++++
2 files changed, 98 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 7221c987b98219..6b8c362c544569 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1940,6 +1940,31 @@ Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
}
+ // fpext (fptrunc(x)) -> x, if the fast math flags allow it
+ if (auto *Trunc = dyn_cast<FPTruncInst>(Src)) {
+ // Whether this transformation is possible depends on the fast math flags of
+ // both the fpext and fptrunc.
+ FastMathFlags SrcFlags = Trunc->getFastMathFlags();
+ FastMathFlags DstFlags = FPExt.getFastMathFlags();
+ // Trunc can introduce inf and change the encoding of a nan, so the
+ // destination must have the nnan and ninf flags to indicate that we don't
+ // need to care about that. We are also removing a rounding step, and that
+ // requires both the source and destination to allow contraction.
+ if (DstFlags.noNaNs() && DstFlags.noInfs() && SrcFlags.allowContract() &&
+ DstFlags.allowContract()) {
+ Value *TruncSrc = Trunc->getOperand(0);
+ // We do need a single cast if the source and destination types don't
+ // match.
+ if (TruncSrc->getType() != Ty) {
+ Instruction *Ret = CastInst::CreateFPCast(TruncSrc, Ty);
+ Ret->copyFastMathFlags(&FPExt);
+ return Ret;
+ } else {
+ return replaceInstUsesWith(FPExt, TruncSrc);
+ }
+ }
+ }
+
return commonCastTransforms(FPExt);
}
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db44..c18238d9721921 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,76 @@ define bfloat @bf16_frem(bfloat %x) {
%t3 = fptrunc float %t2 to bfloat
ret bfloat %t3
}
+
+define double @fptrunc_fpextend_nofast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_nofast(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TRUNC:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT: [[EXT:%.*]] = fpext float [[TRUNC]] to double
+; CHECK-NEXT: [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT: ret double [[ADD2]]
+;
+ %add1 = fadd double %x, %y
+ %trunc = fptrunc double %add1 to float
+ %ext = fpext float %trunc to double
+ %add2 = fadd double %ext, %z
+ ret double %add2
+}
+
+define double @fptrunc_fpextend_fast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_fast(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[ADD2:%.*]] = fadd double [[ADD1]], [[Z:%.*]]
+; CHECK-NEXT: ret double [[ADD2]]
+;
+ %add1 = fadd double %x, %y
+ %trunc = fptrunc contract double %add1 to float
+ %ext = fpext nnan ninf contract float %trunc to double
+ %add2 = fadd double %ext, %z
+ ret double %add2
+}
+
+define float @fptrunc_fpextend_result_smaller(double %x, double %y, float %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_smaller(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[EXT:%.*]] = fptrunc nnan ninf contract double [[ADD1]] to float
+; CHECK-NEXT: [[ADD2:%.*]] = fadd float [[Z:%.*]], [[EXT]]
+; CHECK-NEXT: ret float [[ADD2]]
+;
+ %add1 = fadd double %x, %y
+ %trunc = fptrunc contract double %add1 to half
+ %ext = fpext nnan ninf contract half %trunc to float
+ %add2 = fadd float %ext, %z
+ ret float %add2
+}
+
+define double @fptrunc_fpextend_result_larger(float %x, float %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_larger(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd float [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[EXT:%.*]] = fpext nnan ninf contract float [[ADD1]] to double
+; CHECK-NEXT: [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT: ret double [[ADD2]]
+;
+ %add1 = fadd float %x, %y
+ %trunc = fptrunc contract float %add1 to half
+ %ext = fpext nnan ninf contract half %trunc to double
+ %add2 = fadd double %ext, %z
+ ret double %add2
+}
+
+define double @fptrunc_fpextend_multiple_use(double %x, double %y, double %a, double %b) {
+; CHECK-LABEL: @fptrunc_fpextend_multiple_use(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[ADD2:%.*]] = fadd double [[ADD1]], [[A:%.*]]
+; CHECK-NEXT: [[ADD3:%.*]] = fadd double [[ADD1]], [[B:%.*]]
+; CHECK-NEXT: [[MUL:%.*]] = fmul double [[ADD2]], [[ADD3]]
+; CHECK-NEXT: ret double [[MUL]]
+;
+ %add1 = fadd double %x, %y
+ %trunc = fptrunc contract double %add1 to float
+ %ext = fpext nnan ninf contract float %trunc to double
+ %add2 = fadd double %ext, %a
+ %add3 = fadd double %ext, %b
+ %mul = fmul double %add2, %add3
+ ret double %mul
+}
More information about the llvm-commits
mailing list