[llvm] [ValueTracking] Compute known FPClass from dominating condition (PR #80740)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 12:40:23 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

This patch improves `computeKnownFPClass` by using context-sensitive information from `DomConditionCache`.
The motivation for this patch is to optimize the following case found in [fmt/format.h](https://github.com/fmtlib/fmt/blob/e17bc67547a66cdd378ca6a90c56b865d30d6168/include/fmt/format.h#L3555-L3566):
```
define float @test(float %x, i1 %cond) {
  %i32 = bitcast float %x to i32
  %cmp = icmp slt i32 %i32, 0
  br i1 %cmp, label %if.then1, label %if.else

if.then1:
  %fneg = fneg float %x
  br label %if.end

if.else:
  br i1 %cond, label %if.then2, label %if.end

if.then2:
  br label %if.end

if.end:
  %value = phi float [ %fneg, %if.then1 ], [ %x, %if.then2 ], [ %x, %if.else ]
  %ret = call float @llvm.fabs.f32(float %value)
  ret float %ret
}
```
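We can prove that the sign bit of %value is always zero: on the %if.then1 edge the sign bit of %x is set, so %fneg has it cleared, and on the remaining edges the sign bit of %x itself is clear. Therefore the fabs can be eliminated.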

This pattern also appears in CPython, DuckDB, OpenImageIO (oiio), and OpenEXR.


---
Full diff: https://github.com/llvm/llvm-project/pull/80740.diff


3 Files Affected:

- (modified) llvm/lib/Analysis/DomConditionCache.cpp (+3) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+67-13) 
- (added) llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll (+79) 


``````````diff
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index c7f4cab4158880..7c3d23e26d1183 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -51,6 +51,9 @@ static void findAffectedValues(Value *Cond,
       // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
       if (match(A, m_Add(m_Value(X), m_ConstantInt())))
         AddAffected(X);
+      // Handle icmp slt/sgt (bitcast X to int) 0/-1
+      if (match(A, m_BitCast(m_Value(X))))
+        Affected.push_back(X);
     }
   }
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 58db81f470130e..b3315c0bedd874 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4213,9 +4213,56 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
   return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
 }
 
-static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
-                                                  const SimplifyQuery &Q) {
-  FPClassTest KnownFromAssume = fcAllFlags;
+static KnownFPClass computeKnownFPClassFromContext(const Value *V,
+                                                   const SimplifyQuery &Q) {
+  KnownFPClass KnownFromContext;
+
+  if (!Q.CxtI)
+    return KnownFromContext;
+
+  if (Q.DC && Q.DT) {
+    auto computeKnownFPClassFromCmp = [&](CmpInst::Predicate Pred, Value *LHS,
+                                          Value *RHS) {
+      if (match(LHS, m_BitCast(m_Specific(V)))) {
+        Type *SrcType = V->getType();
+        Type *DstType = LHS->getType();
+
+        // Make sure the bitcast doesn't change between scalar and vector and
+        // doesn't change the number of vector elements.
+        if (SrcType->isVectorTy() == DstType->isVectorTy() &&
+            SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) {
+          // Sign-bit test. TODO: move IsSignBitCheck to ValueTracking
+          if ((Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) ||
+              (Pred == ICmpInst::ICMP_SLE && match(RHS, m_AllOnes())))
+            KnownFromContext.signBitMustBeOne();
+          else if ((Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes())) ||
+                   (Pred == ICmpInst::ICMP_SGE && match(RHS, m_Zero())))
+            KnownFromContext.signBitMustBeZero();
+        }
+      }
+    };
+
+    // Handle dominating conditions.
+    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+      // TODO: handle fcmps
+      auto *Cmp = dyn_cast<ICmpInst>(BI->getCondition());
+      if (!Cmp)
+        continue;
+
+      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
+      if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
+        computeKnownFPClassFromCmp(Cmp->getPredicate(), Cmp->getOperand(0),
+                                   Cmp->getOperand(1));
+
+      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
+      if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
+        computeKnownFPClassFromCmp(Cmp->getInversePredicate(),
+                                   Cmp->getOperand(0), Cmp->getOperand(1));
+    }
+  }
+
+  if (!Q.AC)
+    return KnownFromContext;
 
   // Try to restrict the floating-point classes based on information from
   // assumptions.
@@ -4242,16 +4289,16 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
         auto [CmpVal, MaskIfTrue, MaskIfFalse] =
             fcmpImpliesClass(Pred, *F, LHS, *CRHS, LHS != V);
         if (CmpVal == V)
-          KnownFromAssume &= MaskIfTrue;
+          KnownFromContext.knownNot(~MaskIfTrue);
       }
     } else if (match(I->getArgOperand(0),
                      m_Intrinsic<Intrinsic::is_fpclass>(
                          m_Value(LHS), m_ConstantInt(ClassVal)))) {
-      KnownFromAssume &= static_cast<FPClassTest>(ClassVal);
+      KnownFromContext.knownNot(~static_cast<FPClassTest>(ClassVal));
     }
   }
 
-  return KnownFromAssume;
+  return KnownFromContext;
 }
 
 void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
@@ -4359,10 +4406,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
       KnownNotFromFlags |= fcInf;
   }
 
-  if (Q.AC) {
-    FPClassTest AssumedClasses = computeKnownFPClassFromAssumes(V, Q);
-    KnownNotFromFlags |= ~AssumedClasses;
-  }
+  KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
+  KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
 
   // We no longer need to find out about these bits from inputs if we can
   // assume this from flags/attributes.
@@ -4370,6 +4415,12 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
 
   auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
     Known.knownNot(KnownNotFromFlags);
+    if (!Known.SignBit && AssumedClasses.SignBit) {
+      if (*AssumedClasses.SignBit)
+        Known.signBitMustBeOne();
+      else
+        Known.signBitMustBeZero();
+    }
   });
 
   if (!Op)
@@ -5271,7 +5322,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
 
       bool First = true;
 
-      for (Value *IncValue : P->incoming_values()) {
+      for (const Use &U : P->operands()) {
+        Value *IncValue = U.get();
         // Skip direct self references.
         if (IncValue == P)
           continue;
@@ -5280,8 +5332,10 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
         // Recurse, but cap the recursion to two levels, because we don't want
         // to waste time spinning around in loops. We need at least depth 2 to
         // detect known sign bits.
-        computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
-                            PhiRecursionLimit, Q);
+        computeKnownFPClass(
+            IncValue, DemandedElts, InterestedClasses, KnownSrc,
+            PhiRecursionLimit,
+            Q.getWithInstruction(P->getIncomingBlock(U)->getTerminator()));
 
         if (First) {
           Known = KnownSrc;
diff --git a/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
new file mode 100644
index 00000000000000..7338fa176843a6
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define float @test_signbit_check(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_signbit_check(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[I32:%.*]] = bitcast float [[X]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[I32]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then1:
+; CHECK-NEXT:    [[FNEG:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    br label [[IF_END:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    br i1 [[COND]], label [[IF_THEN2:%.*]], label [[IF_END]]
+; CHECK:       if.then2:
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[X]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
+; CHECK-NEXT:    ret float [[VALUE]]
+;
+  %i32 = bitcast float %x to i32
+  %cmp = icmp slt i32 %i32, 0 ; true iff the sign bit of %x is set
+  br i1 %cmp, label %if.then1, label %if.else
+
+if.then1:
+  %fneg = fneg float %x ; sign bit of %x is set here, so %fneg has it clear
+  br label %if.end
+
+if.else:
+  br i1 %cond, label %if.then2, label %if.end ; sign bit of %x is clear on both edges
+
+if.then2:
+  br label %if.end
+
+if.end:
+  %value = phi float [ %fneg, %if.then1 ], [ %x, %if.then2 ], [ %x, %if.else ] ; sign bit clear on every incoming edge
+  %ret = call float @llvm.fabs.f32(float %value) ; expected to fold to %value (see CHECK lines)
+  ret float %ret
+}
+
+define float @test_signbit_check_fail(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_signbit_check_fail(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[I32:%.*]] = bitcast float [[X]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[I32]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then1:
+; CHECK-NEXT:    [[FNEG:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    br label [[IF_END:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    br i1 [[COND]], label [[IF_THEN2:%.*]], label [[IF_END]]
+; CHECK:       if.then2:
+; CHECK-NEXT:    [[FNEG2:%.*]] = fneg float [[X]]
+; CHECK-NEXT:    br label [[IF_END]]
+; CHECK:       if.end:
+; CHECK-NEXT:    [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[FNEG2]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
+; CHECK-NEXT:    [[RET:%.*]] = call float @llvm.fabs.f32(float [[VALUE]])
+; CHECK-NEXT:    ret float [[RET]]
+;
+  %i32 = bitcast float %x to i32
+  %cmp = icmp slt i32 %i32, 0 ; true iff the sign bit of %x is set
+  br i1 %cmp, label %if.then1, label %if.else
+
+if.then1:
+  %fneg = fneg float %x ; sign bit of %x is set here, so %fneg has it clear
+  br label %if.end
+
+if.else:
+  br i1 %cond, label %if.then2, label %if.end
+
+if.then2:
+  %fneg2 = fneg float %x ; sign bit of %x is clear here, so %fneg2 has it set
+  br label %if.end
+
+if.end:
+  %value = phi float [ %fneg, %if.then1 ], [ %fneg2, %if.then2 ], [ %x, %if.else ] ; sign bit set on the %if.then2 edge
+  %ret = call float @llvm.fabs.f32(float %value) ; must be kept (see CHECK lines)
+  ret float %ret
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/80740


More information about the llvm-commits mailing list