[llvm] [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (PR #87884)

via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 6 10:37:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

This is an alternative to the new pass proposed in #<!-- -->87544.

This improves a pattern that occurs in 531.deepsjeng_r, reducing the dynamic instruction count by 0.5%.

This may be possible to improve in SelectionDAG, but given the special cases around shXadd formation, it's not obvious it can be done in a robust way without adding multiple special cases.

I've used a GEP with 2 indices because that most closely resembles the motivating case. Most of the test cases are the simplest GEP case. One test has a logical right shift on an index which is closer to the deepsjeng code. This requires special handling in isel to reverse a DAGCombiner canonicalization that turns a pair of shifts into (srl (and X, C1), C2).

See also #<!-- -->85734 which had a hacky version of a similar optimization.

---
Full diff: https://github.com/llvm/llvm-project/pull/87884.diff


3 Files Affected:

- (modified) llvm/include/llvm/CodeGen/MachineCombinerPattern.h (+2) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (+159) 
- (modified) llvm/test/CodeGen/RISCV/rv64zba.ll (+16-24) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
index 89eed7463bd783..41b73eaae0298c 100644
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -175,6 +175,8 @@ enum class MachineCombinerPattern {
   FMADD_XA,
   FMSUB,
   FNMSUB,
+  SHXADD_ADD_SLLI_OP1,
+  SHXADD_ADD_SLLI_OP2,
 
   // X86 VNNI
   DPWSSD,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 5582de51b17d19..8b36db23e94c41 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1829,6 +1829,84 @@ static bool getFPPatterns(MachineInstr &Root,
   return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
 }
 
+/// Utility routine that checks if \param MO is defined by an
+/// \param CombineOpc instruction in the basic block \param MBB
+static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
+                                      const MachineOperand &MO,
+                                      unsigned CombineOpc) {
+  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  const MachineInstr *MI = nullptr;
+
+  if (MO.isReg() && MO.getReg().isVirtual())
+    MI = MRI.getUniqueVRegDef(MO.getReg());
+  // And it needs to be in the trace (otherwise, it won't have a depth).
+  if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
+    return nullptr;
+  // Must only be used by the user we combine with.
+  if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
+    return nullptr;
+
+  return MI;
+}
+
+/// Utility routine that checks if \param MO is defined by a SLLI in \param
+/// MBB that can be combined by splitting across 2 SHXADD instructions. The
+/// first SHXADD shift amount is given by \param OuterShiftAmt.
+static bool canCombineShiftIntoShXAdd(const MachineBasicBlock &MBB,
+                                      const MachineOperand &MO,
+                                      unsigned OuterShiftAmt) {
+  const MachineInstr *ShiftMI = canCombine(MBB, MO, RISCV::SLLI);
+  if (!ShiftMI)
+    return false;
+
+  unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
+  if (InnerShiftAmt < OuterShiftAmt || (InnerShiftAmt - OuterShiftAmt) > 3)
+    return false;
+
+  return true;
+}
+
+// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
+// (sh3add (sh2add Y, Z), X).
+static bool
+getSHXADDPatterns(const MachineInstr &Root,
+                  SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  unsigned Opc = Root.getOpcode();
+
+  unsigned ShiftAmt;
+  switch (Opc) {
+  default:
+    return false;
+  case RISCV::SH1ADD:
+    ShiftAmt = 1;
+    break;
+  case RISCV::SH2ADD:
+    ShiftAmt = 2;
+    break;
+  case RISCV::SH3ADD:
+    ShiftAmt = 3;
+    break;
+  }
+
+  const MachineBasicBlock &MBB = *Root.getParent();
+
+  const MachineInstr *AddMI = canCombine(MBB, Root.getOperand(2), RISCV::ADD);
+  if (!AddMI)
+    return false;
+
+  bool Found = false;
+  if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(1), ShiftAmt)) {
+    Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP1);
+    Found = true;
+  }
+  if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(2), ShiftAmt)) {
+    Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP2);
+    Found = true;
+  }
+
+  return Found;
+}
+
 bool RISCVInstrInfo::getMachineCombinerPatterns(
     MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
     bool DoRegPressureReduce) const {
@@ -1836,6 +1914,9 @@ bool RISCVInstrInfo::getMachineCombinerPatterns(
   if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
     return true;
 
+  if (getSHXADDPatterns(Root, Patterns))
+    return true;
+
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
@@ -1918,6 +1999,78 @@ static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
   DelInstrs.push_back(&Root);
 }
 
+// Combine (sh3add Z, (add X, (slli Y, 5))) to (sh3add (sh2add Y, Z), X).
+static void
+genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
+                  SmallVectorImpl<MachineInstr *> &InsInstrs,
+                  SmallVectorImpl<MachineInstr *> &DelInstrs,
+                  DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
+  MachineFunction *MF = Root.getMF();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
+
+  unsigned OuterShiftAmt;
+  switch (Root.getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case RISCV::SH1ADD:
+    OuterShiftAmt = 1;
+    break;
+  case RISCV::SH2ADD:
+    OuterShiftAmt = 2;
+    break;
+  case RISCV::SH3ADD:
+    OuterShiftAmt = 3;
+    break;
+  }
+
+  MachineInstr *AddMI = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
+  MachineInstr *ShiftMI =
+      MRI.getUniqueVRegDef(AddMI->getOperand(AddOpIdx).getReg());
+
+  unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
+  assert(InnerShiftAmt > OuterShiftAmt && "Unexpected shift amount");
+
+  unsigned InnerOpc;
+  switch (InnerShiftAmt - OuterShiftAmt) {
+  default:
+    llvm_unreachable("Unexpected shift amount");
+  case 0:
+    InnerOpc = RISCV::ADD;
+    break;
+  case 1:
+    InnerOpc = RISCV::SH1ADD;
+    break;
+  case 2:
+    InnerOpc = RISCV::SH2ADD;
+    break;
+  case 3:
+    InnerOpc = RISCV::SH3ADD;
+    break;
+  }
+
+  Register X = AddMI->getOperand(3 - AddOpIdx).getReg();
+  Register Y = ShiftMI->getOperand(1).getReg();
+  Register Z = Root.getOperand(1).getReg();
+
+  Register NewVR = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+
+  auto MIB1 = BuildMI(*MF, MIMetadata(Root), TII->get(InnerOpc), NewVR)
+                  .addReg(Y)
+                  .addReg(Z);
+  auto MIB2 = BuildMI(*MF, MIMetadata(Root), TII->get(Root.getOpcode()),
+                      Root.getOperand(0).getReg())
+                  .addReg(NewVR)
+                  .addReg(X);
+
+  InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+  InsInstrs.push_back(MIB1);
+  InsInstrs.push_back(MIB2);
+  DelInstrs.push_back(ShiftMI);
+  DelInstrs.push_back(AddMI);
+  DelInstrs.push_back(&Root);
+}
+
 void RISCVInstrInfo::genAlternativeCodeSequence(
     MachineInstr &Root, MachineCombinerPattern Pattern,
     SmallVectorImpl<MachineInstr *> &InsInstrs,
@@ -1941,6 +2094,12 @@ void RISCVInstrInfo::genAlternativeCodeSequence(
     combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
     return;
   }
+  case MachineCombinerPattern::SHXADD_ADD_SLLI_OP1:
+    genShXAddAddShift(Root, 1, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+    return;
+  case MachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
+    genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+    return;
   }
 }
 
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index 7e32253c8653f1..067addc819f7e6 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
 ;
 ; RV64ZBA-LABEL: sh6_sh3_add2:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a1, a0
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ret
 entry:
   %shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh1_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh1add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh2_sh2:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    sh2add a1, a1, a2
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
 ; RV64ZBA-NEXT:    lw a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh2_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 5
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh2add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh1:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh1add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh1add a0, a1, a0
 ; RV64ZBA-NEXT:    lh a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh2:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 5
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
 ; RV64ZBA-NEXT:    lw a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
 ; RV64ZBA:       # %bb.0:
 ; RV64ZBA-NEXT:    srli a1, a1, 58
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %shr = lshr i64 %idx1, 58

``````````

</details>


https://github.com/llvm/llvm-project/pull/87884


More information about the llvm-commits mailing list