[PATCH] D144225: [InstCombine] Add constant combines for `(urem/srem (shl X, Y), (shl X, Z))`

Sander de Smalen via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 19 06:41:55 PDT 2023


sdesmalen accepted this revision.
sdesmalen added a comment.
This revision is now accepted and ready to land.

LGTM with comment about redundant condition addressed.



================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1715-1717
+       match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) ||
+      (match(Op0, m_Mul(m_APInt(Y), m_Value(X))) &&
+       match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z))))) {
----------------
You can remove this condition, because InstCombine will already have canonicalised the constant to the RHS.
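The point behind this comment can be shown with a small, self-contained C++ model. It is a sketch under stated assumptions: `Operand`, `MulNode`, `canonicalize` and `matchMulByConst` are hypothetical names for a toy node type, not LLVM APIs. The model mimics InstCombine moving the constant operand of a commutative operation to the right-hand side. Once that has happened, a matcher that only looks for the constant on the RHS is sufficient.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Hypothetical toy operand: either a constant or a non-constant value id.
struct Operand {
  bool IsConst;
  uint32_t Val; // constant value, or an id naming a non-constant value
};

// Hypothetical toy multiply node with two operands.
struct MulNode {
  Operand LHS, RHS;
};

// Models the canonicalization: for a commutative op, move a constant
// operand to the right-hand side.
MulNode canonicalize(MulNode M) {
  if (M.LHS.IsConst && !M.RHS.IsConst)
    std::swap(M.LHS, M.RHS);
  return M;
}

// One-sided matcher in the spirit of m_Mul(m_Value(X), m_APInt(C)):
// it only looks for the constant on the right-hand side.
// Returns (value id of X, constant C) on success.
std::optional<std::pair<uint32_t, uint32_t>> matchMulByConst(const MulNode &M) {
  if (!M.LHS.IsConst && M.RHS.IsConst)
    return std::make_pair(M.LHS.Val, M.RHS.Val);
  return std::nullopt;
}
```

For a node like `mul 4, %7`, the one-sided matcher fails on the raw node but succeeds after `canonicalize`. In the model, a commuted `(C, X)` alternative therefore adds nothing once canonicalization has run.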


================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1707
+  bool ShiftX = false, ShiftY = false, ShiftZ = false;
+  if ((match(Op0, m_Mul(m_Value(X), m_APInt(Y))) &&
+       match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) ||
----------------
goldstein.w.n wrote:
> sdesmalen wrote:
> > It might be a bit easier to follow, if you explicitly do the scaling while doing the matching, i.e.
> > 
> >   APInt Y, Z;
> >   const APInt *MatchY = nullptr, *MatchZ = nullptr;
> > 
> >   // Match and normalise shift-amounts to multiplications
> >   if (match(Op0, m_c_Mul(m_Value(X), m_APInt(MatchY)))) {
> >     Y = *MatchY;
> >     if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
> >       //  rem(mul(x, y), shl(x, z))
> >       Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
> >     else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
> >       //  rem(mul(x, y), mul(x, z))
> >       Z = *MatchZ;
> >   } else if (match(Op0, m_Shl(m_Value(X), m_APInt(MatchY)))) {
> >     Y = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchY);
> >     if (match(Op1, m_Shl(m_Specific(X), m_APInt(MatchZ))))
> >       //  rem(shl(x, y), shl(x, z))
> >       Z = APInt(X->getType()->getScalarSizeInBits(), 1).shl(*MatchZ);
> >     else if (match(Op1, m_c_Mul(m_Specific(X), m_APInt(MatchZ))))
> >       //  rem(shl(x, y), mul(x, z))
> >       Z = *MatchZ;
> >   }
> > 
> >   if (!MatchY || !MatchZ)
> >     return nullptr;
> 
> Inlined the ShiftY/ShiftZ
> 
> Prefer keeping the 4 explicit cases (removed the ShiftX case and moved it to the next patch). I think it's clearer to have each case explicitly laid out, rather than nested if/else statements. LMK if that's okay; will change if you feel strongly.
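The suggestion above normalizes both matched forms to a multiplier, treating `shl X, C` as `mul X, (1 << C)` via `APInt(...).shl(...)`. The following standalone C++ sketch models that normalization with plain 8-bit unsigned integers instead of `APInt`. This is an assumption made so that the cases can be checked exhaustively.

The sketch also checks one arithmetic fact that a `urem` combine on the normalized form can rely on. When neither `X*Y` nor `X*Z` wraps and `Y % Z == 0`, `(X*Y) % (X*Z)` is 0. `OpKind`, `toMultiplier` and `remOfMultiplesIsZero` are hypothetical names. Signed `srem` is not modeled.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical kinds for the two patterns in the suggestion: (mul X, C), (shl X, C).
enum class OpKind { Mul, Shl };

// Normalizes the constant of either pattern to a multiplier, in the spirit of
// APInt(BitWidth, 1).shl(C): shl X, C behaves like mul X, (1 << C).
// Modeled at 8 bits so the check below can be exhaustive.
uint8_t toMultiplier(OpKind K, uint8_t C) {
  if (K == OpKind::Mul)
    return C;
  assert(C < 8 && "shift amount must be below the bit width");
  return static_cast<uint8_t>(1u << C);
}

// Checks, for every 8-bit X >= 1 where neither X*Y nor X*Z wraps, that
// (X*Y) urem (X*Z) == 0 whenever Y % Z == 0.
bool remOfMultiplesIsZero(uint8_t Y, uint8_t Z) {
  if (Z == 0 || Y % Z != 0)
    return true; // precondition not met: nothing to check
  for (unsigned X = 1; X < 256; ++X) {
    unsigned XY = X * Y, XZ = X * Z;
    if (XY > 0xFF || XZ > 0xFF)
      continue; // a wrapping product is outside the no-wrap assumption
    if (XY % XZ != 0)
      return false;
  }
  return true;
}
```

For example, `shl X, 3` and `shl X, 1` normalize to multipliers 8 and 2. Since `8 % 2 == 0`, `remOfMultiplesIsZero(8, 2)` holds.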
> 
This is merely a suggestion, I'll leave it to you whether to adopt.

From looking at the new way you've structured the code, it occurred to me that it can also be written as this:

  // If V is not nullptr, it will be matched using m_Specific.
  auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool {
    const APInt *Tmp = nullptr;
    if ((!V && match(Op, m_c_Mul(m_Value(V), m_APInt(Tmp)))) ||
        (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
      C = *Tmp;
    else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
             (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
      C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
    return Tmp != nullptr;
  };

  APInt Y, Z;
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
  if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z))
    return nullptr;

This avoids having to spell out all the permutations. It also avoids the need for the `AdjustedY` and `AdjustedZ` variables.
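The following self-contained C++ model shows how a single matcher can replace the four spelled-out cases. It imitates the `MatchShiftOrMul` lambda over a toy instruction type. `Kind`, `Inst`, `matchShiftOrMul` and `matchRemOperands` are hypothetical names, not LLVM APIs, and the constant is assumed to be canonicalized to the RHS already. The first call captures X. The second call only matches the same X, mirroring the switch from `m_Value` to `m_Specific`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical toy instruction "Kind %VarId, Const", with the constant
// assumed to be already canonicalized to the right-hand side.
enum class Kind { Mul, Shl, Other };
struct Inst {
  Kind K;
  int VarId;      // id of the non-constant operand
  uint32_t Const; // constant operand
};

// Mirrors the suggested MatchShiftOrMul lambda. If V == -1 the variable
// operand is captured (like m_Value); otherwise it must equal V (like
// m_Specific). On success, C holds the constant normalized to a multiplier.
bool matchShiftOrMul(const Inst &Op, int &V, uint32_t &C) {
  if (Op.K != Kind::Mul && Op.K != Kind::Shl)
    return false;
  if (V != -1 && Op.VarId != V)
    return false;
  if (Op.K == Kind::Shl && Op.Const >= 32)
    return false; // this sketch simply rejects out-of-range shift amounts
  V = Op.VarId; // only written once the whole pattern has matched
  C = Op.K == Kind::Mul ? Op.Const : (1u << Op.Const);
  return true;
}

// The same two calls cover all four mul/shl permutations of the operands.
bool matchRemOperands(const Inst &Op0, const Inst &Op1, int &X, uint32_t &Y,
                      uint32_t &Z) {
  X = -1;
  return matchShiftOrMul(Op0, X, Y) && matchShiftOrMul(Op1, X, Z);
}
```

As a design choice, the sketch writes `V` only after the whole pattern has matched. A failed attempt therefore never leaves a stale binding behind.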


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D144225/new/

https://reviews.llvm.org/D144225


