[llvm] Update the base and index value for masked gather (PR #130920)
Rohit Aggarwal via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 17 23:25:21 PDT 2025
================
@@ -56370,6 +56375,112 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
Scatter->isTruncatingStore());
}
+// Targets override this function to decide whether they want to update the
+// base and index values of a non-uniform GEP.
+static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, const SDLoc &DL,
+ const SDValue &Gep, SelectionDAG &DAG) {
+ if (!EnableBaseIndexUpdate)
+ return false;
+
+ SDValue Nbase;
+ SDValue Nindex;
+ bool Changed = false;
+ // This function check the opcode of Index and update the index
+ auto checkAndUpdateIndex = [&](SDValue &Idx) {
+ if (Idx.getOpcode() == ISD::SHL) { // shl zext, BV
+ SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
+ SDValue Op11 = Idx.getOperand(1); // Build vector of constant
+
+ unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+ if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+ Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+ IndexWidth > 32 &&
+ Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+ DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
+ Op11.getOpcode() == ISD::BUILD_VECTOR) {
+
+ KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+ bool ExtIsNonNegative = ExtKnown.isNonNegative();
+ KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+ bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+ if (!ExtIsNonNegative || !ExtOpIsNonNegative)
+ return false;
+
+ SDValue NewOp10 =
+ Op10.getOperand(0); // Get the Operand zero from the ext
+ EVT VT = NewOp10.getValueType(); // Use the operand's type to determine
+ // the type of index
+
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
----------------
rohitaggarwal007 wrote:
OK, here is my understanding. The condition would change from:
```
unsigned IndexWidth = Op10.getScalarValueSizeInBits();
if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
Op10.getOpcode() == ISD::ZERO_EXTEND) &&
IndexWidth > 32 &&
Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
Op11.getOpcode() == ISD::BUILD_VECTOR) { // <-- this check is replaced
```
to:
```
unsigned IndexWidth = Op10.getScalarValueSizeInBits();
if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
Op10.getOpcode() == ISD::ZERO_EXTEND) &&
IndexWidth > 32 &&
Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
DAG.getValidMinimumShiftAmount(Idx)) { // <-- with this check
```
Please correct me if I have misunderstood.
https://github.com/llvm/llvm-project/pull/130920
More information about the llvm-commits
mailing list