[llvm] [GlobalISel] Handle div-by-pow2 (PR #83155)

Jay Foad via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 22 03:17:24 PDT 2024


================
@@ -5270,6 +5270,144 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
   return MIB.buildMul(Ty, Res, Factor);
 }
 
+/// Match a G_SDIV / G_UDIV whose divisor (operand 2) is a power-of-two integer
+/// constant, as checked element-wise by matchUnaryPredicate for constant
+/// vectors. When IsSigned is true, negated powers of two are accepted as well.
+bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) {
+  assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
+          MI.getOpcode() == TargetOpcode::G_UDIV) &&
+         "Expected SDIV or UDIV");
+  // Divisor operand of the division.
+  const Register Divisor = cast<GenericMachineInstr>(MI).getReg(2);
+
+  // Predicate applied to the divisor constant(s).
+  const auto IsPow2Divisor = [IsSigned](const Constant *C) -> bool {
+    const auto *Int = dyn_cast<ConstantInt>(C);
+    if (!Int)
+      return false;
+    const APInt &Val = Int->getValue();
+    if (Val.isPowerOf2())
+      return true;
+    // Only signed division may use a negative power-of-two divisor.
+    return IsSigned && Val.isNegatedPowerOf2();
+  };
+
+  // Undef divisor elements are not accepted.
+  return matchUnaryPredicate(MRI, Divisor, IsPow2Divisor,
+                             /*AllowUndefs=*/false);
+}
+
+/// Rewrite a G_SDIV whose divisor is a power of two, or (signed) a negated
+/// power of two, into shifts, an add and an optional negation. A negative
+/// dividend is biased by (|divisor| - 1) before the arithmetic shift so the
+/// quotient rounds toward zero; the result is negated when the divisor is
+/// negative.
+void CombinerHelper::applySDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
+  auto &SDiv = cast<GenericMachineInstr>(MI);
+  Register Dst = SDiv.getReg(0);
+  Register LHS = SDiv.getReg(1);
+  Register RHS = SDiv.getReg(2);
+  LLT Ty = MRI.getType(Dst);
+  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
+  // to the following version:
+  //
+  // %c1 = G_CTTZ %rhs
+  // %inexact = G_SUB $bitwidth, %c1
+  // %sign = G_ASHR %lhs, $(bitwidth - 1)
+  // %lsrl = G_LSHR %sign, %inexact
+  // %add = G_ADD %lhs, %lsrl
+  // %ashr = G_ASHR %add, %c1
+  // %isone = G_ICMP EQ %rhs, $1
+  // %isallones = G_ICMP EQ %rhs, $-1
+  // %isoneorallones = G_OR %isone, %isallones
+  // %ashr = G_SELECT %isoneorallones, %lhs, %ashr
+  // %zero = G_CONSTANT $0
+  // %neg = G_SUB %zero, %ashr
+  // %isneg = G_ICMP SLT %rhs, %zero
+  // %res = G_SELECT %isneg, %neg, %ashr
+  //
+  // When %rhs is a constant integer, or a splat vector, we can check its value
+  // at compile time such that the first two G_ICMP conditional statements, as
+  // well as the corresponding non-taken branches, can be eliminated. This can
+  // generate compact code even w/o any constant folding afterwards. When %rhs
+  // is not a splat vector, we have to generate those checks via instructions.
+
+  unsigned Bitwidth = Ty.getScalarSizeInBits();
+
+  // TODO: It is not necessary to have this specialized version. We need it *for
+  // now* because the folding/combine can't handle it. Remove this large
+  // conditional statement once we can properly fold the two G_ICMP.
+  if (auto RHSC = getConstantOrConstantSplatVector(RHS)) {
+    // Special case: (sdiv X, 1) -> X
+    if (RHSC->isOne()) {
+      replaceSingleDefInstWithReg(MI, LHS);
+      return;
+    }
+    // Special Case: (sdiv X, -1) -> 0-X
+    if (RHSC->isAllOnes()) {
+      auto Neg = Builder.buildNeg(Ty, LHS);
+      replaceSingleDefInstWithReg(MI, Neg->getOperand(0).getReg());
+      return;
+    }
+
+    // A power of two and its negation have the same trailing-zero count, so
+    // this is log2(|RHS|) in both cases.
+    unsigned TrailingZeros = RHSC->countTrailingZeros();
+    auto C1 = Builder.buildConstant(ShiftAmtTy, TrailingZeros);
+    auto Inexact = Builder.buildConstant(ShiftAmtTy, Bitwidth - TrailingZeros);
+    auto Sign = Builder.buildAShr(
+        Ty, LHS, Builder.buildConstant(ShiftAmtTy, Bitwidth - 1));
+    // Add (LHS < 0) ? abs2 - 1 : 0;
+    auto LSrl = Builder.buildLShr(Ty, Sign, Inexact);
+    auto Add = Builder.buildAdd(Ty, LHS, LSrl);
+    auto AShr = Builder.buildAShr(Ty, Add, C1);
+
+    // If dividing by a positive value, we're done. Otherwise, the result must
+    // be negated.
+    auto Res = RHSC->isNegative() ? Builder.buildNeg(Ty, AShr) : AShr;
+    replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
+    return;
+  }
+
+  // RHS is not a constant splat. Build the above version with instructions.
+  // Comparison results are s1 for scalar types and <N x s1> for vectors.
+  LLT CCVT =
+      Ty.isVector() ? LLT::vector(Ty.getElementCount(), 1) : LLT::scalar(1);
+
+  auto Bits = Builder.buildConstant(ShiftAmtTy, Bitwidth);
+  auto C1 = Builder.buildCTTZ(Ty, RHS);
+  C1 = Builder.buildZExtOrTrunc(ShiftAmtTy, C1);
+  auto Inexact = Builder.buildSub(ShiftAmtTy, Bits, C1);
+  auto Sign = Builder.buildAShr(
+      Ty, LHS, Builder.buildConstant(ShiftAmtTy, Bitwidth - 1));
+
+  // Add (LHS < 0) ? abs2 - 1 : 0. Sign is all-zeros or all-ones, so this must
+  // be a *logical* right shift to yield the low (|RHS| - 1) mask; a left shift
+  // would produce high bits instead and mis-round negative dividends.
+  auto LSrl = Builder.buildLShr(Ty, Sign, Inexact);
+  auto Add = Builder.buildAdd(Ty, LHS, LSrl);
+  auto AShr = Builder.buildAShr(Ty, Add, C1);
+
+  // For RHS == 1 or RHS == -1, CTTZ is 0 and Inexact equals the bit width, so
+  // the LSHR above is out of range; use LHS for those elements instead.
+  auto One = Builder.buildConstant(Ty, 1);
+  auto AllOnes = Builder.buildConstant(Ty, APInt::getAllOnes(Bitwidth));
+  auto IsOne = Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, One);
+  auto IsAllOnes =
+      Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, AllOnes);
+  auto IsOneOrAllOnes = Builder.buildOr(CCVT, IsOne, IsAllOnes);
+  AShr = Builder.buildSelect(Ty, IsOneOrAllOnes, LHS, AShr);
+
+  // If dividing by a positive value, we're done. Otherwise, the result must
+  // be negated. The negation depends on the sign of the divisor, not the
+  // dividend.
+  auto Zero = Builder.buildConstant(Ty, 0);
+  auto Neg = Builder.buildNeg(Ty, AShr);
+  auto IsNeg = Builder.buildICmp(CmpInst::Predicate::ICMP_SLT, CCVT, RHS, Zero);
+  Builder.buildSelect(Dst, IsNeg, Neg, AShr);
+  // The select now defines Dst; remove the original G_SDIV so that Dst keeps a
+  // single definition.
+  MI.eraseFromParent();
+}
+
+void CombinerHelper::applyUDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
+  auto &UDiv = cast<GenericMachineInstr>(MI);
+  Register Dst = UDiv.getReg(0);
+  Register LHS = UDiv.getReg(1);
+  Register RHS = UDiv.getReg(2);
+  LLT Ty = MRI.getType(Dst);
+  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  // TODO: It is not necessary to have this specialized version. We need it *for
+  // now* because the folding/combine can't handle CTTZ.
+  if (auto RHSC = getConstantOrConstantSplatVector(RHS)) {
----------------
jayfoad wrote:

It would be better to get #86224 merged first.

https://github.com/llvm/llvm-project/pull/83155


More information about the llvm-commits mailing list