[llvm] [ValueTracking] Compute known FPClass from dominating condition (PR #80740)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 6 20:56:28 PST 2024
https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/80740
>From b0c9bb77112428e842cc3ec53e71a3f09228f3a5 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 6 Feb 2024 04:19:46 +0800
Subject: [PATCH 1/2] [ValueTracking] Add pre-commit tests. NFC.
---
.../InstCombine/fpclass-from-dom-cond.ll | 80 +++++++++++++++++++
1 file changed, 80 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
diff --git a/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
new file mode 100644
index 00000000000000..ba265018217c90
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
@@ -0,0 +1,80 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+; The entry block branches on the sign bit of %x, tested as an integer
+; compare (icmp slt of its i32 bitcast against 0). On the if.then1 path the
+; sign bit of %x is set, so %fneg has it clear; on the if.else and if.then2
+; paths %x itself reaches if.end with a clear sign bit. Every incoming value
+; of %value therefore has a clear sign bit, which makes the fabs redundant.
+define float @test_signbit_check(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_signbit_check(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: [[I32:%.*]] = bitcast float [[X]] to i32
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[I32]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_ELSE:%.*]]
+; CHECK: if.then1:
+; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[X]]
+; CHECK-NEXT: br label [[IF_END:%.*]]
+; CHECK: if.else:
+; CHECK-NEXT: br i1 [[COND]], label [[IF_THEN2:%.*]], label [[IF_END]]
+; CHECK: if.then2:
+; CHECK-NEXT: br label [[IF_END]]
+; CHECK: if.end:
+; CHECK-NEXT: [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[X]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
+; CHECK-NEXT: [[RET:%.*]] = call float @llvm.fabs.f32(float [[VALUE]])
+; CHECK-NEXT: ret float [[RET]]
+;
+ %i32 = bitcast float %x to i32
+ %cmp = icmp slt i32 %i32, 0
+ br i1 %cmp, label %if.then1, label %if.else
+
+if.then1:
+ %fneg = fneg float %x
+ br label %if.end
+
+if.else:
+ br i1 %cond, label %if.then2, label %if.end
+
+if.then2:
+ br label %if.end
+
+if.end:
+ %value = phi float [ %fneg, %if.then1 ], [ %x, %if.then2 ], [ %x, %if.else ]
+ %ret = call float @llvm.fabs.f32(float %value)
+ ret float %ret
+}
+
+; Negative counterpart of @test_signbit_check. On the if.then2 path %x has a
+; clear sign bit, so %fneg2 has it set, while the if.else path passes %x with
+; a clear sign bit straight to if.end. The incoming values of %value disagree
+; on the sign bit, so it is unknown and the fabs call must be kept.
+define float @test_signbit_check_fail(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_signbit_check_fail(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT: [[I32:%.*]] = bitcast float [[X]] to i32
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[I32]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_ELSE:%.*]]
+; CHECK: if.then1:
+; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[X]]
+; CHECK-NEXT: br label [[IF_END:%.*]]
+; CHECK: if.else:
+; CHECK-NEXT: br i1 [[COND]], label [[IF_THEN2:%.*]], label [[IF_END]]
+; CHECK: if.then2:
+; CHECK-NEXT: [[FNEG2:%.*]] = fneg float [[X]]
+; CHECK-NEXT: br label [[IF_END]]
+; CHECK: if.end:
+; CHECK-NEXT: [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[FNEG2]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
+; CHECK-NEXT: [[RET:%.*]] = call float @llvm.fabs.f32(float [[VALUE]])
+; CHECK-NEXT: ret float [[RET]]
+;
+ %i32 = bitcast float %x to i32
+ %cmp = icmp slt i32 %i32, 0
+ br i1 %cmp, label %if.then1, label %if.else
+
+if.then1:
+ %fneg = fneg float %x
+ br label %if.end
+
+if.else:
+ br i1 %cond, label %if.then2, label %if.end
+
+if.then2:
+ %fneg2 = fneg float %x
+ br label %if.end
+
+if.end:
+ %value = phi float [ %fneg, %if.then1 ], [ %fneg2, %if.then2 ], [ %x, %if.else ]
+ %ret = call float @llvm.fabs.f32(float %value)
+ ret float %ret
+}
>From 9f02fb4657fb393c796045b647dc25bdc4d24432 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 6 Feb 2024 04:22:17 +0800
Subject: [PATCH 2/2] [ValueTracking] Compute known FPClass from dominating
condition
---
llvm/lib/Analysis/DomConditionCache.cpp | 3 +
llvm/lib/Analysis/ValueTracking.cpp | 120 ++++++++++++++----
.../InstCombine/fpclass-from-dom-cond.ll | 3 +-
3 files changed, 97 insertions(+), 29 deletions(-)
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index c7f4cab4158880..7c3d23e26d1183 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -51,6 +51,9 @@ static void findAffectedValues(Value *Cond,
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
if (match(A, m_Add(m_Value(X), m_ConstantInt())))
AddAffected(X);
+ // Handle icmp slt/sgt (bitcast X to int) 0/-1
+ if (match(A, m_BitCast(m_Value(X))))
+ Affected.push_back(X);
}
}
}
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 58db81f470130e..c10c4a77e0695f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4213,9 +4213,82 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
}
-static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
- const SimplifyQuery &Q) {
- FPClassTest KnownFromAssume = fcAllFlags;
+/// Refine \p KnownFromContext with facts about \p V that follow from the
+/// boolean condition \p Cond having the value \p CondIsTrue at \p CxtI.
+/// Recognized condition shapes:
+///   * icmp of (bitcast V to int) against 0 / -1 that tests the sign bit;
+///   * fcmp against an FP constant, via fcmpImpliesClass (which may look
+///     through the compared operand), when the described value is V;
+///   * llvm.is_fpclass(V, Mask).
+/// Any other condition leaves \p KnownFromContext unchanged.
+static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
+                                        bool CondIsTrue,
+                                        const Instruction *CxtI,
+                                        KnownFPClass &KnownFromContext) {
+  CmpInst::Predicate Pred;
+  Value *LHS;
+  Value *RHS;
+  uint64_t ClassVal = 0;
+  if (match(Cond, m_Cmp(Pred, m_Value(LHS), m_Value(RHS)))) {
+    if (CmpInst::isIntPredicate(Pred)) {
+      // Only a sign-bit test on the raw bits of V itself is understood.
+      if (!match(LHS, m_BitCast(m_Specific(V))))
+        return;
+      Type *SrcType = V->getType();
+      Type *DstType = LHS->getType();
+
+      // Make sure the bitcast doesn't change between scalar and vector and
+      // doesn't change the number of vector elements.
+      if (SrcType->isVectorTy() == DstType->isVectorTy() &&
+          SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) {
+        // TODO: move IsSignBitCheck to ValueTracking
+        // "x s< 0" and "x s<= -1" hold iff the sign bit is set;
+        // "x s> -1" and "x s>= 0" hold iff the sign bit is clear.
+        bool TrueIfSigned;
+        if ((Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) ||
+            (Pred == ICmpInst::ICMP_SLE && match(RHS, m_AllOnes())))
+          TrueIfSigned = true;
+        else if ((Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes())) ||
+                 (Pred == ICmpInst::ICMP_SGE && match(RHS, m_Zero())))
+          TrueIfSigned = false;
+        else
+          return;
+        if (TrueIfSigned == CondIsTrue)
+          KnownFromContext.signBitMustBeOne();
+        else
+          KnownFromContext.signBitMustBeZero();
+      }
+    } else {
+      const APFloat *CRHS;
+      if (match(RHS, m_APFloat(CRHS))) {
+        auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
+            Pred, *CxtI->getParent()->getParent(), LHS, *CRHS, LHS != V);
+        // fcmpImpliesClass may look through LHS; the returned masks are only
+        // usable when the value they describe is V.
+        if (CmpVal == V)
+          KnownFromContext.knownNot(~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
+      }
+    }
+  } else if (match(Cond, m_Intrinsic<Intrinsic::is_fpclass>(
+                             m_Specific(V), m_ConstantInt(ClassVal)))) {
+    // Require the tested operand to be V: a class test on any other value
+    // says nothing about V. is_fpclass(V, Mask) is true iff V is in one of
+    // the classes in Mask, so the taken edge excludes the complement of Mask
+    // and the not-taken edge excludes Mask itself.
+    FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
+    KnownFromContext.knownNot(CondIsTrue ? ~Mask : Mask);
+  }
+}
+
+static KnownFPClass computeKnownFPClassFromContext(const Value *V,
+ const SimplifyQuery &Q) {
+ KnownFPClass KnownFromContext;
+
+ if (!Q.CxtI)
+ return KnownFromContext;
+
+ if (Q.DC && Q.DT) {
+ // Handle dominating conditions.
+ for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+ Value *Cond = BI->getCondition();
+
+ BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
+ if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
+ computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/true, Q.CxtI,
+ KnownFromContext);
+
+ BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
+ if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
+ computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/false, Q.CxtI,
+ KnownFromContext);
+ }
+ }
+
+ if (!Q.AC)
+ return KnownFromContext;
// Try to restrict the floating-point classes based on information from
// assumptions.
@@ -4233,25 +4306,11 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
continue;
- CmpInst::Predicate Pred;
- Value *LHS, *RHS;
- uint64_t ClassVal = 0;
- if (match(I->getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) {
- const APFloat *CRHS;
- if (match(RHS, m_APFloat(CRHS))) {
- auto [CmpVal, MaskIfTrue, MaskIfFalse] =
- fcmpImpliesClass(Pred, *F, LHS, *CRHS, LHS != V);
- if (CmpVal == V)
- KnownFromAssume &= MaskIfTrue;
- }
- } else if (match(I->getArgOperand(0),
- m_Intrinsic<Intrinsic::is_fpclass>(
- m_Value(LHS), m_ConstantInt(ClassVal)))) {
- KnownFromAssume &= static_cast<FPClassTest>(ClassVal);
- }
+ computeKnownFPClassFromCond(V, I->getArgOperand(0), /*CondIsTrue=*/true,
+ Q.CxtI, KnownFromContext);
}
- return KnownFromAssume;
+ return KnownFromContext;
}
void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
@@ -4359,10 +4418,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
KnownNotFromFlags |= fcInf;
}
- if (Q.AC) {
- FPClassTest AssumedClasses = computeKnownFPClassFromAssumes(V, Q);
- KnownNotFromFlags |= ~AssumedClasses;
- }
+ KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
+ KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
// We no longer need to find out about these bits from inputs if we can
// assume this from flags/attributes.
@@ -4370,6 +4427,12 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
Known.knownNot(KnownNotFromFlags);
+ if (!Known.SignBit && AssumedClasses.SignBit) {
+ if (*AssumedClasses.SignBit)
+ Known.signBitMustBeOne();
+ else
+ Known.signBitMustBeZero();
+ }
});
if (!Op)
@@ -5271,7 +5334,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
bool First = true;
- for (Value *IncValue : P->incoming_values()) {
+ for (const Use &U : P->operands()) {
+ Value *IncValue = U.get();
// Skip direct self references.
if (IncValue == P)
continue;
@@ -5280,8 +5344,10 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
// Recurse, but cap the recursion to two levels, because we don't want
// to waste time spinning around in loops. We need at least depth 2 to
// detect known sign bits.
- computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
- PhiRecursionLimit, Q);
+ computeKnownFPClass(
+ IncValue, DemandedElts, InterestedClasses, KnownSrc,
+ PhiRecursionLimit,
+ Q.getWithInstruction(P->getIncomingBlock(U)->getTerminator()));
if (First) {
Known = KnownSrc;
diff --git a/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
index ba265018217c90..7338fa176843a6 100644
--- a/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
+++ b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
@@ -16,8 +16,7 @@ define float @test_signbit_check(float %x, i1 %cond) {
; CHECK-NEXT: br label [[IF_END]]
; CHECK: if.end:
; CHECK-NEXT: [[VALUE:%.*]] = phi float [ [[FNEG]], [[IF_THEN1]] ], [ [[X]], [[IF_THEN2]] ], [ [[X]], [[IF_ELSE]] ]
-; CHECK-NEXT: [[RET:%.*]] = call float @llvm.fabs.f32(float [[VALUE]])
-; CHECK-NEXT: ret float [[RET]]
+; CHECK-NEXT: ret float [[VALUE]]
;
%i32 = bitcast float %x to i32
%cmp = icmp slt i32 %i32, 0
More information about the llvm-commits
mailing list