[PATCH] D144225: [InstCombine] Add constant combines for `(urem/srem (shl X, Y), (shl X, Z))`

Noah Goldstein via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 19 20:32:48 PDT 2023


goldstein.w.n marked 2 inline comments as done.
goldstein.w.n added inline comments.


================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1707
+  bool ShiftX = false, ShiftY = false, ShiftZ = false;
+  if ((match(Op0, m_Mul(m_Value(X), m_APInt(Y))) &&
+       match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) ||
----------------
sdesmalen wrote:
> goldstein.w.n wrote:
> > sdesmalen wrote:
> > > It might be a bit easier to follow, if you explicitly do the scaling while doing the matching, i.e.
> > > 
> > >   APInt Y, Z;
> > >   const APInt *MatchY = nullptr, *MatchZ = nullptr;
> > > 
> > >   // Match and normalise shift-amounts to multiplications
> > >   if (match(Op0, m_c_Mul(m_Value(X), m_APInt(MatchY)))) {
> > >     Y = *MatchY;
> > >     if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
> > >       //  rem(mul(x, y), shl(x, z))
> > >       Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
> > >     else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
> > >       //  rem(mul(x, y), mul(x, z))
> > >       Z = *MatchZ;
> > >   } else if (match(Op0, m_Shl(m_Value(X), m_APInt(MatchY)))) {
> > >     Y = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchY);
> > >     if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
> > >       //  rem(shl(x, y), shl(x, z))
> > >       Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
> > >     else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
> > >       //  rem(shl(x, y), mul(x, z))
> > >       Z = *MatchZ;
> > >   }
> > > 
> > >   if (!MatchY || !MatchZ)
> > >     return nullptr;
> > 
> > Inlined the ShiftY/ShiftZ
> > 
> > I'd prefer to keep the 4 explicit cases (I removed the ShiftX case and moved it to the next patch). I think it's clearer to have each case laid out explicitly rather than in nested if/else statements. LMK if that's okay; I'll change it if you feel strongly.
> > 
> > 
> This is merely a suggestion, I'll leave it to you whether to adopt.
> 
> From looking at the new way you've structured the code, it occurred to me that it can also be written as this:
> 
>   // If V is not nullptr, it will be matched using m_Specific.
>   auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool {
>     const APInt *Tmp = nullptr;
>     if ((!V && match(Op, m_c_Mul(m_Value(V), m_APInt(Tmp)))) ||
>         (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
>       C = *Tmp;
>     else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
>              (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
>       C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
>     return Tmp != nullptr;
>   };
> 
>   APInt Y, Z;
>   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
>   if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z))
>     return nullptr;
> 
> Which avoids having to spell out all the permutations. It also avoids the need for the `AdjustedY` and `AdjustedZ` variables.
Done, and likewise for D147108. But I couldn't find a clean way to do it for
D143417. In D143417 the matches are no longer only APInt matches, so it needs
branching logic for either V or Y/Z (Tmp) being nullptr.



Repository:
  rG LLVM Github Monorepo


https://reviews.llvm.org/D144225


