[llvm] [InstCombine] Eliminate fptrunc/fpext if fast math flags allow it (PR #115027)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 09:15:24 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: John Brawn (john-brawn-arm)

<details>
<summary>Changes</summary>

When expressions of a floating-point type are evaluated at a higher precision (e.g. _Float16 being evaluated as float), an fptrunc followed by an fpext is emitted between each pair of operations. With the appropriate fast-math flags (nnan, ninf and contract) these cast instructions can be eliminated. Because cast instructions don't carry fast-math flags, the flags that are checked are those of the source and destination of the casts: the instruction producing the truncated value and the instructions using the extended value.

---
Full diff: https://github.com/llvm/llvm-project/pull/115027.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+26) 
- (modified) llvm/test/Transforms/InstCombine/fpextend.ll (+92) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 6c2554ea73b7f8..4c2bbcfec5cf82 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1949,6 +1949,32 @@ Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
       return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
   }
 
+  // fpext (fptrunc(x)) -> x, if the fast math flags allow it
+  Instruction *SrcInstr;
+  if (match(Src, m_FPTrunc(m_Instruction(SrcInstr)))) {
+    // Whether this transformation is possible depends both on the flags of the
+    // value that is truncated, and the flags on the instructions that use the
+    // fpext.
+    FastMathFlags SrcFlags = SrcInstr->getFastMathFlags();
+    FastMathFlags DstFlags = FastMathFlags::getFast();
+    for (User *U : FPExt.users())
+      if (auto *UInstr = dyn_cast<Instruction>(U))
+        DstFlags &= UInstr->getFastMathFlags();
+    // Trunc can introduce inf and change the encoding of a nan, so the
+    // destination must have the nnan and ninf flags to indicate that we don't
+    // need to care about that. We are also removing a rounding step, and that
+    // requires both the source and destination to allow contraction.
+    if (DstFlags.noNaNs() && DstFlags.noInfs() && SrcFlags.allowContract() &&
+        DstFlags.allowContract()) {
+      // We do need a single cast if the source and destination types don't
+      // match.
+      if (SrcInstr->getType() != Ty)
+        return CastInst::CreateFPCast(SrcInstr, Ty);
+      else
+        return replaceInstUsesWith(FPExt, SrcInstr);
+    }
+  }
+
   return commonCastTransforms(FPExt);
 }
 
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db44..d3ac511b10996d 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,95 @@ define bfloat @bf16_frem(bfloat %x) {
   %t3 = fptrunc float %t2 to bfloat
   ret bfloat %t3
 }
+
+define double @fptrunc_fpextend_nofast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_nofast(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TRUNC:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[TRUNC]] to double
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd double %x, %y
+  %trunc = fptrunc double %add1 to float
+  %ext = fpext float %trunc to double
+  %add2 = fadd double %ext, %z
+  ret double %add2
+}
+
+define double @fptrunc_fpextend_fast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_fast(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd nnan ninf contract double [[ADD1]], [[Z:%.*]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd contract double %x, %y
+  %trunc = fptrunc double %add1 to float
+  %ext = fpext float %trunc to double
+  %add2 = fadd nnan ninf contract double %ext, %z
+  ret double %add2
+}
+
+define float @fptrunc_fpextend_result_smaller(double %x, double %y, float %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_smaller(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[EXT:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd nnan ninf contract float [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret float [[ADD2]]
+;
+  %add1 = fadd contract double %x, %y
+  %trunc = fptrunc double %add1 to half
+  %ext = fpext half %trunc to float
+  %add2 = fadd nnan ninf contract float %ext, %z
+  ret float %add2
+}
+
+define double @fptrunc_fpextend_result_larger(float %x, float %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_larger(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd contract float [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[ADD1]] to double
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd nnan ninf contract double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT:    ret double [[ADD2]]
+;
+  %add1 = fadd contract float %x, %y
+  %trunc = fptrunc float %add1 to half
+  %ext = fpext half %trunc to double
+  %add2 = fadd nnan ninf contract double %ext, %z
+  ret double %add2
+}
+
+define double @fptrunc_fpextend_multiple_use(double %x, double %y, double %a, double %b) {
+; CHECK-LABEL: @fptrunc_fpextend_multiple_use(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd nnan ninf contract double [[ADD1]], [[A:%.*]]
+; CHECK-NEXT:    [[ADD3:%.*]] = fadd nnan ninf contract double [[ADD1]], [[B:%.*]]
+; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[ADD2]], [[ADD3]]
+; CHECK-NEXT:    ret double [[MUL]]
+;
+  %add1 = fadd contract double %x, %y
+  %trunc = fptrunc double %add1 to float
+  %ext = fpext float %trunc to double
+  %add2 = fadd nnan ninf contract double %ext, %a
+  %add3 = fadd nnan ninf contract double %ext, %b
+  %mul = fmul double %add2, %add3
+  ret double %mul
+}
+
+define double @fptrunc_fpextend_multiple_use_flag_mismatch(double %x, double %y, double %a, double %b) {
+; CHECK-LABEL: @fptrunc_fpextend_multiple_use_flag_mismatch(
+; CHECK-NEXT:    [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[TRUNC:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT:    [[EXT:%.*]] = fpext float [[TRUNC]] to double
+; CHECK-NEXT:    [[ADD2:%.*]] = fadd nnan ninf contract double [[A:%.*]], [[EXT]]
+; CHECK-NEXT:    [[ADD3:%.*]] = fadd nnan ninf double [[B:%.*]], [[EXT]]
+; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[ADD2]], [[ADD3]]
+; CHECK-NEXT:    ret double [[MUL]]
+;
+  %add1 = fadd contract double %x, %y
+  %trunc = fptrunc double %add1 to float
+  %ext = fpext float %trunc to double
+  %add2 = fadd nnan ninf contract double %ext, %a
+  %add3 = fadd nnan ninf double %ext, %b
+  %mul = fmul double %add2, %add3
+  ret double %mul
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/115027


More information about the llvm-commits mailing list