[PATCH] D143014: [InstCombine] Add combines for `(urem/srem (mul/shl X, Y), (mul/shl X, Z))`

Matt Devereau via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 6 08:47:30 PST 2023


MattDevereau added a comment.

I'm thinking we could have a function which attempts to simplify irem/mul/shl patterns and returns the replacement instruction if the simplification succeeds (and nullptr otherwise). Similar to how there are calls such as

  if (Instruction *R = FoldOpIntoSelect(I, SI))
            return R;

inside `commonIRemTransforms`, we could have something along the lines of

  if (Instruction *res = SimplifyIRemMulShl(I))
    return res;

at the end of `commonIRemTransforms`.

This would give us a lot more control over managing the cases and their bailout conditions separately and cleanly. I think it would also make it more obvious that we are only handling two primary cases, whereas it's currently quite hard to parse and keep track of the complicated if statements.
I have had a go at rearranging the pre-combine logic:

  // Fold (rem (mul/shl X, Y), (mul/shl X, Z)) for a urem/srem `I` whose two
  // operands share a common factor X. Returns the replacement (a new
  // instruction, or the result of replaceInstUsesWith), or nullptr when no
  // fold applies.
  //
  // (shl V, S) is handled as V * (1 << S). Only the shifted value V of a shl
  // can serve as X; either operand of a mul can.
  Instruction *InstCombinerImpl::SimplifyIRemMulShl(BinaryOperator &I) {
    Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
    Value *A, *B, *C, *D;
    if (!(match(Op0, m_Mul(m_Value(A), m_Value(B))) ||
          match(Op0, m_Shl(m_Value(A), m_Value(B)))) ||
        !(match(Op1, m_Mul(m_Value(C), m_Value(D))) ||
          match(Op1, m_Shl(m_Value(C), m_Value(D)))))
      return nullptr;

    // The matchers above also accept constant expressions; the folds below
    // need the nsw/nuw flags of real instructions.
    BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
    BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
    if (!BO0 || !BO1)
      return nullptr;
    bool IsShl0 = BO0->getOpcode() == Instruction::Shl;
    bool IsShl1 = BO1->getOpcode() == Instruction::Shl;

    // Find X, Y, Z such that Op0 == X * Y and Op1 == X * Z. This is done by
    // hand rather than with m_Specific because either operand of a mul may be
    // X. A shl is not commutative: (shl A, B) is A * 2^B, so for a shl only
    // operand 0 may be X. Treating the shift amount as X would fold e.g.
    // (urem (shl nuw 2, %x), (shl nuw 4, %x)) into (mul %x, 4), which is wrong.
    Value *X = nullptr, *Y = nullptr, *Z = nullptr;
    if (A == C) {
      X = A;
      Y = B;
      Z = D;
    } else if (A == D && !IsShl1) {
      X = A;
      Y = B;
      Z = C;
    } else if (B == C && !IsShl0) {
      X = B;
      Y = A;
      Z = D;
    } else if (B == D && !IsShl0 && !IsShl1) {
      X = B;
      Y = A;
      Z = C;
    }
    if (!X)
      return nullptr;
    // Bail out when the common factor is the constant one.
    if (auto *CX = dyn_cast<Constant>(X))
      if (CX->isOneValue())
        return nullptr;

    bool IsSigned = I.getOpcode() == Instruction::SRem;
    Type *Ty = I.getType();

    // Y and Z as scalar or splat-vector integer constants, if they are.
    const APInt *CY, *CZ;
    bool HaveConstYZ = match(Y, m_APInt(CY)) && match(Z, m_APInt(CZ));

    // Check constant folds first.
    if (HaveConstYZ) {
      unsigned BitWidth = CY->getBitWidth();
      // Convert a shl amount S into the equivalent multiplier 1 << S.
      // Rejected amounts:
      //  * S >= BitWidth: the shl is poison, and 1 << S would be a zero
      //    divisor below.
      //  * for srem, S == BitWidth - 1: 1 << S is INT_MIN, and
      //    `shl nsw V, BitWidth - 1` admits V == -1 whereas
      //    `mul nsw V, INT_MIN` does not, so the signed reasoning below would
      //    not hold (e.g. (srem (shl nsw %x, 7), (mul nsw %x, 3)) on i8).
      auto GetMultiplier = [&](const APInt &Op, bool IsShl,
                               APInt &Multiplier) {
        if (!IsShl) {
          Multiplier = Op;
          return true;
        }
        if (Op.uge(IsSigned ? BitWidth - 1 : BitWidth))
          return false;
        Multiplier = APInt::getOneBitSet(BitWidth, Op.getZExtValue());
        return true;
      };

      // Just treat the shifts as mul, we may end up returning a mul by power
      // of 2 but that will be cleaned up later.
      APInt APIntY, APIntZ;
      if (!GetMultiplier(*CY, IsShl0, APIntY) ||
          !GetMultiplier(*CZ, IsShl1, APIntZ))
        return nullptr;
      // A rem by zero is immediate UB in IR, and APInt::urem/srem assert on a
      // zero divisor, so leave such instructions alone.
      if (APIntZ.isZero())
        return nullptr;

      APInt RemYZ = IsSigned ? APIntY.srem(APIntZ) : APIntY.urem(APIntZ);

      // (rem (mul nuw/nsw X, Y), (mul X, Z))
      //      if (rem Y, Z) == 0
      //          -> 0
      if (RemYZ.isZero() &&
          (IsSigned ? BO0->hasNoSignedWrap() : BO0->hasNoUnsignedWrap()))
        return replaceInstUsesWith(I, ConstantInt::getNullValue(Ty));

      // (rem (mul X, Y), (mul nuw/nsw X, Z))
      //      if (rem Y, Z) == Y
      //          -> (mul nuw/nsw X, Y)
      if (RemYZ == APIntY &&
          (IsSigned ? BO1->hasNoSignedWrap() : BO1->hasNoUnsignedWrap())) {
        // We are returning Op0 essentially but we can also add no wrap flags.
        BinaryOperator *BO =
            BinaryOperator::CreateMul(X, ConstantInt::get(Ty, APIntY));
        // We can add nsw/nuw if remainder op is signed/unsigned, also we
        // can copy any overflow flags from Op0.
        if (IsSigned || BO0->hasNoSignedWrap())
          BO->setHasNoSignedWrap();
        if (!IsSigned || BO0->hasNoUnsignedWrap())
          BO->setHasNoUnsignedWrap();
        return BO;
      }

      // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
      //      if Y >= Z
      //          -> (mul {nuw} nsw X, (rem Y, Z))
      // NB: (rem Y, Z) is a constant.
      if (APIntY.uge(APIntZ) &&
          (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap())
                    : BO0->hasNoUnsignedWrap())) {
        BinaryOperator *BO =
            BinaryOperator::CreateMul(X, ConstantInt::get(Ty, RemYZ));
        BO->setHasNoSignedWrap();
        if (!IsSigned || BO0->hasNoUnsignedWrap())
          BO->setHasNoUnsignedWrap();
        return BO;
      }
    }

    // Check if desirable to do generic replacement.
    // NB: It may be beneficial to do this if we have X << Z even if there are
    // multiple uses of Op0/Op1 as it will eliminate the urem (urem of a power
    // of 2 is converted to add/and) and urem is pretty expensive (maybe more
    // sense in DAGCombiner).
    if (HaveConstYZ ||
        (Op0->hasOneUse() && Op1->hasOneUse() &&
         (IsSigned ? (!IsShl0 && !IsShl1) : (!IsShl0 || IsShl1)))) {
      // (rem (mul nuw/nsw X, Y), (mul nuw {nsw} X, Z)
      //        -> (mul nuw/nsw X, (rem Y, Z))
      if (IsSigned ? (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap() &&
                      BO1->hasNoUnsignedWrap())
                   : (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) {
        // Convert the shifts to multiplies, cleaned up elsewhere. X is the
        // shifted value here (enforced above), so Y/Z are the shift amounts.
        if (IsShl0)
          Y = Builder.CreateShl(ConstantInt::get(Ty, 1), Y);
        if (IsShl1)
          Z = Builder.CreateShl(ConstantInt::get(Ty, 1), Z);
        BinaryOperator *BO = BinaryOperator::CreateMul(
            X, IsSigned ? Builder.CreateSRem(Y, Z) : Builder.CreateURem(Y, Z));

        if (IsSigned || BO0->hasNoSignedWrap() || BO1->hasNoSignedWrap())
          BO->setHasNoSignedWrap();
        if (!IsSigned || (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
          BO->setHasNoUnsignedWrap();
        return BO;
      }
    }

    return nullptr;
  }

This function could be split even further into SimplifyIRemMulShlConst and SimplifyIRemMulShlGeneric, or we could just leave the generic case in this function.



================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1756-1759
+    // BO0 = X * Y
+    BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
+    // BO1 = X * Z
+    BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
----------------
I think we already know `BO0 = X * Y` and `BO1 = X * Z` by now from other comments and matches.


================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1764-1767
+      bool NSW0 = BO0->hasNoSignedWrap();
+      bool NSW1 = BO1->hasNoSignedWrap();
+      bool NUW0 = BO0->hasNoUnsignedWrap();
+      bool NUW1 = BO1->hasNoUnsignedWrap();
----------------
In my opinion these variable names are less descriptive than having the function calls inline.


================
Comment at: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp:1772-1789
+      // Try and get Y/Z as constants.
+      ConstantInt *ConstY = nullptr;
+      ConstantInt *ConstZ = nullptr;
+      if (Ty->isVectorTy()) {
+        auto *VConstY = dyn_cast<Constant>(Y);
+        auto *VConstZ = dyn_cast<Constant>(Z);
+        if (VConstY && VConstZ) {
----------------
We can remove the else branch here and instead initialize ConstY/ConstZ up front, like so:


```
  ConstantInt *ConstY = dyn_cast<ConstantInt>(Y);
  ConstantInt *ConstZ = dyn_cast<ConstantInt>(Z);
  if (Ty->isVectorTy()) {
    auto *VConstY = dyn_cast<Constant>(Y);
    auto *VConstZ = dyn_cast<Constant>(Z);
    if (VConstY && VConstZ) {
      VConstY = VConstY->getSplatValue();
      VConstZ = VConstZ->getSplatValue();
      if (VConstY && VConstZ) {
        ConstY = dyn_cast<ConstantInt>(VConstY);
        ConstZ = dyn_cast<ConstantInt>(VConstZ);
      }
    }
  }
```


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D143014/new/

https://reviews.llvm.org/D143014



More information about the llvm-commits mailing list