[PATCH] D144225: [InstCombine] Add constant combines for `(urem/srem (shl X, Y), (shl X, Z))`
Sander de Smalen via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 19 06:41:55 PDT 2023
sdesmalen accepted this revision.
sdesmalen added a comment.
This revision is now accepted and ready to land.
LGTM with comment about redundant condition addressed.
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1715-1717
+ match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) ||
+ (match(Op0, m_Mul(m_APInt(Y), m_Value(X))) &&
+ match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z))))) {
----------------
You can remove this condition, because InstCombine will already have canonicalised the constant to the RHS.
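The toy model below illustrates why the mirrored pattern is redundant. It is not LLVM code. For a commutative `mul`, the constant operand is moved to the right-hand side before matching, so a matcher that only accepts `mul X, C` also handles input written as `mul C, X`. The `Operand`, `BinOp`, `canonicalize` and `matchMulByConst` names are hypothetical. LLVM's real canonicalisation orders operands by complexity and is more general.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical toy operand: either a constant (IsConst, value C) or an
// opaque non-constant value identified by Id. Not LLVM's representation.
struct Operand {
  bool IsConst;
  uint64_t C;  // meaningful only when IsConst is true
  int Id;      // meaningful only when IsConst is false
};

enum class Opcode { Mul, Shl };

struct BinOp {
  Opcode Op;
  Operand LHS, RHS;
};

// In this toy model only mul is commutative; shl is not.
bool isCommutative(Opcode Op) { return Op == Opcode::Mul; }

// Toy canonicalisation: if a commutative op has a constant LHS and a
// non-constant RHS, swap them so the constant ends up on the RHS.
void canonicalize(BinOp &B) {
  if (isCommutative(B.Op) && B.LHS.IsConst && !B.RHS.IsConst)
    std::swap(B.LHS, B.RHS);
}

// Analogue of match(V, m_Mul(m_Value(X), m_APInt(C))): only accepts the
// "constant on the RHS" form, which is enough after canonicalize().
bool matchMulByConst(const BinOp &B, int &XId, uint64_t &C) {
  if (B.Op != Opcode::Mul || B.LHS.IsConst || !B.RHS.IsConst)
    return false;
  XId = B.LHS.Id;
  C = B.RHS.C;
  return true;
}
```

Before `canonicalize`, `mul 5, x` does not match. After canonicalisation it matches with C = 5. A separate `m_Mul(m_APInt(Y), m_Value(X))` alternative therefore adds nothing once canonicalisation has run.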
================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1707
+ bool ShiftX = false, ShiftY = false, ShiftZ = false;
+ if ((match(Op0, m_Mul(m_Value(X), m_APInt(Y))) &&
+ match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) ||
----------------
goldstein.w.n wrote:
> sdesmalen wrote:
> > It might be a bit easier to follow if you explicitly do the scaling while doing the matching, i.e.:
> >
> >   APInt Y, Z;
> >   const APInt *MatchY = nullptr, *MatchZ = nullptr;
> >
> >   // Match and normalise shift-amounts to multiplications
> >   if (match(Op0, m_c_Mul(m_Value(X), m_APInt(MatchY)))) {
> >     Y = *MatchY;
> >     if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
> >       // rem(mul(x, y), shl(x, z))
> >       Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
> >     else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
> >       // rem(mul(x, y), mul(x, z))
> >       Z = *MatchZ;
> >   } else if (match(Op0, m_Shl(m_Value(X), m_APInt(MatchY)))) {
> >     Y = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchY);
> >     if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
> >       // rem(shl(x, y), shl(x, z))
> >       Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
> >     else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
> >       // rem(shl(x, y), mul(x, z))
> >       Z = *MatchZ;
> >   }
> >
> >   if (!MatchY || !MatchZ)
> >     return nullptr;
>
> Inlined the ShiftY/ShiftZ
>
> I prefer keeping the 4 explicit cases (I removed the ShiftX case and moved it to the next patch). I think it's clearer to have each case laid out explicitly rather than nested if/else statements. LMK if that's okay; I'll change it if you feel strongly.
>
>
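Both versions of the matcher assume that for an n-bit integer and a shift amount c < n, `shl x, c` computes the same value as `mul x, (1 << c)` modulo 2^n. Under that assumption a shift can be handled as a multiplication by a constant. Below is a standalone check of that identity for 32-bit unsigned integers. The helper names are hypothetical, and it uses plain `uint32_t` rather than LLVM's `APInt`.

```cpp
#include <cassert>
#include <cstdint>

// Normalise a left-shift amount into the equivalent multiplier, in the
// spirit of APInt(BitWidth, 1).shl(C). Assumes ShiftAmt < 32: in LLVM IR a
// shl by an amount >= the bit width produces poison, so it is excluded.
uint32_t shiftAmountToMultiplier(uint32_t ShiftAmt) {
  assert(ShiftAmt < 32 && "out-of-range shift");
  return uint32_t{1} << ShiftAmt;
}

// Checks that x << c == x * (1 << c) in 32-bit wrapping unsigned arithmetic.
bool shlMatchesMul(uint32_t X, uint32_t ShiftAmt) {
  return (X << ShiftAmt) == X * shiftAmountToMultiplier(ShiftAmt);
}
```

Bits shifted out of the top are exactly the bits lost to wrapping in the multiplication, so the two forms agree even when the result overflows.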
This is merely a suggestion; I'll leave it to you whether to adopt it.
From looking at the new way you've structured the code, it occurred to me that it could also be written like this:
// If V is not nullptr, it will be matched using m_Specific.
auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool {
  const APInt *Tmp = nullptr;
  if ((!V && match(Op, m_c_Mul(m_Value(V), m_APInt(Tmp)))) ||
      (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
    C = *Tmp;
  else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
           (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
    C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
  return Tmp != nullptr;
};

APInt Y, Z;
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z))
  return nullptr;
This avoids having to spell out all the permutations. It also removes the need for the `AdjustedY` and `AdjustedZ` variables.
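Here is a standalone toy analogue of the suggested `MatchShiftOrMul` lambda. It is not LLVM code. The `Val`, `Inst` and `matchShiftOrMul` names are hypothetical, values are fixed at 32 bits, and constants are assumed to already sit on the RHS, so no commuted form is tried. It shows the two-call pattern: the first call captures X, and the second requires the same X (like `m_Specific`), with shift amounts normalised to multipliers.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical toy IR: Val is an opaque value; Inst is a binary
// instruction whose LHS is a value and whose RHS may be a constant.
struct Val { int Id; };

enum class Opcode { Mul, Shl, Other };

struct Inst {
  Opcode Op;
  const Val *LHS;     // non-constant operand
  bool RHSIsConst;
  uint32_t RHSConst;  // meaningful only when RHSIsConst is true
};

// If V is null, the non-constant operand is captured into V; otherwise it
// must be the same value as V. On success, C holds the constant multiplier,
// with a shl amount normalised to 1 << amount.
bool matchShiftOrMul(const Inst &I, const Val *&V, uint32_t &C) {
  if (!I.RHSIsConst || (V && I.LHS != V))
    return false;
  if (I.Op == Opcode::Mul) {
    C = I.RHSConst;
  } else if (I.Op == Opcode::Shl) {
    if (I.RHSConst >= 32)  // out-of-range shift: poison in LLVM IR
      return false;
    C = uint32_t{1} << I.RHSConst;
  } else {
    return false;
  }
  V = I.LHS;  // bind only on success, so a failed match leaves V untouched
  return true;
}
```

One difference from the real matchers: in LLVM's `PatternMatch`, `m_Value(V)` binds as soon as its own sub-pattern matches, so V can be set even when the enclosing pattern fails. The suggested usage is unaffected because a failed match on `Op0` returns immediately.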
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D144225/new/
https://reviews.llvm.org/D144225