[llvm] [LegalizeTypes] Expand UDIV/UREM by constant via chunk summation (PR #146238)
Shivam Gupta via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 14 13:49:45 PDT 2026
================
@@ -8233,6 +8231,75 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
DAG.getConstant(0, dl, HiLoVT));
Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, Sum, Carry);
}
+ } else {
+ // If we cannot split in two halves, look for a smaller chunk width W
+ // such that (1 << W) % Divisor == 1.
+ const APInt &Divisor = CN->getAPIntValue();
+ unsigned BitWidth = VT.getScalarSizeInBits();
+ unsigned BestChunkWidth = 0;
+
+ // Determine the legal scalar integer type for chunk operations.
+ EVT LegalVT = getTypeToTransformTo(*DAG.getContext(), VT);
+ unsigned LegalWidth = LegalVT.getScalarSizeInBits();
+ unsigned MaxChunk = std::min<unsigned>(LegalWidth, BitWidth);
+
+ // Precompute 2^MaxChunk mod Divisor
+ APInt Mod(Divisor.getBitWidth(), 1);
+ for (unsigned K = 0; K != MaxChunk; ++K)
+ Mod = Mod.shl(1).urem(Divisor);
+
+ // Since Divisor is odd, modular inverse of 2 is (Divisor + 1) / 2
+ APInt Inv2 = (Divisor + 1).lshr(1);
+
+ // Search for W where 2^W % Divisor == 1
+ for (unsigned I = MaxChunk, E = MaxChunk / 2; I > E; --I) {
+ if (Mod.isOne()) {
+ // Safety Check: Ensure (NumChunks * MaxChunkValue) doesn't overflow
+ // LegalVT
+ unsigned NumChunks = divideCeil(BitWidth, I);
+ // If the chunk width (I) plus the potential carry bits is less than the
+ // legal register width (L = LegalWidth, e.g. 64), we have enough "slack"
+ // at the top of the register to let the carries pile up safely.
+ // The maximum sum is NumChunks * (2^I - 1), so conservatively we need
+ // NumChunks * 2^I < 2^L. Taking log2 of both sides gives
+ // log2(NumChunks) + I < L.
+ if (I + Log2_32_Ceil(NumChunks) < LegalWidth) {
+ BestChunkWidth = I;
+ break;
+ }
+ }
+ Mod = (Mod * Inv2).urem(Divisor);
+ }
+
+ if (!BestChunkWidth)
+ return false;
+
+ SDValue In =
+ LL ? DAG.getNode(ISD::BUILD_PAIR, dl, VT, LL, LH) : N->getOperand(0);
+ SDValue TotalSum = DAG.getConstant(0, dl, LegalVT);
+ APInt MaskVal = APInt::getLowBitsSet(LegalWidth, BestChunkWidth);
+ SDValue Mask = DAG.getConstant(MaskVal, dl, LegalVT);
+
+ for (unsigned I = 0; I < BitWidth; I += BestChunkWidth) {
+ SDValue Shift = DAG.getShiftAmountConstant(I, VT, dl);
+ SDValue Chunk = DAG.getNode(ISD::SRL, dl, VT, In, Shift);
+ // Truncate to LegalVT
+ SDValue TruncChunk = DAG.getNode(ISD::TRUNCATE, dl, LegalVT, Chunk);
+ // For the last chunk, we might not need a mask if it's smaller than
+ // BestChunkWidth, but applying it is always safe.
+ SDValue MaskedChunk =
+ DAG.getNode(ISD::AND, dl, LegalVT, TruncChunk, Mask);
+ TotalSum = DAG.getNode(ISD::ADD, dl, LegalVT, TotalSum, MaskedChunk);
+ }
+
+ // Final reduction: TotalSum % Divisor.
+ // Since TotalSum is in LegalVT, this UREM will be lowered via magic
+ // multiplication.
+ SDValue ResRem =
----------------
xgupta wrote:
Yes, I have not checked the code below. I will remove this from here.
https://github.com/llvm/llvm-project/pull/146238
More information about the llvm-commits
mailing list