[llvm] Update the base and index value for masked gather (PR #130920)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 12 23:12:05 PDT 2025


================
@@ -61619,3 +61624,99 @@ Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
     return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
   return TargetLowering::getPrefLoopAlignment();
 }
+
+// Targets override this function to decide whether they want to update the
+// base and index values of a non-uniform GEP.
+bool X86TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+                                           SDValue &Index, const SDLoc &DL,
+                                           const SDValue &Gep,
+                                           SelectionDAG &DAG,
+                                           const BasicBlock *CurBB) const {
+  if (!EnableBaseIndexUpdate)
+    return false;
+
+  const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (GEP && GEP->getParent() != CurBB)
+    return false;
+
+  SDValue nbase;
+  /* For the gep instruction, we try to properly assign the base and index
+  values. We go through the lowered code and iterate backward.
+  */
+  if (Gep.getOpcode() == ISD::ADD) {
+    SDValue Op0 = Gep.getOperand(0); // base or  add
+    SDValue Op1 = Gep.getOperand(1); // build vector or SHL
+    nbase = Op0;
+    SDValue Idx = Op1;
+    auto Flags = Gep->getFlags();
+
+    if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
+      SDValue Op00 = Op0.getOperand(0); // Base
+      nbase = Op00;
+      Idx = Op0.getOperand(1);
+    } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+                 Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+      return false;
+    }
+    SDValue nIndex;
+    if (Idx.getOpcode() == ISD::SHL) {  // shl zext, BV
+      SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
+      SDValue Op11 = Idx.getOperand(1); // Build vector of constant
+
+      unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+      if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+           Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+          IndexWidth > 32 &&
+          Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+          DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
+          Op11.getOpcode() == ISD::BUILD_VECTOR) {
+
+        KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+        bool ExtIsNonNegative = ExtKnown.isNonNegative();
+        KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+        bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+        if (!(ExtIsNonNegative && ExtOpIsNonNegative))
+          return false;
+
+        SDValue newOp10 =
+            Op10.getOperand(0);          // Get the Operand zero from the ext
+        EVT VT = newOp10.getValueType(); // Use the
+
+        auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+        if (!ConstEltNo) {
+          return false;
+        }
+        SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+                                    DAG.getConstant(ConstEltNo->getZExtValue(),
+                                                    DL, VT.getScalarType()));
+        nIndex = DAG.getNode(ISD::SHL, DL, VT, newOp10,
+                             DAG.getBuildVector(VT, DL, Ops));
+      } else {
+        return false;
+      }
+    } else {
+      return false;
+    }
+    if (Op0 != nbase) {
+      auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+      if (!ConstEltNo) {
+        return false;
+      }
+      SmallVector<SDValue, 8> Ops(
+          nIndex.getValueType().getVectorNumElements(),
+          DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+                          nIndex.getValueType().getScalarType()));
+      nIndex = DAG.getNode(ISD::ADD, DL, nIndex.getValueType(), nIndex,
+                           DAG.getBuildVector(nIndex.getValueType(), DL, Ops),
+                           Flags);
+    }
+    Base = nbase.getOperand(0);
+    Index = nIndex;
+    LLVM_DEBUG(dbgs() << "Successfull in updating the non uniform gep "
----------------
topperc wrote:

Successful*

https://github.com/llvm/llvm-project/pull/130920


More information about the llvm-commits mailing list