[PATCH] D143014: [InstCombine] Add combines for `(urem/srem (mul/shl X, Y), (mul/shl X, Z))`
Noah Goldstein via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 6 09:52:19 PST 2023
goldstein.w.n added a comment.
In D143014#4107084 <https://reviews.llvm.org/D143014#4107084>, @MattDevereau wrote:
> I'm thinking we could have a function which attempts to simplify irem/mul/shl patterns and returns success if the simplification is possible. Similarly to how there are calls such as
I see, I thought you meant a function for each case. Done in V3.
> if (Instruction *R = FoldOpIntoSelect(I, SI))
>   return R;
>
> inside `commonIRemTransforms`, we could have something along the lines of
>
> if (Instruction *res = SimplifyIRemMulShl(I))
>   return res;
>
> at the end of `commonIRemTransforms`.
>
> This would give us a lot more control over managing the cases and their bailout conditions separately and cleanly. I think it would also make it more obvious that we are just handling two primary cases, whereas it's quite hard to parse and keep a mental note of the complicated if statements which are present currently.
> I have had a go at rearranging the pre-combine logic:
>
> Instruction *InstCombinerImpl::SimplifyIRemMulShl(BinaryOperator &I) {
>   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
>   Value *A, *B, *C, *D;
>   if (!(match(Op0, m_Mul(m_Value(A), m_Value(B))) ||
>         match(Op0, m_Shl(m_Value(A), m_Value(B)))) ||
>       !(match(Op1, m_Mul(m_Value(C), m_Value(D))) ||
>         match(Op1, m_Shl(m_Value(C), m_Value(D)))))
>     return nullptr;
>
>   Value *X, *Y, *Z;
>   X = nullptr;
>   // Do this by hand as opposed to using m_Specific because either A/B (or
>   // C/D) can be our X.
>   if (A == C || A == D) {
>     X = A;
>     Y = B;
>     Z = A == C ? D : C;
>   } else if (B == C || B == D) {
>     X = B;
>     Y = A;
>     Z = B == C ? D : C;
>   }
>
>   BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
>   BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
>   if (!BO0 || !BO1)
>     return nullptr;
>
>   Constant *CX = X ? dyn_cast<Constant>(X) : nullptr;
>   if (!X || (CX && CX->isOneValue()))
>     return nullptr;
>
>   Type *Ty = I.getType();
>   ConstantInt *ConstY = dyn_cast<ConstantInt>(Y);
>   ConstantInt *ConstZ = dyn_cast<ConstantInt>(Z);
>   if (Ty->isVectorTy()) {
>     auto *VConstY = dyn_cast<Constant>(Y);
>     auto *VConstZ = dyn_cast<Constant>(Z);
>     if (VConstY && VConstZ) {
>       VConstY = VConstY->getSplatValue();
>       VConstZ = VConstZ->getSplatValue();
>       if (VConstY && VConstZ) {
>         ConstY = dyn_cast<ConstantInt>(VConstY);
>         ConstZ = dyn_cast<ConstantInt>(VConstZ);
>       }
>     }
>   }
>
>   bool IsSigned = I.getOpcode() == Instruction::SRem;
>   // Check constant folds first.
>   if (ConstY && ConstZ) {
>     APInt APIntY = ConstY->getValue();
>     APInt APIntZ = ConstZ->getValue();
>
>     // Just treat the shifts as mul, we may end up returning a mul by power
>     // of 2 but that will be cleaned up later.
>     if (BO0->getOpcode() == Instruction::Shl)
>       APIntY = APInt(APIntY.getBitWidth(), 1) << APIntY;
>     if (BO1->getOpcode() == Instruction::Shl)
>       APIntZ = APInt(APIntZ.getBitWidth(), 1) << APIntZ;
>
>     APInt RemYZ = IsSigned ? APIntY.srem(APIntZ) : APIntY.urem(APIntZ);
>
>     // (rem (mul nuw/nsw X, Y), (mul X, Z))
>     // if (rem Y, Z) == 0
>     // -> 0
>     if (RemYZ.isZero() &&
>         (IsSigned ? BO0->hasNoSignedWrap() : BO0->hasNoUnsignedWrap()))
>       return replaceInstUsesWith(I, ConstantInt::getNullValue(Ty));
>
>     // (rem (mul X, Y), (mul nuw/nsw X, Z))
>     // if (rem Y, Z) == Y
>     // -> (mul nuw/nsw X, Y)
>     if (RemYZ == APIntY &&
>         (IsSigned ? BO1->hasNoSignedWrap() : BO1->hasNoUnsignedWrap())) {
>       // We are returning Op0 essentially but we can also add no wrap flags.
>       BinaryOperator *BO =
>           BinaryOperator::CreateMul(X, ConstantInt::get(Ty, APIntY));
>       // We can add nsw/nuw if remainder op is signed/unsigned, also we
>       // can copy any overflow flags from Op0.
>       if (IsSigned || BO0->hasNoSignedWrap())
>         BO->setHasNoSignedWrap();
>       if (!IsSigned || BO0->hasNoUnsignedWrap())
>         BO->setHasNoUnsignedWrap();
>       return BO;
>     }
>
>     // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
>     // if Y >= Z
>     // -> (mul {nuw} nsw X, (rem Y, Z))
>     // NB: (rem Y, Z) is a constant.
>     if (APIntY.uge(APIntZ) &&
>         (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap())
>                   : BO0->hasNoUnsignedWrap())) {
>       BinaryOperator *BO =
>           BinaryOperator::CreateMul(X, ConstantInt::get(Ty, RemYZ));
>       BO->setHasNoSignedWrap();
>       if (!IsSigned || BO0->hasNoUnsignedWrap())
>         BO->setHasNoUnsignedWrap();
>       return BO;
>     }
>   }
>
>   // Check if desirable to do generic replacement.
>   // NB: It may be beneficial to do this if we have X << Z even if there are
>   // multiple uses of Op0/Op1 as it will eliminate the urem (urem of a power
>   // of 2 is converted to add/and) and urem is pretty expensive (maybe more
>   // sense in DAGCombiner).
>   if ((ConstY && ConstZ) ||
>       (Op0->hasOneUse() && Op1->hasOneUse() &&
>        (IsSigned ? (BO0->getOpcode() != Instruction::Shl &&
>                     BO1->getOpcode() != Instruction::Shl)
>                  : (BO0->getOpcode() != Instruction::Shl ||
>                     BO1->getOpcode() == Instruction::Shl)))) {
>     // (rem (mul nuw/nsw X, Y), (mul nuw {nsw} X, Z))
>     // -> (mul nuw/nsw X, (rem Y, Z))
>     if (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap() &&
>                     BO1->hasNoUnsignedWrap())
>                  : (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) {
>       // Convert the shifts to multiplies, cleaned up elsewhere.
>       if (BO0->getOpcode() == Instruction::Shl)
>         Y = Builder.CreateShl(ConstantInt::get(Ty, 1), Y);
>       if (BO1->getOpcode() == Instruction::Shl)
>         Z = Builder.CreateShl(ConstantInt::get(Ty, 1), Z);
>       BinaryOperator *BO = BinaryOperator::CreateMul(
>           X, IsSigned ? Builder.CreateSRem(Y, Z) : Builder.CreateURem(Y, Z));
>
>       if (IsSigned || BO0->hasNoSignedWrap() || BO1->hasNoSignedWrap())
>         BO->setHasNoSignedWrap();
>       if (!IsSigned || (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
>         BO->setHasNoUnsignedWrap();
>       return BO;
>     }
>   }
>
>   return nullptr;
> }
>
> This function could be split even further into SimplifyIRemMulShlConst and SimplifyIRemMulShlGeneric, or we could just leave the generic case in this function.
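The folds in the quoted sketch rest on the identity (rem (mul X, Y), (mul X, Z)) == (mul X, (rem Y, Z)) when the multiplies do not wrap. They also depend on treating `shl X, S` as `mul X, (1 << S)`. Below is a standalone C++ brute-force check of that arithmetic over i8. It is not LLVM code, and the names `fitsI8`, `checkShlAsMul`, `checkUnsignedFolds` and `checkSignedGenericFold` are hypothetical.

The check works as follows:
- nuw/nsw on a multiply is modelled as "the exact product fits in 8 bits".
- A zero divisor is skipped, and so is `srem INT_MIN, -1`, since both are UB in IR.
- Each function counts mismatches.

The signed check only assumes both multiplies are nsw. The quoted code additionally requires nuw on BO1, which this model does not track. Because that is a stricter condition, a pass here covers it too. Vector splats and the one-use heuristics are not modelled.

```cpp
#include <cassert> // used by an assert-based driver such as the one described below

// Hypothetical helper: does an exact product fit in a signed i8
// (i.e. would the IR multiply be `mul nsw i8`)?
bool fitsI8(int V) { return V >= -128 && V <= 127; }

// Checks shl-as-mul and urem-by-power-of-2 lowering for every i8 value and
// every in-range shift amount. Returns the number of mismatches.
int checkShlAsMul() {
  int Failures = 0;
  for (unsigned V = 0; V < 256; ++V)
    for (unsigned S = 0; S < 8; ++S) {
      unsigned Pow2 = 1u << S;
      // (shl V, S) == (mul V, 1 << S), modulo 2^8
      if (((V << S) & 0xFFu) != ((V * Pow2) & 0xFFu))
        ++Failures;
      // (urem V, 1 << S) == (and V, (1 << S) - 1)
      if (V % Pow2 != (V & (Pow2 - 1)))
        ++Failures;
    }
  return Failures;
}

// Exhaustively checks the unsigned folds for i8. "nuw" on a multiply is
// modelled as "the exact product is <= 255". Returns the number of mismatches.
int checkUnsignedFolds() {
  int Failures = 0;
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y)
      for (unsigned Z = 0; Z < 256; ++Z) {
        unsigned XY = (X * Y) & 0xFFu; // i8 result of (mul X, Y)
        unsigned XZ = (X * Z) & 0xFFu; // i8 result of (mul X, Z)
        if (XZ == 0)
          continue; // urem by zero is UB, nothing to check
        bool NUW0 = X * Y <= 0xFFu;
        bool NUW1 = X * Z <= 0xFFu;
        unsigned Rem = XY % XZ;                // (urem (mul X, Y), (mul X, Z))
        unsigned RemYZ = Y % Z;                // Z != 0 because XZ != 0
        unsigned Folded = (X * RemYZ) & 0xFFu; // (mul X, (urem Y, Z))
        // (urem Y, Z) == 0 and BO0 nuw  ->  0
        if (RemYZ == 0 && NUW0 && Rem != 0)
          ++Failures;
        // (urem Y, Z) == Y and BO1 nuw  ->  (mul X, Y)
        if (RemYZ == Y && NUW1 && Rem != XY)
          ++Failures;
        // Y >= Z and BO0 nuw  ->  (mul X, (urem Y, Z))
        if (Y >= Z && NUW0 && Rem != Folded)
          ++Failures;
        // Generic case, both nuw  ->  (mul X, (urem Y, Z))
        if (NUW0 && NUW1 && Rem != Folded)
          ++Failures;
      }
  return Failures;
}

// Exhaustively checks the signed generic fold for i8, assuming both
// multiplies are nsw. C++ '%' truncates toward zero, like srem.
int checkSignedGenericFold() {
  int Failures = 0;
  for (int X = -128; X <= 127; ++X)
    for (int Y = -128; Y <= 127; ++Y)
      for (int Z = -128; Z <= 127; ++Z) {
        int XY = X * Y, XZ = X * Z;
        if (!fitsI8(XY) || !fitsI8(XZ) || XZ == 0)
          continue; // not nsw, or srem by zero (UB)
        if (XY == -128 && XZ == -1)
          continue; // srem INT_MIN, -1 is UB in LLVM IR
        if (XY % XZ != X * (Y % Z)) // Z != 0 because XZ != 0
          ++Failures;
      }
  return Failures;
}
```

A driver would assert that all three functions return 0. Each exhaustive check covers at most 2^24 (X, Y, Z) triples.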
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1756-1759
+ // BO0 = X * Y
+ BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
+ // BO1 = X * Z
+ BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
----------------
MattDevereau wrote:
> I think we already know `BO0 = X * Y` and `BO1 = X * Z` by now from other comments and matches.
IMO there is no real cost to the extra comment. I can obviously delete it, but it seems harmless and may help clarity.
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1764-1767
+ bool NSW0 = BO0->hasNoSignedWrap();
+ bool NSW1 = BO1->hasNoSignedWrap();
+ bool NUW0 = BO0->hasNoUnsignedWrap();
+ bool NUW1 = BO1->hasNoUnsignedWrap();
----------------
MattDevereau wrote:
> In my opinion these variable names are less descriptive than having the function calls inline.
It's mostly stylistic, to save column width. Writing out the full `BO...->has...()` calls makes the if conditions very long and, IMO, harder to follow.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D143014/new/
https://reviews.llvm.org/D143014