[llvm] [SelectOpt] Add handling for Select-like operations. (PR #77284)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 16 07:15:56 PST 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/77284

>From b62eef224349c7ec20561675f06a7e30d4478bb3 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 16 Jan 2024 14:59:39 +0000
Subject: [PATCH] [SelectOpt] Add handling for Select-like operations.

Some operations behave like selects. For example `or(zext(c), y)` is the same
as `select(c, y|1, y)` and instcombine can canonicalize the select to the or
form. These operations can still be worthwhile converting to branch as opposed
to keeping as a select or Or instruction.

This patch attempts to add some basic handling for them, creating a SelectLike
abstraction in the select optimization pass. The backend can opt into handling
`or(zext(c),x)` as a select if it could be profitable, and the select
optimization pass attempts to handle them in much the same way as a
`select(c, x|1, x)`. The Or(x, 1) may need to be added as a new instruction,
generated as the or is converted to branches.

This helps fix a regression from selects being converted to or's recently.
---
 .../llvm/Analysis/TargetTransformInfo.h       |  10 +
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   9 +
 llvm/lib/Analysis/TargetTransformInfo.cpp     |   5 +
 llvm/lib/CodeGen/SelectOptimize.cpp           | 381 ++++++++++++------
 .../AArch64/AArch64TargetTransformInfo.h      |  12 +
 llvm/test/CodeGen/AArch64/selectopt.ll        |  30 +-
 6 files changed, 327 insertions(+), 120 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 9697278eaeaee29..83f8b2d02ba4dc8 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -934,6 +934,12 @@ class TargetTransformInfo {
   /// Should the Select Optimization pass be enabled and ran.
   bool enableSelectOptimize() const;
 
+  /// Should the Select Optimization pass treat the given instruction like a
+  /// select, potentially converting it to a conditional branch. This can
+  /// include select-like instructions like or(zext(c), x) that can be converted
+  /// to selects.
+  bool shouldTreatInstructionLikeSelect(Instruction *I) const;
+
   /// Enable matching of interleaved access groups.
   bool enableInterleavedAccessVectorization() const;
 
@@ -1878,6 +1884,7 @@ class TargetTransformInfo::Concept {
   virtual MemCmpExpansionOptions
   enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const = 0;
   virtual bool enableSelectOptimize() = 0;
+  virtual bool shouldTreatInstructionLikeSelect(Instruction *I) = 0;
   virtual bool enableInterleavedAccessVectorization() = 0;
   virtual bool enableMaskedInterleavedAccessVectorization() = 0;
   virtual bool isFPVectorizationPotentiallyUnsafe() = 0;
@@ -2415,6 +2422,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   bool enableSelectOptimize() override {
     return Impl.enableSelectOptimize();
   }
+  bool shouldTreatInstructionLikeSelect(Instruction *I) override {
+    return Impl.shouldTreatInstructionLikeSelect(I);
+  }
   bool enableInterleavedAccessVectorization() override {
     return Impl.enableInterleavedAccessVectorization();
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 60eab53fa2f6019..3e0c295cae295d6 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -378,6 +378,15 @@ class TargetTransformInfoImplBase {
 
   bool enableSelectOptimize() const { return true; }
 
+  bool shouldTreatInstructionLikeSelect(Instruction *I) {
+    // If the select is a logical-and/logical-or then it is better treated as an
+    // and/or by the backend.
+    using namespace llvm::PatternMatch;
+    return isa<SelectInst>(I) &&
+           !match(I, m_CombineOr(m_LogicalAnd(m_Value(), m_Value()),
+                                 m_LogicalOr(m_Value(), m_Value())));
+  }
+
   bool enableInterleavedAccessVectorization() const { return false; }
 
   bool enableMaskedInterleavedAccessVectorization() const { return false; }
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a5a18a538d76899..6a0c842f62591c6 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -604,6 +604,11 @@ bool TargetTransformInfo::enableSelectOptimize() const {
   return TTIImpl->enableSelectOptimize();
 }
 
+bool TargetTransformInfo::shouldTreatInstructionLikeSelect(
+    Instruction *I) const {
+  return TTIImpl->shouldTreatInstructionLikeSelect(I);
+}
+
 bool TargetTransformInfo::enableInterleavedAccessVectorization() const {
   return TTIImpl->enableInterleavedAccessVectorization();
 }
diff --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index 1316919e65dacc7..5e60e31c958cda3 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -42,6 +42,7 @@
 #include <stack>
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "select-optimize"
 
@@ -114,12 +115,6 @@ class SelectOptimizeImpl {
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
   bool runOnFunction(Function &F, Pass &P);
 
-private:
-  // Select groups consist of consecutive select instructions with the same
-  // condition.
-  using SelectGroup = SmallVector<SelectInst *, 2>;
-  using SelectGroups = SmallVector<SelectGroup, 2>;
-
   using Scaled64 = ScaledNumber<uint64_t>;
 
   struct CostInfo {
@@ -129,6 +124,146 @@ class SelectOptimizeImpl {
     Scaled64 NonPredCost;
   };
 
+  /// SelectLike is an abstraction over SelectInst and other operations that can
+  /// act like selects. For example Or(Zext(icmp), X) can be treated like
+  /// select(icmp, X|1, X).
+  class SelectLike {
+  private:
+    SelectLike(Instruction *I) : I(I) {}
+
+    Instruction *I;
+
+  public:
+    /// Match a select or select-like instruction, returning a SelectLike.
+    static SelectLike match(Instruction *I) {
+      // Select instructions are what we are usually looking for.
+      if (isa<SelectInst>(I))
+        return SelectLike(I);
+
+      // An Or(zext(i1 X), Y) can also be treated like a select, with condition
+      // X and values Y|1 and Y.
+      Value *X;
+      if (PatternMatch::match(
+              I, m_c_Or(m_OneUse(m_ZExt(m_Value(X))), m_Value())) &&
+          X->getType()->isIntegerTy(1))
+        return SelectLike(I);
+
+      return SelectLike(nullptr);
+    }
+
+    bool isValid() { return I; }
+    operator bool() { return isValid(); }
+
+    Instruction *getI() { return I; }
+    const Instruction *getI() const { return I; }
+
+    Type *getType() const { return I->getType(); }
+
+    /// Return the condition for the SelectLike instruction. For example the
+    /// condition of a select or c in `or(zext(c), x)`
+    Value *getCondition() const {
+      if (auto *Sel = dyn_cast<SelectInst>(I))
+        return Sel->getCondition();
+      // Or(zext) case
+      if (auto *BO = dyn_cast<BinaryOperator>(I)) {
+        Value *X;
+        if (PatternMatch::match(BO->getOperand(0),
+                                m_OneUse(m_ZExt(m_Value(X)))))
+          return X;
+        if (PatternMatch::match(BO->getOperand(1),
+                                m_OneUse(m_ZExt(m_Value(X)))))
+          return X;
+      }
+
+      llvm_unreachable("Unhandled case in getCondition");
+    }
+
+    /// Return the true value for the SelectLike instruction. Note this may not
+    /// exist for all SelectLike instructions. For example, for `or(zext(c), x)`
+    /// the true value would be `or(x,1)`. As this value does not exist, nullptr
+    /// is returned.
+    Value *getTrueValue() const {
+      if (auto *Sel = dyn_cast<SelectInst>(I))
+        return Sel->getTrueValue();
+      // Or(zext) case - The true value is Or(X, 1), so return nullptr as the
+      // value does not yet exist.
+      if (isa<BinaryOperator>(I))
+        return nullptr;
+
+      llvm_unreachable("Unhandled case in getTrueValue");
+    }
+
+    /// Return the false value for the SelectLike instruction. For example the
+    /// getFalseValue of a select or `x` in `or(zext(c), x)` (which is
+    /// `select(c, x|1, x)`)
+    Value *getFalseValue() const {
+      if (auto *Sel = dyn_cast<SelectInst>(I))
+        return Sel->getFalseValue();
+      // Or(zext) case - return the operand which is not the zext.
+      if (auto *BO = dyn_cast<BinaryOperator>(I)) {
+        Value *X;
+        if (PatternMatch::match(BO->getOperand(0),
+                                m_OneUse(m_ZExt(m_Value(X)))))
+          return BO->getOperand(1);
+        if (PatternMatch::match(BO->getOperand(1),
+                                m_OneUse(m_ZExt(m_Value(X)))))
+          return BO->getOperand(0);
+      }
+
+      llvm_unreachable("Unhandled case in getFalseValue");
+    }
+
+    /// Return the NonPredCost cost of the true op, given the costs in
+    /// InstCostMap. This may need to be generated for select-like instructions.
+    Scaled64 getTrueOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap,
+                           const TargetTransformInfo *TTI) {
+      if (auto *Sel = dyn_cast<SelectInst>(I))
+        if (auto *I = dyn_cast<Instruction>(Sel->getTrueValue()))
+          return InstCostMap.contains(I) ? InstCostMap[I].NonPredCost
+                                         : Scaled64::getZero();
+
+      // Or case - add the cost of an extra Or to the cost of the False case.
+      if (isa<BinaryOperator>(I))
+        if (auto I = dyn_cast<Instruction>(getFalseValue()))
+          if (InstCostMap.contains(I)) {
+            InstructionCost OrCost = TTI->getArithmeticInstrCost(
+                Instruction::Or, I->getType(), TargetTransformInfo::TCK_Latency,
+                {TargetTransformInfo::OK_AnyValue,
+                 TargetTransformInfo::OP_None},
+                {TTI::OK_UniformConstantValue, TTI::OP_PowerOf2});
+            return InstCostMap[I].NonPredCost +
+                   Scaled64::get(*OrCost.getValue());
+          }
+
+      return Scaled64::getZero();
+    }
+
+    /// Return the NonPredCost cost of the false op, given the costs in
+    /// InstCostMap. This may need to be generated for select-like instructions.
+    Scaled64
+    getFalseOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap,
+                   const TargetTransformInfo *TTI) {
+      if (auto *Sel = dyn_cast<SelectInst>(I))
+        if (auto *I = dyn_cast<Instruction>(Sel->getFalseValue()))
+          return InstCostMap.contains(I) ? InstCostMap[I].NonPredCost
+                                         : Scaled64::getZero();
+
+      // Or case - return the cost of the false case
+      if (isa<BinaryOperator>(I))
+        if (auto I = dyn_cast<Instruction>(getFalseValue()))
+          if (InstCostMap.contains(I))
+            return InstCostMap[I].NonPredCost;
+
+      return Scaled64::getZero();
+    }
+  };
+
+private:
+  // Select groups consist of consecutive select instructions with the same
+  // condition.
+  using SelectGroup = SmallVector<SelectLike, 2>;
+  using SelectGroups = SmallVector<SelectGroup, 2>;
+
   // Converts select instructions of a function to conditional jumps when deemed
   // profitable. Returns true if at least one select was converted.
   bool optimizeSelects(Function &F);
@@ -156,12 +291,12 @@ class SelectOptimizeImpl {
 
   // Determines if a select group should be converted to a branch (base
   // heuristics).
-  bool isConvertToBranchProfitableBase(const SmallVector<SelectInst *, 2> &ASI);
+  bool isConvertToBranchProfitableBase(const SelectGroup &ASI);
 
   // Returns true if there are expensive instructions in the cold value
   // operand's (if any) dependence slice of any of the selects of the given
   // group.
-  bool hasExpensiveColdOperand(const SmallVector<SelectInst *, 2> &ASI);
+  bool hasExpensiveColdOperand(const SelectGroup &ASI);
 
   // For a given source instruction, collect its backwards dependence slice
   // consisting of instructions exclusively computed for producing the operands
@@ -170,7 +305,7 @@ class SelectOptimizeImpl {
                              Instruction *SI, bool ForSinking = false);
 
   // Returns true if the condition of the select is highly predictable.
-  bool isSelectHighlyPredictable(const SelectInst *SI);
+  bool isSelectHighlyPredictable(SelectLike SI);
 
   // Loop-level checks to determine if a non-predicated version (with branches)
   // of the given loop is more profitable than its predicated version.
@@ -189,14 +324,14 @@ class SelectOptimizeImpl {
   std::optional<uint64_t> computeInstCost(const Instruction *I);
 
   // Returns the misprediction cost of a given select when converted to branch.
-  Scaled64 getMispredictionCost(const SelectInst *SI, const Scaled64 CondCost);
+  Scaled64 getMispredictionCost(SelectLike SI, const Scaled64 CondCost);
 
   // Returns the cost of a branch when the prediction is correct.
   Scaled64 getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost,
-                                const SelectInst *SI);
+                                SelectLike SI);
 
   // Returns true if the target architecture supports lowering a given select.
-  bool isSelectKindSupported(SelectInst *SI);
+  bool isSelectKindSupported(SelectLike SI);
 };
 
 class SelectOptimize : public FunctionPass {
@@ -368,15 +503,26 @@ void SelectOptimizeImpl::optimizeSelectsInnerLoops(Function &F,
 /// select instructions in \p Selects, look through the defining select
 /// instruction until the true/false value is not defined in \p Selects.
 static Value *
-getTrueOrFalseValue(SelectInst *SI, bool isTrue,
-                    const SmallPtrSet<const Instruction *, 2> &Selects) {
+getTrueOrFalseValue(SelectOptimizeImpl::SelectLike SI, bool isTrue,
+                    const SmallPtrSet<const Instruction *, 2> &Selects,
+                    IRBuilder<> &IB) {
   Value *V = nullptr;
-  for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI);
+  for (SelectInst *DefSI = dyn_cast<SelectInst>(SI.getI());
+       DefSI != nullptr && Selects.count(DefSI);
        DefSI = dyn_cast<SelectInst>(V)) {
-    assert(DefSI->getCondition() == SI->getCondition() &&
+    assert(DefSI->getCondition() == SI.getCondition() &&
            "The condition of DefSI does not match with SI");
     V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
   }
+
+  if (isa<BinaryOperator>(SI.getI())) {
+    assert(SI.getI()->getOpcode() == Instruction::Or &&
+           "Only currently handling Or instructions.");
+    V = SI.getFalseValue();
+    if (isTrue)
+      V = IB.CreateOr(V, ConstantInt::get(V->getType(), 1));
+  }
+
   assert(V && "Failed to get select true/false value");
   return V;
 }
@@ -424,20 +570,22 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     SmallVector<std::stack<Instruction *>, 2> TrueSlices, FalseSlices;
     typedef std::stack<Instruction *>::size_type StackSizeType;
     StackSizeType maxTrueSliceLen = 0, maxFalseSliceLen = 0;
-    for (SelectInst *SI : ASI) {
+    for (SelectLike SI : ASI) {
       // For each select, compute the sinkable dependence chains of the true and
       // false operands.
-      if (auto *TI = dyn_cast<Instruction>(SI->getTrueValue())) {
+      if (auto *TI = dyn_cast_or_null<Instruction>(SI.getTrueValue())) {
         std::stack<Instruction *> TrueSlice;
-        getExclBackwardsSlice(TI, TrueSlice, SI, true);
+        getExclBackwardsSlice(TI, TrueSlice, SI.getI(), true);
         maxTrueSliceLen = std::max(maxTrueSliceLen, TrueSlice.size());
         TrueSlices.push_back(TrueSlice);
       }
-      if (auto *FI = dyn_cast<Instruction>(SI->getFalseValue())) {
-        std::stack<Instruction *> FalseSlice;
-        getExclBackwardsSlice(FI, FalseSlice, SI, true);
-        maxFalseSliceLen = std::max(maxFalseSliceLen, FalseSlice.size());
-        FalseSlices.push_back(FalseSlice);
+      if (auto *FI = dyn_cast_or_null<Instruction>(SI.getFalseValue())) {
+        if (isa<SelectInst>(SI.getI()) || !FI->hasOneUse()) {
+          std::stack<Instruction *> FalseSlice;
+          getExclBackwardsSlice(FI, FalseSlice, SI.getI(), true);
+          maxFalseSliceLen = std::max(maxFalseSliceLen, FalseSlice.size());
+          FalseSlices.push_back(FalseSlice);
+        }
       }
     }
     // In the case of multiple select instructions in the same group, the order
@@ -469,10 +617,10 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     }
 
     // We split the block containing the select(s) into two blocks.
-    SelectInst *SI = ASI.front();
-    SelectInst *LastSI = ASI.back();
-    BasicBlock *StartBlock = SI->getParent();
-    BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI));
+    SelectLike SI = ASI.front();
+    SelectLike LastSI = ASI.back();
+    BasicBlock *StartBlock = SI.getI()->getParent();
+    BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI.getI()));
     BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end");
     BFI->setBlockFreq(EndBlock, BFI->getBlockFreq(StartBlock));
     // Delete the unconditional branch that was just created by the split.
@@ -481,8 +629,8 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     // Move any debug/pseudo instructions that were in-between the select
     // group to the newly-created end block.
     SmallVector<Instruction *, 2> DebugPseudoINS;
-    auto DIt = SI->getIterator();
-    while (&*DIt != LastSI) {
+    auto DIt = SI.getI()->getIterator();
+    while (&*DIt != LastSI.getI()) {
       if (DIt->isDebugOrPseudoInst())
         DebugPseudoINS.push_back(&*DIt);
       DIt++;
@@ -496,18 +644,19 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr;
     BranchInst *TrueBranch = nullptr, *FalseBranch = nullptr;
     if (!TrueSlicesInterleaved.empty()) {
-      TrueBlock = BasicBlock::Create(LastSI->getContext(), "select.true.sink",
+      TrueBlock = BasicBlock::Create(EndBlock->getContext(), "select.true.sink",
                                      EndBlock->getParent(), EndBlock);
       TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
-      TrueBranch->setDebugLoc(LastSI->getDebugLoc());
+      TrueBranch->setDebugLoc(LastSI.getI()->getDebugLoc());
       for (Instruction *TrueInst : TrueSlicesInterleaved)
         TrueInst->moveBefore(TrueBranch);
     }
     if (!FalseSlicesInterleaved.empty()) {
-      FalseBlock = BasicBlock::Create(LastSI->getContext(), "select.false.sink",
-                                      EndBlock->getParent(), EndBlock);
+      FalseBlock =
+          BasicBlock::Create(EndBlock->getContext(), "select.false.sink",
+                             EndBlock->getParent(), EndBlock);
       FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
-      FalseBranch->setDebugLoc(LastSI->getDebugLoc());
+      FalseBranch->setDebugLoc(LastSI.getI()->getDebugLoc());
       for (Instruction *FalseInst : FalseSlicesInterleaved)
         FalseInst->moveBefore(FalseBranch);
     }
@@ -517,10 +666,10 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
       assert(TrueBlock == nullptr &&
              "Unexpected basic block transform while optimizing select");
 
-      FalseBlock = BasicBlock::Create(SI->getContext(), "select.false",
+      FalseBlock = BasicBlock::Create(StartBlock->getContext(), "select.false",
                                       EndBlock->getParent(), EndBlock);
       auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
-      FalseBranch->setDebugLoc(SI->getDebugLoc());
+      FalseBranch->setDebugLoc(SI.getI()->getDebugLoc());
     }
 
     // Insert the real conditional branch based on the original condition.
@@ -541,44 +690,36 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
       TT = TrueBlock;
       FT = FalseBlock;
     }
-    IRBuilder<> IB(SI);
-    auto *CondFr =
-        IB.CreateFreeze(SI->getCondition(), SI->getName() + ".frozen");
-    IB.CreateCondBr(CondFr, TT, FT, SI);
+    IRBuilder<> IB(SI.getI());
+    auto *CondFr = IB.CreateFreeze(SI.getCondition(),
+                                   SI.getCondition()->getName() + ".frozen");
 
     SmallPtrSet<const Instruction *, 2> INS;
-    INS.insert(ASI.begin(), ASI.end());
+    for (auto SI : ASI)
+      INS.insert(SI.getI());
+
     // Use reverse iterator because later select may use the value of the
     // earlier select, and we need to propagate value through earlier select
     // to get the PHI operand.
     for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
-      SelectInst *SI = *It;
+      SelectLike SI = *It;
       // The select itself is replaced with a PHI Node.
-      PHINode *PN = PHINode::Create(SI->getType(), 2, "");
+      PHINode *PN = PHINode::Create(SI.getType(), 2, "");
       PN->insertBefore(EndBlock->begin());
-      PN->takeName(SI);
-      PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock);
-      PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock);
-      PN->setDebugLoc(SI->getDebugLoc());
-
-      SI->replaceAllUsesWith(PN);
-      SI->eraseFromParent();
-      INS.erase(SI);
+      PN->takeName(SI.getI());
+      PN->addIncoming(getTrueOrFalseValue(SI, true, INS, IB), TrueBlock);
+      PN->addIncoming(getTrueOrFalseValue(SI, false, INS, IB), FalseBlock);
+      PN->setDebugLoc(SI.getI()->getDebugLoc());
+      SI.getI()->replaceAllUsesWith(PN);
+      INS.erase(SI.getI());
       ++NumSelectsConverted;
     }
-  }
-}
-
-static bool isSpecialSelect(SelectInst *SI) {
-  using namespace llvm::PatternMatch;
+    IB.CreateCondBr(CondFr, TT, FT, SI.getI());
 
-  // If the select is a logical-and/logical-or then it is better treated as a
-  // and/or by the backend.
-  if (match(SI, m_CombineOr(m_LogicalAnd(m_Value(), m_Value()),
-                            m_LogicalOr(m_Value(), m_Value()))))
-    return true;
-
-  return false;
+    // Remove the old select instructions, now that they are no longer used.
+    for (auto SI : ASI)
+      SI.getI()->eraseFromParent();
+  }
 }
 
 void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB,
@@ -586,22 +727,30 @@ void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB,
   BasicBlock::iterator BBIt = BB.begin();
   while (BBIt != BB.end()) {
     Instruction *I = &*BBIt++;
-    if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
-      if (isSpecialSelect(SI))
+    if (SelectLike SI = SelectLike::match(I)) {
+      if (!TTI->shouldTreatInstructionLikeSelect(I))
         continue;
 
       SelectGroup SIGroup;
       SIGroup.push_back(SI);
       while (BBIt != BB.end()) {
         Instruction *NI = &*BBIt;
-        SelectInst *NSI = dyn_cast<SelectInst>(NI);
-        if (NSI && SI->getCondition() == NSI->getCondition()) {
+        // Debug/pseudo instructions should be skipped and not prevent the
+        // formation of a select group.
+        if (NI->isDebugOrPseudoInst()) {
+          ++BBIt;
+          continue;
+        }
+        // We only allow selects in the same group, not other select-like
+        // instructions.
+        if (!isa<SelectInst>(NI))
+          break;
+
+        SelectLike NSI = SelectLike::match(NI);
+        if (NSI && SI.getCondition() == NSI.getCondition()) {
           SIGroup.push_back(NSI);
-        } else if (!NI->isDebugOrPseudoInst()) {
-          // Debug/pseudo instructions should be skipped and not prevent the
-          // formation of a select group.
+        } else
           break;
-        }
         ++BBIt;
       }
 
@@ -655,12 +804,12 @@ void SelectOptimizeImpl::findProfitableSIGroupsInnerLoops(
     // Assuming infinite resources, the cost of a group of instructions is the
     // cost of the most expensive instruction of the group.
     Scaled64 SelectCost = Scaled64::getZero(), BranchCost = Scaled64::getZero();
-    for (SelectInst *SI : ASI) {
-      SelectCost = std::max(SelectCost, InstCostMap[SI].PredCost);
-      BranchCost = std::max(BranchCost, InstCostMap[SI].NonPredCost);
+    for (SelectLike SI : ASI) {
+      SelectCost = std::max(SelectCost, InstCostMap[SI.getI()].PredCost);
+      BranchCost = std::max(BranchCost, InstCostMap[SI.getI()].NonPredCost);
     }
     if (BranchCost < SelectCost) {
-      OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", ASI.front());
+      OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", ASI.front().getI());
       OR << "Profitable to convert to branch (loop analysis). BranchCost="
          << BranchCost.toString() << ", SelectCost=" << SelectCost.toString()
          << ". ";
@@ -668,7 +817,8 @@ void SelectOptimizeImpl::findProfitableSIGroupsInnerLoops(
       ++NumSelectConvertedLoop;
       ProfSIGroups.push_back(ASI);
     } else {
-      OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", ASI.front());
+      OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti",
+                                      ASI.front().getI());
       ORmiss << "Select is more profitable (loop analysis). BranchCost="
              << BranchCost.toString()
              << ", SelectCost=" << SelectCost.toString() << ". ";
@@ -678,14 +828,15 @@ void SelectOptimizeImpl::findProfitableSIGroupsInnerLoops(
 }
 
 bool SelectOptimizeImpl::isConvertToBranchProfitableBase(
-    const SmallVector<SelectInst *, 2> &ASI) {
-  SelectInst *SI = ASI.front();
-  LLVM_DEBUG(dbgs() << "Analyzing select group containing " << *SI << "\n");
-  OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", SI);
-  OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", SI);
+    const SelectGroup &ASI) {
+  SelectLike SI = ASI.front();
+  LLVM_DEBUG(dbgs() << "Analyzing select group containing " << SI.getI()
+                    << "\n");
+  OptimizationRemark OR(DEBUG_TYPE, "SelectOpti", SI.getI());
+  OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", SI.getI());
 
   // Skip cold basic blocks. Better to optimize for size for cold blocks.
-  if (PSI->isColdBlock(SI->getParent(), BFI)) {
+  if (PSI->isColdBlock(SI.getI()->getParent(), BFI)) {
     ++NumSelectColdBB;
     ORmiss << "Not converted to branch because of cold basic block. ";
     EmitAndPrintRemark(ORE, ORmiss);
@@ -693,7 +844,7 @@ bool SelectOptimizeImpl::isConvertToBranchProfitableBase(
   }
 
   // If unpredictable, branch form is less profitable.
-  if (SI->getMetadata(LLVMContext::MD_unpredictable)) {
+  if (SI.getI()->getMetadata(LLVMContext::MD_unpredictable)) {
     ++NumSelectUnPred;
     ORmiss << "Not converted to branch because of unpredictable branch. ";
     EmitAndPrintRemark(ORE, ORmiss);
@@ -728,17 +879,24 @@ static InstructionCost divideNearest(InstructionCost Numerator,
   return (Numerator + (Denominator / 2)) / Denominator;
 }
 
-bool SelectOptimizeImpl::hasExpensiveColdOperand(
-    const SmallVector<SelectInst *, 2> &ASI) {
+static bool extractBranchWeights(SelectOptimizeImpl::SelectLike SI,
+                                 uint64_t &TrueVal, uint64_t &FalseVal) {
+  if (isa<SelectInst>(SI.getI()))
+    return extractBranchWeights(*SI.getI(), TrueVal, FalseVal);
+  return false;
+}
+
+bool SelectOptimizeImpl::hasExpensiveColdOperand(const SelectGroup &ASI) {
   bool ColdOperand = false;
   uint64_t TrueWeight, FalseWeight, TotalWeight;
-  if (extractBranchWeights(*ASI.front(), TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(ASI.front(), TrueWeight, FalseWeight)) {
     uint64_t MinWeight = std::min(TrueWeight, FalseWeight);
     TotalWeight = TrueWeight + FalseWeight;
     // Is there a path with frequency <ColdOperandThreshold% (default:20%) ?
     ColdOperand = TotalWeight * ColdOperandThreshold > 100 * MinWeight;
   } else if (PSI->hasProfileSummary()) {
-    OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti", ASI.front());
+    OptimizationRemarkMissed ORmiss(DEBUG_TYPE, "SelectOpti",
+                                    ASI.front().getI());
     ORmiss << "Profile data available but missing branch-weights metadata for "
               "select instruction. ";
     EmitAndPrintRemark(ORE, ORmiss);
@@ -747,19 +905,19 @@ bool SelectOptimizeImpl::hasExpensiveColdOperand(
     return false;
   // Check if the cold path's dependence slice is expensive for any of the
   // selects of the group.
-  for (SelectInst *SI : ASI) {
+  for (SelectLike SI : ASI) {
     Instruction *ColdI = nullptr;
     uint64_t HotWeight;
     if (TrueWeight < FalseWeight) {
-      ColdI = dyn_cast<Instruction>(SI->getTrueValue());
+      ColdI = dyn_cast_or_null<Instruction>(SI.getTrueValue());
       HotWeight = FalseWeight;
     } else {
-      ColdI = dyn_cast<Instruction>(SI->getFalseValue());
+      ColdI = dyn_cast_or_null<Instruction>(SI.getFalseValue());
       HotWeight = TrueWeight;
     }
     if (ColdI) {
       std::stack<Instruction *> ColdSlice;
-      getExclBackwardsSlice(ColdI, ColdSlice, SI);
+      getExclBackwardsSlice(ColdI, ColdSlice, SI.getI());
       InstructionCost SliceCost = 0;
       while (!ColdSlice.empty()) {
         SliceCost += TTI->getInstructionCost(ColdSlice.top(),
@@ -849,9 +1007,9 @@ void SelectOptimizeImpl::getExclBackwardsSlice(Instruction *I,
   }
 }
 
-bool SelectOptimizeImpl::isSelectHighlyPredictable(const SelectInst *SI) {
+bool SelectOptimizeImpl::isSelectHighlyPredictable(SelectLike SI) {
   uint64_t TrueWeight, FalseWeight;
-  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(SI, TrueWeight, FalseWeight)) {
     uint64_t Max = std::max(TrueWeight, FalseWeight);
     uint64_t Sum = TrueWeight + FalseWeight;
     if (Sum != 0) {
@@ -945,7 +1103,7 @@ bool SelectOptimizeImpl::computeLoopCosts(
     // Cost of the loop's critical path.
     CostInfo &MaxCost = LoopCost[Iter];
     for (BasicBlock *BB : L->getBlocks()) {
-      for (const Instruction &I : *BB) {
+      for (Instruction &I : *BB) {
         if (I.isDebugOrPseudoInst())
           continue;
         // Compute the predicated and non-predicated cost of the instruction.
@@ -983,21 +1141,16 @@ bool SelectOptimizeImpl::computeLoopCosts(
         // PredictedPathCost = TrueOpCost * TrueProb + FalseOpCost * FalseProb
         // MispredictCost = max(MispredictPenalty, CondCost) * MispredictRate
         if (SIset.contains(&I)) {
-          auto SI = cast<SelectInst>(&I);
-
-          Scaled64 TrueOpCost = Scaled64::getZero(),
-                   FalseOpCost = Scaled64::getZero();
-          if (auto *TI = dyn_cast<Instruction>(SI->getTrueValue()))
-            if (InstCostMap.count(TI))
-              TrueOpCost = InstCostMap[TI].NonPredCost;
-          if (auto *FI = dyn_cast<Instruction>(SI->getFalseValue()))
-            if (InstCostMap.count(FI))
-              FalseOpCost = InstCostMap[FI].NonPredCost;
+          auto SI = SelectLike::match(&I);
+          assert(SI && "Expected to match an existing SelectLike");
+
+          Scaled64 TrueOpCost = SI.getTrueOpCost(InstCostMap, TTI);
+          Scaled64 FalseOpCost = SI.getFalseOpCost(InstCostMap, TTI);
           Scaled64 PredictedPathCost =
               getPredictedPathCost(TrueOpCost, FalseOpCost, SI);
 
           Scaled64 CondCost = Scaled64::getZero();
-          if (auto *CI = dyn_cast<Instruction>(SI->getCondition()))
+          if (auto *CI = dyn_cast<Instruction>(SI.getCondition()))
             if (InstCostMap.count(CI))
               CondCost = InstCostMap[CI].NonPredCost;
           Scaled64 MispredictCost = getMispredictionCost(SI, CondCost);
@@ -1023,8 +1176,8 @@ SmallPtrSet<const Instruction *, 2>
 SelectOptimizeImpl::getSIset(const SelectGroups &SIGroups) {
   SmallPtrSet<const Instruction *, 2> SIset;
   for (const SelectGroup &ASI : SIGroups)
-    for (const SelectInst *SI : ASI)
-      SIset.insert(SI);
+    for (SelectLike SI : ASI)
+      SIset.insert(SI.getI());
   return SIset;
 }
 
@@ -1038,7 +1191,7 @@ SelectOptimizeImpl::computeInstCost(const Instruction *I) {
 }
 
 ScaledNumber<uint64_t>
-SelectOptimizeImpl::getMispredictionCost(const SelectInst *SI,
+SelectOptimizeImpl::getMispredictionCost(SelectLike SI,
                                          const Scaled64 CondCost) {
   uint64_t MispredictPenalty = TSchedModel.getMCSchedModel()->MispredictPenalty;
 
@@ -1065,10 +1218,10 @@ SelectOptimizeImpl::getMispredictionCost(const SelectInst *SI,
 // TrueCost * TrueProbability + FalseCost * FalseProbability.
 ScaledNumber<uint64_t>
 SelectOptimizeImpl::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost,
-                                         const SelectInst *SI) {
+                                         SelectLike SI) {
   Scaled64 PredPathCost;
   uint64_t TrueWeight, FalseWeight;
-  if (extractBranchWeights(*SI, TrueWeight, FalseWeight)) {
+  if (extractBranchWeights(SI, TrueWeight, FalseWeight)) {
     uint64_t SumWeight = TrueWeight + FalseWeight;
     if (SumWeight != 0) {
       PredPathCost = TrueCost * Scaled64::get(TrueWeight) +
@@ -1085,12 +1238,12 @@ SelectOptimizeImpl::getPredictedPathCost(Scaled64 TrueCost, Scaled64 FalseCost,
   return PredPathCost;
 }
 
-bool SelectOptimizeImpl::isSelectKindSupported(SelectInst *SI) {
-  bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
+bool SelectOptimizeImpl::isSelectKindSupported(SelectLike SI) {
+  bool VectorCond = !SI.getCondition()->getType()->isIntegerTy(1);
   if (VectorCond)
     return false;
   TargetLowering::SelectSupportKind SelectKind;
-  if (SI->getType()->isVectorTy())
+  if (SI.getType()->isVectorTy())
     SelectKind = TargetLowering::ScalarCondVectorVal;
   else
     SelectKind = TargetLowering::ScalarValSelect;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index f471294ffc25207..1012e4bb860ffc8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -412,6 +412,18 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   bool enableSelectOptimize() { return ST->enableSelectOptimize(); }
 
+  bool shouldTreatInstructionLikeSelect(Instruction *I) {
+    // Reports whether \p I should be treated like a select. Binary operators
+    // (e.g. or) need more care than selects: an Or is accepted here when it is
+    // immediately followed by an unconditional branch, i.e. it already sits at
+    // a natural break point. Everything else defers to the base implementation.
+    if (I->getOpcode() == Instruction::Or &&
+        isa<BranchInst>(I->getNextNode()) &&
+        cast<BranchInst>(I->getNextNode())->isUnconditional())
+      return true;
+    return BaseT::shouldTreatInstructionLikeSelect(I);
+  }
+
   unsigned getStoreMinimumVF(unsigned VF, Type *ScalarMemTy,
                              Type *ScalarValTy) const {
     // We can vectorize store v4i8.
diff --git a/llvm/test/CodeGen/AArch64/selectopt.ll b/llvm/test/CodeGen/AArch64/selectopt.ll
index f2b86dcb9211127..acf59fa5fe4ce6c 100644
--- a/llvm/test/CodeGen/AArch64/selectopt.ll
+++ b/llvm/test/CodeGen/AArch64/selectopt.ll
@@ -341,10 +341,16 @@ define void @replace_or(ptr nocapture noundef %newst, ptr noundef %t, ptr nounde
 ; CHECKOO-NEXT:    [[TMP8:%.*]] = load i64, ptr [[FLOW83]], align 8
 ; CHECKOO-NEXT:    [[CMP84:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]]
 ; CHECKOO-NEXT:    [[ADD:%.*]] = zext i1 [[CMP84]] to i64
-; CHECKOO-NEXT:    [[SPEC_SELECT:%.*]] = or disjoint i64 [[MUL]], [[ADD]]
+; CHECKOO-NEXT:    [[CMP84_FROZEN:%.*]] = freeze i1 [[CMP84]]
+; CHECKOO-NEXT:    [[TMP9:%.*]] = or i64 [[MUL]], 1
+; CHECKOO-NEXT:    br i1 [[CMP84_FROZEN]], label [[SELECT_END:%.*]], label [[SELECT_FALSE:%.*]]
+; CHECKOO:       select.false:
+; CHECKOO-NEXT:    br label [[SELECT_END]]
+; CHECKOO:       select.end:
+; CHECKOO-NEXT:    [[SPEC_SELECT:%.*]] = phi i64 [ [[TMP9]], [[IF_THEN]] ], [ [[MUL]], [[SELECT_FALSE]] ]
 ; CHECKOO-NEXT:    br label [[IF_END87]]
 ; CHECKOO:       if.end87:
-; CHECKOO-NEXT:    [[CMP_1]] = phi i64 [ [[MUL]], [[WHILE_BODY]] ], [ [[SPEC_SELECT]], [[IF_THEN]] ]
+; CHECKOO-NEXT:    [[CMP_1]] = phi i64 [ [[MUL]], [[WHILE_BODY]] ], [ [[SPEC_SELECT]], [[SELECT_END]] ]
 ; CHECKOO-NEXT:    [[CMP16_NOT:%.*]] = icmp sgt i64 [[CMP_1]], [[MA]]
 ; CHECKOO-NEXT:    br i1 [[CMP16_NOT]], label [[WHILE_END]], label [[LAND_RHS]]
 ; CHECKOO:       while.end:
@@ -663,10 +669,16 @@ define i32 @or_samegroup(ptr nocapture noundef %x, i32 noundef %n, ptr nocapture
 ; CHECKOO-NEXT:    br label [[SELECT_END]]
 ; CHECKOO:       select.end:
 ; CHECKOO-NEXT:    [[SEL:%.*]] = phi i32 [ [[ADD]], [[IF_THEN]] ], [ 1, [[SELECT_FALSE]] ]
-; CHECKOO-NEXT:    [[OR:%.*]] = or i32 [[CONV]], [[SEL]]
+; CHECKOO-NEXT:    [[CMP5_FROZEN3:%.*]] = freeze i1 [[CMP5]]
+; CHECKOO-NEXT:    [[TMP2:%.*]] = or i32 [[SEL]], 1
+; CHECKOO-NEXT:    br i1 [[CMP5_FROZEN3]], label [[SELECT_END1:%.*]], label [[SELECT_FALSE2:%.*]]
+; CHECKOO:       select.false2:
+; CHECKOO-NEXT:    br label [[SELECT_END1]]
+; CHECKOO:       select.end1:
+; CHECKOO-NEXT:    [[OR:%.*]] = phi i32 [ [[TMP2]], [[SELECT_END]] ], [ [[SEL]], [[SELECT_FALSE2]] ]
 ; CHECKOO-NEXT:    br label [[IF_END]]
 ; CHECKOO:       if.end:
-; CHECKOO-NEXT:    [[Y_1]] = phi i32 [ [[SEL]], [[SELECT_END]] ], [ 0, [[FOR_BODY]] ]
+; CHECKOO-NEXT:    [[Y_1]] = phi i32 [ [[SEL]], [[SELECT_END1]] ], [ 0, [[FOR_BODY]] ]
 ; CHECKOO-NEXT:    store i32 [[Y_1]], ptr [[ARRAYIDX]], align 4
 ; CHECKOO-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
 ; CHECKOO-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[WIDE_TRIP_COUNT]]
@@ -776,10 +788,16 @@ define i32 @or_oneusevalues(ptr nocapture noundef %x, i32 noundef %n, ptr nocapt
 ; CHECKOO-NEXT:    [[CONV:%.*]] = zext i1 [[CMP5]] to i32
 ; CHECKOO-NEXT:    [[ADD1:%.*]] = add i32 [[ADD]], 1
 ; CHECKOO-NEXT:    [[ADD2:%.*]] = or i32 [[ADD1]], 1
-; CHECKOO-NEXT:    [[OR:%.*]] = or i32 [[CONV]], [[ADD2]]
+; CHECKOO-NEXT:    [[CMP5_FROZEN:%.*]] = freeze i1 [[CMP5]]
+; CHECKOO-NEXT:    [[TMP2:%.*]] = or i32 [[ADD2]], 1
+; CHECKOO-NEXT:    br i1 [[CMP5_FROZEN]], label [[SELECT_END:%.*]], label [[SELECT_FALSE:%.*]]
+; CHECKOO:       select.false:
+; CHECKOO-NEXT:    br label [[SELECT_END]]
+; CHECKOO:       select.end:
+; CHECKOO-NEXT:    [[OR:%.*]] = phi i32 [ [[TMP2]], [[IF_THEN]] ], [ [[ADD2]], [[SELECT_FALSE]] ]
 ; CHECKOO-NEXT:    br label [[IF_END]]
 ; CHECKOO:       if.end:
-; CHECKOO-NEXT:    [[Y_1]] = phi i32 [ [[OR]], [[IF_THEN]] ], [ 0, [[FOR_BODY]] ]
+; CHECKOO-NEXT:    [[Y_1]] = phi i32 [ [[OR]], [[SELECT_END]] ], [ 0, [[FOR_BODY]] ]
 ; CHECKOO-NEXT:    store i32 [[Y_1]], ptr [[ARRAYIDX]], align 4
 ; CHECKOO-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
 ; CHECKOO-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[WIDE_TRIP_COUNT]]



More information about the llvm-commits mailing list