[llvm] 7f1b9ad - [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (#87884)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 10 08:39:59 PDT 2024


Author: Craig Topper
Date: 2024-04-10T08:39:56-07:00
New Revision: 7f1b9adfc8d86c77ee87a268b3d30e0eda8ed493

URL: https://github.com/llvm/llvm-project/commit/7f1b9adfc8d86c77ee87a268b3d30e0eda8ed493
DIFF: https://github.com/llvm/llvm-project/commit/7f1b9adfc8d86c77ee87a268b3d30e0eda8ed493.diff

LOG: [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (#87884)

This improves a pattern that occurs in 531.deepsjeng_r, reducing the
dynamic instruction count by 0.5%.

This may be possible to improve in SelectionDAG, but given the special
cases around shXadd formation, it's not obvious it can be done in a
robust way without adding multiple special cases.

I've used a GEP with 2 indices because that most closely resembles the
motivating case. Most of the test cases use the simplest GEP form. One
test has a logical right shift on an index, which is closer to the
deepsjeng code. This requires special handling in isel to reverse a
DAGCombiner canonicalization that turns a pair of shifts into (srl (and
X, C1), C2).
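
As a quick sanity check of the algebra (a standalone snippet for
illustration; it is not part of the patch):

#include <cassert>
#include <cstdint>

// Mirrors the Zba sh3add semantics: (A << 3) + B.
static uint64_t sh3add(uint64_t A, uint64_t B) { return (A << 3) + B; }

int main() {
  // (sh3add Z, (add X, (slli Y, 6))) == (sh3add (sh3add Y, Z), X),
  // i.e. Z*8 + (X + Y*64) == ((Y*8 + Z)*8) + X.
  uint64_t X = 0x1000, Y = 5, Z = 7;
  assert(sh3add(Z, X + (Y << 6)) == sh3add(sh3add(Y, Z), X));
}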

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/MachineCombinerPattern.h
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/test/CodeGen/RISCV/rv64zba.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
index 89eed7463bd783..41b73eaae0298c 100644
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -175,6 +175,8 @@ enum class MachineCombinerPattern {
   FMADD_XA,
   FMSUB,
   FNMSUB,
+  SHXADD_ADD_SLLI_OP1,
+  SHXADD_ADD_SLLI_OP2,
 
   // X86 VNNI
   DPWSSD,

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index be63bc936ae8a1..6b75efe684d913 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1775,6 +1775,86 @@ static bool getFPPatterns(MachineInstr &Root,
   return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
 }
 
+/// Utility routine that checks if \param MO is defined by an
+/// \param CombineOpc instruction in the basic block \param MBB
+static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
+                                      const MachineOperand &MO,
+                                      unsigned CombineOpc) {
+  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  const MachineInstr *MI = nullptr;
+
+  if (MO.isReg() && MO.getReg().isVirtual())
+    MI = MRI.getUniqueVRegDef(MO.getReg());
+  // And it needs to be in the trace (otherwise, it won't have a depth).
+  if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
+    return nullptr;
+  // Must only be used by the user we combine with.
+  if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
+    return nullptr;
+
+  return MI;
+}
+
+/// Utility routine that checks if \param MO is defined by a SLLI in \param
+/// MBB that can be combined by splitting across 2 SHXADD instructions. The
+/// first SHXADD shift amount is given by \param OuterShiftAmt.
+static bool canCombineShiftIntoShXAdd(const MachineBasicBlock &MBB,
+                                      const MachineOperand &MO,
+                                      unsigned OuterShiftAmt) {
+  const MachineInstr *ShiftMI = canCombine(MBB, MO, RISCV::SLLI);
+  if (!ShiftMI)
+    return false;
+
+  unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
+  if (InnerShiftAmt < OuterShiftAmt || (InnerShiftAmt - OuterShiftAmt) > 3)
+    return false;
+
+  return true;
+}
+
+// Returns the shift amount from a SHXADD instruction. Returns 0 if the
+// instruction is not a SHXADD.
+static unsigned getSHXADDShiftAmount(unsigned Opc) {
+  switch (Opc) {
+  default:
+    return 0;
+  case RISCV::SH1ADD:
+    return 1;
+  case RISCV::SH2ADD:
+    return 2;
+  case RISCV::SH3ADD:
+    return 3;
+  }
+}
+
+// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
+// (sh3add (sh2add Y, Z), X).
+static bool
+getSHXADDPatterns(const MachineInstr &Root,
+                  SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  unsigned ShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
+  if (!ShiftAmt)
+    return false;
+
+  const MachineBasicBlock &MBB = *Root.getParent();
+
+  const MachineInstr *AddMI = canCombine(MBB, Root.getOperand(2), RISCV::ADD);
+  if (!AddMI)
+    return false;
+
+  bool Found = false;
+  if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(1), ShiftAmt)) {
+    Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP1);
+    Found = true;
+  }
+  if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(2), ShiftAmt)) {
+    Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP2);
+    Found = true;
+  }
+
+  return Found;
+}
+
 bool RISCVInstrInfo::getMachineCombinerPatterns(
     MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
     bool DoRegPressureReduce) const {
@@ -1782,6 +1862,9 @@ bool RISCVInstrInfo::getMachineCombinerPatterns(
   if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
     return true;
 
+  if (getSHXADDPatterns(Root, Patterns))
+    return true;
+
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
@@ -1864,6 +1947,68 @@ static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
   DelInstrs.push_back(&Root);
 }
 
+// Combine patterns like (sh3add Z, (add X, (slli Y, 5))) to
+// (sh3add (sh2add Y, Z), X) if the shift amount can be split across two
+// shXadd instructions. The outer shXadd keeps its original opcode.
+static void
+genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
+                  SmallVectorImpl<MachineInstr *> &InsInstrs,
+                  SmallVectorImpl<MachineInstr *> &DelInstrs,
+                  DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
+  MachineFunction *MF = Root.getMF();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
+
+  unsigned OuterShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
+  assert(OuterShiftAmt != 0 && "Unexpected opcode");
+
+  MachineInstr *AddMI = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
+  MachineInstr *ShiftMI =
+      MRI.getUniqueVRegDef(AddMI->getOperand(AddOpIdx).getReg());
+
+  unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
+  assert(InnerShiftAmt >= OuterShiftAmt && "Unexpected shift amount");
+
+  unsigned InnerOpc;
+  switch (InnerShiftAmt - OuterShiftAmt) {
+  default:
+    llvm_unreachable("Unexpected shift amount");
+  case 0:
+    InnerOpc = RISCV::ADD;
+    break;
+  case 1:
+    InnerOpc = RISCV::SH1ADD;
+    break;
+  case 2:
+    InnerOpc = RISCV::SH2ADD;
+    break;
+  case 3:
+    InnerOpc = RISCV::SH3ADD;
+    break;
+  }
+
+  const MachineOperand &X = AddMI->getOperand(3 - AddOpIdx);
+  const MachineOperand &Y = ShiftMI->getOperand(1);
+  const MachineOperand &Z = Root.getOperand(1);
+
+  Register NewVR = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+
+  auto MIB1 = BuildMI(*MF, MIMetadata(Root), TII->get(InnerOpc), NewVR)
+                  .addReg(Y.getReg(), getKillRegState(Y.isKill()))
+                  .addReg(Z.getReg(), getKillRegState(Z.isKill()));
+  auto MIB2 = BuildMI(*MF, MIMetadata(Root), TII->get(Root.getOpcode()),
+                      Root.getOperand(0).getReg())
+                  .addReg(NewVR, RegState::Kill)
+                  .addReg(X.getReg(), getKillRegState(X.isKill()));
+
+  InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+  InsInstrs.push_back(MIB1);
+  InsInstrs.push_back(MIB2);
+  DelInstrs.push_back(ShiftMI);
+  DelInstrs.push_back(AddMI);
+  DelInstrs.push_back(&Root);
+}
+
 void RISCVInstrInfo::genAlternativeCodeSequence(
     MachineInstr &Root, MachineCombinerPattern Pattern,
     SmallVectorImpl<MachineInstr *> &InsInstrs,
@@ -1887,6 +2032,12 @@ void RISCVInstrInfo::genAlternativeCodeSequence(
     combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
     return;
   }
+  case MachineCombinerPattern::SHXADD_ADD_SLLI_OP1:
+    genShXAddAddShift(Root, 1, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+    return;
+  case MachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
+    genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+    return;
   }
 }
 

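The legality check in canCombineShiftIntoShXAdd and the inner opcode
selection in genShXAddAddShift follow one rule: the fold applies when
0 <= slli amount - shXadd amount <= 3, and the difference picks the
inner instruction. A standalone check over every legal pair (again
illustrative only, not from the patch):

#include <cassert>
#include <cstdint>

// Generic shNadd semantics: (A << N) + B; N == 0 degenerates to add.
static uint64_t shNadd(unsigned N, uint64_t A, uint64_t B) {
  return (A << N) + B;
}

int main() {
  // Difference 0 -> add, 1 -> sh1add, 2 -> sh2add, 3 -> sh3add.
  for (unsigned Outer = 1; Outer <= 3; ++Outer)
    for (unsigned Inner = Outer; Inner <= Outer + 3; ++Inner) {
      uint64_t X = 0x12345, Y = 11, Z = 7;
      uint64_t Before = shNadd(Outer, Z, X + (Y << Inner));
      uint64_t After = shNadd(Outer, shNadd(Inner - Outer, Y, Z), X);
      assert(Before == After);
    }
}
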
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index 7e32253c8653f1..067addc819f7e6 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
 ;
 ; RV64ZBA-LABEL: sh6_sh3_add2:
 ; RV64ZBA:       # %bb.0: # %entry
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a1, a0
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ret
 entry:
   %shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh1_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh1add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh2_sh2:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    sh2add a1, a1, a2
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
 ; RV64ZBA-NEXT:    lw a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh2_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 5
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh2add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh1:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 4
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh1add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh1add a0, a1, a0
 ; RV64ZBA-NEXT:    lh a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh2:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 5
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh2add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh2add a0, a1, a0
 ; RV64ZBA-NEXT:    lw a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ;
 ; RV64ZBA-LABEL: array_index_sh3_sh3:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
 ; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
 ; RV64ZBA:       # %bb.0:
 ; RV64ZBA-NEXT:    srli a1, a1, 58
-; RV64ZBA-NEXT:    slli a1, a1, 6
-; RV64ZBA-NEXT:    add a0, a0, a1
-; RV64ZBA-NEXT:    sh3add a0, a2, a0
+; RV64ZBA-NEXT:    sh3add a1, a1, a2
+; RV64ZBA-NEXT:    sh3add a0, a1, a0
 ; RV64ZBA-NEXT:    ld a0, 0(a0)
 ; RV64ZBA-NEXT:    ret
   %shr = lshr i64 %idx1, 58
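
For this last test, isel has to undo the DAGCombiner canonicalization
mentioned in the log before the machine combiner can see the slli. A
standalone check of that equivalence for the 58/6 shift pair used here
(the exact C1/C2 constants are illustrative):

#include <cassert>
#include <cstdint>

int main() {
  // (shl (srl X, 58), 6) is canonicalized to (srl (and X, C1), C2),
  // here with C1 = 0xFC00000000000000 and C2 = 52.
  uint64_t X = 0xDEADBEEFCAFEF00DULL;
  assert(((X >> 58) << 6) == ((X & 0xFC00000000000000ULL) >> 52));
}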


More information about the llvm-commits mailing list