[llvm] [InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL (PR #87474)

Sushant Gokhale via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 00:44:21 PST 2025


================
@@ -1864,6 +1950,63 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
   return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
 }
 
+// Change
+// X = 1/sqrt(a)
+// R1 = X * X
+// R2 = a * X
+//
+// TO
+//
+// FDiv = 1/a
+// FSqrt = sqrt(a)
+// FMul = FDiv * FSqrt
+// Replace Uses Of R1 With FDiv
+// Replace Uses Of R2 With FSqrt
+// Replace Uses Of X With FMul
+static Value *convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
+                                      ArrayRef<Instruction *> R1,
+                                      ArrayRef<Instruction *> R2,
+                                      InstCombiner::BuilderTy &B) {
+
+  B.SetInsertPoint(X);
+
+  // Every instance of R1 may have different fpmath metadata and fpmath flags.
+  // We try to preserve them by having seperate fdiv instruction per R1
+  // instance.
+  Value *SqrtOp = CI->getArgOperand(0);
+  Instruction *FDiv;
+  for (Instruction *I : R1) {
+    FDiv = cast<Instruction>(
+        B.CreateFDiv(ConstantFP::get(R1[0]->getType(), 1.0), SqrtOp));
+    FDiv->copyMetadata(*I);
+    FDiv->copyFastMathFlags(I);
+    I->replaceAllUsesWith(FDiv);
+  }
+
+  // Although, by value, FSqrt = CI , every instance of R2 may have different
+  // fpmath metadata and fpmath flags. We try to preserve them by cloning the
+  // call instruction per R2 instance.
----------------
sushgokh wrote:

We clone the sqrt call to get the value, but take the flags from the original R2 instruction, so the transformation looks like this:

Value-wise: R2 --> sqrt(a)
Flags-wise: flagsBeforeTransform(R2) --> flagsAfterTransform(R2)

Consider this example:
```
; Example input. X = 1/sqrt(a) is %1. It has one R1 user (%2 = X * X) and
; two R2 users (%3, %4). The R2 users are written as a / sqrt(a), which
; equals a * X in value; they differ from each other only in fast-math flags.
@x = global double 0.000000e+00
@r1 = global double 0.000000e+00
@r2.1 = global double 0.000000e+00
@r2.2 = global double 0.000000e+00

define void @bb_constraint_case1(double %a) {
entry:
  %sqrt = call reassoc nnan nsz ninf double @llvm.sqrt.f64(double %a)
  ; X = 1 / sqrt(a)
  %1 = fdiv reassoc arcp ninf double 1.000000e+00, %sqrt
  store double %1, ptr @x
  ; R1 = X * X
  %2 = fmul reassoc double %1, %1
  store double %2, ptr @r1
  ; Two R2 instances: same value a / sqrt(a), different flags.
  %3 = fdiv reassoc double %a, %sqrt
  %4 = fdiv ninf reassoc double %a, %sqrt
  store double %3, ptr @r2.1
  store double %4, ptr @r2.2
  ret void
}
```

The output of the transformation is:
```
; Example output. Each new instruction carries the fast-math flags of the
; instruction whose value it replaces.
; NOTE(review): the nnan/nsz flags of the original sqrt call do not appear on
; any instruction below -- confirm this is intended.
@x = global double 0.000000e+00
@r1 = global double 0.000000e+00
@r2.1 = global double 0.000000e+00
@r2.2 = global double 0.000000e+00

define void @bb_constraint_case1(double %a) {
entry:
  ; Clone of sqrt(a) with the second R2's flags; its value is stored to @r2.2.
  %0 = call reassoc ninf double @llvm.sqrt.f64(double %a)
  ; Clone of sqrt(a) with the first R2's flags; its value is stored to @r2.1
  ; and is also the sqrt operand of the fmul below.
  %1 = call reassoc double @llvm.sqrt.f64(double %a)
  ; FDiv = 1 / a with R1's flags; replaces R1 (stored to @r1).
  %2 = fdiv reassoc double 1.000000e+00, %a
  ; FMul = FDiv * sqrt(a) with X's flags; replaces X (stored to @x).
  %3 = fmul reassoc ninf arcp double %2, %1
  store double %3, ptr @x, align 8
  store double %2, ptr @r1, align 8
  store double %1, ptr @r2.1, align 8
  store double %0, ptr @r2.2, align 8
  ret void
}
```

So, we are not losing any flags here.


https://github.com/llvm/llvm-project/pull/87474


More information about the llvm-commits mailing list