[llvm] [SelectOpt] Refactor to prepare for supporting more select-like operations (PR #117582)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 09:17:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Igor Kirillov (igogo-x86)

<details>
<summary>Changes</summary>

* Enables conversion of several select-like instructions within one group
* Allows any number of auxiliary instructions (such as `not` or `zext`) that depend on the same condition to appear between the select-like instructions of a group
* After splitting the basic block, moves the select-like instructions into the relevant true/false blocks and optimises them there
* Makes it easier to add support for shift-based select-like instructions, as well as any mixture of zext/sext/not instructions

---

Patch is 34.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117582.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectOptimize.cpp (+262-220) 
- (modified) llvm/test/CodeGen/AArch64/selectopt.ll (+16-21) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index 81796fcf2842a8..d480642171e8e5 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -127,77 +127,26 @@ class SelectOptimizeImpl {
   /// act like selects. For example Or(Zext(icmp), X) can be treated like
   /// select(icmp, X|1, X).
   class SelectLike {
-    SelectLike(Instruction *I) : I(I) {}
-
     /// The select (/or) instruction.
     Instruction *I;
     /// Whether this select is inverted, "not(cond), FalseVal, TrueVal", as
     /// opposed to the original condition.
     bool Inverted = false;
 
-  public:
-    /// Match a select or select-like instruction, returning a SelectLike.
-    static SelectLike match(Instruction *I) {
-      // Select instruction are what we are usually looking for.
-      if (isa<SelectInst>(I))
-        return SelectLike(I);
-
-      // An Or(zext(i1 X), Y) can also be treated like a select, with condition
-      // C and values Y|1 and Y.
-      Value *X;
-      if (PatternMatch::match(
-              I, m_c_Or(m_OneUse(m_ZExt(m_Value(X))), m_Value())) &&
-          X->getType()->isIntegerTy(1))
-        return SelectLike(I);
-
-      return SelectLike(nullptr);
-    }
+    /// The index of the operand that depends on condition. Only for select-like
+    /// instruction such as Or/Add.
+    unsigned CondIdx;
 
-    bool isValid() { return I; }
-    operator bool() { return isValid(); }
-
-    /// Invert the select by inverting the condition and switching the operands.
-    void setInverted() {
-      assert(!Inverted && "Trying to invert an inverted SelectLike");
-      assert(isa<Instruction>(getCondition()) &&
-             cast<Instruction>(getCondition())->getOpcode() ==
-                 Instruction::Xor);
-      Inverted = true;
-    }
-    bool isInverted() const { return Inverted; }
+  public:
+    SelectLike(Instruction *I, bool Inverted = false, unsigned CondIdx = 0)
+        : I(I), Inverted(Inverted), CondIdx(CondIdx) {}
 
     Instruction *getI() { return I; }
     const Instruction *getI() const { return I; }
 
     Type *getType() const { return I->getType(); }
 
-    Value *getNonInvertedCondition() const {
-      if (auto *Sel = dyn_cast<SelectInst>(I))
-        return Sel->getCondition();
-      // Or(zext) case
-      if (auto *BO = dyn_cast<BinaryOperator>(I)) {
-        Value *X;
-        if (PatternMatch::match(BO->getOperand(0),
-                                m_OneUse(m_ZExt(m_Value(X)))))
-          return X;
-        if (PatternMatch::match(BO->getOperand(1),
-                                m_OneUse(m_ZExt(m_Value(X)))))
-          return X;
-      }
-
-      llvm_unreachable("Unhandled case in getCondition");
-    }
-
-    /// Return the condition for the SelectLike instruction. For example the
-    /// condition of a select or c in `or(zext(c), x)`
-    Value *getCondition() const {
-      Value *CC = getNonInvertedCondition();
-      // For inverted conditions the CC is checked when created to be a not
-      // (xor) instruction.
-      if (Inverted)
-        return cast<Instruction>(CC)->getOperand(0);
-      return CC;
-    }
+    unsigned getConditionOpIndex() { return CondIdx; };
 
     /// Return the true value for the SelectLike instruction. Note this may not
     /// exist for all SelectLike instructions. For example, for `or(zext(c), x)`
@@ -224,74 +173,56 @@ class SelectOptimizeImpl {
         return getTrueValue(/*HonorInverts=*/false);
       if (auto *Sel = dyn_cast<SelectInst>(I))
         return Sel->getFalseValue();
-      // Or(zext) case - return the operand which is not the zext.
-      if (auto *BO = dyn_cast<BinaryOperator>(I)) {
-        Value *X;
-        if (PatternMatch::match(BO->getOperand(0),
-                                m_OneUse(m_ZExt(m_Value(X)))))
-          return BO->getOperand(1);
-        if (PatternMatch::match(BO->getOperand(1),
-                                m_OneUse(m_ZExt(m_Value(X)))))
-          return BO->getOperand(0);
-      }
+      // We are on the branch where the condition is zero, which means BinOp
+      // does not perform any computation, and we can simply return the operand
+      // that is not related to the condition
+      if (auto *BO = dyn_cast<BinaryOperator>(I))
+        return BO->getOperand(1 - CondIdx);
 
       llvm_unreachable("Unhandled case in getFalseValue");
     }
 
-    /// Return the NonPredCost cost of the true op, given the costs in
-    /// InstCostMap. This may need to be generated for select-like instructions.
-    Scaled64 getTrueOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap,
-                           const TargetTransformInfo *TTI) {
-      if (isa<SelectInst>(I))
-        if (auto *I = dyn_cast<Instruction>(getTrueValue())) {
-          auto It = InstCostMap.find(I);
-          return It != InstCostMap.end() ? It->second.NonPredCost
-                                         : Scaled64::getZero();
-        }
-
-      // Or case - add the cost of an extra Or to the cost of the False case.
-      if (isa<BinaryOperator>(I))
-        if (auto I = dyn_cast<Instruction>(getFalseValue())) {
-          auto It = InstCostMap.find(I);
-          if (It != InstCostMap.end()) {
-            InstructionCost OrCost = TTI->getArithmeticInstrCost(
-                Instruction::Or, I->getType(), TargetTransformInfo::TCK_Latency,
-                {TargetTransformInfo::OK_AnyValue,
-                 TargetTransformInfo::OP_None},
-                {TTI::OK_UniformConstantValue, TTI::OP_PowerOf2});
-            return It->second.NonPredCost + Scaled64::get(*OrCost.getValue());
-          }
-        }
-
-      return Scaled64::getZero();
-    }
-
-    /// Return the NonPredCost cost of the false op, given the costs in
-    /// InstCostMap. This may need to be generated for select-like instructions.
-    Scaled64
-    getFalseOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap,
-                   const TargetTransformInfo *TTI) {
-      if (isa<SelectInst>(I))
-        if (auto *I = dyn_cast<Instruction>(getFalseValue())) {
-          auto It = InstCostMap.find(I);
+    /// Return the NonPredCost cost of the op on \p isTrue branch, given the
+    /// costs in \p InstCostMap. This may need to be generated for select-like
+    /// instructions.
+    Scaled64 getOpCostOnBranch(
+        bool IsTrue, const DenseMap<const Instruction *, CostInfo> &InstCostMap,
+        const TargetTransformInfo *TTI) {
+      auto *V = IsTrue ? getTrueValue() : getFalseValue();
+      if (V) {
+        if (auto *IV = dyn_cast<Instruction>(V)) {
+          auto It = InstCostMap.find(IV);
           return It != InstCostMap.end() ? It->second.NonPredCost
                                          : Scaled64::getZero();
         }
-
-      // Or case - return the cost of the false case
-      if (isa<BinaryOperator>(I))
-        if (auto I = dyn_cast<Instruction>(getFalseValue()))
-          if (auto It = InstCostMap.find(I); It != InstCostMap.end())
-            return It->second.NonPredCost;
-
-      return Scaled64::getZero();
+        return Scaled64::getZero();
+      }
+      // If getTrue(False)Value() return nullptr, it means we are dealing with
+      // select-like instructions on the branch where the actual computation is
+      // happening. In that case the cost is equal to the cost of computation +
+      // cost of non-dependant on condition operand
+      InstructionCost Cost = TTI->getArithmeticInstrCost(
+          getI()->getOpcode(), I->getType(), TargetTransformInfo::TCK_Latency,
+          {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+          {TTI::OK_UniformConstantValue, TTI::OP_PowerOf2});
+      auto TotalCost = Scaled64::get(*Cost.getValue());
+      if (auto *OpI = dyn_cast<Instruction>(I->getOperand(1 - CondIdx))) {
+        auto It = InstCostMap.find(OpI);
+        if (It != InstCostMap.end())
+          TotalCost += It->second.NonPredCost;
+      }
+      return TotalCost;
     }
   };
 
 private:
-  // Select groups consist of consecutive select instructions with the same
-  // condition.
-  using SelectGroup = SmallVector<SelectLike, 2>;
+  // Select groups consist of consecutive select-like instructions with the same
+  // condition. Between select-likes could be any number of auxiliary
+  // instructions related to the condition like not, zext
+  struct SelectGroup {
+    Value *Condition;
+    SmallVector<SelectLike, 2> Selects;
+  };
   using SelectGroups = SmallVector<SelectGroup, 2>;
 
   // Converts select instructions of a function to conditional jumps when deemed
@@ -351,6 +282,11 @@ class SelectOptimizeImpl {
   SmallDenseMap<const Instruction *, SelectLike, 2>
   getSImap(const SelectGroups &SIGroups);
 
+  // Returns a map from select-like instructions to the corresponding select
+  // group.
+  SmallDenseMap<const Instruction *, const SelectGroup *, 2>
+  getSGmap(const SelectGroups &SIGroups);
+
   // Returns the latency cost of a given instruction.
   std::optional<uint64_t> computeInstCost(const Instruction *I);
 
@@ -529,34 +465,45 @@ void SelectOptimizeImpl::optimizeSelectsInnerLoops(Function &F,
   }
 }
 
-/// If \p isTrue is true, return the true value of \p SI, otherwise return
-/// false value of \p SI. If the true/false value of \p SI is defined by any
-/// select instructions in \p Selects, look through the defining select
-/// instruction until the true/false value is not defined in \p Selects.
-static Value *
-getTrueOrFalseValue(SelectOptimizeImpl::SelectLike SI, bool isTrue,
-                    const SmallPtrSet<const Instruction *, 2> &Selects,
-                    IRBuilder<> &IB) {
-  Value *V = nullptr;
-  for (SelectInst *DefSI = dyn_cast<SelectInst>(SI.getI());
-       DefSI != nullptr && Selects.count(DefSI);
-       DefSI = dyn_cast<SelectInst>(V)) {
-    if (DefSI->getCondition() == SI.getCondition())
-      V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
-    else // Handle inverted SI
-      V = (!isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
+/// Returns optimised value on \p IsTrue branch. For SelectInst that would be
+/// either True or False value. For (BinaryOperator) instructions, where the
+/// condition may be skipped, the operation will use a non-conditional operand.
+/// For example, for `or(V,zext(cond))` this function would return V.
+/// However, if the conditional operand on \p IsTrue branch matters, we create a
+/// clone of instruction at the end of that branch \p B and replace the
+/// condition operand with a constant.
+///
+/// Also /p OptSelects contains previously optimised select-like instructions.
+/// If the current value uses one of the optimised values, we can optimise it
+/// further by replacing it with the corresponding value on the given branch
+static Value *getTrueOrFalseValue(
+    SelectOptimizeImpl::SelectLike &SI, bool isTrue,
+    SmallDenseMap<Instruction *, std::pair<Value *, Value *>, 2> &OptSelects,
+    BasicBlock *B) {
+  Value *V = isTrue ? SI.getTrueValue() : SI.getFalseValue();
+  if (V) {
+    auto *IV = dyn_cast<Instruction>(V);
+    if (IV && OptSelects.count(IV))
+      return isTrue ? OptSelects[IV].first : OptSelects[IV].second;
+    return V;
   }
 
-  if (isa<BinaryOperator>(SI.getI())) {
-    assert(SI.getI()->getOpcode() == Instruction::Or &&
-           "Only currently handling Or instructions.");
-    V = SI.getFalseValue();
-    if (isTrue)
-      V = IB.CreateOr(V, ConstantInt::get(V->getType(), 1));
-  }
+  auto *BO = cast<BinaryOperator>(SI.getI());
+  assert(BO->getOpcode() == Instruction::Or &&
+         "Only currently handling Or instructions.");
+
+  auto *CBO = BO->clone();
+  auto CondIdx = SI.getConditionOpIndex();
+  CBO->setOperand(CondIdx, ConstantInt::get(CBO->getType(), 1));
 
-  assert(V && "Failed to get select true/false value");
-  return V;
+  unsigned OtherIdx = 1 - CondIdx;
+  if (auto *IV = dyn_cast<Instruction>(CBO->getOperand(OtherIdx))) {
+    if (OptSelects.count(IV))
+      CBO->setOperand(OtherIdx,
+                      isTrue ? OptSelects[IV].first : OptSelects[IV].second);
+  }
+  CBO->insertBefore(B->getTerminator());
+  return CBO;
 }
 
 void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
@@ -602,7 +549,9 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     SmallVector<std::stack<Instruction *>, 2> TrueSlices, FalseSlices;
     typedef std::stack<Instruction *>::size_type StackSizeType;
     StackSizeType maxTrueSliceLen = 0, maxFalseSliceLen = 0;
-    for (SelectLike SI : ASI) {
+    for (SelectLike &SI : ASI.Selects) {
+      if (!isa<SelectInst>(SI.getI()))
+        continue;
       // For each select, compute the sinkable dependence chains of the true and
       // false operands.
       if (auto *TI = dyn_cast_or_null<Instruction>(SI.getTrueValue())) {
@@ -649,8 +598,8 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     }
 
     // We split the block containing the select(s) into two blocks.
-    SelectLike SI = ASI.front();
-    SelectLike LastSI = ASI.back();
+    SelectLike &SI = ASI.Selects.front();
+    SelectLike &LastSI = ASI.Selects.back();
     BasicBlock *StartBlock = SI.getI()->getParent();
     BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI.getI()));
     // With RemoveDIs turned off, SplitPt can be a dbg.* intrinsic. With
@@ -664,19 +613,21 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     // Delete the unconditional branch that was just created by the split.
     StartBlock->getTerminator()->eraseFromParent();
 
-    // Move any debug/pseudo instructions and not's that were in-between the
+    // Move any debug/pseudo and auxiliary instructions that were in-between the
     // select group to the newly-created end block.
     SmallVector<Instruction *, 2> SinkInstrs;
     auto DIt = SI.getI()->getIterator();
+    auto NIt = ASI.Selects.begin();
     while (&*DIt != LastSI.getI()) {
-      if (DIt->isDebugOrPseudoInst())
-        SinkInstrs.push_back(&*DIt);
-      if (match(&*DIt, m_Not(m_Specific(SI.getCondition()))))
+      if (NIt != ASI.Selects.end() && &*DIt == NIt->getI())
+        ++NIt;
+      else
         SinkInstrs.push_back(&*DIt);
       DIt++;
     }
+    auto InsertionPoint = EndBlock->getFirstInsertionPt();
     for (auto *DI : SinkInstrs)
-      DI->moveBeforePreserving(&*EndBlock->getFirstInsertionPt());
+      DI->moveBeforePreserving(&*InsertionPoint);
 
     // Duplicate implementation for DbgRecords, the non-instruction debug-info
     // format. Helper lambda for moving DbgRecords to the end block.
@@ -700,7 +651,15 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
     // At least one will become an actual new basic block.
     BasicBlock *TrueBlock = nullptr, *FalseBlock = nullptr;
     BranchInst *TrueBranch = nullptr, *FalseBranch = nullptr;
-    if (!TrueSlicesInterleaved.empty()) {
+    // Checks if select-like instruction would materialise on the given branch
+    auto HasSelectLike = [](SelectGroup &SG, bool IsTrue) {
+      for (auto &SL : SG.Selects) {
+        if ((IsTrue ? SL.getTrueValue() : SL.getFalseValue()) == nullptr)
+          return true;
+      }
+      return false;
+    };
+    if (!TrueSlicesInterleaved.empty() || HasSelectLike(ASI, true)) {
       TrueBlock = BasicBlock::Create(EndBlock->getContext(), "select.true.sink",
                                      EndBlock->getParent(), EndBlock);
       TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
@@ -708,7 +667,7 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
       for (Instruction *TrueInst : TrueSlicesInterleaved)
         TrueInst->moveBefore(TrueBranch);
     }
-    if (!FalseSlicesInterleaved.empty()) {
+    if (!FalseSlicesInterleaved.empty() || HasSelectLike(ASI, false)) {
       FalseBlock =
           BasicBlock::Create(EndBlock->getContext(), "select.false.sink",
                              EndBlock->getParent(), EndBlock);
@@ -748,93 +707,167 @@ void SelectOptimizeImpl::convertProfitableSIGroups(SelectGroups &ProfSIGroups) {
       FT = FalseBlock;
     }
     IRBuilder<> IB(SI.getI());
-    auto *CondFr = IB.CreateFreeze(SI.getCondition(),
-                                   SI.getCondition()->getName() + ".frozen");
+    auto *CondFr =
+        IB.CreateFreeze(ASI.Condition, ASI.Condition->getName() + ".frozen");
 
-    SmallPtrSet<const Instruction *, 2> INS;
-    for (auto SI : ASI)
-      INS.insert(SI.getI());
+    SmallDenseMap<Instruction *, std::pair<Value *, Value *>, 2> INS;
 
     // Use reverse iterator because later select may use the value of the
     // earlier select, and we need to propagate value through earlier select
     // to get the PHI operand.
-    for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
-      SelectLike SI = *It;
+    InsertionPoint = EndBlock->begin();
+    for (SelectLike &SI : ASI.Selects) {
       // The select itself is replaced with a PHI Node.
       PHINode *PN = PHINode::Create(SI.getType(), 2, "");
-      PN->insertBefore(EndBlock->begin());
+      PN->insertBefore(InsertionPoint);
       PN->takeName(SI.getI());
-      PN->addIncoming(getTrueOrFalseValue(SI, true, INS, IB), TrueBlock);
-      PN->addIncoming(getTrueOrFalseValue(SI, false, INS, IB), FalseBlock);
-      PN->setDebugLoc(SI.getI()->getDebugLoc());
+      // Current instruction might be a condition of some other group, so we
+      // need to replace it there to avoid dangling pointer
+      if (PN->getType()->isIntegerTy(1)) {
+        for (auto &SG : ProfSIGroups) {
+          if (SG.Condition == SI.getI())
+            SG.Condition = PN;
+        }
+      }
       SI.getI()->replaceAllUsesWith(PN);
-      INS.erase(SI.getI());
+      auto *TV = getTrueOrFalseValue(SI, true, INS, TrueBlock);
+      auto *FV = getTrueOrFalseValue(SI, false, INS, FalseBlock);
+      INS[PN] = {TV, FV};
+      PN->addIncoming(TV, TrueBlock);
+      PN->addIncoming(FV, FalseBlock);
+      PN->setDebugLoc(SI.getI()->getDebugLoc());
       ++NumSelectsConverted;
     }
     IB.CreateCondBr(CondFr, TT, FT, SI.getI());
 
     // Remove the old select instructions, now that they are not longer used.
-    for (auto SI : ASI)
+    for (SelectLike &SI : ASI.Selects)
       SI.getI()->eraseFromParent();
   }
 }
 
 void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB,
                                              SelectGroups &SIGroups) {
+  // Represents something that can be considered as select instruction.
+  // Auxiliary instruction are instructions that depends on a condition and have
+  // zero or some constant value on True/False branch, such as:
+  // * ZExt(1bit)
+  // * Not(1bit)
+  struct SelectLikeInfo {
+    Value *Cond;
+    bool IsAuxiliary;
+    bool IsInverted;
+    unsigned ConditionIdx;
+  };
+
+  DenseMap<Value *, SelectLikeInfo> SelectInfo;
+
+  // Check if the instruction is SelectLike or might be part of SelectLike
+  // expression, put information into SelectInfo and return the iterator to the
+  // inserted position.
+  auto ProcessSelectInfo = [&SelectInfo](Instruction *I) {
+    Value *Cond;
+    if (match(I, m_OneUse(m_ZExt(m_Value(Cond)))) &&
+        Cond->getType()->isIntegerTy(1)) {
+      bool Inverted = match(Cond, m_Not(m_Value(Cond)));
+      return SelectInfo.insert({I, {Cond, true, Inverted, 0}}).first;
+    }
+
+    if (match(I, m_Not(m_Value(Cond)))) {
+      return SelectInfo.insert({I, {Cond, true, true, 0}}).first;
+    }
+
+    // Select instruction are what we are usually looking for.
+    if (match(I, m_Select(m_Value(Cond), m_Value(), m_Value()))) {
+      bool Inverted = match(Cond, m_Not(m_Value(Cond)));
+      return SelectInfo.insert({I, {Cond, false, Inverted, 0}}).first;
+    }
+
+    // An Or(zext(i1 X), Y) can also be treated like a select, with condition X
+    // and values Y|1 and Y.
+    if (auto *BO = dyn_cast<BinaryOperator>(I)) {
+      if (BO->getType()->isIntegerTy(1) || BO->getOpcode() != Instruction::Or)
+        return SelectInfo.end();
+
+      for (unsigned Idx = 0; Idx < 2; Idx++) {
+        auto *Op = BO->getOperand(Idx);
+        auto It = SelectInfo.find(Op);
+        if (It != SelectInfo.end() && It->second.IsAuxiliary) {
+          Cond = It->second.Cond;
+          bool Inverted =...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/117582


More information about the llvm-commits mailing list