[llvm] [GlobalISel] Handle div-by-pow2 (PR #83155)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 29 00:33:31 PST 2024


================
@@ -5286,6 +5286,106 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
   return MIB.buildMul(Ty, Res, Factor);
 }
 
+/// Match a G_SDIV whose divisor is a scalar constant equal to +/-2^K, so that
+/// applySDivByPow2 can rewrite it with shifts. Divisions carrying the exact
+/// flag are not matched.
+bool CombinerHelper::matchSDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
+  if (MI.getFlag(MachineInstr::MIFlag::IsExact))
+    return false;
+  auto &SDiv = cast<GenericMachineInstr>(MI);
+  Register RHS = SDiv.getReg(2);
+  // applySDivByPow2 reads the divisor through
+  // getIConstantVRegValWithLookThrough, which only resolves scalar constants.
+  // matchUnaryPredicate on its own would also accept a G_BUILD_VECTOR whose
+  // elements are all powers of two, so reject divisors the apply cannot read
+  // instead of tripping its "RHS must be a constant" assertion.
+  if (!getIConstantVRegValWithLookThrough(RHS, MRI))
+    return false;
+  auto MatchPow2 = [&](const Constant *C) {
+    if (auto *CI = dyn_cast<ConstantInt>(C))
+      return CI->getValue().isPowerOf2() || CI->getValue().isNegatedPowerOf2();
+    return false;
+  };
+  return matchUnaryPredicate(MRI, RHS, MatchPow2, /* AllowUndefs= */ false);
+}
+
+/// Rewrite (sdiv X, C), with C a scalar constant equal to +/-2^K, into
+/// shift/add/sub operations. Expects matchSDivByPow2 to have accepted MI.
+void CombinerHelper::applySDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
+  auto &SDiv = cast<GenericMachineInstr>(MI);
+  Register Dst = SDiv.getReg(0);
+  Register LHS = SDiv.getReg(1);
+  Register RHS = SDiv.getReg(2);
+  LLT Ty = MRI.getType(Dst);
+  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
+  assert(RHSC.has_value() && "RHS must be a constant");
+  const APInt &RHSCV = RHSC->Value;
+
+  // Special case: (sdiv X, 1) -> X
+  if (RHSCV.isOne()) {
+    replaceSingleDefInstWithReg(MI, LHS);
+    return;
+  }
+
+  // Special case: (sdiv X, -1) -> 0 - X
+  if (RHSCV.isAllOnes()) {
+    auto Zero = Builder.buildConstant(Ty, 0);
+    auto Neg = Builder.buildSub(Ty, Zero, LHS);
+    replaceSingleDefInstWithReg(MI, Neg.getReg(0));
+    return;
+  }
+
+  // General case: |RHS| == 2^K with 1 <= K <= BitWidth - 1. An arithmetic
+  // shift right rounds towards -inf, whereas sdiv rounds towards zero, so a
+  // negative dividend is first biased by 2^K - 1:
+  //   Sign = X >>s (BitWidth - 1)    ; all ones if X < 0, else 0
+  //   Bias = Sign >>u (BitWidth - K) ; 2^K - 1  if X < 0, else 0
+  //   Quot = (X + Bias) >>s K
+  // The bias must use a *logical* shift right: a left shift would set the
+  // high bits instead of the low K bits.
+  unsigned BitWidth = Ty.getScalarSizeInBits();
+  unsigned Log2Div = RHSCV.countr_zero();
+  auto SignShAmt = Builder.buildConstant(ShiftAmtTy, BitWidth - 1);
+  auto BiasShAmt = Builder.buildConstant(ShiftAmtTy, BitWidth - Log2Div);
+  auto QuotShAmt = Builder.buildConstant(ShiftAmtTy, Log2Div);
+  auto Sign = Builder.buildAShr(Ty, LHS, SignShAmt);
+  auto Bias = Builder.buildLShr(Ty, Sign, BiasShAmt);
+  auto Biased = Builder.buildAdd(Ty, LHS, Bias);
+  auto Quot = Builder.buildAShr(Ty, Biased, QuotShAmt);
+  Register Res = Quot.getReg(0);
+
+  // If dividing by a negative value, the quotient must be negated.
+  if (RHSCV.isNegative()) {
+    auto Zero = Builder.buildConstant(Ty, 0);
+    Res = Builder.buildSub(Ty, Zero, Res).getReg(0);
+  }
+  replaceSingleDefInstWithReg(MI, Res);
+}
+
+bool CombinerHelper::matchUDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
+  if (MI.getFlag(MachineInstr::MIFlag::IsExact))
+    return false;
+  auto &UDiv = cast<GenericMachineInstr>(MI);
+  Register RHS = UDiv.getReg(2);
+  auto MatchPow2 = [&](const Constant *C) {
+    if (auto *CI = dyn_cast<ConstantInt>(C))
+      return CI->getValue().isPowerOf2();
+    return false;
+  };
+  return matchUnaryPredicate(MRI, RHS, MatchPow2, /* AllowUndefs */ false);
----------------
arsenm wrote:

```suggestion
  return matchUnaryPredicate(MRI, RHS, MatchPow2, /* AllowUndefs= */ false);
```

https://github.com/llvm/llvm-project/pull/83155


More information about the llvm-commits mailing list