[llvm] [DAGCombiner] Option --combiner-select-seq (PR #134813)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 30 13:02:36 PDT 2025


https://github.com/ppetrovic98 updated https://github.com/llvm/llvm-project/pull/134813

>From 580d309e88125caa7dad04a2678dcdcdf89a539a Mon Sep 17 00:00:00 2001
From: "Anmol P. Paralkar" <anmol.paralkar at oss.nxp.com>
Date: Tue, 8 Apr 2025 01:15:48 -0700
Subject: [PATCH] [DAGCombiner] Minimize condition-code lifetime for select
 sequences over constants in arithmetic progressions. It reduces the number of
 comparisons and shortens the condition code lifetime, improving performance.

Original patch by Anmol Paralkar (@anmolparalkar-nxp), initially proposed on Phabricator as D136047.
Modified and updated to work with the latest LLVM version. The test case has been updated to match the current output format.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 566 +++++++++++++++++-
 .../CodeGen/AArch64/combiner-select-seq.ll    | 202 +++++++
 2 files changed, 766 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/combiner-select-seq.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ea1435c3934be..cfff49f61a82e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -149,6 +149,12 @@ static cl::opt<bool> EnableShrinkLoadReplaceStoreWithStore(
     cl::desc("DAG combiner enable load/<replace bytes>/store with "
              "a narrower store"));
 
+// Minimum net instruction-count reduction required before the select-sequence
+// rewrite (SelectSeqMinCCLifetime) is applied.
+static cl::opt<unsigned> SelectSeqMinCostBenefit(
+    "combiner-select-seq-min-cost-benefit", cl::Hidden, cl::init(1),
+    cl::desc("Transform only when the cost benefit (instruction count to be "
+             "reduced by) is at least as specified (default: 1)"));
+
 namespace {
 
   class DAGCombiner {
@@ -878,6 +884,25 @@ namespace {
     void ExtendSetCCUses(const SmallVectorImpl<SDNode *> &SetCCs,
                          SDValue OrigLoad, SDValue ExtLoad,
                          ISD::NodeType ExtType);
+
+    /// Returns true if \p C is non-null and its value needs at most 64
+    /// significant bits, i.e. getSExtValue() can safely be called on it.
+    bool isConstantInt64(const ConstantSDNode *C);
+
+    /// Checks whether the values in \p OpSeq form a "well-formed" arithmetic
+    /// progression of constant offsets from a common base. On success, sets
+    /// \p BaseReg, \p InitialVal and \p Delta, and adds the number of matched
+    /// ISD::ADD nodes to \p ADDCount.
+    bool isWellFormedAP(const SmallVector<SDValue, 64> &OpSeq, SDValue &BaseReg,
+                        int64_t &InitialVal, int64_t &Delta,
+                        uint64_t &ADDCount);
+
+    /// Returns the estimated net instruction-count reduction (positive means
+    /// fewer instructions) of rewriting the selects in \p SelectUser.
+    int64_t analyzeSelectSeqCost(SmallVector<SDNode *, 64> &SelectUser,
+                                 uint64_t &TSeqADDCount, uint64_t &FSeqADDCount,
+                                 unsigned &BRIndex, SDValue &InitialValSelect);
+
+    /// Rewrites a sequence of selects on the setcc \p N into selects of the
+    /// sequences' defining values followed by a chain of adds, minimizing the
+    /// condition code's live range. Returns true if the DAG was modified.
+    bool SelectSeqMinCCLifetime(SDNode *N);
   };
 
 /// This class is a DAGUpdateListener that removes any deleted
@@ -1732,6 +1757,534 @@ bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
   return true;
 }
 
+////////////////////////////////////////////////////////////////////////////////
+//
+// SelectSeqMinCCLifetime:
+// ======================
+//
+// Summary:
+// --------
+//
+// We deal with the special case of a sequence of select instructions that use
+// the condition code register to decide between corresponding elements of two
+// sequences of constants, each of which is in arithmetic progression.
+// Typically, the arguments to the selects would form a sequence of accesses to
+// fields of a struct or the elements of an array. Instead of keeping the CC
+// live across the whole sequence for the pointwise selection of the constants,
+// we transform the code to decide upfront between the defining constructs of
+// the two sequences of constants, and then generate the selected sequence
+// elements via a corresponding sequence of add instructions.
+//
+// Introduction:
+// -------------
+//
+// Consider a sequence of selects "S", consuming a setcc such as:
+//
+//   cond = setcc ...
+//     ...
+//   S[0]: reg[0] = select cond t[0] f[0]
+//     ...
+//   S[n-1]: reg[n-1] = select cond t[n-1] f[n-1]
+//     ...
+//
+// Two sequences arise from the operands of each select:
+//
+//   t[0], t[1], ..., t[n-1]
+//   f[0], f[1], ..., f[n-1]
+//
+// Denote either sequence by (say) "X". (PS: Read ':=' as "is of the form")
+//
+// If,
+//
+//   X[i] := offset[i] (forall i : i <- [0, n))
+//
+// or,
+//
+//   X[i] := reg[i] (forall i : i <- [0, n))
+//   where:
+//     reg[0] := base-register
+//            |  add base-register, offset[0]
+//     reg[i] := add base-register, offset[i] (forall i : i <- (0, n))
+//     where:
+//       base-register := BaseRegDef <>, Register:(i32|i64) %<>
+//
+// where:
+//   offset[] := Constant:i64<>
+//            |  Constant:i32<>
+//   BaseRegDef := CopyFromReg | Load
+//
+// Define:
+//
+//   Offset(X)[0] = Constant:(<i32>|<i64>)<0>
+//     if X[0] := (reg[0] := base-register)
+//   Offset(X)[i] = offset[i]
+//     if X[i] := offset[i] (forall i : i <- [0, n))
+//   Offset(X)[i] = offset[i]
+//     if X[i] := add base-register, offset[i] (forall i : i <- [0, n))
+//
+// Now, (for: n > 1) if Offset(X) is an arithmetic progression, i.e:
+//
+//   Offset(X)[i] := InitialValue(X) (if: i == 0)
+//                |  Offset(X)[i-1] + Delta(X) (forall i : i <- (0, n))
+//     where:
+//       InitialValue(X) = k (k is an arbitrary integer constant)
+//       Delta(X) = (Offset(X)[1] - Offset(X)[0])
+//
+//   Further define:
+//       BaseReg(X) = Constant:(<i32>|<i64>)<0>
+//         if X[0] := offset[0]
+//       BaseReg(X) = base-register
+//         if X[0] := base-register
+//                 |  add base-register, offset[0]
+//
+// PS: In practice, we also accept the reverse-form:
+//
+//   reg[n-1] := base-register
+//            |  add base-register, offset[n-1]
+//   reg[i] := add base-register, offset[i] (forall i : i <- [0, n-1))
+//   where:
+//     base-register := BaseRegDef <>, Register:(i32|i64) %<>
+//
+//   However, the essential idea remains as above.
+//
+// Proposition:
+// ------------
+//
+// Then, the sequence of selects "S", can be rewritten as sequence "S'", where
+// a choice is made between the two sequences via selects on the base-register,
+// the initial-value and the constant-difference (delta) between the successive
+// elements - i.e. the constructs that define such a sequence, and then these
+// are used to generate each access via a sequence of adds:
+//
+//   cond = setcc ...
+//   # Insert: 0
+//   BaseReg(S) = select cond BaseReg(t) BaseReg(f)
+//   # Insert: 1
+//   InitialValue(S) = select cond InitialValue(t) InitialValue(f)
+//   # Insert: 2
+//   Delta(S) = select cond Delta(t) Delta(f)
+//     ...
+//   # Rewrite: 0
+//   S'[0]: reg[0] = add BaseReg(S) InitialValue(S)
+//     ...
+//   # Rewrite: i
+//   S'[i]: reg[i] = add reg[i-1] Delta(S)
+//     ...
+//   # Rewrite: n-1
+//   S'[n-1]: reg[n-1] = add reg[n-2] Delta(S)
+//     ...
+//
+// Conclusion:
+// -----------
+//
+// The above code transformation has two effects:
+//
+// a. Minimization of the setcc lifetime.
+//
+//    The Rewrites: [0, n) do not have a dependency on the setcc. This is the
+//    primary motivation for performing this transformation; see:
+//
+//      [AArch64] Extremely slow code generation for series of function
+//                calls/addition #50491
+//      a.k.a: https://bugs.llvm.org/show_bug.cgi?id=51147
+//
+//    As the length of "S" grows, so does the lifetime of the setcc, creating
+//    interference with all the other computations in the DAG and thus causing
+//    long build times. This transformation keeps the setcc lifetime minimal in
+//    this case.
+//
+// b. Code size reduction.
+//
+//    Also, while we introduce up to three new selects (Inserts [0-2]), each
+//    rewrite in [0, n) eliminates one select and up to two adds while bringing
+//    in only one add, thereby reducing overall code size and potentially
+//    improving the runtime performance of the code.
+//
+//      NewCodeSize = OldCodeSize - (3 * n) + (3 + n) (Best case).
+//
+// Notes:
+// ------
+//
+// As a future extension, extend the transformation to include sequences not
+// in arithmetic progression by creating two lookup tables for storing the
+// constant offsets for the [tf]-sequences, and selecting the appropriate
+// table once, based on the setcc value. Then, the offset to add to the base
+// register can be looked up in its entry in the appropriate table; the idea
+// being similar to the scheme above.
+//
+////////////////////////////////////////////////////////////////////////////////
+// Returns true if C is non-null and fits in 64 signed bits. getSExtValue()
+// requires this, so callers check it before extracting the value.
+bool DAGCombiner::isConstantInt64(const ConstantSDNode *C) {
+  if (C && C->getAPIntValue().getSignificantBits() <= 64)
+    return true;
+  LLVM_DEBUG(dbgs() << "Unable to obtain value; skip.\n");
+  return false;
+}
+
+// Returns true if OpSeq is an arithmetic progression of offsets from a common
+// base, in one of these shapes:
+//   - every element is an ISD::Constant (the base is the constant 0), or
+//   - (Add)+, (BaseRegDef)(Add)+ or (Add)+(BaseRegDef), where every Add is
+//     "BaseReg + <constant>" and a bare BaseRegDef stands for offset 0.
+// On success, BaseReg, InitialVal and Delta describe the progression, and
+// ADDCount is incremented once per matched Add node.
+bool DAGCombiner::isWellFormedAP(const SmallVector<SDValue, 64> &OpSeq,
+                                 SDValue &BaseReg, int64_t &InitialVal,
+                                 int64_t &Delta, uint64_t &ADDCount) {
+  // The DAG adds wrap modulo 2^N. Compute offset differences modulo 2^64
+  // instead of with signed subtraction, which is UB on overflow for i64.
+  auto WrappingSub = [](int64_t A, int64_t B) -> int64_t {
+    return static_cast<int64_t>(static_cast<uint64_t>(A) -
+                                static_cast<uint64_t>(B));
+  };
+  if (OpSeq.size() < 2) {
+    LLVM_DEBUG(dbgs() << "Need at least two elements in an arithmetic "
+                         "progression; skip.\n");
+    return false;
+  }
+
+  // Offset of each element from BaseReg, in sequence order.
+  SmallVector<int64_t, 64> RawOffsets;
+  if (OpSeq[0].getOpcode() == ISD::Constant) {
+    for (SDValue V : OpSeq) {
+      if (V.getOpcode() != ISD::Constant) {
+        LLVM_DEBUG(dbgs() << "If the zeroeth element is a constant, so must be "
+                             "each element in the sequence; skip.\n");
+        return false;
+      }
+      const auto *C = cast<ConstantSDNode>(V);
+      if (!isConstantInt64(C))
+        return false;
+      RawOffsets.push_back(C->getSExtValue());
+    }
+    BaseReg = DAG.getConstant(0, SDLoc(OpSeq[0]), OpSeq[0].getValueType());
+  } else {
+    // Three cases arise:
+    // 0. All the elements are ISD::ADD's
+    // 1. The zeroeth element is a base register definition and the rest are
+    //    ISD::ADD's
+    // 2. The last element is a base register definition and the rest are
+    //    ISD::ADD's
+    // Elements in [First, Last) must be "BaseReg + constant" adds; any element
+    // outside that range is BaseReg itself.
+    auto IsBaseRegDef = [](SDValue V) {
+      return V.getOpcode() == ISD::CopyFromReg || V.getOpcode() == ISD::LOAD;
+    };
+    const unsigned Size = OpSeq.size();
+    unsigned First = 0;
+    unsigned Last = Size;
+    if (OpSeq[0].getOpcode() == ISD::ADD &&
+        OpSeq[Size - 1].getOpcode() == ISD::ADD) {
+      if (!IsBaseRegDef(OpSeq[0].getOperand(0))) {
+        LLVM_DEBUG(dbgs() << "Unable to get BaseReg; skip.\n");
+        return false;
+      }
+      BaseReg = OpSeq[0].getOperand(0);
+    } else if (IsBaseRegDef(OpSeq[0])) {
+      First = 1;
+      BaseReg = OpSeq[0];
+    } else if (IsBaseRegDef(OpSeq[Size - 1])) {
+      Last = Size - 1;
+      BaseReg = OpSeq[Size - 1];
+    } else {
+      LLVM_DEBUG(
+          dbgs()
+          << "Sequence not in (Add)+|(BaseRegDef)(Add)+|(Add)+(BaseRegDef) "
+             "form; skip.\n");
+      return false;
+    }
+    for (unsigned I = 0; I < Size; ++I) {
+      // The bare base register is at offset 0 from itself. Record it so that
+      // it is also covered by the constant-difference check below.
+      if (I < First || I >= Last) {
+        RawOffsets.push_back(0);
+        continue;
+      }
+      SDValue V = OpSeq[I];
+      if (V.getOpcode() != ISD::ADD || V.getOperand(0) != BaseReg ||
+          V.getOperand(1).getOpcode() != ISD::Constant) {
+        LLVM_DEBUG(dbgs() << "Sequence not in (Add = BaseReg + <constant "
+                             "offset>)+ form; skip.\n");
+        return false;
+      }
+      ++ADDCount;
+      const auto *C = cast<ConstantSDNode>(V.getOperand(1));
+      if (!isConstantInt64(C))
+        return false;
+      RawOffsets.push_back(C->getSExtValue());
+    }
+  }
+
+  InitialVal = RawOffsets[0];
+  Delta = WrappingSub(RawOffsets[1], RawOffsets[0]);
+  for (unsigned I = 2, E = RawOffsets.size(); I < E; ++I) {
+    if (WrappingSub(RawOffsets[I], RawOffsets[I - 1]) != Delta) {
+      LLVM_DEBUG(dbgs() << "Sequence not in arithmetic progression; skip.\n");
+      return false;
+    }
+  }
+  return true;
+}
+
+// Estimates how many instructions rewriting the selects in SelectUser saves.
+// Removed: the selects, plus the adds that fed their true/false operands
+// (TSeqADDCount, FSeqADDCount).
+// Added: three sequence selectors and one add per select.
+// A positive result means fewer instructions after the rewrite.
+int64_t DAGCombiner::analyzeSelectSeqCost(SmallVector<SDNode *, 64> &SelectUser,
+                                          uint64_t &TSeqADDCount,
+                                          uint64_t &FSeqADDCount,
+                                          unsigned &BRIndex,
+                                          SDValue &InitialValSelect) {
+  bool Adjusted = false;
+  // Instructions we expect to remove: one select per sequence element.
+  int64_t SelectInstCount = static_cast<int64_t>(SelectUser.size());
+  // Instructions we expect to add: one select each for BaseReg, InitialVal and
+  // Delta, plus one add per access address.
+  int64_t SeqSelectorCount = 3;
+  int64_t AddInstCount = SelectInstCount;
+  if (BRIndex == 0 && isNullConstant(InitialValSelect)) {
+    // The BaseReg select is SelectUser[0] and the initial value is zero, so
+    // AccessAddr[0] reduces to the BaseReg select itself. One select fewer is
+    // eliminated and one add fewer is created.
+    --SelectInstCount;
+    --AddInstCount;
+    // The BaseReg and InitialVal selectors counted above are not newly added
+    // instructions in this case.
+    SeqSelectorCount -= 2;
+    Adjusted = true;
+  }
+  // Use signed arithmetic so that a net loss yields a negative benefit instead
+  // of relying on an unsigned-to-signed conversion of a wrapped value.
+  int64_t ToBeEliminatedInstCount = SelectInstCount +
+                                    static_cast<int64_t>(TSeqADDCount) +
+                                    static_cast<int64_t>(FSeqADDCount);
+  int64_t ToBeAddedInstCount = AddInstCount + SeqSelectorCount;
+  int64_t Benefit = ToBeEliminatedInstCount - ToBeAddedInstCount;
+  LLVM_DEBUG(dbgs() << "Cost Benefit Analysis:\n");
+  LLVM_DEBUG(dbgs() << "Adjusted: " << Adjusted << "\n");
+  LLVM_DEBUG(dbgs() << "Number of select's eliminated: " << SelectInstCount
+                    << "\n");
+  LLVM_DEBUG(dbgs() << "Number of TSeq ADD's eliminated: " << TSeqADDCount
+                    << "\n");
+  LLVM_DEBUG(dbgs() << "Number of FSeq ADD's eliminated: " << FSeqADDCount
+                    << "\n");
+  LLVM_DEBUG(dbgs() << "Number of Sequence Selectors added: "
+                    << SeqSelectorCount << "\n");
+  LLVM_DEBUG(dbgs() << "Number of ADD accesses added: " << AddInstCount
+                    << "\n");
+  LLVM_DEBUG(dbgs() << "CostBenefit (i.e. instruction count to be reduced by): "
+                    << Benefit << "\n");
+  LLVM_DEBUG(dbgs() << "SelectSeqMinCostBenefit: " << SelectSeqMinCostBenefit
+                    << "\n");
+  return Benefit;
+}
+
+// Rewrites the ISD::SELECT users of the setcc N when their true and false
+// operands each form a well-formed arithmetic progression. The result is three
+// "sequence selector" selects (base register, initial value, delta) followed by
+// a chain of adds that no longer depend on the setcc. Returns true if the DAG
+// was modified.
+bool DAGCombiner::SelectSeqMinCCLifetime(SDNode *N) {
+  assert(N->getOpcode() == ISD::SETCC && "Expected SETCC node!");
+
+  // A sequence needs at least two users of the condition.
+  if (N->use_size() <= 1)
+    return false;
+  SDValue SetccVal(N, 0);
+
+  LLVM_DEBUG(dbgs() << "Setcc: "; N->dump());
+  // SelectUser: all the select instructions that use this setcc. TSeq and FSeq
+  // hold operands 1 and 2 of each select. The lists are kept in reverse
+  // users() order, which is the top-to-bottom DAG order the accesses are
+  // generated in.
+  SmallVector<SDNode *, 64> SelectUser;
+  SmallVector<SDValue, 64> TSeq;
+  SmallVector<SDValue, 64> FSeq;
+  for (SDNode *User : SetccVal->users()) {
+    if (User->getOpcode() != ISD::SELECT)
+      continue;
+    // The rewrite builds one chain of adds of a single type, and each select
+    // is replaced by a node of that type, so all selects must agree on it.
+    if (!SelectUser.empty() &&
+        User->getValueType(0) != SelectUser.front()->getValueType(0)) {
+      LLVM_DEBUG(dbgs() << "Select users have differing types; skip.\n");
+      return false;
+    }
+    SelectUser.push_back(User);
+    TSeq.push_back(User->getOperand(1));
+    FSeq.push_back(User->getOperand(2));
+  }
+  // Appending and then reversing gives the same order as inserting at the front
+  // but in linear time.
+  std::reverse(SelectUser.begin(), SelectUser.end());
+  std::reverse(TSeq.begin(), TSeq.end());
+  std::reverse(FSeq.begin(), FSeq.end());
+
+  LLVM_DEBUG({
+    dbgs() << "User:\n";
+    for (unsigned I = 0, E = SelectUser.size(); I != E; ++I) {
+      dbgs() << I << " ";
+      SelectUser[I]->dump();
+      for (const SDValue &Op : SelectUser[I]->op_values()) {
+        dbgs() << "  ";
+        Op.dump();
+      }
+    }
+  });
+
+  uint64_t TSeqADDCount = 0;
+  SDValue TSeqBaseReg;
+  int64_t TSeqInitialVal = 0;
+  int64_t TSeqDelta = 0;
+  bool TSeqValid = isWellFormedAP(TSeq, TSeqBaseReg, TSeqInitialVal, TSeqDelta,
+                                  TSeqADDCount);
+  uint64_t FSeqADDCount = 0;
+  SDValue FSeqBaseReg;
+  int64_t FSeqInitialVal = 0;
+  int64_t FSeqDelta = 0;
+  bool FSeqValid = isWellFormedAP(FSeq, FSeqBaseReg, FSeqInitialVal, FSeqDelta,
+                                  FSeqADDCount);
+  if (!TSeqValid || !FSeqValid) {
+    LLVM_DEBUG(dbgs() << "Operands not well formed or not in aritmetic "
+                         "progression; skip.\n");
+    return false;
+  }
+
+  EVT TSeqEltVT = TSeqBaseReg.getValueType();
+  EVT FSeqEltVT = FSeqBaseReg.getValueType();
+  unsigned TSeqEltBits = TSeqEltVT.getScalarSizeInBits();
+  unsigned FSeqEltBits = FSeqEltVT.getScalarSizeInBits();
+  if (!isIntN(TSeqEltBits, TSeqDelta) || !isIntN(FSeqEltBits, FSeqDelta)) {
+    LLVM_DEBUG(dbgs() << "Delta values too wide for bitwidth!\n");
+    return false;
+  }
+
+  SDLoc DL(N);
+  SDValue TSeqInitialValReg = DAG.getConstant(
+      APInt(TSeqEltBits, static_cast<uint64_t>(TSeqInitialVal), true), DL,
+      TSeqEltVT);
+  SDValue TSeqDeltaReg = DAG.getConstant(
+      APInt(TSeqEltBits, static_cast<uint64_t>(TSeqDelta), true), DL,
+      TSeqEltVT);
+  SDValue FSeqInitialValReg = DAG.getConstant(
+      APInt(FSeqEltBits, static_cast<uint64_t>(FSeqInitialVal), true), DL,
+      FSeqEltVT);
+  SDValue FSeqDeltaReg = DAG.getConstant(
+      APInt(FSeqEltBits, static_cast<uint64_t>(FSeqDelta), true), DL,
+      FSeqEltVT);
+
+  // Sequence selector nodes: choose the base register, the initial value and
+  // the delta of the progression once, up front.
+  SDValue BaseRegSelect =
+      DAG.getSelect(DL, TSeqEltVT, SetccVal, TSeqBaseReg, FSeqBaseReg);
+  LLVM_DEBUG(dbgs() << "BaseReg: "; BaseRegSelect->dump());
+  SDValue InitialValSelect = DAG.getSelect(DL, TSeqEltVT, SetccVal,
+                                           TSeqInitialValReg, FSeqInitialValReg);
+  LLVM_DEBUG(dbgs() << "InitialVal: "; InitialValSelect->dump());
+  SDValue DeltaSelect =
+      DAG.getSelect(DL, TSeqEltVT, SetccVal, TSeqDeltaReg, FSeqDeltaReg);
+  LLVM_DEBUG(dbgs() << "Delta: "; DeltaSelect->dump());
+
+  // If a sequence selector was CSE'd to one of the selects being replaced, the
+  // rewrite would create a cycle. The one exception is when the BaseReg select
+  // is SelectUser[0] and the initial value is zero: AccessAddr[0] then folds to
+  // that very select.
+  auto IndexInSelectUser = [&SelectUser](const SDNode *Node) -> unsigned {
+    unsigned Index = 0;
+    for (const SDNode *Sel : SelectUser) {
+      if (Sel == Node)
+        break;
+      ++Index;
+    }
+    return Index;
+  };
+  const unsigned SUSize = SelectUser.size();
+  unsigned BRIndex = IndexInSelectUser(BaseRegSelect.getNode());
+  const bool BaseRegIsFirstWithZeroInit =
+      BRIndex == 0 && isNullConstant(InitialValSelect);
+  if ((BRIndex != SUSize && !BaseRegIsFirstWithZeroInit) ||
+      IndexInSelectUser(InitialValSelect.getNode()) != SUSize ||
+      IndexInSelectUser(DeltaSelect.getNode()) != SUSize) {
+    LLVM_DEBUG(
+        dbgs() << "Generated code will introduce cycle(s) in the DAG; skip.\n");
+    return false;
+  }
+
+  if (analyzeSelectSeqCost(SelectUser, TSeqADDCount, FSeqADDCount, BRIndex,
+                           InitialValSelect) < SelectSeqMinCostBenefit) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Generated code will not be of cost benefit to the DAG; skip.\n");
+    return false;
+  }
+
+  // Build the access addresses:
+  //   AccessAddr[0] = BaseRegSelect + InitialValSelect
+  //   AccessAddr[i] = AccessAddr[i-1] + DeltaSelect
+  // No nsw/nuw flags are set. The original adds' wrap flags do not carry over
+  // to this chain: a descending constant sequence 91, 81, ... computes
+  // 91 + (-10), which wraps in the unsigned sense and is poison under nuw.
+  LLVM_DEBUG(dbgs() << "AccessAddr:\n");
+  SDValue PrevAccessAddr;
+  for (unsigned I = 0; I != SUSize; ++I) {
+    SDLoc AccessDL(SelectUser[I]);
+    SDValue CurAccessAddr =
+        I == 0 ? DAG.getNode(ISD::ADD, AccessDL, TSeqEltVT, BaseRegSelect,
+                             InitialValSelect)
+               : DAG.getNode(ISD::ADD, AccessDL, TSeqEltVT, PrevAccessAddr,
+                             DeltaSelect);
+    DAG.ReplaceAllUsesWith(SelectUser[I], CurAccessAddr.getNode());
+    LLVM_DEBUG(dbgs() << I << " "; CurAccessAddr->dump());
+    PrevAccessAddr = CurAccessAddr;
+  }
+  // The replaced selects are now dead. They are deliberately not deleted here
+  // (e.g. via DAG.RemoveDeadNodes()). This runs while DAGCombiner::Run is
+  // still populating its worklist, so nodes already on the worklist would be
+  // freed without being removed from it. Unused nodes are left for the
+  // combiner to delete when it reaches them.
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 //  Main DAG Combiner implementation
 //===----------------------------------------------------------------------===//
@@ -1746,13 +2299,22 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
   WorklistInserter AddNodes(*this);
 
   // Add all the dag nodes to the worklist.
-  //
+  // Also, in case of SETCC node, we trigger our 'SelectSeqMinCCLifetime'
+  // transformation here (without relying on the worklist) to ensure that this
+  // optimization, if it's applicable, happens while the full select sequence is
+  // still intact.
+
   // Note: All nodes are not added to PruningList here, this is because the only
   // nodes which can be deleted are those which have no uses and all other nodes
   // which would otherwise be added to the worklist by the first call to
   // getNextWorklistEntry are already present in it.
-  for (SDNode &Node : DAG.allnodes())
+  for (SDNode &Node : DAG.allnodes()) {
+    if (Node.getOpcode() == ISD::SETCC &&
+        !TLI.hasMultipleConditionRegisters())
+      SelectSeqMinCCLifetime(&Node);
+      
     AddToWorklist(&Node, /* IsCandidateForPruning */ Node.use_empty());
+  }
 
   // Create a dummy node (which is not added to allnodes), that adds a reference
   // to the root node, preventing it from being deleted, and tracking any
diff --git a/llvm/test/CodeGen/AArch64/combiner-select-seq.ll b/llvm/test/CodeGen/AArch64/combiner-select-seq.ll
new file mode 100644
index 0000000000000..fc1528da19c75
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/combiner-select-seq.ll
@@ -0,0 +1,202 @@
+; RUN: llc < %s -march=aarch64 -simplify-mir -stop-after=aarch64-isel -combiner-select-seq-min-cost-benefit=5 | FileCheck %s --check-prefix=PST
+
+; Description:
+; ------------
+; Given:
+;   struct A {
+;     char junk[53];
+;     int a0;
+;     int a1;
+;     int a2;
+;     int a3;
+;     int a4;
+;     int a5;
+;     int a6;
+;     int a7;
+;   };
+;   extern int sum(struct A *pa);
+;   extern int access(void *p);
+;   int sum(struct A *pa) {
+;     int s = 0;
+;     s += access(pa ? (void *) &pa->a0 : (void *) 91);
+;     s += access(pa ? (void *) &pa->a1 : (void *) 81);
+;     s += access(pa ? (void *) &pa->a2 : (void *) 71);
+;     s += access(pa ? (void *) &pa->a3 : (void *) 61);
+;     s += access(pa ? (void *) &pa->a4 : (void *) 51);
+;     s += access(pa ? (void *) &pa->a5 : (void *) 41);
+;     s += access(pa ? (void *) &pa->a6 : (void *) 31);
+;     s += access(pa ? (void *) &pa->a7 : (void *) 21);
+;     return s;
+;   }
+; Compiled to LLVM IR with: clang -O3 --target=aarch64 -emit-llvm -S
+; Do we identify the setcc:
+;   Setcc: t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+; along with its select users:
+;   User:
+;   0 t9: i64 = select t5, Constant:i64<91>, t7
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t8: i64 = Constant<91>
+;     t7: i64 = add nuw t2, Constant:i64<56>
+;   1 t26: i64 = select t5, Constant:i64<81>, t24
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t25: i64 = Constant<81>
+;     t24: i64 = add nuw t2, Constant:i64<60>
+;   2 t37: i64 = select t5, Constant:i64<71>, t35
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t36: i64 = Constant<71>
+;     t35: i64 = add nuw t2, Constant:i64<64>
+;   3 t48: i64 = select t5, Constant:i64<61>, t46
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t47: i64 = Constant<61>
+;     t46: i64 = add nuw t2, Constant:i64<68>
+;   4 t59: i64 = select t5, Constant:i64<51>, t57
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t58: i64 = Constant<51>
+;     t57: i64 = add nuw t2, Constant:i64<72>
+;   5 t70: i64 = select t5, Constant:i64<41>, t68
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t69: i64 = Constant<41>
+;     t68: i64 = add nuw t2, Constant:i64<76>
+;   6 t81: i64 = select t5, Constant:i64<31>, t79
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t80: i64 = Constant<31>
+;     t79: i64 = add nuw t2, Constant:i64<80>
+;   7 t92: i64 = select t5, Constant:i64<21>, t90
+;     t5: i1 = setcc t2, Constant:i64<0>, seteq:ch
+;     t91: i64 = Constant<21>
+;     t90: i64 = add nuw t2, Constant:i64<84>
+; defining the two sequences:
+;   t-seq: ( 0 + 91), ( 0 + 81), ..., (0  + 31), (0  + 21)
+;   f-seq: (t2 + 56), (t2 + 60), ..., (t2 + 80), (t2 + 84)
+; and derive and inject the sequence selectors:
+;   BaseReg: t104: i64 = select t5, Constant:i64<0>, t2
+;   InitialVal: t105: i64 = select t5, Constant:i64<91>, Constant:i64<56>
+;   Delta: t106: i64 = select t5, Constant:i64<-10>, Constant:i64<4>
+; that define:
+;   AccessAddr:
+;   0 t107: i64 = add nuw t104, t105
+;   1 t108: i64 = add nuw t107, t106
+;   2 t109: i64 = add nuw t108, t106
+;   3 t110: i64 = add nuw t109, t106
+;   4 t111: i64 = add nuw t110, t106
+;   5 t112: i64 = add nuw t111, t106
+;   6 t113: i64 = add nuw t112, t106
+;   7 t114: i64 = add nuw t113, t106
+; and do we rewrite the DAG, eliminating the selects and the base-displacement
+; adds, so that the AccessAddr values are used instead?
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32"
+target triple = "aarch64"
+
+; NOTE(review): attribute groups #0, #1, #2 are referenced but never defined here.
+define dso_local i32 @sum(ptr noundef %pa) local_unnamed_addr #0 { ; 8 selects on one condition, each feeding @access
+entry:
+  %tobool.not = icmp eq ptr %pa, null ; the single condition shared by all eight selects below
+  %a0 = getelementptr inbounds nuw i8, ptr %pa, i64 56 ; false operands: %pa + 56, 60, ..., 84 (step +4)
+  %cond = select i1 %tobool.not, ptr inttoptr (i64 91 to ptr), ptr %a0 ; true operands: 91, 81, ..., 21 (step -10)
+  %call = tail call i32 @access(ptr noundef nonnull %cond) #2
+  %a1 = getelementptr inbounds nuw i8, ptr %pa, i64 60
+  %cond5 = select i1 %tobool.not, ptr inttoptr (i64 81 to ptr), ptr %a1
+  %call6 = tail call i32 @access(ptr noundef nonnull %cond5) #2
+  %add7 = add nsw i32 %call6, %call ; running sum of the @access results
+  %a2 = getelementptr inbounds nuw i8, ptr %pa, i64 64
+  %cond12 = select i1 %tobool.not, ptr inttoptr (i64 71 to ptr), ptr %a2
+  %call13 = tail call i32 @access(ptr noundef nonnull %cond12) #2
+  %add14 = add nsw i32 %add7, %call13
+  %a3 = getelementptr inbounds nuw i8, ptr %pa, i64 68
+  %cond19 = select i1 %tobool.not, ptr inttoptr (i64 61 to ptr), ptr %a3
+  %call20 = tail call i32 @access(ptr noundef nonnull %cond19) #2
+  %add21 = add nsw i32 %add14, %call20
+  %a4 = getelementptr inbounds nuw i8, ptr %pa, i64 72
+  %cond26 = select i1 %tobool.not, ptr inttoptr (i64 51 to ptr), ptr %a4
+  %call27 = tail call i32 @access(ptr noundef nonnull %cond26) #2
+  %add28 = add nsw i32 %add21, %call27
+  %a5 = getelementptr inbounds nuw i8, ptr %pa, i64 76
+  %cond33 = select i1 %tobool.not, ptr inttoptr (i64 41 to ptr), ptr %a5
+  %call34 = tail call i32 @access(ptr noundef nonnull %cond33) #2
+  %add35 = add nsw i32 %add28, %call34
+  %a6 = getelementptr inbounds nuw i8, ptr %pa, i64 80
+  %cond40 = select i1 %tobool.not, ptr inttoptr (i64 31 to ptr), ptr %a6
+  %call41 = tail call i32 @access(ptr noundef nonnull %cond40) #2
+  %add42 = add nsw i32 %add35, %call41
+  %a7 = getelementptr inbounds nuw i8, ptr %pa, i64 84
+  %cond47 = select i1 %tobool.not, ptr inttoptr (i64 21 to ptr), ptr %a7
+  %call48 = tail call i32 @access(ptr noundef nonnull %cond47) #2
+  %add49 = add nsw i32 %add42, %call48
+  ret i32 %add49 ; sum of all eight @access results
+}
+
+declare dso_local i32 @access(ptr noundef) local_unnamed_addr #1 ; external callee, declared only
+
+; PST-LABEL: name:            sum
+; PST:     %0:gpr64common = COPY $x0
+; PST-NEXT:     %1:gpr64 = SUBSXri %0, 0, 0, implicit-def $nzcv
+; PST-NEXT:     %2:gpr32 = MOVi32imm 4
+; PST-NEXT:     %3:gpr64 = SUBREG_TO_REG 0, killed %2, %subreg.sub_32
+; PST-NEXT:     %4:gpr64 = MOVi64imm -10
+; PST-NEXT:     %5:gpr64 = CSELXr killed %4, killed %3, 0, implicit $nzcv
+; PST-NEXT:     %6:gpr32 = MOVi32imm 56
+; PST-NEXT:     %7:gpr64 = SUBREG_TO_REG 0, killed %6, %subreg.sub_32
+; PST-NEXT:     %8:gpr32 = MOVi32imm 91
+; PST-NEXT:     %9:gpr64 = SUBREG_TO_REG 0, killed %8, %subreg.sub_32
+; PST-NEXT:     %10:gpr64 = CSELXr killed %9, killed %7, 0, implicit $nzcv
+; PST-NEXT:     %11:gpr64 = COPY $xzr
+; PST-NEXT:     %12:gpr64 = CSELXr %11, %0, 0, implicit $nzcv
+; PST-NEXT:     %13:gpr64 = nuw ADDXrr killed %12, killed %10
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %13
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %14:gpr32 = COPY $w0
+; PST-NEXT:     %15:gpr64 = nuw ADDXrr %13, %5
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %15
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %16:gpr32 = COPY $w0
+; PST-NEXT:     %17:gpr32 = nsw ADDWrr %16, %14
+; PST-NEXT:     %18:gpr64 = nuw ADDXrr %15, %5
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %18
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %19:gpr32 = COPY $w0
+; PST-NEXT:     %20:gpr32 = nsw ADDWrr killed %17, %19
+; PST-NEXT:     %21:gpr64 = nuw ADDXrr %18, %5
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %21
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %22:gpr32 = COPY $w0
+; PST-NEXT:     %23:gpr32 = nsw ADDWrr killed %20, %22
+; PST-NEXT:     %24:gpr64 = nuw ADDXrr %21, %5
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %24
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %25:gpr32 = COPY $w0
+; PST-NEXT:     %26:gpr32 = nsw ADDWrr killed %23, %25
+; PST-NEXT:     %27:gpr64 = nuw ADDXrr %24, %5
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %27
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %28:gpr32 = COPY $w0
+; PST-NEXT:     %29:gpr32 = nsw ADDWrr killed %26, %28
+; PST-NEXT:     %30:gpr64 = nuw ADDXrr %27, %5
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %30
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %31:gpr32 = COPY $w0
+; PST-NEXT:     %32:gpr32 = nsw ADDWrr killed %29, %31
+; PST-NEXT:     %33:gpr64 = nuw ADDXrr %30, %5
+; PST-NEXT:     ADJCALLSTACKDOWN 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     $x0 = COPY %33
+; PST-NEXT:     BL @access, csr_aarch64_aapcs, implicit-def dead $lr, implicit $sp, implicit $x0, implicit-def $sp, implicit-def $w0
+; PST-NEXT:     ADJCALLSTACKUP 0, 0, implicit-def dead $sp, implicit $sp
+; PST-NEXT:     %34:gpr32 = COPY $w0
+; PST-NEXT:     %35:gpr32 = nsw ADDWrr killed %32, %34
+; PST-NEXT:     $w0 = COPY %35
+; PST-NEXT:     RET_ReallyLR implicit $w0
+; PST-NEXT: ...



More information about the llvm-commits mailing list