[llvm] [InstCombine] Drop poison-generating flags in `threadBinOpOverSelect` (PR #87230)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 1 04:38:13 PDT 2024


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/87230

Alive2: https://alive2.llvm.org/ce/z/y_Jmdn
Fixes https://github.com/llvm/llvm-project/issues/87042.

This is an alternative to #87075. Unfortunately, I cannot move the whole function into InstCombine because doing so causes significant regressions. Instead, I add a `DropFlags` parameter, as `simplifyWithOpReplaced` does.


>From 9707684e7714b1d485a31906ff3e17d6f6f94ace Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 1 Apr 2024 19:20:07 +0800
Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests for PR87042. NFC.

---
 llvm/test/Transforms/InstCombine/pr87042.ll | 49 +++++++++++++++++++++
 1 file changed, 49 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/pr87042.ll

diff --git a/llvm/test/Transforms/InstCombine/pr87042.ll b/llvm/test/Transforms/InstCombine/pr87042.ll
new file mode 100644
index 00000000000000..d9624faaedfa55
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr87042.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+define i64 @test_disjoint_or(i1 %cond, i64 %x) {
+; CHECK-LABEL: define i64 @test_disjoint_or(
+; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[OR2:%.*]] = or disjoint i64 [[X]], 7
+; CHECK-NEXT:    ret i64 [[OR2]]
+;
+  %or1 = or disjoint i64 %x, 7
+  %sel1 = select i1 %cond, i64 %or1, i64 %x
+  %or2 = or i64 %sel1, 7
+  ret i64 %or2
+}
+
+define i64 @test_or(i1 %cond, i64 %x) {
+; CHECK-LABEL: define i64 @test_or(
+; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[OR2:%.*]] = or i64 [[X]], 7
+; CHECK-NEXT:    ret i64 [[OR2]]
+;
+  %or1 = or i64 %x, 7
+  %sel1 = select i1 %cond, i64 %or1, i64 %x
+  %or2 = or i64 %sel1, 7
+  ret i64 %or2
+}
+
+define i64 @pr87042(i64 %x) {
+; CHECK-LABEL: define i64 @pr87042(
+; CHECK-SAME: i64 [[X:%.*]]) {
+; CHECK-NEXT:    [[AND1:%.*]] = and i64 [[X]], 65535
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[AND1]], 0
+; CHECK-NEXT:    [[OR1:%.*]] = or disjoint i64 [[X]], 7
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i64 [[OR1]], i64 [[X]]
+; CHECK-NEXT:    [[AND2:%.*]] = and i64 [[SEL1]], 16776960
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i64 [[AND2]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i64 [[OR1]], i64 [[SEL1]]
+; CHECK-NEXT:    ret i64 [[SEL2]]
+;
+  %and1 = and i64 %x, 65535
+  %cmp1 = icmp eq i64 %and1, 0
+  %or1 = or disjoint i64 %x, 7
+  %sel1 = select i1 %cmp1, i64 %or1, i64 %x
+  %and2 = and i64 %sel1, 16776960
+  %cmp2 = icmp eq i64 %and2, 0
+  %or2 = or i64 %sel1, 7
+  %sel2 = select i1 %cmp2, i64 %or2, i64 %sel1
+  ret i64 %sel2
+}

>From 900106267f135c4e3c5a6f744096e61dd94f3e44 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 1 Apr 2024 19:21:02 +0800
Subject: [PATCH 2/2] [InstCombine] Drop poison-generating flags in
 `threadBinOpOverSelect`

---
 .../include/llvm/Analysis/InstructionSimplify.h | 11 +++++++++++
 llvm/lib/Analysis/InstructionSimplify.cpp       | 17 ++++++++++++++++-
 .../InstCombine/InstructionCombining.cpp        |  9 +++++++++
 llvm/test/Transforms/InstCombine/pr87042.ll     |  4 ++--
 4 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 03d7ad12c12d8f..eacc9a25086cb7 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -264,6 +264,17 @@ simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                        const SimplifyQuery &Q, bool AllowRefinement,
                        SmallVectorImpl<Instruction *> *DropFlags = nullptr);
 
+/// In the case of a binary operation with a select instruction as an operand,
+/// try to simplify the binop by seeing whether evaluating it on both branches
+/// of the select results in the same value. Returns the common value if so,
+/// otherwise returns null.
+///
+/// If DropFlags is non-null and *DropFlags is set, the result is only valid
+/// once poison-generating flags on the instruction *DropFlags are dropped.
+Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
+                             Value *RHS, const SimplifyQuery &Q,
+                             Instruction **DropFlags);
+
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
 ///
 /// This first performs a normal RAUW of I with SimpleV. It then recursively
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 9ff3faff799027..cee84793d43268 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -393,7 +393,8 @@ static Value *simplifyAssociativeBinOp(Instruction::BinaryOps Opcode,
 /// otherwise returns null.
 static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
                                     Value *RHS, const SimplifyQuery &Q,
-                                    unsigned MaxRecurse) {
+                                    unsigned MaxRecurse,
+                                    Instruction **DropFlags = nullptr) {
   // Recursion is always used, so bail out at once if we already hit the limit.
   if (!MaxRecurse--)
     return nullptr;
@@ -447,6 +448,13 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
       Value *UnsimplifiedBranch = FV ? SI->getTrueValue() : SI->getFalseValue();
       Value *UnsimplifiedLHS = SI == LHS ? UnsimplifiedBranch : LHS;
       Value *UnsimplifiedRHS = SI == LHS ? RHS : UnsimplifiedBranch;
+
+      if (Simplified->hasPoisonGeneratingFlags()) {
+        if (!DropFlags)
+          return nullptr;
+        *DropFlags = Simplified;
+      }
+
       if (Simplified->getOperand(0) == UnsimplifiedLHS &&
           Simplified->getOperand(1) == UnsimplifiedRHS)
         return Simplified;
@@ -460,6 +468,13 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
   return nullptr;
 }
 
+Value *llvm::threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
+                                   Value *RHS, const SimplifyQuery &Q,
+                                   Instruction **DropFlags) {
+  return ::threadBinOpOverSelect(Opcode, LHS, RHS, Q, RecursionLimit,
+                                 DropFlags);
+}
+
 /// In the case of a comparison with a select instruction, try to simplify the
 /// comparison by seeing whether both branches of the select result in the same
 /// value. Returns the common value if so, otherwise returns null.
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 7c40fb4fc86082..d3ac2762d64f4f 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1964,6 +1964,15 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
   if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {
     if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))
       return NewSel;
+
+    const SimplifyQuery SQ = getSimplifyQuery().getWithInstruction(&I);
+    Instruction *DropFlags = nullptr;
+    if (Value *V = threadBinOpOverSelect(I.getOpcode(), I.getOperand(0),
+                                         I.getOperand(1), SQ, &DropFlags)) {
+      if (DropFlags)
+        DropFlags->dropPoisonGeneratingFlags();
+      return replaceInstUsesWith(I, V);
+    }
   } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
     if (Instruction *NewPhi = foldOpIntoPhi(I, PN))
       return NewPhi;
diff --git a/llvm/test/Transforms/InstCombine/pr87042.ll b/llvm/test/Transforms/InstCombine/pr87042.ll
index d9624faaedfa55..b3d6821e43c3d4 100644
--- a/llvm/test/Transforms/InstCombine/pr87042.ll
+++ b/llvm/test/Transforms/InstCombine/pr87042.ll
@@ -4,7 +4,7 @@
 define i64 @test_disjoint_or(i1 %cond, i64 %x) {
 ; CHECK-LABEL: define i64 @test_disjoint_or(
 ; CHECK-SAME: i1 [[COND:%.*]], i64 [[X:%.*]]) {
-; CHECK-NEXT:    [[OR2:%.*]] = or disjoint i64 [[X]], 7
+; CHECK-NEXT:    [[OR2:%.*]] = or i64 [[X]], 7
 ; CHECK-NEXT:    ret i64 [[OR2]]
 ;
   %or1 = or disjoint i64 %x, 7
@@ -30,7 +30,7 @@ define i64 @pr87042(i64 %x) {
 ; CHECK-SAME: i64 [[X:%.*]]) {
 ; CHECK-NEXT:    [[AND1:%.*]] = and i64 [[X]], 65535
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[AND1]], 0
-; CHECK-NEXT:    [[OR1:%.*]] = or disjoint i64 [[X]], 7
+; CHECK-NEXT:    [[OR1:%.*]] = or i64 [[X]], 7
 ; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i64 [[OR1]], i64 [[X]]
 ; CHECK-NEXT:    [[AND2:%.*]] = and i64 [[SEL1]], 16776960
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i64 [[AND2]], 0



More information about the llvm-commits mailing list