[PATCH] D143014: [InstCombine] Add combines for `(urem/srem (mul/shl X, Y), (mul/shl X, Z))`
Matt Devereau via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 6 08:47:30 PST 2023
MattDevereau added a comment.
I'm thinking we could have a function which attempts to simplify irem/mul/shl patterns and returns the simplified instruction if the simplification is possible (and nullptr otherwise). Similarly to how there are calls such as
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
inside `commonIRemTransforms`, we could have something along the lines of
if (Instruction *res = SimplifyIRemMulShl(I))
return res;
at the end of `commonIRemTransforms`.
This would give us a lot more control over managing the cases and their bailout conditions separately and cleanly. I think it would also make it more obvious that we are just handling two primary cases, whereas it's quite hard to parse and keep a mental note of the complicated if statements which are present currently.
I have had a go at rearranging the pre-combine logic:
/// Try to simplify `(urem/srem (mul/shl X, Y), (mul/shl X, Z))`, i.e. a
/// remainder whose two operands share a common factor X.
///
/// A `shl V, C` operand is handled as `mul V, (1 << C)`; consequently only the
/// shifted value (operand 0) of a shl may act as the common factor X, while
/// either operand of a mul may.
///
/// \param I the remainder instruction; it is treated as signed iff its opcode
///          is SRem.
/// \returns the result of replaceInstUsesWith or a newly created mul when a
///          fold applies, nullptr otherwise.
Instruction *InstCombinerImpl::SimplifyIRemMulShl(BinaryOperator &I) {
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
  Value *A, *B, *C, *D;
  if (!(match(Op0, m_Mul(m_Value(A), m_Value(B))) ||
        match(Op0, m_Shl(m_Value(A), m_Value(B)))) ||
      !(match(Op1, m_Mul(m_Value(C), m_Value(D))) ||
        match(Op1, m_Shl(m_Value(C), m_Value(D)))))
    return nullptr;

  // The wrap flags and opcodes are read from the operands below, so both must
  // be real BinaryOperators.
  BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
  BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
  if (!BO0 || !BO1)
    return nullptr;

  const bool Op0IsShl = BO0->getOpcode() == Instruction::Shl;
  const bool Op1IsShl = BO1->getOpcode() == Instruction::Shl;

  // Find the common factor X. Do this by hand as opposed to using m_Specific
  // because either A/B (or C/D) can be our X when the operand is a mul.
  // shl is not commutative: in `shl V, C` only V is a multiplicand, so the
  // shift amount (B resp. D) must never be picked as X.
  Value *X = nullptr, *Y = nullptr, *Z = nullptr;
  if (A == C || (!Op1IsShl && A == D)) {
    X = A;
    Y = B;
    Z = A == C ? D : C;
  } else if (!Op0IsShl && (B == C || (!Op1IsShl && B == D))) {
    X = B;
    Y = A;
    Z = B == C ? D : C;
  }

  // No common factor, or the common factor is the constant one.
  if (!X)
    return nullptr;
  if (auto *CX = dyn_cast<Constant>(X))
    if (CX->isOneValue())
      return nullptr;

  Type *Ty = I.getType();

  // Try to get Y/Z as constant integers. For vector types, only splat
  // constants are accepted and the splatted scalar is used.
  ConstantInt *ConstY = dyn_cast<ConstantInt>(Y);
  ConstantInt *ConstZ = dyn_cast<ConstantInt>(Z);
  if (Ty->isVectorTy()) {
    auto *VConstY = dyn_cast<Constant>(Y);
    auto *VConstZ = dyn_cast<Constant>(Z);
    if (VConstY && VConstZ) {
      VConstY = VConstY->getSplatValue();
      VConstZ = VConstZ->getSplatValue();
      if (VConstY && VConstZ) {
        ConstY = dyn_cast<ConstantInt>(VConstY);
        ConstZ = dyn_cast<ConstantInt>(VConstZ);
      }
    }
  }

  bool IsSigned = I.getOpcode() == Instruction::SRem;

  // Check constant folds first.
  if (ConstY && ConstZ) {
    APInt APIntY = ConstY->getValue();
    APInt APIntZ = ConstZ->getValue();
    // Just treat the shifts as mul, we may end up returning a mul by power
    // of 2 but that will be cleaned up later.
    if (Op0IsShl)
      APIntY = APInt(APIntY.getBitWidth(), 1) << APIntY;
    if (Op1IsShl)
      APIntZ = APInt(APIntZ.getBitWidth(), 1) << APIntZ;

    // A zero divisor (e.g. `mul X, 0`) must not reach APInt::urem/srem, and
    // there is nothing useful to fold in that case anyway.
    if (APIntZ.isZero())
      return nullptr;

    APInt RemYZ = IsSigned ? APIntY.srem(APIntZ) : APIntY.urem(APIntZ);

    // (rem (mul nuw/nsw X, Y), (mul X, Z))
    //      if (rem Y, Z) == 0
    //          -> 0
    if (RemYZ.isZero() &&
        (IsSigned ? BO0->hasNoSignedWrap() : BO0->hasNoUnsignedWrap()))
      return replaceInstUsesWith(I, ConstantInt::getNullValue(Ty));

    // (rem (mul X, Y), (mul nuw/nsw X, Z))
    //      if (rem Y, Z) == Y
    //          -> (mul nuw/nsw X, Y)
    if (RemYZ == APIntY &&
        (IsSigned ? BO1->hasNoSignedWrap() : BO1->hasNoUnsignedWrap())) {
      // We are returning Op0 essentially but we can also add no wrap flags.
      BinaryOperator *BO =
          BinaryOperator::CreateMul(X, ConstantInt::get(Ty, APIntY));
      // We can add nsw/nuw if remainder op is signed/unsigned, also we
      // can copy any overflow flags from Op0.
      if (IsSigned || BO0->hasNoSignedWrap())
        BO->setHasNoSignedWrap();
      if (!IsSigned || BO0->hasNoUnsignedWrap())
        BO->setHasNoUnsignedWrap();
      return BO;
    }

    // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
    //      if Y >= Z
    //          -> (mul {nuw} nsw X, (rem Y, Z))
    // NB: (rem Y, Z) is a constant.
    if (APIntY.uge(APIntZ) &&
        (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap())
                  : BO0->hasNoUnsignedWrap())) {
      BinaryOperator *BO =
          BinaryOperator::CreateMul(X, ConstantInt::get(Ty, RemYZ));
      BO->setHasNoSignedWrap();
      if (!IsSigned || BO0->hasNoUnsignedWrap())
        BO->setHasNoUnsignedWrap();
      return BO;
    }
  }

  // Check if desirable to do generic replacement.
  // NB: It may be beneficial to do this if we have X << Z even if there are
  // multiple uses of Op0/Op1 as it will eliminate the urem (urem of a power
  // of 2 is converted to add/and) and urem is pretty expensive (maybe more
  // sense in DAGCombiner).
  if ((ConstY && ConstZ) ||
      (Op0->hasOneUse() && Op1->hasOneUse() &&
       (IsSigned ? (!Op0IsShl && !Op1IsShl) : (!Op0IsShl || Op1IsShl)))) {
    // (rem (mul nuw/nsw X, Y), (mul nuw {nsw} X, Z)
    //      -> (mul nuw/nsw X, (rem Y, Z))
    if (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap() &&
                    BO1->hasNoUnsignedWrap())
                 : (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) {
      // Convert the shifts to multiplies, cleaned up elsewhere.
      if (Op0IsShl)
        Y = Builder.CreateShl(ConstantInt::get(Ty, 1), Y);
      if (Op1IsShl)
        Z = Builder.CreateShl(ConstantInt::get(Ty, 1), Z);
      BinaryOperator *BO = BinaryOperator::CreateMul(
          X, IsSigned ? Builder.CreateSRem(Y, Z) : Builder.CreateURem(Y, Z));
      if (IsSigned || BO0->hasNoSignedWrap() || BO1->hasNoSignedWrap())
        BO->setHasNoSignedWrap();
      if (!IsSigned || (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
        BO->setHasNoUnsignedWrap();
      return BO;
    }
  }

  return nullptr;
}
This function could be split even further into SimplifyIRemMulShlConst and SimplifyIRemMulShlGeneric, or we could just leave the generic case in this function.
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1756-1759
+ // BO0 = X * Y
+ BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
+ // BO1 = X * Z
+ BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
----------------
I think we already know `BO0 = X * Y` and `BO1 = X * Z` by now from other comments and matches.
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1764-1767
+ bool NSW0 = BO0->hasNoSignedWrap();
+ bool NSW1 = BO1->hasNoSignedWrap();
+ bool NUW0 = BO0->hasNoUnsignedWrap();
+ bool NUW1 = BO1->hasNoUnsignedWrap();
----------------
In my opinion these variable names are less descriptive than having the function calls inline.
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1772-1789
+ // Try and get Y/Z as constants.
+ ConstantInt *ConstY = nullptr;
+ ConstantInt *ConstZ = nullptr;
+ if (Ty->isVectorTy()) {
+ auto *VConstY = dyn_cast<Constant>(Y);
+ auto *VConstZ = dyn_cast<Constant>(Z);
+ if (VConstY && VConstZ) {
----------------
We can remove the else branch here and instead initialize ConstY/ConstZ up front, like so:
```
// Scalar case: take Y/Z directly when they are ConstantInts (null otherwise).
ConstantInt *ConstY = dyn_cast<ConstantInt>(Y);
ConstantInt *ConstZ = dyn_cast<ConstantInt>(Z);
// Vector case: only when both Y and Z are constants with a splat value, use
// the splatted scalars instead.
if (Ty->isVectorTy()) {
auto *VConstY = dyn_cast<Constant>(Y);
auto *VConstZ = dyn_cast<Constant>(Z);
if (VConstY && VConstZ) {
// getSplatValue() result is checked for null below before it is used.
VConstY = VConstY->getSplatValue();
VConstZ = VConstZ->getSplatValue();
if (VConstY && VConstZ) {
ConstY = dyn_cast<ConstantInt>(VConstY);
ConstZ = dyn_cast<ConstantInt>(VConstZ);
}
}
}
```
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D143014/new/
https://reviews.llvm.org/D143014
More information about the llvm-commits
mailing list