[llvm] a6382de - AMDGPU: Refactor mfma hazard handling [NFC] (#84276)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 7 01:10:04 PST 2024
Author: Matt Arsenault
Date: 2024-03-07T14:39:59+05:30
New Revision: a6382de3999280ef7bf8bb63750686cdad889cd5
URL: https://github.com/llvm/llvm-project/commit/a6382de3999280ef7bf8bb63750686cdad889cd5
DIFF: https://github.com/llvm/llvm-project/commit/a6382de3999280ef7bf8bb63750686cdad889cd5.diff
LOG: AMDGPU: Refactor mfma hazard handling [NFC] (#84276)
Try to make this easier to edit by computing the number of wait states
with functions of the number of passes, instead of a separate constant
per pass count. I'm assuming the current hazard test coverage is comprehensive.
This could probably use another round of simplification.
Alternatively, I believe this could all be expressed as a constant table
indexed by an instruction classification function and the number of passes.
Added:
Modified:
llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index 7bed0d8ef0d670..e515b729e7d7e8 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -2136,6 +2136,41 @@ int GCNHazardRecognizer::checkMAIHazards908(MachineInstr *MI) {
return WaitStatesNeeded;
}
+static int
+GFX940_XDL_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(int NumPasses) {
+ // GFX940: wait states after an N-pass XDL MFMA writes a VGPR that
+ // overlaps an SMFMA's SrcC operand.
+ // 2 pass -> 3, 4 pass -> 5, 8 pass -> 9, 16 pass -> 17,
+ // i.e. NumPasses + 1.
+ return NumPasses + 1;
+}
+
+static int
+GFX940_SMFMA_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(int NumPasses) {
+ // GFX940: wait states after an N-pass non-XDL SMFMA writes a VGPR that
+ // overlaps an SMFMA's SrcC operand.
+ // 2 pass -> 2, 4 pass -> 4, 8 pass -> 8, 16 pass -> 16,
+ // i.e. exactly NumPasses.
+ return NumPasses;
+}
+
+static int
+GFX940_SMFMA_N_PassWritesVGPROverlappedSrcABWaitStates(int NumPasses) {
+ // GFX940: wait states after an N-pass non-XDL SMFMA writes a VGPR that
+ // overlaps an MFMA's SrcA/SrcB operand.
+ // 2 pass -> 4, 4 pass -> 6, 8 pass -> 10, 16 pass -> 18,
+ // i.e. NumPasses + 2.
+ return NumPasses + 2;
+}
+
+static int GFX940_XDL_N_PassWritesVGPROverlappedSrcABWaitStates(int NumPasses) {
+ // GFX940: wait states after an N-pass XDL MFMA writes a VGPR that
+ // overlaps an MFMA's SrcA/SrcB operand.
+ // 2 pass -> 5, 4 pass -> 7, 8 pass -> 11, 16 pass -> 19,
+ // i.e. NumPasses + 3.
+ return NumPasses + 3;
+}
+
int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
int WaitStatesNeeded = 0;
unsigned Opc = MI->getOpcode();
@@ -2164,13 +2199,6 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
for (const MachineOperand &Use : MI->explicit_uses()) {
const int LegacyVALUNotDotWritesVGPRWaitStates = 2;
const int SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates = 2;
- const int GFX940_XDL2PassWritesVGPROverlappedSMFMASrcCWaitStates = 3;
- const int GFX940_XDL4PassWritesVGPROverlappedSMFMASrcCWaitStates = 5;
- const int GFX940_SMFMA4PassWritesVGPROverlappedSMFMASrcCWaitStates = 4;
- const int GFX940_XDL8PassWritesVGPROverlappedSMFMASrcCWaitStates = 9;
- const int GFX940_SMFMA8PassWritesVGPROverlappedSMFMASrcCWaitStates = 8;
- const int GFX940_XDL16PassWritesVGPROverlappedSMFMASrcCWaitStates = 17;
- const int GFX940_SMFMA16PassWritesVGPROverlappedSMFMASrcCWaitStates = 16;
const int SMFMA16x16WritesVGPROverlappedSMFMASrcCWaitStates = 8;
const int SMFMA32x32WritesVGPROverlappedSMFMASrcCWaitStates = 16;
const int SMFMA4x4WritesVGPROverlappedDMFMASrcCWaitStates = 3;
@@ -2181,14 +2209,6 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
const int SMFMA4x4WritesVGPROverlappedSrcABWaitStates = 5;
const int SMFMA16x16WritesVGPROverlappedSrcABWaitStates = 11;
const int SMFMA32x32WritesVGPROverlappedSrcABWaitStates = 19;
- const int GFX940_SMFMA2PassWritesVGPROverlappedSrcABWaitStates = 4;
- const int GFX940_SMFMA4PassWritesVGPROverlappedSrcABWaitStates = 6;
- const int GFX940_SMFMA8PassWritesVGPROverlappedSrcABWaitStates = 10;
- const int GFX940_SMFMA16PassWritesVGPROverlappedSrcABWaitStates = 18;
- const int GFX940_XDL2PassWritesVGPROverlappedSrcABWaitStates = 5;
- const int GFX940_XDL4PassWritesVGPROverlappedSrcABWaitStates = 7;
- const int GFX940_XDL8PassWritesVGPROverlappedSrcABWaitStates = 11;
- const int GFX940_XDL16PassWritesVGPROverlappedSrcABWaitStates = 19;
const int DMFMA4x4WritesVGPROverlappedMFMASrcABWaitStates = 6;
const int DMFMA16x16WritesVGPROverlappedMFMASrcABWaitStates = 11;
const int DMFMA4x4WritesVGPRFullSrcCWaitStates = 4;
@@ -2250,42 +2270,40 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
NeedWaitStates = DMFMA4x4WritesVGPROverlappedSrcCWaitStates;
break;
default:
- if (ST.hasGFX940Insts() && isXDL(ST, *MI) && !isXDL(ST, *MI1))
+ int NumPasses = TSchedModel.computeInstrLatency(MI1);
+ if (ST.hasGFX940Insts()) {
+ if (isXDL(ST, *MI) && !isXDL(ST, *MI1))
+ break;
+
+ NeedWaitStates =
+ isXDL(ST, *MI1)
+ ? GFX940_XDL_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(
+ NumPasses)
+ : GFX940_SMFMA_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(
+ NumPasses);
break;
- switch (TSchedModel.computeInstrLatency(MI1)) {
+ }
+
+ switch (NumPasses) {
case 2:
- NeedWaitStates = ST.hasGFX940Insts()
- ? isXDL(ST, *MI1)
- ? GFX940_XDL2PassWritesVGPROverlappedSMFMASrcCWaitStates
- : SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates
- : isDGEMM(Opc)
- ? SMFMA4x4WritesVGPROverlappedDMFMASrcCWaitStates
- : SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates;
- break;
- case 4:
- assert(ST.hasGFX940Insts());
- NeedWaitStates = isXDL(ST, *MI1)
- ? GFX940_XDL4PassWritesVGPROverlappedSMFMASrcCWaitStates
- : GFX940_SMFMA4PassWritesVGPROverlappedSMFMASrcCWaitStates;
+ NeedWaitStates =
+ isDGEMM(Opc) ? SMFMA4x4WritesVGPROverlappedDMFMASrcCWaitStates
+ : SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates;
break;
case 8:
- NeedWaitStates = ST.hasGFX940Insts()
- ? isXDL(ST, *MI1)
- ? GFX940_XDL8PassWritesVGPROverlappedSMFMASrcCWaitStates
- : GFX940_SMFMA8PassWritesVGPROverlappedSMFMASrcCWaitStates
- : isDGEMM(Opc)
- ? SMFMA16x16WritesVGPROverlappedDMFMASrcCWaitStates
- : SMFMA16x16WritesVGPROverlappedSMFMASrcCWaitStates;
+ NeedWaitStates =
+ isDGEMM(Opc)
+ ? SMFMA16x16WritesVGPROverlappedDMFMASrcCWaitStates
+ : SMFMA16x16WritesVGPROverlappedSMFMASrcCWaitStates;
+ break;
+ case 16:
+ NeedWaitStates =
+ isDGEMM(Opc)
+ ? SMFMA32x32WritesVGPROverlappedDMFMASrcCWaitStates
+ : SMFMA32x32WritesVGPROverlappedSMFMASrcCWaitStates;
break;
- case 16: [[fallthrough]];
default:
- NeedWaitStates = ST.hasGFX940Insts()
- ? isXDL(ST, *MI1)
- ? GFX940_XDL16PassWritesVGPROverlappedSMFMASrcCWaitStates
- : GFX940_SMFMA16PassWritesVGPROverlappedSMFMASrcCWaitStates
- : isDGEMM(Opc)
- ? SMFMA32x32WritesVGPROverlappedDMFMASrcCWaitStates
- : SMFMA32x32WritesVGPROverlappedSMFMASrcCWaitStates;
+ llvm_unreachable("unexpected number of passes");
}
}
}
@@ -2302,34 +2320,30 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
NeedWaitStates = DMFMA4x4WritesVGPROverlappedMFMASrcABWaitStates;
break;
default:
- switch (TSchedModel.computeInstrLatency(MI1)) {
+ int NumPasses = TSchedModel.computeInstrLatency(MI1);
+
+ if (ST.hasGFX940Insts()) {
+ NeedWaitStates =
+ isXDL(ST, *MI1)
+ ? GFX940_XDL_N_PassWritesVGPROverlappedSrcABWaitStates(
+ NumPasses)
+ : GFX940_SMFMA_N_PassWritesVGPROverlappedSrcABWaitStates(
+ NumPasses);
+ break;
+ }
+
+ switch (NumPasses) {
case 2:
- NeedWaitStates = ST.hasGFX940Insts()
- ? isXDL(ST, *MI1)
- ? GFX940_XDL2PassWritesVGPROverlappedSrcABWaitStates
- : GFX940_SMFMA2PassWritesVGPROverlappedSrcABWaitStates
- : SMFMA4x4WritesVGPROverlappedSrcABWaitStates;
+ NeedWaitStates = SMFMA4x4WritesVGPROverlappedSrcABWaitStates;
break;
case 4:
- assert(ST.hasGFX940Insts());
- NeedWaitStates = isXDL(ST, *MI1)
- ? GFX940_XDL4PassWritesVGPROverlappedSrcABWaitStates
- : GFX940_SMFMA4PassWritesVGPROverlappedSrcABWaitStates;
- break;
+ llvm_unreachable("unexpected number of passes for mfma");
case 8:
- NeedWaitStates = ST.hasGFX940Insts()
- ? isXDL(ST, *MI1)
- ? GFX940_XDL8PassWritesVGPROverlappedSrcABWaitStates
- : GFX940_SMFMA8PassWritesVGPROverlappedSrcABWaitStates
- : SMFMA16x16WritesVGPROverlappedSrcABWaitStates;
+ NeedWaitStates = SMFMA16x16WritesVGPROverlappedSrcABWaitStates;
break;
case 16: [[fallthrough]];
default:
- NeedWaitStates = ST.hasGFX940Insts()
- ? isXDL(ST, *MI1)
- ? GFX940_XDL16PassWritesVGPROverlappedSrcABWaitStates
- : GFX940_SMFMA16PassWritesVGPROverlappedSrcABWaitStates
- : SMFMA32x32WritesVGPROverlappedSrcABWaitStates;
+ NeedWaitStates = SMFMA32x32WritesVGPROverlappedSrcABWaitStates;
}
}
}
@@ -2393,6 +2407,38 @@ int GCNHazardRecognizer::checkMAILdStHazards(MachineInstr *MI) {
return WaitStatesNeeded;
}
+static int GFX940_SMFMA_N_PassWriteVgprVALUWawWaitStates(int NumPasses) {
+ // GFX940: wait states between an N-pass non-XDL SMFMA VGPR write and a
+ // later VALU write of that VGPR (write-after-write, per the name).
+ // 2 pass -> 4, 4 pass -> 6, 8 pass -> 10, 16 pass -> 18,
+ // i.e. NumPasses + 2.
+ return NumPasses + 2;
+}
+
+static int GFX940_XDL_N_PassWriteVgprVALUWawWaitStates(int NumPasses) {
+ // GFX940: wait states between an N-pass XDL MFMA VGPR write and a
+ // later VALU write of that VGPR (write-after-write, per the name).
+ // 2 pass -> 5, 4 pass -> 7, 8 pass -> 11, 16 pass -> 19,
+ // i.e. NumPasses + 3.
+ return NumPasses + 3;
+}
+
+static int GFX940_XDL_N_PassWriteVgprVALUMemExpReadWaitStates(int NumPasses) {
+ // GFX940: wait states between an N-pass XDL MFMA VGPR write and a VALU,
+ // memory or export instruction reading that VGPR.
+ // 2 pass -> 5, 4 pass -> 7, 8 pass -> 11, 16 pass -> 19,
+ // i.e. NumPasses + 3.
+ return NumPasses + 3;
+}
+
+static int GFX940_SMFMA_N_PassWriteVgprVALUMemExpReadWaitStates(int NumPasses) {
+ // GFX940: wait states between an N-pass non-XDL SMFMA VGPR write and a
+ // VALU, memory or export instruction reading that VGPR.
+ // 2 pass -> 4, 4 pass -> 6, 8 pass -> 10, 16 pass -> 18,
+ // i.e. NumPasses + 2.
+ return NumPasses + 2;
+}
+
int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
if (!ST.hasGFX90AInsts())
return 0;
@@ -2455,14 +2501,6 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
const int SMFMA4x4WriteVgprVALUMemExpReadWaitStates = 5;
const int SMFMA16x16WriteVgprVALUMemExpReadWaitStates = 11;
const int SMFMA32x32WriteVgprVALUMemExpReadWaitStates = 19;
- const int GFX940_SMFMA2PassWriteVgprVALUMemExpReadWaitStates = 4;
- const int GFX940_SMFMA4PassWriteVgprVALUMemExpReadWaitStates = 6;
- const int GFX940_SMFMA8PassWriteVgprVALUMemExpReadWaitStates = 10;
- const int GFX940_SMFMA16PassWriteVgprVALUMemExpReadWaitStates = 18;
- const int GFX940_XDL2PassWriteVgprVALUMemExpReadWaitStates = 5;
- const int GFX940_XDL4PassWriteVgprVALUMemExpReadWaitStates = 7;
- const int GFX940_XDL8PassWriteVgprVALUMemExpReadWaitStates = 11;
- const int GFX940_XDL16PassWriteVgprVALUMemExpReadWaitStates = 19;
const int DMFMA4x4WriteVgprMemExpReadWaitStates = 9;
const int DMFMA16x16WriteVgprMemExpReadWaitStates = 18;
const int DMFMA4x4WriteVgprVALUReadWaitStates = 6;
@@ -2516,47 +2554,44 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
continue;
unsigned HazardDefLatency = TSchedModel.computeInstrLatency(MFMA);
+ int NumPasses = HazardDefLatency;
int NeedWaitStates = MaxWaitStates;
- switch (HazardDefLatency) {
- case 2:
- NeedWaitStates =
- ST.hasGFX940Insts()
- ? isXDL(ST, *MFMA)
- ? GFX940_XDL2PassWriteVgprVALUMemExpReadWaitStates
- : GFX940_SMFMA2PassWriteVgprVALUMemExpReadWaitStates
- : SMFMA4x4WriteVgprVALUMemExpReadWaitStates;
- break;
- case 4:
- assert(isDGEMM(MFMA->getOpcode()) || ST.hasGFX940Insts());
- NeedWaitStates =
- isDGEMM(MFMA->getOpcode())
- ? IsMemOrExport ? DMFMA4x4WriteVgprMemExpReadWaitStates
- : DMFMA4x4WriteVgprVALUReadWaitStates
- : isXDL(ST, *MFMA)
- ? GFX940_XDL4PassWriteVgprVALUMemExpReadWaitStates
- : GFX940_SMFMA4PassWriteVgprVALUMemExpReadWaitStates;
- break;
- case 8:
- NeedWaitStates =
- isDGEMM(MFMA->getOpcode())
- ? IsMemOrExport ? DMFMA16x16WriteVgprMemExpReadWaitStates
- : DMFMA16x16WriteVgprVALUReadWaitStates
- : ST.hasGFX940Insts()
- ? isXDL(ST, *MFMA)
- ? GFX940_XDL8PassWriteVgprVALUMemExpReadWaitStates
- : GFX940_SMFMA8PassWriteVgprVALUMemExpReadWaitStates
- : SMFMA16x16WriteVgprVALUMemExpReadWaitStates;
- break;
- case 16: [[fallthrough]];
- default:
- assert(!isDGEMM(MFMA->getOpcode()));
+
+ if (isDGEMM(MFMA->getOpcode())) {
+ switch (HazardDefLatency) {
+ case 4:
+ NeedWaitStates = IsMemOrExport ? DMFMA4x4WriteVgprMemExpReadWaitStates
+ : DMFMA4x4WriteVgprVALUReadWaitStates;
+ break;
+ case 8:
+ case 16:
+ NeedWaitStates = IsMemOrExport
+ ? DMFMA16x16WriteVgprMemExpReadWaitStates
+ : DMFMA16x16WriteVgprVALUReadWaitStates;
+ break;
+ default:
+ llvm_unreachable("unexpected dgemm");
+ }
+ } else if (ST.hasGFX940Insts()) {
NeedWaitStates =
- ST.hasGFX940Insts()
- ? isXDL(ST, *MFMA)
- ? GFX940_XDL16PassWriteVgprVALUMemExpReadWaitStates
- : GFX940_SMFMA16PassWriteVgprVALUMemExpReadWaitStates
- : SMFMA32x32WriteVgprVALUMemExpReadWaitStates;
- break;
+ isXDL(ST, *MFMA)
+ ? GFX940_XDL_N_PassWriteVgprVALUMemExpReadWaitStates(NumPasses)
+ : GFX940_SMFMA_N_PassWriteVgprVALUMemExpReadWaitStates(
+ NumPasses);
+ } else {
+ switch (HazardDefLatency) {
+ case 2:
+ NeedWaitStates = SMFMA4x4WriteVgprVALUMemExpReadWaitStates;
+ break;
+ case 8:
+ NeedWaitStates = SMFMA16x16WriteVgprVALUMemExpReadWaitStates;
+ break;
+ case 16:
+ NeedWaitStates = SMFMA32x32WriteVgprVALUMemExpReadWaitStates;
+ break;
+ default:
+ llvm_unreachable("unexpected number of passes for mfma");
+ }
}
int WaitStatesNeededForUse = NeedWaitStates - WaitStatesSinceDef;
@@ -2585,14 +2620,6 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
const int SMFMA4x4WriteVgprVALUWawWaitStates = 5;
const int SMFMA16x16WriteVgprVALUWawWaitStates = 11;
const int SMFMA32x32WriteVgprVALUWawWaitStates = 19;
- const int GFX940_SMFMA2PassWriteVgprVALUWawWaitStates = 4;
- const int GFX940_SMFMA4PassWriteVgprVALUWawWaitStates = 6;
- const int GFX940_SMFMA8PassWriteVgprVALUWawWaitStates = 10;
- const int GFX940_SMFMA16PassWriteVgprVALUWawWaitStates = 18;
- const int GFX940_XDL2PassWriteVgprVALUWawWaitStates = 5;
- const int GFX940_XDL4PassWriteVgprVALUWawWaitStates = 7;
- const int GFX940_XDL8PassWriteVgprVALUWawWaitStates = 11;
- const int GFX940_XDL16PassWriteVgprVALUWawWaitStates = 19;
const int SMFMA4x4ReadVgprVALUWarWaitStates = 1;
const int GFX940_XDL4PassReadVgprVALUWarWaitStates = 3;
const int SMFMA16x16ReadVgprVALUWarWaitStates = 7;
@@ -2617,42 +2644,39 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
getWaitStatesSinceDef(Reg, IsMFMAWriteFn, MaxWaitStates);
if (MFMA) {
int NeedWaitStates = MaxWaitStates;
- switch (TSchedModel.computeInstrLatency(MFMA)) {
- case 2:
- NeedWaitStates = ST.hasGFX940Insts()
- ? isXDL(ST, *MFMA)
- ? GFX940_XDL2PassWriteVgprVALUWawWaitStates
- : GFX940_SMFMA2PassWriteVgprVALUWawWaitStates
- : SMFMA4x4WriteVgprVALUWawWaitStates;
- break;
- case 4:
- assert(isDGEMM(MFMA->getOpcode()) || ST.hasGFX940Insts());
- NeedWaitStates = isDGEMM(MFMA->getOpcode())
- ? DMFMA4x4WriteVgprVALUWriteWaitStates
- : isXDL(ST, *MFMA)
- ? GFX940_XDL4PassWriteVgprVALUWawWaitStates
- : GFX940_SMFMA4PassWriteVgprVALUWawWaitStates;
- break;
- case 8:
- NeedWaitStates =
- isDGEMM(MFMA->getOpcode()) ? DMFMA16x16WriteVgprVALUWriteWaitStates
- :
+ int NumPasses = TSchedModel.computeInstrLatency(MFMA);
- ST.hasGFX940Insts()
- ? isXDL(ST, *MFMA) ? GFX940_XDL8PassWriteVgprVALUWawWaitStates
- : GFX940_SMFMA8PassWriteVgprVALUWawWaitStates
- : SMFMA16x16WriteVgprVALUWawWaitStates;
- break;
- case 16: [[fallthrough]];
- default:
- assert(!isDGEMM(MFMA->getOpcode()));
+ if (isDGEMM(MFMA->getOpcode())) {
+ switch (NumPasses) {
+ case 4:
+ NeedWaitStates = DMFMA4x4WriteVgprVALUWriteWaitStates;
+ break;
+ case 8:
+ case 16:
+ NeedWaitStates = DMFMA16x16WriteVgprVALUWriteWaitStates;
+ break;
+ default:
+ llvm_unreachable("unexpected number of cycles for dgemm");
+ }
+ } else if (ST.hasGFX940Insts()) {
NeedWaitStates =
- ST.hasGFX940Insts()
- ? isXDL(ST, *MFMA)
- ? GFX940_XDL16PassWriteVgprVALUWawWaitStates
- : GFX940_SMFMA16PassWriteVgprVALUWawWaitStates
- : SMFMA32x32WriteVgprVALUWawWaitStates;
- break;
+ isXDL(ST, *MFMA)
+ ? GFX940_XDL_N_PassWriteVgprVALUWawWaitStates(NumPasses)
+ : GFX940_SMFMA_N_PassWriteVgprVALUWawWaitStates(NumPasses);
+ } else {
+ switch (NumPasses) {
+ case 2:
+ NeedWaitStates = SMFMA4x4WriteVgprVALUWawWaitStates;
+ break;
+ case 8:
+ NeedWaitStates = SMFMA16x16WriteVgprVALUWawWaitStates;
+ break;
+ case 16:
+ NeedWaitStates = SMFMA32x32WriteVgprVALUWawWaitStates;
+ break;
+ default:
+ llvm_unreachable("Unexpected number of passes for mfma");
+ }
}
int WaitStatesNeededForUse = NeedWaitStates - WaitStatesSinceDef;
More information about the llvm-commits
mailing list