[llvm] [InstCombine] Eliminate fptrunc/fpext if fast math flags allow it (PR #115027)

John Brawn via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 06:16:55 PST 2024


https://github.com/john-brawn-arm updated https://github.com/llvm/llvm-project/pull/115027

From 78e1634b60254f3c71fea51740ba65b57270e8fe Mon Sep 17 00:00:00 2001
From: John Brawn <john.brawn at arm.com>
Date: Fri, 1 Nov 2024 15:56:51 +0000
Subject: [PATCH] [InstCombine] Eliminate fptrunc/fpext if fast math flags
 allow it

When expressions of a floating-point type are evaluated at a higher
precision (e.g. _Float16 being evaluated as float), this results in an
fptrunc followed by an fpext between each operation. With the
appropriate fast math flags (nnan ninf contract) we can eliminate these
cast instructions.
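
For illustration, a minimal hypothetical example of the pattern this
targets (the function and value names below are made up and are not
part of the added tests): the round trip through half between the two
fadds folds away once the fpext carries nnan and ninf and both casts
carry contract.

  define half @example(half %x, half %y, half %z) {
    ; _Float16 arithmetic evaluated in float: each operation is
    ; bracketed by fpext/fptrunc casts.
    %x.ext = fpext half %x to float
    %y.ext = fpext half %y to float
    %add1 = fadd float %x.ext, %y.ext
    %t = fptrunc contract float %add1 to half
    %t.ext = fpext nnan ninf contract half %t to float
    %z.ext = fpext half %z to float
    %add2 = fadd float %t.ext, %z.ext
    %r = fptrunc float %add2 to half
    ret half %r
  }

With this patch, %t.ext is replaced by %add1 and the now-dead
fptrunc/fpext pair between the two fadds disappears, so %add1 feeds
%add2 directly.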
---
 .../InstCombine/InstCombineCasts.cpp          | 25 +++++++
 llvm/test/Transforms/InstCombine/fpextend.ll  | 73 +++++++++++++++++++
 2 files changed, 98 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 7221c987b98219..6b8c362c544569 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1940,6 +1940,31 @@ Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
       return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
   }
 
+  // fpext (fptrunc(x)) -> x, if the fast math flags allow it
+  if (auto *Trunc = dyn_cast<FPTruncInst>(Src)) {
+    // Whether this transformation is possible depends on the fast math flags of
+    // both the fpext and fptrunc.
+    FastMathFlags SrcFlags = Trunc->getFastMathFlags();
+    FastMathFlags DstFlags = FPExt.getFastMathFlags();
+    // The fptrunc can introduce inf on overflow and change the payload of a
+    // NaN, so the fpext must have the nnan and ninf flags to indicate that
+    // we don't care about those cases. We are also removing a rounding step,
+    // and that requires both the fptrunc and fpext to allow contraction.
+    if (DstFlags.noNaNs() && DstFlags.noInfs() && SrcFlags.allowContract() &&
+        DstFlags.allowContract()) {
+      Value *TruncSrc = Trunc->getOperand(0);
+      // We still need a single cast if the source and destination types
+      // don't match.
+      if (TruncSrc->getType() != Ty) {
+        Instruction *Ret = CastInst::CreateFPCast(TruncSrc, Ty);
+        Ret->copyFastMathFlags(&FPExt);
+        return Ret;
+      } else {
+        return replaceInstUsesWith(FPExt, TruncSrc);
+      }
+    }
+  }
+
   return commonCastTransforms(FPExt);
 }
 
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db44..c18238d9721921 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,76 @@ define bfloat @bf16_frem(bfloat %x) {
   %t3 = fptrunc float %t2 to bfloat
   ret bfloat %t3
 }
+
+define double @fptrunc_fpextend_nofast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_nofast(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TRUNC:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[TRUNC]] to double
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc double %add1 to float
+  %ext = fpext float %trunc to double
+  %add2 = fadd double %ext, %z
+  ret double %add2
+}
+
+define double @fptrunc_fpextend_fast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_fast(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[ADD1]], [[Z:%.*]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc contract double %add1 to float
+  %ext = fpext nnan ninf contract float %trunc to double
+  %add2 = fadd double %ext, %z
+  ret double %add2
+}
+
+define float @fptrunc_fpextend_result_smaller(double %x, double %y, float %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_smaller(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[EXT:%.*]] = fptrunc nnan ninf contract double [[ADD1]] to float
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd float [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret float [[ADD2]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc contract double %add1 to half
+  %ext = fpext nnan ninf contract half %trunc to float
+  %add2 = fadd float %ext, %z
+  ret float %add2
+}
+
+define double @fptrunc_fpextend_result_larger(float %x, float %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_larger(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd float [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[EXT:%.*]] = fpext nnan ninf contract float [[ADD1]] to double
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd float %x, %y
+  %trunc = fptrunc contract float %add1 to half
+  %ext = fpext nnan ninf contract half %trunc to double
+  %add2 = fadd double %ext, %z
+  ret double %add2
+}
+
+define double @fptrunc_fpextend_multiple_use(double %x, double %y, double %a, double %b) {
+; CHECK-LABEL: @fptrunc_fpextend_multiple_use(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[ADD1]], [[A:%.*]]
+; CHECK-NEXT:    [[ADD3:%.*]] = fadd double [[ADD1]], [[B:%.*]]
+; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[ADD2]], [[ADD3]]
+; CHECK-NEXT:    ret double [[MUL]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc contract double %add1 to float
+  %ext = fpext nnan ninf contract float %trunc to double
+  %add2 = fadd double %ext, %a
+  %add3 = fadd double %ext, %b
+  %mul = fmul double %add2, %add3
+  ret double %mul
+}


