[llvm] [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (PR #87884)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 9 10:07:31 PDT 2024
- Previous message: [llvm] [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (PR #87884)
- Next message: [llvm] [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (PR #87884)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/87884
>From 1af7725ad971663d1719fba27bfea7f3a6c134fc Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Sat, 6 Apr 2024 09:40:11 -0700
Subject: [PATCH 1/2] [RISCV] Add MachineCombiner to fold (sh3add Z, (add X,
(slli Y, 6))) -> (sh3add (sh3add Y, Z), X).
This is an alternative to the new pass proposed in #87544.
This improves a pattern that occurs in 531.deepsjeng_r, reducing
the dynamic instruction count by 0.5%.
This may be possible to improve in SelectionDAG, but given the special
cases around shXadd formation, it's not obvious it can be done in a
robust way without adding multiple special cases.
I've used a GEP with 2 indices because that most closely resembles
the motivating case. Most of the test cases are the simplest GEP case.
One test has a logical right shift on an index which is closer to
the deepsjeng code. This requires special handling in isel to reverse
a DAGCombiner canonicalization that turns a pair of shifts into
(srl (and X, C1), C2).
See also #85734 which had a hacky version of a similar optimization.
---
.../llvm/CodeGen/MachineCombinerPattern.h | 2 +
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 159 ++++++++++++++++++
llvm/test/CodeGen/RISCV/rv64zba.ll | 40 ++---
3 files changed, 177 insertions(+), 24 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
index 89eed7463bd783..41b73eaae0298c 100644
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -175,6 +175,8 @@ enum class MachineCombinerPattern {
FMADD_XA,
FMSUB,
FNMSUB,
+ SHXADD_ADD_SLLI_OP1,
+ SHXADD_ADD_SLLI_OP2,
// X86 VNNI
DPWSSD,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 5582de51b17d19..8b36db23e94c41 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1829,6 +1829,84 @@ static bool getFPPatterns(MachineInstr &Root,
return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
}
+/// Returns the \p CombineOpc instruction in basic block \p MBB that defines
+/// \p MO, or nullptr if there is none or its result lacks a single non-debug use.
+static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
+ const MachineOperand &MO,
+ unsigned CombineOpc) {
+ const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+ const MachineInstr *MI = nullptr;
+
+ if (MO.isReg() && MO.getReg().isVirtual())
+ MI = MRI.getUniqueVRegDef(MO.getReg());
+ // And it needs to be in the trace (otherwise, it won't have a depth).
+ if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
+ return nullptr;
+ // Must only be used by the user we combine with.
+ if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
+ return nullptr;
+
+ return MI;
+}
+
+/// Utility routine that checks if \p MO is defined by a SLLI in \p MBB that
+/// can be combined by splitting across 2 SHXADD instructions. The first SHXADD
+/// shift amount is \p OuterShiftAmt; the SLLI amount may exceed it by 0 to 3.
+static bool canCombineShiftIntoShXAdd(const MachineBasicBlock &MBB,
+ const MachineOperand &MO,
+ unsigned OuterShiftAmt) {
+ const MachineInstr *ShiftMI = canCombine(MBB, MO, RISCV::SLLI);
+ if (!ShiftMI)
+ return false;
+
+ unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
+ if (InnerShiftAmt < OuterShiftAmt || (InnerShiftAmt - OuterShiftAmt) > 3)
+ return false;
+
+ return true;
+}
+
+// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
+// (sh3add (sh2add Y, Z), X).
+static bool
+getSHXADDPatterns(const MachineInstr &Root,
+ SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+ unsigned Opc = Root.getOpcode();
+
+ unsigned ShiftAmt;
+ switch (Opc) {
+ default:
+ return false;
+ case RISCV::SH1ADD:
+ ShiftAmt = 1;
+ break;
+ case RISCV::SH2ADD:
+ ShiftAmt = 2;
+ break;
+ case RISCV::SH3ADD:
+ ShiftAmt = 3;
+ break;
+ }
+
+ const MachineBasicBlock &MBB = *Root.getParent();
+
+ const MachineInstr *AddMI = canCombine(MBB, Root.getOperand(2), RISCV::ADD);
+ if (!AddMI) // Root's operand 2 must come from an ADD that canCombine accepts.
+ return false;
+
+ bool Found = false; // Both ADD operands are tried; each match adds a pattern.
+ if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(1), ShiftAmt)) {
+ Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP1);
+ Found = true;
+ }
+ if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(2), ShiftAmt)) {
+ Patterns.push_back(MachineCombinerPattern::SHXADD_ADD_SLLI_OP2);
+ Found = true;
+ }
+
+ return Found;
+}
+
bool RISCVInstrInfo::getMachineCombinerPatterns(
MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
bool DoRegPressureReduce) const {
@@ -1836,6 +1914,9 @@ bool RISCVInstrInfo::getMachineCombinerPatterns(
if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
return true;
+ if (getSHXADDPatterns(Root, Patterns))
+ return true;
+
return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
DoRegPressureReduce);
}
@@ -1918,6 +1999,78 @@ static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
DelInstrs.push_back(&Root);
}
+// Combine (sh3add Z, (add X, (slli Y, 5))) to (sh3add (sh2add Y, Z), X).
+static void
+genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx, // AddOpIdx: ADD operand (1 or 2) that is the SLLI.
+ SmallVectorImpl<MachineInstr *> &InsInstrs,
+ SmallVectorImpl<MachineInstr *> &DelInstrs,
+ DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
+ MachineFunction *MF = Root.getMF();
+ MachineRegisterInfo &MRI = MF->getRegInfo();
+ const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
+
+ unsigned OuterShiftAmt;
+ switch (Root.getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected opcode");
+ case RISCV::SH1ADD:
+ OuterShiftAmt = 1;
+ break;
+ case RISCV::SH2ADD:
+ OuterShiftAmt = 2;
+ break;
+ case RISCV::SH3ADD:
+ OuterShiftAmt = 3;
+ break;
+ }
+
+ MachineInstr *AddMI = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
+ MachineInstr *ShiftMI =
+ MRI.getUniqueVRegDef(AddMI->getOperand(AddOpIdx).getReg());
+
+ unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm(); // May equal OuterShiftAmt (case 0 below).
+ assert(InnerShiftAmt >= OuterShiftAmt && "Unexpected shift amount");
+
+ unsigned InnerOpc; // Opcode that applies the shift left over after the outer SHXADD.
+ switch (InnerShiftAmt - OuterShiftAmt) {
+ default:
+ llvm_unreachable("Unexpected shift amount");
+ case 0:
+ InnerOpc = RISCV::ADD;
+ break;
+ case 1:
+ InnerOpc = RISCV::SH1ADD;
+ break;
+ case 2:
+ InnerOpc = RISCV::SH2ADD;
+ break;
+ case 3:
+ InnerOpc = RISCV::SH3ADD;
+ break;
+ }
+
+ Register X = AddMI->getOperand(3 - AddOpIdx).getReg(); // The ADD operand that is not the SLLI.
+ Register Y = ShiftMI->getOperand(1).getReg();
+ Register Z = Root.getOperand(1).getReg();
+
+ Register NewVR = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+
+ auto MIB1 = BuildMI(*MF, MIMetadata(Root), TII->get(InnerOpc), NewVR)
+ .addReg(Y)
+ .addReg(Z);
+ auto MIB2 = BuildMI(*MF, MIMetadata(Root), TII->get(Root.getOpcode()),
+ Root.getOperand(0).getReg())
+ .addReg(NewVR)
+ .addReg(X);
+
+ InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0)); // NewVR is defined by InsInstrs[0] (MIB1).
+ InsInstrs.push_back(MIB1);
+ InsInstrs.push_back(MIB2);
+ DelInstrs.push_back(ShiftMI);
+ DelInstrs.push_back(AddMI);
+ DelInstrs.push_back(&Root);
+}
+
void RISCVInstrInfo::genAlternativeCodeSequence(
MachineInstr &Root, MachineCombinerPattern Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
@@ -1941,6 +2094,12 @@ void RISCVInstrInfo::genAlternativeCodeSequence(
combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
return;
}
+ case MachineCombinerPattern::SHXADD_ADD_SLLI_OP1:
+ genShXAddAddShift(Root, 1, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+ return;
+ case MachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
+ genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+ return;
}
}
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
index 7e32253c8653f1..067addc819f7e6 100644
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1404,9 +1404,8 @@ define i64 @sh6_sh3_add2(i64 noundef %x, i64 noundef %y, i64 noundef %z) {
;
; RV64ZBA-LABEL: sh6_sh3_add2:
; RV64ZBA: # %bb.0: # %entry
-; RV64ZBA-NEXT: slli a1, a1, 6
-; RV64ZBA-NEXT: add a0, a1, a0
-; RV64ZBA-NEXT: sh3add a0, a2, a0
+; RV64ZBA-NEXT: sh3add a1, a1, a2
+; RV64ZBA-NEXT: sh3add a0, a1, a0
; RV64ZBA-NEXT: ret
entry:
%shl = shl i64 %z, 3
@@ -2111,9 +2110,8 @@ define i64 @array_index_sh1_sh3(ptr %p, i64 %idx1, i64 %idx2) {
;
; RV64ZBA-LABEL: array_index_sh1_sh3:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: slli a1, a1, 4
-; RV64ZBA-NEXT: add a0, a0, a1
-; RV64ZBA-NEXT: sh3add a0, a2, a0
+; RV64ZBA-NEXT: sh1add a1, a1, a2
+; RV64ZBA-NEXT: sh3add a0, a1, a0
; RV64ZBA-NEXT: ld a0, 0(a0)
; RV64ZBA-NEXT: ret
%a = getelementptr inbounds [2 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2174,9 +2172,8 @@ define i32 @array_index_sh2_sh2(ptr %p, i64 %idx1, i64 %idx2) {
;
; RV64ZBA-LABEL: array_index_sh2_sh2:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: slli a1, a1, 4
-; RV64ZBA-NEXT: add a0, a0, a1
-; RV64ZBA-NEXT: sh2add a0, a2, a0
+; RV64ZBA-NEXT: sh2add a1, a1, a2
+; RV64ZBA-NEXT: sh2add a0, a1, a0
; RV64ZBA-NEXT: lw a0, 0(a0)
; RV64ZBA-NEXT: ret
%a = getelementptr inbounds [4 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2196,9 +2193,8 @@ define i64 @array_index_sh2_sh3(ptr %p, i64 %idx1, i64 %idx2) {
;
; RV64ZBA-LABEL: array_index_sh2_sh3:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: slli a1, a1, 5
-; RV64ZBA-NEXT: add a0, a0, a1
-; RV64ZBA-NEXT: sh3add a0, a2, a0
+; RV64ZBA-NEXT: sh2add a1, a1, a2
+; RV64ZBA-NEXT: sh3add a0, a1, a0
; RV64ZBA-NEXT: ld a0, 0(a0)
; RV64ZBA-NEXT: ret
%a = getelementptr inbounds [4 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2238,9 +2234,8 @@ define i16 @array_index_sh3_sh1(ptr %p, i64 %idx1, i64 %idx2) {
;
; RV64ZBA-LABEL: array_index_sh3_sh1:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: slli a1, a1, 4
-; RV64ZBA-NEXT: add a0, a0, a1
-; RV64ZBA-NEXT: sh1add a0, a2, a0
+; RV64ZBA-NEXT: sh3add a1, a1, a2
+; RV64ZBA-NEXT: sh1add a0, a1, a0
; RV64ZBA-NEXT: lh a0, 0(a0)
; RV64ZBA-NEXT: ret
%a = getelementptr inbounds [8 x i16], ptr %p, i64 %idx1, i64 %idx2
@@ -2260,9 +2255,8 @@ define i32 @array_index_sh3_sh2(ptr %p, i64 %idx1, i64 %idx2) {
;
; RV64ZBA-LABEL: array_index_sh3_sh2:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: slli a1, a1, 5
-; RV64ZBA-NEXT: add a0, a0, a1
-; RV64ZBA-NEXT: sh2add a0, a2, a0
+; RV64ZBA-NEXT: sh3add a1, a1, a2
+; RV64ZBA-NEXT: sh2add a0, a1, a0
; RV64ZBA-NEXT: lw a0, 0(a0)
; RV64ZBA-NEXT: ret
%a = getelementptr inbounds [8 x i32], ptr %p, i64 %idx1, i64 %idx2
@@ -2282,9 +2276,8 @@ define i64 @array_index_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
;
; RV64ZBA-LABEL: array_index_sh3_sh3:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: slli a1, a1, 6
-; RV64ZBA-NEXT: add a0, a0, a1
-; RV64ZBA-NEXT: sh3add a0, a2, a0
+; RV64ZBA-NEXT: sh3add a1, a1, a2
+; RV64ZBA-NEXT: sh3add a0, a1, a0
; RV64ZBA-NEXT: ld a0, 0(a0)
; RV64ZBA-NEXT: ret
%a = getelementptr inbounds [8 x i64], ptr %p, i64 %idx1, i64 %idx2
@@ -2308,9 +2301,8 @@ define i64 @array_index_lshr_sh3_sh3(ptr %p, i64 %idx1, i64 %idx2) {
; RV64ZBA-LABEL: array_index_lshr_sh3_sh3:
; RV64ZBA: # %bb.0:
; RV64ZBA-NEXT: srli a1, a1, 58
-; RV64ZBA-NEXT: slli a1, a1, 6
-; RV64ZBA-NEXT: add a0, a0, a1
-; RV64ZBA-NEXT: sh3add a0, a2, a0
+; RV64ZBA-NEXT: sh3add a1, a1, a2
+; RV64ZBA-NEXT: sh3add a0, a1, a0
; RV64ZBA-NEXT: ld a0, 0(a0)
; RV64ZBA-NEXT: ret
%shr = lshr i64 %idx1, 58
>From f1e72389fc4b515cf8ff1ef7565ed8c543ce30b9 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 9 Apr 2024 09:59:47 -0700
Subject: [PATCH 2/2] fixup! Add helper to get shift amount from SHXADD.
---
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 48 ++++++++++--------------
1 file changed, 19 insertions(+), 29 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 8b36db23e94c41..eed3dcd136587f 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1866,27 +1866,29 @@ static bool canCombineShiftIntoShXAdd(const MachineBasicBlock &MBB,
return true;
}
-// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
-// (sh3add (sh2add Y, Z), X).
-static bool
-getSHXADDPatterns(const MachineInstr &Root,
- SmallVectorImpl<MachineCombinerPattern> &Patterns) {
- unsigned Opc = Root.getOpcode();
-
- unsigned ShiftAmt;
+// Returns the shift amount from a SHXADD instruction. Returns 0 if the
+// instruction is not a SHXADD.
+static unsigned getSHXADDShiftAmount(unsigned Opc) {
switch (Opc) {
default:
- return false;
+ return 0;
case RISCV::SH1ADD:
- ShiftAmt = 1;
- break;
+ return 1;
case RISCV::SH2ADD:
- ShiftAmt = 2;
- break;
+ return 2;
case RISCV::SH3ADD:
- ShiftAmt = 3;
- break;
+ return 3;
}
+}
+
+// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
+// (sh3add (sh2add Y, Z), X).
+static bool
+getSHXADDPatterns(const MachineInstr &Root,
+ SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+ unsigned ShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
+ if (!ShiftAmt)
+ return false;
const MachineBasicBlock &MBB = *Root.getParent();
@@ -2009,20 +2011,8 @@ genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
MachineRegisterInfo &MRI = MF->getRegInfo();
const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
- unsigned OuterShiftAmt;
- switch (Root.getOpcode()) {
- default:
- llvm_unreachable("Unexpected opcode");
- case RISCV::SH1ADD:
- OuterShiftAmt = 1;
- break;
- case RISCV::SH2ADD:
- OuterShiftAmt = 2;
- break;
- case RISCV::SH3ADD:
- OuterShiftAmt = 3;
- break;
- }
+ unsigned OuterShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
+ assert(OuterShiftAmt != 0 && "Unexpected opcode");
MachineInstr *AddMI = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
MachineInstr *ShiftMI =
- Previous message: [llvm] [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (PR #87884)
- Next message: [llvm] [RISCV] Add MachineCombiner to fold (sh3add Z, (add X, (slli Y, 6))) -> (sh3add (sh3add Y, Z), X). (PR #87884)
- Messages sorted by:
[ date ]
[ thread ]
[ subject ]
[ author ]
More information about the llvm-commits
mailing list