[llvm] Update the base and index value for masked gather (PR #130920)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 12 23:12:05 PDT 2025
================
@@ -61619,3 +61624,99 @@ Align X86TargetLowering::getPrefLoopAlignment(MachineLoop *ML) const {
return Align(1ULL << ExperimentalPrefInnermostLoopAlignment);
return TargetLowering::getPrefLoopAlignment();
}
+
+// Targets override this function to decide whether they want to update the
+// base and index values of a non-uniform GEP.
+bool X86TargetLowering::updateBaseAndIndex(const Value *Ptr, SDValue &Base,
+ SDValue &Index, const SDLoc &DL,
+ const SDValue &Gep,
+ SelectionDAG &DAG,
+ const BasicBlock *CurBB) const {
+ if (!EnableBaseIndexUpdate)
+ return false;
+
+ const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+ if (GEP && GEP->getParent() != CurBB)
+ return false;
+
+ SDValue nbase;
+ /* For the gep instruction, we are trying to properly assign the base and
+ index values. We go through the lowered code and iterate backward.
+ */
+ if (Gep.getOpcode() == ISD::ADD) {
+ SDValue Op0 = Gep.getOperand(0); // base or add
+ SDValue Op1 = Gep.getOperand(1); // build vector or SHL
+ nbase = Op0;
+ SDValue Idx = Op1;
+ auto Flags = Gep->getFlags();
+
+ if (Op0->getOpcode() == ISD::ADD) { // add t15(base), t18(Idx)
+ SDValue Op00 = Op0.getOperand(0); // Base
+ nbase = Op00;
+ Idx = Op0.getOperand(1);
+ } else if (!(Op0->getOpcode() == ISD::BUILD_VECTOR &&
+ Op0.getOperand(0).getOpcode() == ISD::CopyFromReg)) {
+ return false;
+ }
+ SDValue nIndex;
+ if (Idx.getOpcode() == ISD::SHL) { // shl zext, BV
+ SDValue Op10 = Idx.getOperand(0); // Zext or Sext value
+ SDValue Op11 = Idx.getOperand(1); // Build vector of constant
+
+ unsigned IndexWidth = Op10.getScalarValueSizeInBits();
+ if ((Op10.getOpcode() == ISD::SIGN_EXTEND ||
+ Op10.getOpcode() == ISD::ZERO_EXTEND) &&
+ IndexWidth > 32 &&
+ Op10.getOperand(0).getScalarValueSizeInBits() <= 32 &&
+ DAG.ComputeNumSignBits(Op10) > (IndexWidth - 32) &&
+ Op11.getOpcode() == ISD::BUILD_VECTOR) {
+
+ KnownBits ExtKnown = DAG.computeKnownBits(Op10);
+ bool ExtIsNonNegative = ExtKnown.isNonNegative();
+ KnownBits ExtOpKnown = DAG.computeKnownBits(Op10.getOperand(0));
+ bool ExtOpIsNonNegative = ExtOpKnown.isNonNegative();
+ if (!(ExtIsNonNegative && ExtOpIsNonNegative))
+ return false;
+
+ SDValue newOp10 =
+ Op10.getOperand(0); // Get the Operand zero from the ext
+ EVT VT = newOp10.getValueType(); // Use the
+
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op11.getOperand(0));
+ if (!ConstEltNo) {
+ return false;
+ }
+ SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(),
+ DAG.getConstant(ConstEltNo->getZExtValue(),
+ DL, VT.getScalarType()));
+ nIndex = DAG.getNode(ISD::SHL, DL, VT, newOp10,
+ DAG.getBuildVector(VT, DL, Ops));
+ } else {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ if (Op0 != nbase) {
+ auto *ConstEltNo = dyn_cast<ConstantSDNode>(Op1.getOperand(0));
+ if (!ConstEltNo) {
+ return false;
+ }
+ SmallVector<SDValue, 8> Ops(
+ nIndex.getValueType().getVectorNumElements(),
+ DAG.getConstant(ConstEltNo->getZExtValue(), DL,
+ nIndex.getValueType().getScalarType()));
+ nIndex = DAG.getNode(ISD::ADD, DL, nIndex.getValueType(), nIndex,
+ DAG.getBuildVector(nIndex.getValueType(), DL, Ops),
+ Flags);
+ }
+ Base = nbase.getOperand(0);
+ Index = nIndex;
+ LLVM_DEBUG(dbgs() << "Successfull in updating the non uniform gep "
----------------
topperc wrote:
Successful*
https://github.com/llvm/llvm-project/pull/130920
More information about the llvm-commits
mailing list