[PATCH] D143014: [InstCombine] Add combines for `(urem/srem (mul/shl X, Y), (mul/shl X, Z))`

Noah Goldstein via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 6 09:52:19 PST 2023


goldstein.w.n added a comment.

In D143014#4107084 <https://reviews.llvm.org/D143014#4107084>, @MattDevereau wrote:

> I'm thinking we could have a function which attempts to simplify irem/mul/shl patterns and returns the replacement if the simplification is possible, similar to how there are calls such as

I see, I thought you meant a function for each case. Done in V3.

>   if (Instruction *R = FoldOpIntoSelect(I, SI))
>             return R;
>
> inside `commonIRemTransforms`, we could have something along the lines of
>
>   if (Instruction *res = SimplifyIRemMulShl(I))
>     return res;
>
> at the end of `commonIRemTransforms`.
>
> This would give us a lot more control over managing the cases and their bailout conditions separately and cleanly. I think it would also make it more obvious that we are just handling two primary cases, whereas it's quite hard to parse and keep a mental note of the complicated if statements which are present currently.
> I have had a go at rearranging the pre-combine logic:
>
>   Instruction *InstCombinerImpl::SimplifyIRemMulShl(BinaryOperator &I) {
>     Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
>     Value *A, *B, *C, *D;
>     if (!(match(Op0, m_Mul(m_Value(A), m_Value(B))) ||
>           match(Op0, m_Shl(m_Value(A), m_Value(B)))) ||
>         !(match(Op1, m_Mul(m_Value(C), m_Value(D))) ||
>           match(Op1, m_Shl(m_Value(C), m_Value(D)))))
>       return nullptr;
>   
>     Value *X, *Y, *Z;
>     X = nullptr;
>     // Do this by hand as opposed to using m_Specific because either A/B (or
>     // C/D) can be our X.
>     if (A == C || A == D) {
>       X = A;
>       Y = B;
>       Z = A == C ? D : C;
>     } else if (B == C || B == D) {
>       X = B;
>       Y = A;
>       Z = B == C ? D : C;
>     }
>   
>     BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
>     BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
>     if (!BO0 || !BO1)
>       return nullptr;
>   
>     Constant *CX = X ? dyn_cast<Constant>(X) : nullptr;
>     if (!X || (CX && CX->isOneValue()))
>       return nullptr;
>       
>     Type *Ty = I.getType();
>     ConstantInt *ConstY = dyn_cast<ConstantInt>(Y);
>     ConstantInt *ConstZ = dyn_cast<ConstantInt>(Z);
>     if (Ty->isVectorTy()) {
>       auto *VConstY = dyn_cast<Constant>(Y);
>       auto *VConstZ = dyn_cast<Constant>(Z);
>       if (VConstY && VConstZ) {
>         VConstY = VConstY->getSplatValue();
>         VConstZ = VConstZ->getSplatValue();
>         if (VConstY && VConstZ) {
>           ConstY = dyn_cast<ConstantInt>(VConstY);
>           ConstZ = dyn_cast<ConstantInt>(VConstZ);
>         }
>       }
>     }
>   
>     bool IsSigned = I.getOpcode() == Instruction::SRem;
>     // Check constant folds first.
>     if (ConstY && ConstZ) {
>       APInt APIntY = ConstY->getValue();
>       APInt APIntZ = ConstZ->getValue();
>   
>       // Just treat the shifts as mul, we may end up returning a mul by power
>       // of 2 but that will be cleaned up later.
>       if (BO0->getOpcode() == Instruction::Shl)
>         APIntY = APInt(APIntY.getBitWidth(), 1) << APIntY;
>       if (BO1->getOpcode() == Instruction::Shl)
>         APIntZ = APInt(APIntZ.getBitWidth(), 1) << APIntZ;
>   
>       APInt RemYZ = IsSigned ? APIntY.srem(APIntZ) : APIntY.urem(APIntZ);
>   
>       // (rem (mul nuw/nsw X, Y), (mul X, Z))
>       //      if (rem Y, Z) == 0
>       //          -> 0
>       if (RemYZ.isZero() &&
>           (IsSigned ? BO0->hasNoSignedWrap() : BO0->hasNoUnsignedWrap()))
>         return replaceInstUsesWith(I, ConstantInt::getNullValue(Ty));
>   
>       // (rem (mul X, Y), (mul nuw/nsw X, Z))
>       //      if (rem Y, Z) == Y
>       //          -> (mul nuw/nsw X, Y)
>       if (RemYZ == APIntY &&
>           (IsSigned ? BO1->hasNoSignedWrap() : BO1->hasNoUnsignedWrap())) {
>         // We are returning Op0 essentially but we can also add no wrap flags.
>         BinaryOperator *BO =
>             BinaryOperator::CreateMul(X, ConstantInt::get(Ty, APIntY));
>         // We can add nsw/nuw if remainder op is signed/unsigned, also we
>         // can copy any overflow flags from Op0.
>         if (IsSigned || BO0->hasNoSignedWrap())
>           BO->setHasNoSignedWrap();
>         if (!IsSigned || BO0->hasNoUnsignedWrap())
>           BO->setHasNoUnsignedWrap();
>         return BO;
>       }
>   
>       // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
>       //      if Y >= Z
>       //          -> (mul {nuw} nsw X, (rem Y, Z))
>       // NB: (rem Y, Z) is a constant.
>       if (APIntY.uge(APIntZ) &&
>           (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap())
>                     : BO0->hasNoUnsignedWrap())) {
>         BinaryOperator *BO =
>             BinaryOperator::CreateMul(X, ConstantInt::get(Ty, RemYZ));
>         BO->setHasNoSignedWrap();
>         if (!IsSigned || BO0->hasNoUnsignedWrap())
>           BO->setHasNoUnsignedWrap();
>         return BO;
>       }
>     }
>   
>     // Check if desirable to do generic replacement.
>     // NB: It may be beneficial to do this if we have X << Z even if there are
>     // multiple uses of Op0/Op1 as it will eliminate the urem (urem of a power
>     // of 2 is converted to add/and) and urem is pretty expensive (maybe more
>     // sense in DAGCombiner).
>     if ((ConstY && ConstZ) ||
>         (Op0->hasOneUse() && Op1->hasOneUse() &&
>          (IsSigned ? (BO0->getOpcode() != Instruction::Shl &&
>                       BO1->getOpcode() != Instruction::Shl)
>                    : (BO0->getOpcode() != Instruction::Shl ||
>                       BO1->getOpcode() == Instruction::Shl)))) {
>       // (rem (mul nuw/nsw X, Y), (mul nuw {nsw} X, Z))
>       //        -> (mul nuw/nsw X, (rem Y, Z))
>       if (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap() &&
>                       BO1->hasNoUnsignedWrap())
>                    : (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) {
>         // Convert the shifts to multiplies, cleaned up elsewhere.
>         if (BO0->getOpcode() == Instruction::Shl)
>           Y = Builder.CreateShl(ConstantInt::get(Ty, 1), Y);
>         if (BO1->getOpcode() == Instruction::Shl)
>           Z = Builder.CreateShl(ConstantInt::get(Ty, 1), Z);
>         BinaryOperator *BO = BinaryOperator::CreateMul(
>             X, IsSigned ? Builder.CreateSRem(Y, Z) : Builder.CreateURem(Y, Z));
>   
>         if (IsSigned || BO0->hasNoSignedWrap() || BO1->hasNoSignedWrap())
>           BO->setHasNoSignedWrap();
>         if (!IsSigned || (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
>           BO->setHasNoUnsignedWrap();
>         return BO;
>       }
>     }
>   
>     return nullptr;
>   }
>
> This function could be split even further into SimplifyIRemMulShlConst and SimplifyIRemMulShlGeneric, or we could just leave the generic case in this function.
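
As a sanity check on the generic fold quoted above, `(rem (mul nuw X, Y), (mul nuw X, Z)) -> (mul nuw X, (rem Y, Z))`, here is a minimal standalone sketch for the unsigned case on 8-bit values. It is not LLVM code: the function names are hypothetical, and "no unsigned wrap" is modelled simply as the product fitting in 8 bits. It also shows one input where a wrapping multiply makes the two forms disagree, which is why the flag checks are needed.

```cpp
#include <cassert>
#include <cstdint>

// Exhaustively checks, over all 8-bit X, Y, Z with X != 0 and Z != 0, that
//   (X * Y) urem (X * Z) == X * (Y urem Z)
// whenever neither product wraps (the i8 analogue of `mul nuw`).
// Hypothetical helper for illustration only.
bool uremMulIdentityHoldsU8() {
  for (uint32_t X = 1; X <= 0xFF; ++X)
    for (uint32_t Y = 0; Y <= 0xFF; ++Y)
      for (uint32_t Z = 1; Z <= 0xFF; ++Z) {
        // Skip inputs where either multiply would wrap in 8 bits.
        if (X * Y > 0xFF || X * Z > 0xFF)
          continue;
        if ((X * Y) % (X * Z) != X * (Y % Z))
          return false;
      }
  return true;
}

// Original form with wrapping 8-bit multiplies (no nuw guarantee).
// The caller must ensure the wrapped X * Z is nonzero.
uint8_t remOfWrappedProducts(uint8_t X, uint8_t Y, uint8_t Z) {
  uint8_t XY = static_cast<uint8_t>(X * Y);
  uint8_t XZ = static_cast<uint8_t>(X * Z);
  return static_cast<uint8_t>(XY % XZ);
}

// Folded form: X * (Y urem Z), truncated to 8 bits.
uint8_t foldedForm(uint8_t X, uint8_t Y, uint8_t Z) {
  return static_cast<uint8_t>(X * (Y % Z));
}

void runChecks() {
  assert(uremMulIdentityHoldsU8());
  // 16 * 17 wraps to 16 in i8, so the two forms differ: 16 vs 32.
  assert(remOfWrappedProducts(16, 17, 3) == 16);
  assert(foldedForm(16, 17, 3) == 32);
}
```

Calling `runChecks()` from a `main` exercises both the exhaustive check and the wrapping counterexample. The signed case would need the analogous check with `nsw`, which this sketch does not cover.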





================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1756-1759
+    // BO0 = X * Y
+    BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
+    // BO1 = X * Z
+    BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
----------------
MattDevereau wrote:
> I think we already know `BO0 = X * Y` and `BO1 = X * Z` by now from other comments and matches.

IMO there is no real cost to the extra comment. I can obviously delete it, but it seems harmless and may add clarity.


================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1764-1767
+      bool NSW0 = BO0->hasNoSignedWrap();
+      bool NSW1 = BO1->hasNoSignedWrap();
+      bool NUW0 = BO0->hasNoUnsignedWrap();
+      bool NUW1 = BO1->hasNoUnsignedWrap();
----------------
MattDevereau wrote:
> In my opinion these variable names are less descriptive than having the function calls inline

It's mostly stylistic, to save column width. Writing out the full BO....->has...() calls makes the if conditions very large and, IMO, harder to follow.
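
To make the trade-off concrete, here is a rough sketch of the two styles applied to the flag condition from the generic case quoted earlier. `FakeBinOp` and both function names are hypothetical stand-ins so the snippet compiles on its own; this is not the actual patch code.

```cpp
#include <cassert>

// Hypothetical stand-in for llvm::BinaryOperator, exposing only the two
// wrap-flag queries used in the condition.
struct FakeBinOp {
  bool NSW;
  bool NUW;
  bool hasNoSignedWrap() const { return NSW; }
  bool hasNoUnsignedWrap() const { return NUW; }
};

// Style 1: flag queries written inline in the condition.
bool genericFoldOkInline(bool IsSigned, const FakeBinOp *BO0,
                         const FakeBinOp *BO1) {
  return IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap() &&
                     BO1->hasNoUnsignedWrap())
                  : (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap());
}

// Style 2: flags hoisted into short locals, as in the patch under review.
bool genericFoldOkLocals(bool IsSigned, const FakeBinOp *BO0,
                         const FakeBinOp *BO1) {
  bool NSW0 = BO0->hasNoSignedWrap();
  bool NSW1 = BO1->hasNoSignedWrap();
  bool NUW0 = BO0->hasNoUnsignedWrap();
  bool NUW1 = BO1->hasNoUnsignedWrap();
  return IsSigned ? (NSW0 && NSW1 && NUW1) : (NUW0 && NUW1);
}

// Checks that both styles agree for every combination of flags.
void checkFormsAgree() {
  for (int Bits = 0; Bits < 32; ++Bits) {
    FakeBinOp A{(Bits & 1) != 0, (Bits & 2) != 0};
    FakeBinOp B{(Bits & 4) != 0, (Bits & 8) != 0};
    bool IsSigned = (Bits & 16) != 0;
    assert(genericFoldOkInline(IsSigned, &A, &B) ==
           genericFoldOkLocals(IsSigned, &A, &B));
  }
}
```

Both functions compute the same predicate; the difference is only in how much of each line is taken up by the accessor names.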


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D143014/new/

https://reviews.llvm.org/D143014


