[llvm] Update the base and index value for masked gather (PR #130920)

Rohit Aggarwal via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 17 23:25:21 PDT 2025


================
@@ -56370,6 +56375,112 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
                               Scatter->isTruncatingStore());
 }
 
+// Target override this function to decide whether it want to update the base
+// and index value of a non-uniform gep
+static bool updateBaseAndIndex(SDValue &Base, SDValue &Index, const SDLoc &DL,
+                               const SDValue &Gep, SelectionDAG &DAG) {
+  if (!EnableBaseIndexUpdate)
+    return false;
+
+  SDValue Nbase;
+  SDValue Nindex;
+  bool Changed = false;
+  // This function check the opcode of Index and update the index
+  auto checkAndUpdateIndex = [&](SDValue &Idx) {
+    if (Idx.getOpcode() == ISD::SHL) {  // shl zext, BV
+      SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
+      SDValue Op11 = Idx.getOperand(1); // Build vector of constant
+
+      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+          IndexWidth > 32 &&
+          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
+          Op11.getOpcode() == ISD::BUILD_VECTOR) {
+
+        KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+        bool ExtIsNonNegative = ExtKnown.isNonNegative();
+        KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+        bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+        if (!ExtIsNonNegative || !ExtOpIsNonNegative)
+          return false;
+
+        SDValue NewOp10 =
+            Op10.getOperand(0);          // Get the Operand zero from the ext
+        EVT VT = NewOp10.getValueType(); // Use the operand's type to determine
+                                         // the type of index
+
+        auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
----------------
rohitaggarwal007 wrote:

OK, my understanding is that the condition should change from
```
      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
          IndexWidth > 32 &&
          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
          Op11.getOpcode() == ISD::BUILD_VECTOR) { // <-- current check
```
to
```
      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
          IndexWidth > 32 &&
          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
          DAG.getValidMinimumShiftAmount(Idx)) { // <-- proposed check
```
Please correct me if I have misunderstood.
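
To make the intent of the proposed check concrete, here is a minimal standalone sketch of what a "valid minimum shift amount" query means for a vector of constant shift amounts. It is only a model: `ShiftAmounts`, `validMinimumShiftAmount` and `shiftIsUsable` are hypothetical names, not LLVM APIs, and the sketch assumes the real query yields an optional-like value that is empty when any lane's shift amount is out of range.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a constant shift-amount vector; not an LLVM type.
using ShiftAmounts = std::vector<uint64_t>;

// Returns the smallest per-lane shift amount if every lane is a valid shift
// (strictly less than the element bit width); otherwise returns std::nullopt.
std::optional<uint64_t> validMinimumShiftAmount(const ShiftAmounts &Amts,
                                                unsigned BitWidth) {
  assert(BitWidth > 0 && "element width must be non-zero");
  if (Amts.empty())
    return std::nullopt;
  for (uint64_t A : Amts)
    if (A >= BitWidth)
      return std::nullopt; // an out-of-range lane invalidates the whole query
  return *std::min_element(Amts.begin(), Amts.end());
}

// Mirrors the shape of the proposed condition: only the presence of a value
// is tested, so a valid minimum of 0 still counts as usable.
bool shiftIsUsable(const ShiftAmounts &Amts, unsigned BitWidth) {
  return validMinimumShiftAmount(Amts, BitWidth).has_value();
}
```

One point the sketch highlights: when an optional is used directly in an `&&` chain, it is tested for presence, not for a non-zero value, so a shift by 0 in every lane would still pass the check.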


https://github.com/llvm/llvm-project/pull/130920

