[llvm] [GlobalISel] Handle div-by-pow2 (PR #83155)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 28 10:04:54 PST 2024


https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/83155

>From e279053c2266ab59508b9a75ad5bbbe2ab60ba69 Mon Sep 17 00:00:00 2001
From: Shilei Tian <i at tianshilei.me>
Date: Wed, 28 Feb 2024 13:04:43 -0500
Subject: [PATCH] [GlobalISel] Handle div-by-pow2

This patch adds handling of division by a constant power of two (`G_SDIV`/`G_UDIV`) to the GlobalISel combiner, mirroring the existing lowering in `SelectionDAG`.
---
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |  10 ++
 .../include/llvm/Target/GlobalISel/Combine.td |  21 +++-
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 102 +++++++++++++++++-
 llvm/test/CodeGen/AMDGPU/div_i128.ll          |   4 +-
 4 files changed, 130 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index 23728636498ba0..c19efba984d0d9 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -673,6 +673,16 @@ class CombinerHelper {
   bool matchSDivByConst(MachineInstr &MI);
   void applySDivByConst(MachineInstr &MI);
 
+  /// Given a non-exact G_SDIV \p MI dividing by a constant +/- power of two,
+  /// replace it with shifts and an add (negated for a negative divisor).
+  bool matchSDivByPow2(MachineInstr &MI);
+  void applySDivByPow2(MachineInstr &MI);
+
+  /// Given a non-exact G_UDIV \p MI dividing by a constant power of two,
+  /// replace it with a logical shift right.
+  bool matchUDivByPow2(MachineInstr &MI);
+  void applyUDivByPow2(MachineInstr &MI);
+
   // G_UMULH x, (1 << c)) -> x >> (bitwidth - c)
   bool matchUMulHToLShr(MachineInstr &MI);
   void applyUMulHToLShr(MachineInstr &MI);
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 17757ca3e41111..1d9a60bd27e7ac 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -264,7 +264,7 @@ def combine_extracted_vector_load : GICombineRule<
   (match (wip_match_opcode G_EXTRACT_VECTOR_ELT):$root,
         [{ return Helper.matchCombineExtractedVectorLoad(*${root}, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
-  
+
 def combine_indexed_load_store : GICombineRule<
   (defs root:$root, indexed_load_store_matchdata:$matchinfo),
   (match (wip_match_opcode G_LOAD, G_SEXTLOAD, G_ZEXTLOAD, G_STORE):$root,
@@ -1005,7 +1005,20 @@ def sdiv_by_const : GICombineRule<
    [{ return Helper.matchSDivByConst(*${root}); }]),
   (apply [{ Helper.applySDivByConst(*${root}); }])>;
 
-def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const]>;
+def sdiv_by_pow2 : GICombineRule<
+  (defs root:$root),
+  (match (G_SDIV $dst, $x, $y):$root,
+   [{ return Helper.matchSDivByPow2(*${root}); }]),
+  (apply [{ Helper.applySDivByPow2(*${root}); }])>;
+
+def udiv_by_pow2 : GICombineRule<
+  (defs root:$root),
+  (match (G_UDIV $dst, $x, $y):$root,
+   [{ return Helper.matchUDivByPow2(*${root}); }]),
+  (apply [{ Helper.applyUDivByPow2(*${root}); }])>;
+
+def intdiv_combines : GICombineGroup<
+  [udiv_by_const, sdiv_by_const, sdiv_by_pow2, udiv_by_pow2]>;
 
 def reassoc_ptradd : GICombineRule<
   (defs root:$root, build_fn_matchinfo:$matchinfo),
@@ -1325,7 +1338,7 @@ def constant_fold_binops : GICombineGroup<[constant_fold_binop,
 
 def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
     extract_vec_elt_combines, combines_for_extload, combine_extracted_vector_load,
-    undef_combines, identity_combines, phi_combines, 
+    undef_combines, identity_combines, phi_combines,
     simplify_add_to_sub, hoist_logic_op_with_same_opcode_hands, shifts_too_big,
     reassocs, ptr_add_immed_chain,
     shl_ashr_to_sext_inreg, sext_inreg_of_load,
@@ -1342,7 +1355,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
     intdiv_combines, mulh_combines, redundant_neg_operands,
     and_or_disjoint_mask, fma_combines, fold_binop_into_select,
     sub_add_reg, select_to_minmax, redundant_binop_in_equality,
-    fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors, 
+    fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
     combine_concat_vector]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 2f18a64ca285bd..d094fcd0ec3af8 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -1490,7 +1490,7 @@ void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI,
   Observer.changedInstr(*BrCond);
 }
 
- 
+
 bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) {
   MachineIRBuilder HelperBuilder(MI);
   GISelObserverWrapper DummyObserver;
@@ -5286,6 +5286,106 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
   return MIB.buildMul(Ty, Res, Factor);
 }
 
+bool CombinerHelper::matchSDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
+  // Exact divisions are not handled by this combine.
+  if (MI.getFlag(MachineInstr::MIFlag::IsExact))
+    return false;
+  // Accept only a scalar constant divisor: applySDivByPow2 reads it with
+  // getIConstantVRegValWithLookThrough, which cannot see a G_BUILD_VECTOR.
+  Register RHS = MI.getOperand(2).getReg();
+  auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
+  if (!RHSC)
+    return false;
+  return RHSC->Value.isPowerOf2() || RHSC->Value.isNegatedPowerOf2();
+}
+
+void CombinerHelper::applySDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
+  auto &SDiv = cast<GenericMachineInstr>(MI);
+  Register Dst = SDiv.getReg(0);
+  Register LHS = SDiv.getReg(1);
+  Register RHS = SDiv.getReg(2);
+  LLT Ty = MRI.getType(Dst);
+  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
+  assert(RHSC.has_value() && "RHS must be a constant");
+  const APInt &RHSCV = RHSC->Value;
+
+  // Special case: (sdiv X, 1) -> X
+  if (RHSCV.isOne()) {
+    replaceSingleDefInstWithReg(MI, LHS);
+    return;
+  }
+  // Special Case: (sdiv X, -1) -> 0-X
+  if (RHSCV.isAllOnes()) {
+    auto Neg = Builder.buildNeg(Ty, LHS);
+    replaceSingleDefInstWithReg(MI, Neg->getOperand(0).getReg());
+    return;
+  }
+
+  // General case |RHS| == 2^k with 1 <= k <= Bitwidth - 1 (+-1 returned
+  // above), so every shift amount computed below is in range.
+  unsigned Bitwidth = Ty.getScalarSizeInBits();
+  unsigned TrailingZeros = RHSCV.countTrailingZeros();
+  auto C1 = Builder.buildConstant(ShiftAmtTy, TrailingZeros);
+  auto Inexact = Builder.buildConstant(ShiftAmtTy, Bitwidth - TrailingZeros);
+  auto Sign = Builder.buildAShr(
+      Ty, LHS, Builder.buildConstant(ShiftAmtTy, Bitwidth - 1));
+  // Add (LHS < 0) ? 2^k - 1 : 0. The sign mask must be shifted *right*
+  // (logically) so that only its low k bits remain set.
+  auto Srl = Builder.buildLShr(Ty, Sign, Inexact);
+  auto Add = Builder.buildAdd(Ty, LHS, Srl);
+  auto Sra = Builder.buildAShr(Ty, Add, C1);
+  // Dividing by a negative power of two: negate the rounded quotient.
+  auto Res = RHSCV.isNegative() ? Builder.buildNeg(Ty, Sra) : Sra;
+  replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
+}
+
+bool CombinerHelper::matchUDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
+  // Exact divisions are not handled by this combine.
+  if (MI.getFlag(MachineInstr::MIFlag::IsExact))
+    return false;
+  // Accept only a scalar constant divisor: applyUDivByPow2 reads it with
+  // getIConstantVRegValWithLookThrough, which cannot see a G_BUILD_VECTOR.
+  Register RHS = MI.getOperand(2).getReg();
+  auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
+  if (!RHSC)
+    return false;
+  return RHSC->Value.isPowerOf2();
+}
+
+void CombinerHelper::applyUDivByPow2(MachineInstr &MI) {
+  assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
+  auto &UDiv = cast<GenericMachineInstr>(MI);
+  Register Dst = UDiv.getReg(0);
+  Register LHS = UDiv.getReg(1);
+  Register RHS = UDiv.getReg(2);
+  LLT Ty = MRI.getType(Dst);
+  LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
+
+  Builder.setInstrAndDebugLoc(MI);
+  auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
+  assert(RHSC.has_value() && "RHS must be a constant");
+  const APInt &RHSCV = RHSC->Value;
+
+  // Special case: (udiv X, 1) -> X
+  if (RHSCV.isOne()) {
+    replaceSingleDefInstWithReg(MI, LHS);
+    return;
+  }
+
+  // General case: (udiv X, 2^k) -> (lshr X, k)
+  unsigned TrailingZeros = RHSCV.countTrailingZeros();
+  auto C1 = Builder.buildConstant(ShiftAmtTy, TrailingZeros);
+  auto Res = Builder.buildLShr(Ty, LHS, C1);
+  replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
+}
+
 bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
   assert(MI.getOpcode() == TargetOpcode::G_UMULH);
   Register RHS = MI.getOperand(2).getReg();
diff --git a/llvm/test/CodeGen/AMDGPU/div_i128.ll b/llvm/test/CodeGen/AMDGPU/div_i128.ll
index 5296ad3ab51d31..0abbef2df38ccd 100644
--- a/llvm/test/CodeGen/AMDGPU/div_i128.ll
+++ b/llvm/test/CodeGen/AMDGPU/div_i128.ll
@@ -3,8 +3,8 @@
 ; RUN: llc -O0 -global-isel=0 -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 -o - %s | FileCheck -check-prefixes=GFX9-O0,GFX9-SDAG-O0 %s
 
 ; FIXME: GlobalISel missing the power-of-2 cases in legalization. https://github.com/llvm/llvm-project/issues/80671
-; xUN: llc -global-isel=1 -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 -o - %s | FileCheck -check-prefixes=GFX9,GFX9 %s
-; xUN: llc -O0 -global-isel=1 -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 -o - %s | FileCheck -check-prefixes=GFX9-O0,GFX9-O0 %s
+; RUN: llc -global-isel=1 -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 -o - %s | FileCheck -check-prefixes=GFX9,GFX9-G %s
+; RUN: llc -O0 -global-isel=1 -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 -o - %s | FileCheck -check-prefixes=GFX9-O0,GFX9-G-O0 %s
 
 define i128 @v_sdiv_i128_vv(i128 %lhs, i128 %rhs) {
 ; GFX9-LABEL: v_sdiv_i128_vv:



More information about the llvm-commits mailing list