[llvm] a6382de - AMDGPU: Refactor mfma hazard handling [NFC] (#84276)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 01:10:04 PST 2024


Author: Matt Arsenault
Date: 2024-03-07T14:39:59+05:30
New Revision: a6382de3999280ef7bf8bb63750686cdad889cd5

URL: https://github.com/llvm/llvm-project/commit/a6382de3999280ef7bf8bb63750686cdad889cd5
DIFF: https://github.com/llvm/llvm-project/commit/a6382de3999280ef7bf8bb63750686cdad889cd5.diff

LOG: AMDGPU: Refactor mfma hazard handling [NFC] (#84276)

Try to make this easier to edit by computing the number of wait states
with helper functions of the number of passes, instead of a separate
hard-coded constant for each pass count. I'm assuming the current
hazard test coverage is comprehensive.

This could probably use another round to further simplify it.
Alternatively, I believe this could all be expressed as a constant table
indexed by the instruction class (as computed by a classification
function) and the number of passes.

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
index 7bed0d8ef0d670..e515b729e7d7e8 100644
--- a/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
+++ b/llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp
@@ -2136,6 +2136,41 @@ int GCNHazardRecognizer::checkMAIHazards908(MachineInstr *MI) {
   return WaitStatesNeeded;
 }
 
+static int
+GFX940_XDL_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(int NumPasses) {
+  // 2 pass -> 3
+  // 4 pass -> 5
+  // 8 pass -> 9
+  // 16 pass -> 17
+  return NumPasses + 1;
+}
+
+static int
+GFX940_SMFMA_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(int NumPasses) {
+  // 2 pass -> 2
+  // 4 pass -> 4
+  // 8 pass -> 8
+  // 16 pass -> 16
+  return NumPasses;
+}
+
+static int
+GFX940_SMFMA_N_PassWritesVGPROverlappedSrcABWaitStates(int NumPasses) {
+  // 2 pass -> 4
+  // 4 pass -> 6
+  // 8 pass -> 10
+  // 16 pass -> 18
+  return NumPasses + 2;
+}
+
+static int GFX940_XDL_N_PassWritesVGPROverlappedSrcABWaitStates(int NumPasses) {
+  // 2 pass -> 5
+  // 4 pass -> 7
+  // 8 pass -> 11
+  // 16 pass -> 19
+  return NumPasses + 3;
+}
+
 int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
   int WaitStatesNeeded = 0;
   unsigned Opc = MI->getOpcode();
@@ -2164,13 +2199,6 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
   for (const MachineOperand &Use : MI->explicit_uses()) {
     const int LegacyVALUNotDotWritesVGPRWaitStates = 2;
     const int SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates = 2;
-    const int GFX940_XDL2PassWritesVGPROverlappedSMFMASrcCWaitStates = 3;
-    const int GFX940_XDL4PassWritesVGPROverlappedSMFMASrcCWaitStates = 5;
-    const int GFX940_SMFMA4PassWritesVGPROverlappedSMFMASrcCWaitStates = 4;
-    const int GFX940_XDL8PassWritesVGPROverlappedSMFMASrcCWaitStates = 9;
-    const int GFX940_SMFMA8PassWritesVGPROverlappedSMFMASrcCWaitStates = 8;
-    const int GFX940_XDL16PassWritesVGPROverlappedSMFMASrcCWaitStates = 17;
-    const int GFX940_SMFMA16PassWritesVGPROverlappedSMFMASrcCWaitStates = 16;
     const int SMFMA16x16WritesVGPROverlappedSMFMASrcCWaitStates = 8;
     const int SMFMA32x32WritesVGPROverlappedSMFMASrcCWaitStates = 16;
     const int SMFMA4x4WritesVGPROverlappedDMFMASrcCWaitStates = 3;
@@ -2181,14 +2209,6 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
     const int SMFMA4x4WritesVGPROverlappedSrcABWaitStates = 5;
     const int SMFMA16x16WritesVGPROverlappedSrcABWaitStates = 11;
     const int SMFMA32x32WritesVGPROverlappedSrcABWaitStates = 19;
-    const int GFX940_SMFMA2PassWritesVGPROverlappedSrcABWaitStates = 4;
-    const int GFX940_SMFMA4PassWritesVGPROverlappedSrcABWaitStates = 6;
-    const int GFX940_SMFMA8PassWritesVGPROverlappedSrcABWaitStates = 10;
-    const int GFX940_SMFMA16PassWritesVGPROverlappedSrcABWaitStates = 18;
-    const int GFX940_XDL2PassWritesVGPROverlappedSrcABWaitStates = 5;
-    const int GFX940_XDL4PassWritesVGPROverlappedSrcABWaitStates = 7;
-    const int GFX940_XDL8PassWritesVGPROverlappedSrcABWaitStates = 11;
-    const int GFX940_XDL16PassWritesVGPROverlappedSrcABWaitStates = 19;
     const int DMFMA4x4WritesVGPROverlappedMFMASrcABWaitStates = 6;
     const int DMFMA16x16WritesVGPROverlappedMFMASrcABWaitStates = 11;
     const int DMFMA4x4WritesVGPRFullSrcCWaitStates = 4;
@@ -2250,42 +2270,40 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
             NeedWaitStates = DMFMA4x4WritesVGPROverlappedSrcCWaitStates;
           break;
         default:
-          if (ST.hasGFX940Insts() && isXDL(ST, *MI) && !isXDL(ST, *MI1))
+          int NumPasses = TSchedModel.computeInstrLatency(MI1);
+          if (ST.hasGFX940Insts()) {
+            if (isXDL(ST, *MI) && !isXDL(ST, *MI1))
+              break;
+
+            NeedWaitStates =
+                isXDL(ST, *MI1)
+                    ? GFX940_XDL_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(
+                          NumPasses)
+                    : GFX940_SMFMA_N_PassWritesVGPROverlappedSMFMASrcCWaitStates(
+                          NumPasses);
             break;
-          switch (TSchedModel.computeInstrLatency(MI1)) {
+          }
+
+          switch (NumPasses) {
           case 2:
-            NeedWaitStates = ST.hasGFX940Insts()
-              ? isXDL(ST, *MI1)
-                ? GFX940_XDL2PassWritesVGPROverlappedSMFMASrcCWaitStates
-                : SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates
-              : isDGEMM(Opc)
-                ? SMFMA4x4WritesVGPROverlappedDMFMASrcCWaitStates
-                : SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates;
-            break;
-          case 4:
-            assert(ST.hasGFX940Insts());
-            NeedWaitStates = isXDL(ST, *MI1)
-              ? GFX940_XDL4PassWritesVGPROverlappedSMFMASrcCWaitStates
-              : GFX940_SMFMA4PassWritesVGPROverlappedSMFMASrcCWaitStates;
+            NeedWaitStates =
+                isDGEMM(Opc) ? SMFMA4x4WritesVGPROverlappedDMFMASrcCWaitStates
+                             : SMFMA4x4WritesVGPROverlappedSMFMASrcCWaitStates;
             break;
           case 8:
-            NeedWaitStates = ST.hasGFX940Insts()
-              ? isXDL(ST, *MI1)
-                ? GFX940_XDL8PassWritesVGPROverlappedSMFMASrcCWaitStates
-                : GFX940_SMFMA8PassWritesVGPROverlappedSMFMASrcCWaitStates
-              : isDGEMM(Opc)
-                ? SMFMA16x16WritesVGPROverlappedDMFMASrcCWaitStates
-                : SMFMA16x16WritesVGPROverlappedSMFMASrcCWaitStates;
+            NeedWaitStates =
+                isDGEMM(Opc)
+                    ? SMFMA16x16WritesVGPROverlappedDMFMASrcCWaitStates
+                    : SMFMA16x16WritesVGPROverlappedSMFMASrcCWaitStates;
+            break;
+          case 16:
+            NeedWaitStates =
+                isDGEMM(Opc)
+                    ? SMFMA32x32WritesVGPROverlappedDMFMASrcCWaitStates
+                    : SMFMA32x32WritesVGPROverlappedSMFMASrcCWaitStates;
             break;
-          case 16: [[fallthrough]];
           default:
-            NeedWaitStates = ST.hasGFX940Insts()
-              ? isXDL(ST, *MI1)
-                ? GFX940_XDL16PassWritesVGPROverlappedSMFMASrcCWaitStates
-                : GFX940_SMFMA16PassWritesVGPROverlappedSMFMASrcCWaitStates
-              : isDGEMM(Opc)
-                ? SMFMA32x32WritesVGPROverlappedDMFMASrcCWaitStates
-                : SMFMA32x32WritesVGPROverlappedSMFMASrcCWaitStates;
+            llvm_unreachable("unexpected number of passes");
           }
         }
       }
@@ -2302,34 +2320,30 @@ int GCNHazardRecognizer::checkMAIHazards90A(MachineInstr *MI) {
         NeedWaitStates = DMFMA4x4WritesVGPROverlappedMFMASrcABWaitStates;
         break;
       default:
-        switch (TSchedModel.computeInstrLatency(MI1)) {
+        int NumPasses = TSchedModel.computeInstrLatency(MI1);
+
+        if (ST.hasGFX940Insts()) {
+          NeedWaitStates =
+              isXDL(ST, *MI1)
+                  ? GFX940_XDL_N_PassWritesVGPROverlappedSrcABWaitStates(
+                        NumPasses)
+                  : GFX940_SMFMA_N_PassWritesVGPROverlappedSrcABWaitStates(
+                        NumPasses);
+          break;
+        }
+
+        switch (NumPasses) {
         case 2:
-          NeedWaitStates = ST.hasGFX940Insts()
-            ? isXDL(ST, *MI1)
-              ? GFX940_XDL2PassWritesVGPROverlappedSrcABWaitStates
-              : GFX940_SMFMA2PassWritesVGPROverlappedSrcABWaitStates
-            : SMFMA4x4WritesVGPROverlappedSrcABWaitStates;
+          NeedWaitStates = SMFMA4x4WritesVGPROverlappedSrcABWaitStates;
           break;
         case 4:
-          assert(ST.hasGFX940Insts());
-          NeedWaitStates = isXDL(ST, *MI1)
-            ? GFX940_XDL4PassWritesVGPROverlappedSrcABWaitStates
-            : GFX940_SMFMA4PassWritesVGPROverlappedSrcABWaitStates;
-          break;
+          llvm_unreachable("unexpected number of passes for mfma");
         case 8:
-          NeedWaitStates = ST.hasGFX940Insts()
-            ? isXDL(ST, *MI1)
-              ? GFX940_XDL8PassWritesVGPROverlappedSrcABWaitStates
-              : GFX940_SMFMA8PassWritesVGPROverlappedSrcABWaitStates
-            : SMFMA16x16WritesVGPROverlappedSrcABWaitStates;
+          NeedWaitStates = SMFMA16x16WritesVGPROverlappedSrcABWaitStates;
           break;
         case 16: [[fallthrough]];
         default:
-          NeedWaitStates = ST.hasGFX940Insts()
-            ? isXDL(ST, *MI1)
-              ? GFX940_XDL16PassWritesVGPROverlappedSrcABWaitStates
-              : GFX940_SMFMA16PassWritesVGPROverlappedSrcABWaitStates
-            : SMFMA32x32WritesVGPROverlappedSrcABWaitStates;
+          NeedWaitStates = SMFMA32x32WritesVGPROverlappedSrcABWaitStates;
         }
       }
     }
@@ -2393,6 +2407,38 @@ int GCNHazardRecognizer::checkMAILdStHazards(MachineInstr *MI) {
   return WaitStatesNeeded;
 }
 
+static int GFX940_SMFMA_N_PassWriteVgprVALUWawWaitStates(int NumPasses) {
+  // 2 pass -> 4
+  // 4 pass -> 6
+  // 8 pass -> 10
+  // 16 pass -> 18
+  return NumPasses + 2;
+}
+
+static int GFX940_XDL_N_PassWriteVgprVALUWawWaitStates(int NumPasses) {
+  // 2 pass -> 5
+  // 4 pass -> 7
+  // 8 pass -> 11
+  // 16 pass -> 19
+  return NumPasses + 3;
+}
+
+static int GFX940_XDL_N_PassWriteVgprVALUMemExpReadWaitStates(int NumPasses) {
+  // 2 pass -> 5
+  // 4 pass -> 7
+  // 8 pass -> 11
+  // 16 pass -> 19
+  return NumPasses + 3;
+}
+
+static int GFX940_SMFMA_N_PassWriteVgprVALUMemExpReadWaitStates(int NumPasses) {
+  // 2 pass -> 4
+  // 4 pass -> 6
+  // 8 pass -> 10
+  // 16 pass -> 18
+  return NumPasses + 2;
+}
+
 int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
   if (!ST.hasGFX90AInsts())
     return 0;
@@ -2455,14 +2501,6 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
     const int SMFMA4x4WriteVgprVALUMemExpReadWaitStates = 5;
     const int SMFMA16x16WriteVgprVALUMemExpReadWaitStates = 11;
     const int SMFMA32x32WriteVgprVALUMemExpReadWaitStates = 19;
-    const int GFX940_SMFMA2PassWriteVgprVALUMemExpReadWaitStates = 4;
-    const int GFX940_SMFMA4PassWriteVgprVALUMemExpReadWaitStates = 6;
-    const int GFX940_SMFMA8PassWriteVgprVALUMemExpReadWaitStates = 10;
-    const int GFX940_SMFMA16PassWriteVgprVALUMemExpReadWaitStates = 18;
-    const int GFX940_XDL2PassWriteVgprVALUMemExpReadWaitStates = 5;
-    const int GFX940_XDL4PassWriteVgprVALUMemExpReadWaitStates = 7;
-    const int GFX940_XDL8PassWriteVgprVALUMemExpReadWaitStates = 11;
-    const int GFX940_XDL16PassWriteVgprVALUMemExpReadWaitStates = 19;
     const int DMFMA4x4WriteVgprMemExpReadWaitStates = 9;
     const int DMFMA16x16WriteVgprMemExpReadWaitStates = 18;
     const int DMFMA4x4WriteVgprVALUReadWaitStates = 6;
@@ -2516,47 +2554,44 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
         continue;
 
       unsigned HazardDefLatency = TSchedModel.computeInstrLatency(MFMA);
+      int NumPasses = HazardDefLatency;
       int NeedWaitStates = MaxWaitStates;
-      switch (HazardDefLatency) {
-      case 2:
-        NeedWaitStates =
-          ST.hasGFX940Insts()
-            ? isXDL(ST, *MFMA)
-              ? GFX940_XDL2PassWriteVgprVALUMemExpReadWaitStates
-              : GFX940_SMFMA2PassWriteVgprVALUMemExpReadWaitStates
-            : SMFMA4x4WriteVgprVALUMemExpReadWaitStates;
-        break;
-      case 4:
-        assert(isDGEMM(MFMA->getOpcode()) || ST.hasGFX940Insts());
-        NeedWaitStates =
-          isDGEMM(MFMA->getOpcode())
-            ? IsMemOrExport ? DMFMA4x4WriteVgprMemExpReadWaitStates
-                            : DMFMA4x4WriteVgprVALUReadWaitStates
-            : isXDL(ST, *MFMA)
-              ? GFX940_XDL4PassWriteVgprVALUMemExpReadWaitStates
-              : GFX940_SMFMA4PassWriteVgprVALUMemExpReadWaitStates;
-        break;
-      case 8:
-        NeedWaitStates =
-            isDGEMM(MFMA->getOpcode())
-                ? IsMemOrExport ? DMFMA16x16WriteVgprMemExpReadWaitStates
-                                : DMFMA16x16WriteVgprVALUReadWaitStates
-            : ST.hasGFX940Insts()
-                ? isXDL(ST, *MFMA)
-                      ? GFX940_XDL8PassWriteVgprVALUMemExpReadWaitStates
-                      : GFX940_SMFMA8PassWriteVgprVALUMemExpReadWaitStates
-                : SMFMA16x16WriteVgprVALUMemExpReadWaitStates;
-        break;
-      case 16: [[fallthrough]];
-      default:
-        assert(!isDGEMM(MFMA->getOpcode()));
+
+      if (isDGEMM(MFMA->getOpcode())) {
+        switch (HazardDefLatency) {
+        case 4:
+          NeedWaitStates = IsMemOrExport ? DMFMA4x4WriteVgprMemExpReadWaitStates
+                                         : DMFMA4x4WriteVgprVALUReadWaitStates;
+          break;
+        case 8:
+        case 16:
+          NeedWaitStates = IsMemOrExport
+                               ? DMFMA16x16WriteVgprMemExpReadWaitStates
+                               : DMFMA16x16WriteVgprVALUReadWaitStates;
+          break;
+        default:
+          llvm_unreachable("unexpected dgemm");
+        }
+      } else if (ST.hasGFX940Insts()) {
         NeedWaitStates =
-            ST.hasGFX940Insts()
-                ? isXDL(ST, *MFMA)
-                      ? GFX940_XDL16PassWriteVgprVALUMemExpReadWaitStates
-                      : GFX940_SMFMA16PassWriteVgprVALUMemExpReadWaitStates
-                : SMFMA32x32WriteVgprVALUMemExpReadWaitStates;
-        break;
+            isXDL(ST, *MFMA)
+                ? GFX940_XDL_N_PassWriteVgprVALUMemExpReadWaitStates(NumPasses)
+                : GFX940_SMFMA_N_PassWriteVgprVALUMemExpReadWaitStates(
+                      NumPasses);
+      } else {
+        switch (HazardDefLatency) {
+        case 2:
+          NeedWaitStates = SMFMA4x4WriteVgprVALUMemExpReadWaitStates;
+          break;
+        case 8:
+          NeedWaitStates = SMFMA16x16WriteVgprVALUMemExpReadWaitStates;
+          break;
+        case 16:
+          NeedWaitStates = SMFMA32x32WriteVgprVALUMemExpReadWaitStates;
+          break;
+        default:
+          llvm_unreachable("unexpected number of passes for mfma");
+        }
       }
 
       int WaitStatesNeededForUse = NeedWaitStates - WaitStatesSinceDef;
@@ -2585,14 +2620,6 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
     const int SMFMA4x4WriteVgprVALUWawWaitStates = 5;
     const int SMFMA16x16WriteVgprVALUWawWaitStates = 11;
     const int SMFMA32x32WriteVgprVALUWawWaitStates = 19;
-    const int GFX940_SMFMA2PassWriteVgprVALUWawWaitStates = 4;
-    const int GFX940_SMFMA4PassWriteVgprVALUWawWaitStates = 6;
-    const int GFX940_SMFMA8PassWriteVgprVALUWawWaitStates = 10;
-    const int GFX940_SMFMA16PassWriteVgprVALUWawWaitStates = 18;
-    const int GFX940_XDL2PassWriteVgprVALUWawWaitStates = 5;
-    const int GFX940_XDL4PassWriteVgprVALUWawWaitStates = 7;
-    const int GFX940_XDL8PassWriteVgprVALUWawWaitStates = 11;
-    const int GFX940_XDL16PassWriteVgprVALUWawWaitStates = 19;
     const int SMFMA4x4ReadVgprVALUWarWaitStates = 1;
     const int GFX940_XDL4PassReadVgprVALUWarWaitStates = 3;
     const int SMFMA16x16ReadVgprVALUWarWaitStates = 7;
@@ -2617,42 +2644,39 @@ int GCNHazardRecognizer::checkMAIVALUHazards(MachineInstr *MI) {
         getWaitStatesSinceDef(Reg, IsMFMAWriteFn, MaxWaitStates);
     if (MFMA) {
       int NeedWaitStates = MaxWaitStates;
-      switch (TSchedModel.computeInstrLatency(MFMA)) {
-      case 2:
-        NeedWaitStates = ST.hasGFX940Insts()
-          ? isXDL(ST, *MFMA)
-            ? GFX940_XDL2PassWriteVgprVALUWawWaitStates
-            : GFX940_SMFMA2PassWriteVgprVALUWawWaitStates
-          : SMFMA4x4WriteVgprVALUWawWaitStates;
-        break;
-      case 4:
-        assert(isDGEMM(MFMA->getOpcode()) || ST.hasGFX940Insts());
-        NeedWaitStates = isDGEMM(MFMA->getOpcode())
-            ? DMFMA4x4WriteVgprVALUWriteWaitStates
-            : isXDL(ST, *MFMA)
-              ? GFX940_XDL4PassWriteVgprVALUWawWaitStates
-              : GFX940_SMFMA4PassWriteVgprVALUWawWaitStates;
-        break;
-      case 8:
-        NeedWaitStates =
-            isDGEMM(MFMA->getOpcode()) ? DMFMA16x16WriteVgprVALUWriteWaitStates
-            :
+      int NumPasses = TSchedModel.computeInstrLatency(MFMA);
 
-            ST.hasGFX940Insts()
-                ? isXDL(ST, *MFMA) ? GFX940_XDL8PassWriteVgprVALUWawWaitStates
-                                   : GFX940_SMFMA8PassWriteVgprVALUWawWaitStates
-                : SMFMA16x16WriteVgprVALUWawWaitStates;
-        break;
-      case 16: [[fallthrough]];
-      default:
-        assert(!isDGEMM(MFMA->getOpcode()));
+      if (isDGEMM(MFMA->getOpcode())) {
+        switch (NumPasses) {
+        case 4:
+          NeedWaitStates = DMFMA4x4WriteVgprVALUWriteWaitStates;
+          break;
+        case 8:
+        case 16:
+          NeedWaitStates = DMFMA16x16WriteVgprVALUWriteWaitStates;
+          break;
+        default:
+          llvm_unreachable("unexpected number of cycles for dgemm");
+        }
+      } else if (ST.hasGFX940Insts()) {
         NeedWaitStates =
-            ST.hasGFX940Insts()
-                ? isXDL(ST, *MFMA)
-                      ? GFX940_XDL16PassWriteVgprVALUWawWaitStates
-                      : GFX940_SMFMA16PassWriteVgprVALUWawWaitStates
-                : SMFMA32x32WriteVgprVALUWawWaitStates;
-        break;
+            isXDL(ST, *MFMA)
+                ? GFX940_XDL_N_PassWriteVgprVALUWawWaitStates(NumPasses)
+                : GFX940_SMFMA_N_PassWriteVgprVALUWawWaitStates(NumPasses);
+      } else {
+        switch (NumPasses) {
+        case 2:
+          NeedWaitStates = SMFMA4x4WriteVgprVALUWawWaitStates;
+          break;
+        case 8:
+          NeedWaitStates = SMFMA16x16WriteVgprVALUWawWaitStates;
+          break;
+        case 16:
+          NeedWaitStates = SMFMA32x32WriteVgprVALUWawWaitStates;
+          break;
+        default:
+          llvm_unreachable("Unexpected number of passes for mfma");
+        }
       }
 
       int WaitStatesNeededForUse = NeedWaitStates - WaitStatesSinceDef;


        


More information about the llvm-commits mailing list