[llvm] [RISCV] Optimize divide by constant for VP intrinsics (PR #125991)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 25 14:50:43 PST 2025
================
@@ -6492,6 +6492,137 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::ADD, dl, VT, Q, T);
}
+/// Given an ISD::VP_SDIV node expressing a divide by constant,
+/// return a DAG expression to select that will generate the same value by
+/// multiplying by a magic number.
+/// Ref: "Hacker's Delight" or "The PowerPC Compiler Writer's Guide".
+SDValue TargetLowering::BuildVPSDIV(SDNode *N, SelectionDAG &DAG,
+ bool IsAfterLegalization,
+ SmallVectorImpl<SDNode *> &Created) const {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ EVT SVT = VT.getScalarType();
+ EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
+ EVT ShSVT = ShVT.getScalarType();
+ unsigned EltBits = VT.getScalarSizeInBits();
+
+ // Check to see if we can do this.
+ if (!isTypeLegal(VT) ||
+ !isOperationLegalOrCustom(ISD::VP_MULHS, VT, IsAfterLegalization))
+ return SDValue();
+
+ bool AnyFactorOne = false;
+ bool AnyFactorNegOne = false;
+
+ SmallVector<SDValue, 16> MagicFactors, Factors, Shifts, ShiftMasks;
+
+ auto BuildSDIVPattern = [&](ConstantSDNode *C) {
+ if (C->isZero())
+ return false;
+
+ const APInt &Divisor = C->getAPIntValue();
+ SignedDivisionByConstantInfo magics =
+ SignedDivisionByConstantInfo::get(Divisor);
+ int NumeratorFactor = 0;
+ int ShiftMask = -1;
+
+ if (Divisor.isOne() || Divisor.isAllOnes()) {
+ // If d is +1/-1, we just multiply the numerator by +1/-1.
+ NumeratorFactor = Divisor.getSExtValue();
+ magics.Magic = 0;
+ magics.ShiftAmount = 0;
+ ShiftMask = 0;
+ AnyFactorOne |= Divisor.isOne();
+ AnyFactorNegOne |= Divisor.isAllOnes();
+ } else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
+ // If d > 0 and m < 0, add the numerator.
+ NumeratorFactor = 1;
+ AnyFactorOne = true;
+ } else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
+ // If d < 0 and m > 0, subtract the numerator.
+ NumeratorFactor = -1;
+ AnyFactorNegOne = true;
+ }
+
+ MagicFactors.push_back(DAG.getConstant(magics.Magic, DL, SVT));
+ Factors.push_back(DAG.getSignedConstant(NumeratorFactor, DL, SVT));
+ Shifts.push_back(DAG.getConstant(magics.ShiftAmount, DL, ShSVT));
+ ShiftMasks.push_back(DAG.getSignedConstant(ShiftMask, DL, SVT));
+ return true;
+ };
+
+ SDValue N0 = N->getOperand(0);
+ SDValue N1 = N->getOperand(1);
+ SDValue Mask = N->getOperand(2);
+ SDValue VL = N->getOperand(3);
+
+ // Collect the shifts / magic values from each element.
+ if (!ISD::matchUnaryPredicate(N1, BuildSDIVPattern))
+ return SDValue();
+
+ SDValue MagicFactor, Factor, Shift, ShiftMask;
+ if (N1.getOpcode() == ISD::BUILD_VECTOR) {
+ MagicFactor = DAG.getBuildVector(VT, DL, MagicFactors);
+ Factor = DAG.getBuildVector(VT, DL, Factors);
+ Shift = DAG.getBuildVector(ShVT, DL, Shifts);
+ ShiftMask = DAG.getBuildVector(VT, DL, ShiftMasks);
+ } else if (N1.getOpcode() == ISD::SPLAT_VECTOR) {
+ assert(MagicFactors.size() == 1 && Factors.size() == 1 &&
+ Shifts.size() == 1 && ShiftMasks.size() == 1 &&
+ "Expected matchUnaryPredicate to return one element for scalable "
+ "vectors");
+ MagicFactor = DAG.getSplatVector(VT, DL, MagicFactors[0]);
+ Factor = DAG.getSplatVector(VT, DL, Factors[0]);
+ Shift = DAG.getSplatVector(ShVT, DL, Shifts[0]);
+ ShiftMask = DAG.getSplatVector(VT, DL, ShiftMasks[0]);
+ } else {
+ assert(isa<ConstantSDNode>(N1) && "Expected a constant");
----------------
topperc wrote:
This code should be unreachable for VP opcodes.
https://github.com/llvm/llvm-project/pull/125991
More information about the llvm-commits mailing list