[llvm] d70573b - [RISCV][NFC] Make Reduction scheduler resources SEW aware

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Tue May 30 07:27:43 PDT 2023


Author: Michael Maitland
Date: 2023-05-30T07:27:25-07:00
New Revision: d70573b18e9af94dcae7de2287ca56c77da27e7c

URL: https://github.com/llvm/llvm-project/commit/d70573b18e9af94dcae7de2287ca56c77da27e7c
DIFF: https://github.com/llvm/llvm-project/commit/d70573b18e9af94dcae7de2287ca56c77da27e7c.diff

LOG: [RISCV][NFC] Make Reduction scheduler resources SEW aware

Create SchedWrites and WriteRes records for reduction instructions that
are SEW-specific. Future patches can use these resources
to customize their scheduling behavior depending on SEW.

Differential Revision: https://reviews.llvm.org/D151470

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
    llvm/lib/Target/RISCV/RISCVSchedSiFive7.td
    llvm/lib/Target/RISCV/RISCVScheduleV.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 85046f1b40a3..d0d462287726 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -115,8 +115,14 @@ defvar MxListF = [V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8];
 
 // Used for widening and narrowing instructions as it doesn't contain M8.
 defvar MxListW = [V_MF8, V_MF4, V_MF2, V_M1, V_M2, V_M4];
+// Used for widening reductions. It can contain M8 because wider operands are
+// scalar operands.
+defvar MxListWRed = MxList;
 // For floating point which don't need MF8.
 defvar MxListFW = [V_MF4, V_MF2, V_M1, V_M2, V_M4];
+// For widening floating-point reductions. It doesn't contain MF8, but it can
+// contain M8 because the wider operands are scalar operands.
+defvar MxListFWRed = MxListF;
 
 // Use for zext/sext.vf2
 defvar MxListVF2 = [V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8];
@@ -3180,16 +3186,14 @@ multiclass VPseudoTernaryWithTailPolicy_E<VReg RetClass,
                                           RegisterClass Op1Class,
                                           DAGOperand Op2Class,
                                           LMULInfo MInfo,
+                                          int sew,
                                           string Constraint = "",
                                           bit Commutable = 0> {
   let VLMul = MInfo.value in {
     defvar mx = MInfo.MX;
-    defvar sews = SchedSEWSet<mx>.val;
-    foreach e = sews in {
       let isCommutable = Commutable in
-      def "_" # mx # "_E" # e : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
-      def "_" # mx # "_E" # e # "_MASK" : VPseudoBinaryTailPolicy<RetClass, Op1Class, Op2Class, Constraint>;
-    }
+      def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
+      def "_" # mx # "_E" # sew # "_MASK" : VPseudoBinaryTailPolicy<RetClass, Op1Class, Op2Class, Constraint>;
   }
 }
 
@@ -3448,50 +3452,60 @@ multiclass VPseudoVCMPM_VX_VI {
 multiclass VPseudoVRED_VS {
   foreach m = MxList in {
     defvar mx = m.MX;
-    defvar WriteVIRedV_From_MX = !cast<SchedWrite>("WriteVIRedV_From_" # mx);
-    defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
-               Sched<[WriteVIRedV_From_MX, ReadVIRedV, ReadVIRedV, ReadVIRedV,
-                      ReadVMask]>;
+    foreach e = SchedSEWSet<mx>.val in { // One _E<sew> pseudo per SEW.
+      defvar WriteVIRedV_From_MX_E = !cast<SchedWrite>("WriteVIRedV_From_" # mx # "_E" # e);
+      defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
+                 Sched<[WriteVIRedV_From_MX_E, ReadVIRedV, ReadVIRedV, ReadVIRedV,
+                        ReadVMask]>;
+    }
   }
 }
 
 multiclass VPseudoVWRED_VS {
-  foreach m = MxList in {
+  foreach m = MxListWRed in {
     defvar mx = m.MX;
-    defvar WriteVIWRedV_From_MX = !cast<SchedWrite>("WriteVIWRedV_From_" # mx);
-    defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
-               Sched<[WriteVIWRedV_From_MX, ReadVIWRedV, ReadVIWRedV,
-                      ReadVIWRedV, ReadVMask]>;
+    foreach e = SchedSEWSet<mx, 1>.val in { // Widening: SEW=64 excluded.
+      defvar WriteVIWRedV_From_MX_E = !cast<SchedWrite>("WriteVIWRedV_From_" # mx # "_E" # e);
+      defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
+                 Sched<[WriteVIWRedV_From_MX_E, ReadVIWRedV, ReadVIWRedV,
+                        ReadVIWRedV, ReadVMask]>;
+    }
   }
 }
 
 multiclass VPseudoVFRED_VS {
   foreach m = MxListF in {
     defvar mx = m.MX;
-    defvar WriteVFRedV_From_MX = !cast<SchedWrite>("WriteVFRedV_From_" # mx);
-    defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
-               Sched<[WriteVFRedV_From_MX, ReadVFRedV, ReadVFRedV, ReadVFRedV,
-                      ReadVMask]>;
+    foreach e = SchedSEWSetF<mx>.val in { // FP: SEW=8 excluded.
+      defvar WriteVFRedV_From_MX_E = !cast<SchedWrite>("WriteVFRedV_From_" # mx # "_E" # e);
+      defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
+                 Sched<[WriteVFRedV_From_MX_E, ReadVFRedV, ReadVFRedV, ReadVFRedV,
+                        ReadVMask]>;
+    }
   }
 }
 
 multiclass VPseudoVFREDO_VS {
   foreach m = MxListF in {
     defvar mx = m.MX;
-    defvar WriteVFRedOV_From_MX = !cast<SchedWrite>("WriteVFRedOV_From_" # mx);
-    defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
-               Sched<[WriteVFRedOV_From_MX, ReadVFRedOV, ReadVFRedOV,
-                      ReadVFRedOV, ReadVMask]>;
+    foreach e = SchedSEWSetF<mx>.val in { // FP: SEW=8 excluded.
+      defvar WriteVFRedOV_From_MX_E = !cast<SchedWrite>("WriteVFRedOV_From_" # mx # "_E" # e);
+      defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
+                 Sched<[WriteVFRedOV_From_MX_E, ReadVFRedOV, ReadVFRedOV,
+                        ReadVFRedOV, ReadVMask]>;
+    }
   }
 }
 
 multiclass VPseudoVFWRED_VS {
-  foreach m = MxListF in {
+  foreach m = MxListFWRed in {
     defvar mx = m.MX;
-    defvar WriteVFWRedV_From_MX = !cast<SchedWrite>("WriteVFWRedV_From_" # mx);
-    defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
-               Sched<[WriteVFWRedV_From_MX, ReadVFWRedV, ReadVFWRedV,
-                      ReadVFWRedV, ReadVMask]>;
+    foreach e = SchedSEWSetF<mx, 1>.val in { // Widening FP: no SEW=8 or 64.
+      defvar WriteVFWRedV_From_MX_E = !cast<SchedWrite>("WriteVFWRedV_From_" # mx # "_E" # e);
+      defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
+                 Sched<[WriteVFWRedV_From_MX_E, ReadVFWRedV, ReadVFWRedV,
+                        ReadVFWRedV, ReadVMask]>;
+    }
   }
 }
 

diff  --git a/llvm/lib/Target/RISCV/RISCVSchedSiFive7.td b/llvm/lib/Target/RISCV/RISCVSchedSiFive7.td
index d38051d2420a..345dd90157e2 100644
--- a/llvm/lib/Target/RISCV/RISCVSchedSiFive7.td
+++ b/llvm/lib/Target/RISCV/RISCVSchedSiFive7.td
@@ -620,12 +620,12 @@ foreach mx = SchedMxListFW in {
 
 // 14. Vector Reduction Operations
 let Latency = 32 in {
-defm "" : LMULWriteRes<"WriteVIRedV_From", [SiFive7VA]>;
-defm "" : LMULWriteRes<"WriteVIWRedV_From", [SiFive7VA]>;
-defm "" : LMULWriteRes<"WriteVFRedV_From", [SiFive7VA]>;
-defm "" : LMULWriteRes<"WriteVFRedOV_From", [SiFive7VA]>;
-defm "" : LMULWriteResFWRed<"WriteVFWRedV_From", [SiFive7VA]>;
-defm "" : LMULWriteResFWRed<"WriteVFWRedOV_From", [SiFive7VA]>;
+defm "" : LMULSEWWriteRes<"WriteVIRedV_From", [SiFive7VA]>;
+defm "" : LMULSEWWriteResWRed<"WriteVIWRedV_From", [SiFive7VA]>;
+defm "" : LMULSEWWriteResF<"WriteVFRedV_From", [SiFive7VA]>;
+defm "" : LMULSEWWriteResF<"WriteVFRedOV_From", [SiFive7VA]>;
+defm "" : LMULSEWWriteResFWRed<"WriteVFWRedV_From", [SiFive7VA]>;
+defm "" : LMULSEWWriteResFWRed<"WriteVFWRedOV_From", [SiFive7VA]>;
 }
 
 // 15. Vector Mask Instructions

diff  --git a/llvm/lib/Target/RISCV/RISCVScheduleV.td b/llvm/lib/Target/RISCV/RISCVScheduleV.td
index b6ab10454cfd..5863f170d5d9 100644
--- a/llvm/lib/Target/RISCV/RISCVScheduleV.td
+++ b/llvm/lib/Target/RISCV/RISCVScheduleV.td
@@ -12,30 +12,35 @@
 defvar SchedMxList = ["MF8", "MF4", "MF2", "M1", "M2", "M4", "M8"];
 // Used for widening and narrowing instructions as it doesn't contain M8.
 defvar SchedMxListW = !listremove(SchedMxList, ["M8"]);
+// Used for widening reductions, which do contain M8.
+defvar SchedMxListWRed = SchedMxList;
 defvar SchedMxListFW = !listremove(SchedMxList, ["M8", "MF8"]);
 // Used for floating-point as it doesn't contain MF8.
 defvar SchedMxListF = !listremove(SchedMxList, ["MF8"]);
 // Used for widening floating-point Reduction as it doesn't contain MF8.
 defvar SchedMxListFWRed = SchedMxListF;
 
-class SchedSEWSet<string mx> {
-  list<int> val = !cond(!eq(mx, "M1"):  [8, 16, 32, 64],
-                        !eq(mx, "M2"):  [8, 16, 32, 64],
-                        !eq(mx, "M4"):  [8, 16, 32, 64],
-                        !eq(mx, "M8"):  [8, 16, 32, 64],
-                        !eq(mx, "MF2"): [8, 16, 32],
-                        !eq(mx, "MF4"): [8, 16],
-                        !eq(mx, "MF8"): [8]);
+// For widening instructions (isWidening = 1), SEW will not be 64.
+class SchedSEWSet<string mx, bit isWidening = 0> {
+  defvar t = !cond(!eq(mx, "M1"):  [8, 16, 32, 64],
+                   !eq(mx, "M2"):  [8, 16, 32, 64],
+                   !eq(mx, "M4"):  [8, 16, 32, 64],
+                   !eq(mx, "M8"):  [8, 16, 32, 64],
+                   !eq(mx, "MF2"): [8, 16, 32],
+                   !eq(mx, "MF4"): [8, 16],
+                   !eq(mx, "MF8"): [8]);
+  list<int> val = !if(isWidening, !listremove(t, [64]), t);
 }
 
 // For floating-point instructions, SEW won't be 8.
-class SchedSEWSetF<string mx> {
-  list<int> val = !cond(!eq(mx, "M1"):  [16, 32, 64],
-                        !eq(mx, "M2"):  [16, 32, 64],
-                        !eq(mx, "M4"):  [16, 32, 64],
-                        !eq(mx, "M8"):  [16, 32, 64],
-                        !eq(mx, "MF2"): [16, 32],
-                        !eq(mx, "MF4"): [16]);
+class SchedSEWSetF<string mx, bit isWidening = 0> { // isWidening drops SEW=64.
+  defvar t = !cond(!eq(mx, "M1"):  [16, 32, 64],
+                   !eq(mx, "M2"):  [16, 32, 64],
+                   !eq(mx, "M4"):  [16, 32, 64],
+                   !eq(mx, "M8"):  [16, 32, 64],
+                   !eq(mx, "MF2"): [16, 32],
+                   !eq(mx, "MF4"): [16]);
+  list<int> val = !if(isWidening, !listremove(t, [64]), t);
 }
 
 // Helper function to get the largest LMUL from MxList
@@ -102,34 +107,46 @@ multiclass LMULReadAdvanceImpl<string name, int val,
 // ReadAdvance for each (name, LMUL, SEW) tuple for each LMUL in each of the
 // SchedMxList variants above. Each multiclass is responsible for defining
 // a record that represents the WorseCase behavior for name.
-multiclass LMULSEWSchedWritesImpl<string name, list<string> MxList, bit isF = 0> {
+multiclass LMULSEWSchedWritesImpl<string name, list<string> MxList, bit isF = 0,
+                                  bit isWidening = 0> {
   def name # "_WorstCase" : SchedWrite;
   foreach mx = MxList in {
-    foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
+    foreach sew = !if(isF, SchedSEWSetF<mx, isWidening>.val,
+                      SchedSEWSet<mx, isWidening>.val) in
       def name # "_" # mx # "_E" # sew : SchedWrite;
   }
 }
-multiclass LMULSEWSchedReadsImpl<string name, list<string> MxList, bit isF = 0> {
+multiclass LMULSEWSchedReadsImpl<string name, list<string> MxList, bit isF = 0,
+                                 bit isWidening = 0> {
   def name # "_WorstCase" : SchedRead;
   foreach mx = MxList in {
-    foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
+    foreach sew = !if(isF, SchedSEWSetF<mx, isWidening>.val,
+                      SchedSEWSet<mx, isWidening>.val) in
       def name # "_" # mx # "_E" # sew : SchedRead;
   }
 }
 multiclass LMULSEWWriteResImpl<string name, list<ProcResourceKind> resources,
-                               bit isF = 0> {
-  def : WriteRes<!cast<SchedWrite>(name # "_WorstCase"), resources>;
-  foreach mx = !if(isF, SchedMxListF, SchedMxList) in {
-    foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
-      def : WriteRes<!cast<SchedWrite>(name # "_" # mx # "_E" # sew), resources>;
+                               list<string> MxList, bit isF = 0,
+                               bit isWidening = 0> {
+  if !exists<SchedWrite>(name # "_WorstCase") then // Skip undeclared names.
+    def : WriteRes<!cast<SchedWrite>(name # "_WorstCase"), resources>;
+  foreach mx = MxList in {
+    foreach sew = !if(isF, SchedSEWSetF<mx, isWidening>.val,
+                      SchedSEWSet<mx, isWidening>.val) in
+      if !exists<SchedWrite>(name # "_" # mx # "_E" # sew) then
+        def : WriteRes<!cast<SchedWrite>(name # "_" # mx # "_E" # sew), resources>;
   }
 }
 multiclass LMULSEWReadAdvanceImpl<string name, int val, list<SchedWrite> writes = [],
-                                  bit isF = 0> {
-  def : ReadAdvance<!cast<SchedRead>(name # "_WorstCase"), val, writes>;
-  foreach mx = !if(isF, SchedMxListF, SchedMxList) in {
-    foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
-      def : ReadAdvance<!cast<SchedRead>(name # "_" # mx # "_E" # sew), val, writes>;
+                                  list<string> MxList, bit isF = 0,
+                                  bit isWidening = 0> {
+  if !exists<SchedRead>(name # "_WorstCase") then // Skip undeclared names.
+    def : ReadAdvance<!cast<SchedRead>(name # "_WorstCase"), val, writes>;
+  foreach mx = MxList in {
+    foreach sew = !if(isF, SchedSEWSetF<mx, isWidening>.val,
+                      SchedSEWSet<mx, isWidening>.val) in
+      if !exists<SchedRead>(name # "_" # mx # "_E" # sew) then
+        def : ReadAdvance<!cast<SchedRead>(name # "_" # mx # "_E" # sew), val, writes>;
   }
 }
 // Define classes to define list containing all SchedWrites for each (name, LMUL)
@@ -159,16 +176,26 @@ class LMULSchedWriteList<list<string> names> : LMULSchedWriteListImpl<names, Sch
 multiclass LMULSEWSchedWrites<string name> : LMULSEWSchedWritesImpl<name, SchedMxList>;
 multiclass LMULSEWSchedReads<string name> : LMULSEWSchedReadsImpl<name, SchedMxList>;
 multiclass LMULSEWWriteRes<string name, list<ProcResourceKind> resources>
-  : LMULSEWWriteResImpl<name, resources>;
+  : LMULSEWWriteResImpl<name, resources, SchedMxList>;
 multiclass LMULSEWReadAdvance<string name, int val, list<SchedWrite> writes = []>
-  : LMULSEWReadAdvanceImpl<name, val, writes>;
+  : LMULSEWReadAdvanceImpl<name, val, writes, SchedMxList>;
+
+multiclass LMULSEWSchedWritesWRed<string name>
+  : LMULSEWSchedWritesImpl<name, SchedMxListWRed, 0, 1>;
+multiclass LMULSEWWriteResWRed<string name, list<ProcResourceKind> resources>
+  : LMULSEWWriteResImpl<name, resources, SchedMxListWRed, 0, 1>;
+
+multiclass LMULSEWSchedWritesFWRed<string name>
+  : LMULSEWSchedWritesImpl<name, SchedMxListFWRed, 1, 1>;
+multiclass LMULSEWWriteResFWRed<string name, list<ProcResourceKind> resources>
+  : LMULSEWWriteResImpl<name, resources, SchedMxListFWRed, 1, 1>;
 
 multiclass LMULSEWSchedWritesF<string name> : LMULSEWSchedWritesImpl<name, SchedMxListF, 1>;
 multiclass LMULSEWSchedReadsF<string name> : LMULSEWSchedReadsImpl<name, SchedMxListF, 1>;
 multiclass LMULSEWWriteResF<string name, list<ProcResourceKind> resources>
-  : LMULSEWWriteResImpl<name, resources, 1>;
+  : LMULSEWWriteResImpl<name, resources, SchedMxListF, 1>;
 multiclass LMULSEWReadAdvanceF<string name, int val, list<SchedWrite> writes = []>
-  : LMULSEWReadAdvanceImpl<name, val, writes, 1>;
+  : LMULSEWReadAdvanceImpl<name, val, writes, SchedMxListF, 1>;
 
 multiclass LMULSchedWritesW<string name> : LMULSchedWritesImpl<name, SchedMxListW>;
 multiclass LMULSchedReadsW<string name> : LMULSchedReadsImpl<name, SchedMxListW>;
@@ -186,12 +213,6 @@ multiclass LMULReadAdvanceFW<string name, int val, list<SchedWrite> writes = []>
   : LMULReadAdvanceImpl<name, val, writes>;
 class LMULSchedWriteListFW<list<string> names> : LMULSchedWriteListImpl<names, SchedMxListFW>;
 
-multiclass LMULSchedWritesFWRed<string name> : LMULSchedWritesImpl<name, SchedMxListFWRed>;
-multiclass LMULWriteResFWRed<string name, list<ProcResourceKind> resources>
-  : LMULWriteResImpl<name, resources>;
-class LMULSchedWriteListFWRed<list<string> names> : LMULSchedWriteListImpl<names, SchedMxListFWRed>;
-
-
 // 3.6 Vector Byte Length vlenb
 def WriteRdVLENB      : SchedWrite;
 
@@ -389,15 +410,15 @@ defm "" : LMULSchedWritesFW<"WriteVFNCvtFToFV">;
 // MF8 and M8. Use the _From suffix to indicate the number of the
 // LMUL from VS2.
 // 14.1. Vector Single-Width Integer Reduction Instructions
-defm "" : LMULSchedWrites<"WriteVIRedV_From">;
+defm "" : LMULSEWSchedWrites<"WriteVIRedV_From">;
 // 14.2. Vector Widening Integer Reduction Instructions
-defm "" : LMULSchedWrites<"WriteVIWRedV_From">;
+defm "" : LMULSEWSchedWritesWRed<"WriteVIWRedV_From">; // No _E64 variants.
 // 14.3. Vector Single-Width Floating-Point Reduction Instructions
-defm "" : LMULSchedWrites<"WriteVFRedV_From">;
-defm "" : LMULSchedWrites<"WriteVFRedOV_From">;
+defm "" : LMULSEWSchedWritesF<"WriteVFRedV_From">; // No MF8/_E8 variants.
+defm "" : LMULSEWSchedWritesF<"WriteVFRedOV_From">; // No MF8/_E8 variants.
 // 14.4. Vector Widening Floating-Point Reduction Instructions
-defm "" : LMULSchedWritesFWRed<"WriteVFWRedV_From">;
-defm "" : LMULSchedWritesFWRed<"WriteVFWRedOV_From">;
+defm "" : LMULSEWSchedWritesFWRed<"WriteVFWRedV_From">; // No MF8/_E8/_E64.
+defm "" : LMULSEWSchedWritesFWRed<"WriteVFWRedOV_From">; // No MF8/_E8/_E64.
 
 // 15. Vector Mask Instructions
 // 15.1. Vector Mask-Register Logical Instructions
@@ -821,12 +842,12 @@ defm "" : LMULWriteResW<"WriteVFNCvtFToIV", []>;
 defm "" : LMULWriteResFW<"WriteVFNCvtFToFV", []>;
 
 // 14. Vector Reduction Operations
-defm "" : LMULWriteRes<"WriteVIRedV_From", []>;
-defm "" : LMULWriteRes<"WriteVIWRedV_From", []>;
-defm "" : LMULWriteRes<"WriteVFRedV_From", []>;
-defm "" : LMULWriteRes<"WriteVFRedOV_From", []>;
-defm "" : LMULWriteResFWRed<"WriteVFWRedV_From", []>;
-defm "" : LMULWriteResFWRed<"WriteVFWRedOV_From", []>;
+defm "" : LMULSEWWriteRes<"WriteVIRedV_From", []>;
+defm "" : LMULSEWWriteResWRed<"WriteVIWRedV_From", []>; // No _E64 variants.
+defm "" : LMULSEWWriteResF<"WriteVFRedV_From", []>; // No MF8/_E8 variants.
+defm "" : LMULSEWWriteResF<"WriteVFRedOV_From", []>; // No MF8/_E8 variants.
+defm "" : LMULSEWWriteResFWRed<"WriteVFWRedV_From", []>; // No MF8/_E8/_E64.
+defm "" : LMULSEWWriteResFWRed<"WriteVFWRedOV_From", []>; // No MF8/_E8/_E64.
 
 // 15. Vector Mask Instructions
 defm "" : LMULWriteRes<"WriteVMALUV", []>;


        


More information about the llvm-commits mailing list