[llvm] [DAGCombiner] Option --combiner-select-seq (PR #134813)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 21 04:55:19 PDT 2025
================
@@ -1732,11 +1748,570 @@ bool DAGCombiner::recursivelyDeleteUnusedNodes(SDNode *N) {
return true;
}
+////////////////////////////////////////////////////////////////////////////////
+//
+// --combiner-select-seq:
+// ======================
+//
+// Summary:
+// --------
+//
+// We handle the special case of a sequence of select instructions that use
+// the condition code register to choose between corresponding elements of two
+// sequences of constants, each of which is in arithmetic progression.
+// Typically, the arguments to the selects form a sequence of accesses to the
+// fields of a struct or the elements of an array. Instead of keeping the CC
+// live across the whole sequence for the pointwise selection of the constants,
+// we transform the code to decide upfront between the defining constructs of
+// the two sequences of constants, and then generate the selected sequence
+// elements via a corresponding sequence of add instructions.
+//
+// Introduction:
+// -------------
+//
+// Consider a sequence of selects "S", consuming a setcc such as:
+//
+// cond = setcc ...
+// ...
+// S[0]: reg[0] = select cond t[0] f[0]
+// ...
+// S[n-1]: reg[n-1] = select cond t[n-1] f[n-1]
+// ...
+//
+// Two sequences arise from the operands of each select:
+//
+// t[0], t[1], ..., t[n-1]
+// f[0], f[1], ..., f[n-1]
+//
+// Denote either sequence by (say) "X". (PS: Read ':=' as "is of the form")
+//
+// If,
+//
+// X[i] := offset[i] (forall i : i <- [0, n))
+//
+// or,
+//
+// X[i] := reg[i] (forall i : i <- [0, n))
+// where:
+// reg[0] := base-register
+// | add base-register, offset[0]
+// reg[i] := add base-register, offset[i] (forall i : i <- (0, n))
+// where:
+// base-register := BaseRegDef <>, Register:(i32|i64) %<>
+//
+// where:
+// offset[] := Constant:i64<>
+// | Constant:i32<>
+// BaseRegDef := CopyFromReg | Load
+//
+// Define:
+//
+// Offset(X)[0] = Constant:(<i32>|<i64>)<0>
+// if X[0] := (reg[0] := base-register)
+// Offset(X)[i] = offset[i]
+// if X[i] := offset[i] (forall i : i <- [0, n))
+// Offset(X)[i] = offset[i]
+// if X[i] := add base-register, offset[i] (forall i : i <- [0, n))
+//
+// Now, (for: n > 1) if Offset(X) is an arithmetic progression, i.e:
+//
+// Offset(X)[i] := InitialValue(X) (if: i == 0)
+// | Offset(X)[i-1] + Delta(X) (forall i : i <- (0, n))
+// where:
+// InitialValue(X) = k (k is an arbitrary integer constant)
+// Delta(X) = (Offset(X)[1] - Offset(X)[0])
+//
+// Further define:
+// BaseReg(X) = Constant:(<i32>|<i64>)<0>
+// if X[0] := offset[0]
+// BaseReg(X) = base-register
+// if X[0] := base-register
+// | add base-register, offset[0]
+//
+// PS: In practice, we also accept the reverse-form:
+//
+// reg[n-1] := base-register
+// | add base-register, offset[n-1]
+// reg[i] := add base-register, offset[i] (forall i : i <- [0, n-1))
+// where:
+// base-register := BaseRegDef <>, Register:(i32|i64) %<>
+//
+// However, the essential idea remains as above.
+//
+// Proposition:
+// ------------
+//
+// Then, the sequence of selects "S", can be rewritten as sequence "S'", where
+// a choice is made between the two sequences via selects on the base-register,
+// the initial-value and the constant-difference (delta) between the successive
+// elements - i.e. the constructs that define such a sequence, and then these
+// are used to generate each access via a sequence of adds:
+//
+// cond = setcc ...
+// # Insert: 0
+// BaseReg(S) = select cond BaseReg(t) BaseReg(f)
+// # Insert: 1
+// InitialValue(S) = select cond InitialValue(t) InitialValue(f)
+// # Insert: 2
+// Delta(S) = select cond Delta(t) Delta(f)
+// ...
+// # Rewrite: 0
+// S'[0]: reg[0] = add BaseReg(S) InitialValue(S)
+// ...
+// # Rewrite: i
+// S'[i]: reg[i] = add reg[i-1] Delta(S)
+// ...
+// # Rewrite: n-1
+// S'[n-1]: reg[n-1] = add reg[n-2] Delta(S)
+// ...
+//
+// Conclusion:
+// -----------
+//
+// The above code transformation has two effects:
+//
+// a. Minimization of the setcc lifetime.
+//
+// The Rewrites: [0, n) do not have a dependency on the setcc. This is the
+// primary motivation for performing this transformation; see:
+//
+// [AArch64] Extremely slow code generation for series of function
+// calls/addition #50491
+// a.k.a: https://bugs.llvm.org/show_bug.cgi?id=51147
+//
+// As the length of "S" grows, so does the lifetime of the setcc, causing it
+// to interfere with all the other computations in the DAG and thus leading
+// to long build times. This transformation keeps the setcc lifetime minimal
+// in this case.
+//
+// b. Code size reduction.
+//
+// Also, while we introduce (up to) three new selects (Inserts [0-2]), each
+// rewrite [0, n) eliminates one select and (up to) two adds while bringing
+// in a single add, thereby reducing overall code size and potentially
+// improving the runtime performance of the code.
+//
+// NewCodeSize = OldCodeSize - (3 * n) + (3 + n) (Best case).
+//
+// Notes:
+// ------
+//
+// As a future extension, extend the transformation to include sequences not
+// in arithmetic progression by creating two lookup tables for storing the
+// constant offsets for the [tf]-sequences, and selecting the appropriate
+// table once, based on the setcc value. Then, the offset to add to the base
+// register can be looked up in its entry in the appropriate table; the idea
+// being similar to the scheme above.
+//
+////////////////////////////////////////////////////////////////////////////////
+bool DAGCombiner::SelectSeqMinCCLifetime(void) {
+
+ LLVM_DEBUG(dbgs() << "SelectSeqMinCCLifetime:\n");
+
+ LLVM_DEBUG(dbgs() << "DAG PRE:\n");
+ for (SDNode &Node : DAG.allnodes())
+ LLVM_DEBUG(dbgs() << ""; Node.dump());
+
+ // Run through the DAG, looking for selects that use setcc,
+ // collect the setcc operand if we have not collected it already:
+ SmallSet<SDValue, 16> CandidateSetcc;
+ for (SDNode &Node : DAG.allnodes()) {
+ if (Node.getOpcode() != ISD::SELECT)
+ continue;
+ SDValue Op0 = Node.getOperand(0);
+ if (Op0.getOpcode() != ISD::SETCC)
+ continue;
+ if (!CandidateSetcc.contains(Op0))
+ CandidateSetcc.insert(Op0);
+ }
+
+ auto ProcessSetcc = [this](SDValue N) -> bool {
+ bool DAGModified = false;
+ assert(N.getOpcode() == ISD::SETCC);
+ LLVM_DEBUG(dbgs() << "Setcc: "; N.dump());
+ // SelectUser: All the select instructions that use this setcc.
+ // SelectUser lists the selects in the order they appear in the DAG i.e.
+ // going from top to down in the DAG, they appear in left to right order, 0
+ // onwards ... Ultimately, the accesses are generated using SelectUser.
+ SmallVector<SDNode *, 64> SelectUser;
+ SmallVector<SDValue, 64> TSeq; // Operand 1 of each SelectUser
+ SmallVector<SDValue, 64> FSeq; // Operand 2 of each SelectUser
+ SDNodeFlags AccessAddrFlags;
+ unsigned i = 0;
+ // Collect the 't' & 'f' operands:
+ for (SDNode *User : N->users()) {
+ // NOTE: The SDNode::use_iterator presents the use's in *reverse* order.
+ ++i;
+ if (User->getOpcode() != ISD::SELECT)
+ continue;
+ SDValue N1 = User->getOperand(1);
+ SDValue N2 = User->getOperand(2);
+ auto itU = SelectUser.begin();
+ SelectUser.insert(itU, User);
+ auto itO1 = TSeq.begin();
+ TSeq.insert(itO1, N1);
+ auto itO2 = FSeq.begin();
+ FSeq.insert(itO2, N2);
+ SDNodeFlags N1Flags = N1->getFlags();
+ SDNodeFlags N2Flags = N2->getFlags();
+ // Note the flags for later use ...
+ if (N1.getOpcode() == ISD::ADD)
+ AccessAddrFlags = N1Flags;
+ if (N2.getOpcode() == ISD::ADD)
+ AccessAddrFlags = N2Flags;
+ if (N1.getOpcode() != ISD::ADD || N2.getOpcode() != ISD::ADD)
+ continue;
+ // Both operands are ISD::ADD's Ascertain that either: Both operands have
+ // the same (NoSigned/NoUnsigned)Wrap flags or both operands have no
+ // (NoSigned/NoUnsigned)Wrap flags:
+ if (!((N1Flags.hasNoSignedWrap() && N2Flags.hasNoSignedWrap()) ||
+ (N1Flags.hasNoUnsignedWrap() && N2Flags.hasNoUnsignedWrap()) ||
+ ((!N1Flags.hasNoSignedWrap() && !N2Flags.hasNoSignedWrap()) &&
+ (!N1Flags.hasNoUnsignedWrap() && !N2Flags.hasNoUnsignedWrap())))) {
+ LLVM_DEBUG(
+ dbgs() << "Operand (Signed/Unsigned)flags mismatch; skip.\n");
+ return false;
+ }
+ }
+ LLVM_DEBUG(dbgs() << "User:\n");
+ for (unsigned i = 0; i < SelectUser.size(); ++i) {
+ LLVM_DEBUG(dbgs() << i << " "; SelectUser[i]->dump());
+ for (const SDValue &Op : SelectUser[i]->op_values())
+ LLVM_DEBUG(dbgs() << " "; Op.dump());
+ }
+ auto WellFormedAP = [this](const SmallVector<SDValue, 64> &OpSeq,
+ SDValue &BaseReg, int64_t &InitialVal,
+ int64_t &Delta, uint64_t &ADDCount) -> bool {
+ auto AscertainConstDiff = [](const SmallVector<int64_t, 64> &RawOffsets,
+ int64_t diff) -> bool {
+ bool WF = true;
+ for (unsigned i = 1; WF && (i < RawOffsets.size()); ++i)
+ WF &= (diff == (RawOffsets[i] - RawOffsets[i - 1]));
+ return WF;
+ };
+ if (OpSeq.size() < 2) {
+ LLVM_DEBUG(dbgs() << "Need at least two elements in an arithmetic "
+ "progression; skip.\n");
+ return false;
+ }
+ bool WF = true;
+ SmallVector<int64_t, 64> RawOffsets;
+ if (OpSeq[0].getOpcode() == ISD::Constant) {
+ for (auto V : OpSeq) {
+ WF &= ((V.getOpcode() == ISD::Constant));
+ if (WF) {
+ auto *C = dyn_cast<ConstantSDNode>(V);
+ if (C && C->getAPIntValue().getSignificantBits() <= 64)
+ RawOffsets.push_back(C->getSExtValue());
+ else {
+ LLVM_DEBUG(dbgs() << "Unable to obtain value; skip.\n");
+ return false;
+ }
+ } else {
+ LLVM_DEBUG(dbgs()
+ << "If the zeroeth element is a constant, so must be "
+ "each element in the sequence; skip.\n");
+ return false;
+ }
+ }
+ BaseReg = DAG.getConstant(0, SDLoc(OpSeq[0]), OpSeq[0].getValueType());
+ auto *C0 = dyn_cast<ConstantSDNode>(OpSeq[0]);
+ auto *C1 = dyn_cast<ConstantSDNode>(OpSeq[1]);
+ if ((C0 && C0->getAPIntValue().getSignificantBits() <= 64) &&
+ (C1 && C1->getAPIntValue().getSignificantBits() <= 64)) {
+ InitialVal = C0->getSExtValue();
+ Delta = (C1->getSExtValue()) - InitialVal;
+ WF &= AscertainConstDiff(RawOffsets, Delta);
+ return WF;
+ } else {
+ LLVM_DEBUG(dbgs() << "Unable to obtain value; skip.\n");
+ return false;
+ }
+ } else {
+ // Three cases arise:
+ // 0. All the elements are ISD::ADD's
+ // 1. The zeroeth element is a base register definition and the rest are
+ // ISD::ADD's
+ // 2. The OpSeq.size()-1'th element is a base register definition and
+ // the rest are ISD::ADD's
+ // l: inclusive lower bound
+ // u: non-inclusive upper bound
+ auto isBaseRegDef = [](const SDValue V) -> bool {
+ return (V.getOpcode() == ISD::CopyFromReg) ||
+ (V.getOpcode() == ISD::LOAD);
+ };
+ unsigned l, u;
+ if (OpSeq[0].getOpcode() == ISD::ADD &&
----------------
arsenm wrote:
isAddLike? Should catch the disjoint or case
https://github.com/llvm/llvm-project/pull/134813
More information about the llvm-commits
mailing list