[llvm] [ValueTracking] Handle `not` in `isImpliedCondition` (PR #85397)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 15 06:24:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

This patch teaches `isImpliedCondition` to look through `not` (i.e. `xor X, true`) on either the LHS or the RHS condition, which enables more folds in some multi-use cases.


---
Full diff: https://github.com/llvm/llvm-project/pull/85397.diff


3 Files Affected:

- (modified) llvm/lib/Analysis/ValueTracking.cpp (+23-8) 
- (added) llvm/test/Transforms/InstCombine/and-or-implied-cond-not.ll (+69) 
- (modified) llvm/test/Transforms/LoopVectorize/reduction-inloop.ll (+1-2) 


``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index edbeede910d7f7..9aa6886ac02b09 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8622,6 +8622,10 @@ llvm::isImpliedCondition(const Value *LHS, CmpInst::Predicate RHSPred,
   assert(LHS->getType()->isIntOrIntVectorTy(1) &&
          "Expected integer type only!");
 
+  // Match not
+  if (match(LHS, m_Not(m_Value(LHS))))
+    LHSIsTrue = !LHSIsTrue;
+
   // Both LHS and RHS are icmps.
   const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS);
   if (LHSCmp)
@@ -8648,10 +8652,21 @@ std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
   if (LHS == RHS)
     return LHSIsTrue;
 
-  if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS))
-    return isImpliedCondition(LHS, RHSCmp->getPredicate(),
-                              RHSCmp->getOperand(0), RHSCmp->getOperand(1), DL,
-                              LHSIsTrue, Depth);
+  // Match not
+  bool InvertRHS = false;
+  if (match(RHS, m_Not(m_Value(RHS)))) {
+    if (LHS == RHS)
+      return !LHSIsTrue;
+    InvertRHS = true;
+  }
+
+  if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS)) {
+    if (auto Implied = isImpliedCondition(
+            LHS, RHSCmp->getPredicate(), RHSCmp->getOperand(0),
+            RHSCmp->getOperand(1), DL, LHSIsTrue, Depth))
+      return InvertRHS ? !*Implied : *Implied;
+    return std::nullopt;
+  }
 
   if (Depth == MaxAnalysisRecursionDepth)
     return std::nullopt;
@@ -8663,21 +8678,21 @@ std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
     if (std::optional<bool> Imp =
             isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1))
       if (*Imp == true)
-        return true;
+        return !InvertRHS;
     if (std::optional<bool> Imp =
             isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1))
       if (*Imp == true)
-        return true;
+        return !InvertRHS;
   }
   if (match(RHS, m_LogicalAnd(m_Value(RHS1), m_Value(RHS2)))) {
     if (std::optional<bool> Imp =
             isImpliedCondition(LHS, RHS1, DL, LHSIsTrue, Depth + 1))
       if (*Imp == false)
-        return false;
+        return InvertRHS;
     if (std::optional<bool> Imp =
             isImpliedCondition(LHS, RHS2, DL, LHSIsTrue, Depth + 1))
       if (*Imp == false)
-        return false;
+        return InvertRHS;
   }
 
   return std::nullopt;
diff --git a/llvm/test/Transforms/InstCombine/and-or-implied-cond-not.ll b/llvm/test/Transforms/InstCombine/and-or-implied-cond-not.ll
new file mode 100644
index 00000000000000..89b3164018ac9a
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/and-or-implied-cond-not.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i1 @test_imply_not1(i32 %depth) {
+; CHECK-LABEL: define i1 @test_imply_not1(
+; CHECK-SAME: i32 [[DEPTH:%.*]]) {
+; CHECK-NEXT:    [[CMP1_NOT1:%.*]] = icmp eq i32 [[DEPTH]], 16
+; CHECK-NEXT:    call void @use(i1 [[CMP1_NOT1]])
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i32 [[DEPTH]], 8
+; CHECK-NEXT:    call void @use(i1 [[CMP2]])
+; CHECK-NEXT:    br i1 [[CMP1_NOT1]], label [[IF_ELSE:%.*]], label [[IF_THEN:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       if.else:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+;
+  %cmp1 = icmp eq i32 %depth, 16
+  call void @use(i1 %cmp1)
+  %cmp2 = icmp slt i32 %depth, 8
+  call void @use(i1 %cmp2)
+  %cmp.not = xor i1 %cmp1, true
+  %brmerge = or i1 %cmp2, %cmp.not
+  br i1 %brmerge, label %if.then, label %if.else
+if.then:
+  call void @func1()
+  unreachable
+
+if.else:
+  call void @func2()
+  unreachable
+}
+
+define i1 @test_imply_not2(i32 %a, i1 %cmp2) {
+; CHECK-LABEL: define i1 @test_imply_not2(
+; CHECK-SAME: i32 [[A:%.*]], i1 [[CMP2:%.*]]) {
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ne i32 [[A]], 0
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[CMP1]], i1 true, i1 [[CMP2]]
+; CHECK-NEXT:    ret i1 [[BRMERGE]]
+;
+  %cmp1 = icmp eq i32 %a, 0
+  %or.cond = select i1 %cmp1, i1 %cmp2, i1 false
+  %cmp.not = xor i1 %cmp1, true
+  %brmerge = or i1 %or.cond, %cmp.not
+  ret i1 %brmerge
+}
+
+define i1 @test_imply_not3(i32 %a, i32 %b, i1 %cond) {
+; CHECK-LABEL: define i1 @test_imply_not3(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[A]], [[B]]
+; CHECK-NEXT:    call void @use(i1 [[CMP1]])
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i32 [[A]], [[B]]
+; CHECK-NEXT:    [[AND:%.*]] = select i1 [[CMP2]], i1 [[COND]], i1 false
+; CHECK-NEXT:    ret i1 [[AND]]
+;
+  %cmp1 = icmp eq i32 %a, %b
+  call void @use(i1 %cmp1)
+  %cmp2 = icmp slt i32 %a, %b
+  %cmp.not = xor i1 %cmp1, true
+  %sel = select i1 %cmp.not, i1 %cond, i1 false
+  %and = and i1 %cmp2, %sel
+  ret i1 %and
+}
+
+declare void @func1()
+declare void @func2()
+declare void @use(i1)
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index d85241167d0cd8..cd2161d279dec1 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1354,9 +1354,8 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK:       pred.load.continue6:
 ; CHECK-NEXT:    [[TMP43:%.*]] = phi <4 x i32> [ [[TMP37]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP42]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP44:%.*]] = icmp ne <4 x i32> [[TMP43]], zeroinitializer
-; CHECK-NEXT:    [[TMP45:%.*]] = select <4 x i1> [[TMP19]], <4 x i1> [[TMP44]], <4 x i1> zeroinitializer
 ; CHECK-NEXT:    [[TMP46:%.*]] = xor <4 x i1> [[TMP19]], <i1 true, i1 true, i1 true, i1 true>
-; CHECK-NEXT:    [[TMP47:%.*]] = or <4 x i1> [[TMP45]], [[TMP46]]
+; CHECK-NEXT:    [[TMP47:%.*]] = select <4 x i1> [[TMP46]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
 ; CHECK-NEXT:    [[TMP48:%.*]] = bitcast <4 x i1> [[TMP47]] to i4
 ; CHECK-NEXT:    [[TMP49:%.*]] = call i4 @llvm.ctpop.i4(i4 [[TMP48]]), !range [[RNG42:![0-9]+]]
 ; CHECK-NEXT:    [[TMP50:%.*]] = zext nneg i4 [[TMP49]] to i32

``````````

</details>


https://github.com/llvm/llvm-project/pull/85397


More information about the llvm-commits mailing list