[llvm] [NFC][NVVM][NVPTX] Moved common code for tcgen05.mma to the base class (PR #176327)

Pradeep Kumar via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 18 22:07:02 PST 2026


================
@@ -6071,195 +6060,142 @@ foreach sp = [0, 1] in {
                                           !eq(kind, "tf32")), [0, 1], [0]) in {
             foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
               def :
-                Tcgen05MMADisableOutputLaneInst<sp, space, kind, cta_group,
-                                               collector_usage, scale_input_d,
-                                               ashift>;
-            }
-          }
-        }
-      }
-    }
-  }
-}
+                Tcgen05MMADisableOutputLaneInst<sparse, space, kind, cta_group,
+                                                collector_usage, scale_input_d,
+                                                ashift>;
+            } // ashift
+          } // scale_input_d
+        } // collector_usage
+      } // cta_group
+    } // kind
+  } // space
+} // sparse
 
-class Tcgen05MMABlockScaleInst<bit Sp, string ASpace, string KindStr,
-                     int CtaGroup, string ScaleVecSize, string CollectorUsageStr>:
-         NVPTXInst<(outs), (ins), "?", []>, Requires<[]> {
+//
+// tcgen05.mma.block_scale Instructions
+//
+
+class Tcgen05MMABlockScaleInst<bit IsSparse, string ASpace, string Kind,
+                               int CtaGroup, string ScaleVecSize,
+                               string CollectorUsage>:
+        Tcgen05MMABase<IsSparse, ASpace, Kind, CtaGroup, CollectorUsage> {
 
   let Predicates = !cond(
-    !and(!eq(Sp, 1),
-         !eq(KindStr, "mxf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4">],
-    !and(!eq(Sp, 1),
-         !eq(KindStr, "mxf4nvf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4nvf4">],
+    !and(IsSparse,
+         !eq(Kind, "mxf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4">],
+    !and(IsSparse,
+         !eq(Kind, "mxf4nvf4")) : [callSubtarget<"hasTcgen05MMASparseMxf4nvf4">],
     !ne(ScaleVecSize, "") : [callSubtarget<"hasTcgen05InstSupport">, hasPTX<88>],
     true : [callSubtarget<"hasTcgen05InstSupport">]
   );
 
   Intrinsic Intrin = !cast<Intrinsic>(
-                             NVVM_TCGEN05_MMA_BLOCKSCALE<Sp, ASpace, KindStr, ScaleVecSize>.record_name);
-
-  dag SparseMetadataIns = !if(!eq(Sp, 1), (ins B32:$spmetadata), (ins));
-  dag SparseMetadataIntr = !if(!eq(Sp, 1), (Intrin i32:$spmetadata), (Intrin));
-  string SparseMetadataStr = !if(!eq(Sp, 1), ", [$spmetadata]", "");
+                             NVVM_TCGEN05_MMA_BLOCKSCALE<IsSparse, ASpace, Kind, ScaleVecSize>.record_name);
 
-  int KindVal = !cond(
-                  !eq(KindStr, "mxf8f6f4") : 0,
-                  !eq(KindStr, "mxf4")     : 1,
-                  !eq(KindStr, "mxf4nvf4") : 2,
+  let KindVal = !cond(
+                  !eq(Kind, "mxf8f6f4") : 0,
+                  !eq(Kind, "mxf4")     : 1,
+                  !eq(Kind, "mxf4nvf4") : 2,
                 );
 
-  int CollectorUsage = !cond(
-    !eq(CollectorUsageStr, "discard") : 0,
-    !eq(CollectorUsageStr, "lastuse") : 1,
-    !eq(CollectorUsageStr, "fill")    : 2,
-    !eq(CollectorUsageStr, "use")     : 3,
-  );
-
-  string AOperandStr = !if(!eq(ASpace, "tensor"), "[$a]", "$a");
-  NVPTXRegClass ARegClass = !if(!eq(ASpace, "tensor"), B32, B64);
+  let InOperandList = !con(BaseInOperandList,
+                           (ins B32:$scale_a,
+                                B32:$scale_b));
+  let AsmString = Prefix
+                  # SpCtaKindStr
+                  # ".block_scale" # ScaleVecSize
+                  # ".collector::a::" # CollectorUsage
+                  # BaseOperandsStr
+                  # ", [$scale_a], [$scale_b]"
+                  # InputDStr
+                  # ";";
 
-  dag input = !con((ins B32:$dtmem, ARegClass:$a, B64:$b,
-                        B32:$idesc, B1:$enable_inp_d),
-                    SparseMetadataIns,
-                    (ins B32:$scale_a,
-                         B32:$scale_b));
+  dag IntrinsicPattern = !con(!foreach(tmp, BasePatternArgs, !subst(ins, Intrin, tmp)),
+                              (Intrin i32:$scale_a,
+                                      i32:$scale_b));
 
-  let InOperandList = input;
-  let OutOperandList = (outs);
-  let AsmString = "tcgen05.mma"
-                   # !if(!eq(Sp, 1), ".sp", "")
-                   # ".cta_group::" # CtaGroup
-                   # ".kind::" # KindStr
-                   # ".block_scale" # ScaleVecSize
-                   # ".collector::a::" # CollectorUsageStr
-                   # " [$dtmem], " # AOperandStr # ", $b"
-                   # SparseMetadataStr
-                   # ", $idesc, [$scale_a], [$scale_b], $enable_inp_d;";
-
-  dag IntrinsicPattern = !con((Intrin i32:$dtmem,
-                                      ARegClass:$a, i64:$b,
-                                      i32:$idesc,
-                                      i1:$enable_inp_d),
-                               SparseMetadataIntr,
-                               (Intrin i32:$scale_a,
-                                       i32:$scale_b));
-
-  dag FlagOperands = (Intrin (i32 CtaGroup), (i32 CollectorUsage));
+  dag FlagOperands = (Intrin (i32 CtaGroup), (i32 CollectorUsageVal));
 
   let Pattern = [!con(IntrinsicPattern, FlagOperands)];
 }
 
 // tcgen05.mma.block_scale
-foreach sp = [0, 1] in {
+foreach sparse = [0, 1] in {
   foreach space = ["tensor", "shared"] in {
     foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
       foreach scale_vec_size = ["", ".block16", ".block32"] in {
         foreach cta_group = [1, 2] in {
           foreach collector_usage = ["fill", "use", "lastuse", "discard"] in {
             if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
-              def : Tcgen05MMABlockScaleInst<sp, space, kind, cta_group, scale_vec_size,
-                                             collector_usage>;
+              def : Tcgen05MMABlockScaleInst<sparse, space, kind, cta_group,
+                                             scale_vec_size, collector_usage>;
             }
-          }
-        }
-      }
-    }
-  }
-}
+          } // collector_usage
+        } // cta_group
+      } // scale_vec_size
+    } // kind
+  } // space
+} // sparse
 
 //
 // tcgen05.mma.ws Instructions
 //
 
-class Tcgen05MMAWSInst<bit Sp, string ASpace, string KindStr,
-                       int CollectorBufferB, string CollectorUsageOpStr,
-                       bit HasZeroColMask> :
-         NVPTXInst<(outs), (ins), "?", []>, Requires<[]> {
+class Tcgen05MMAWSInst<bit IsSparse, string ASpace, string Kind,
+                       int CollectorBufferB, string CollectorUsage,
+                       bit IsZeroColMask> :
+        Tcgen05MMABase<IsSparse, ASpace, Kind, /*CtaGroup=*/ 1, CollectorUsage> {
 
   let Predicates = !cond(
-    !eq(KindStr, "i8") : [callSubtarget<"hasTcgen05MMAI8Kind">],
+    !eq(Kind, "i8") : [callSubtarget<"hasTcgen05MMAI8Kind">],
     true : [callSubtarget<"hasTcgen05InstSupport">]
   );
 
   Intrinsic Intrin = !cast<Intrinsic>(
-                            NVVM_TCGEN05_MMA_WS<Sp, ASpace, HasZeroColMask>.record_name);
-
-  dag ZeroColMaskIns = !if(!eq(HasZeroColMask, 1),
-                              (ins B64:$zero_col_mask), (ins));
-  string ZeroColMaskStr = !if(!eq(HasZeroColMask, 1), ", $zero_col_mask", "");
-  dag ZeroColMaskIntr = !if(!eq(HasZeroColMask, 1),
-                               (Intrin i64:$zero_col_mask), (Intrin));
-
-  dag SparseMetadataIns = !if(!eq(Sp, 1), (ins B32:$spmetadata), (ins));
-  dag SparseMetadataIntr = !if(!eq(Sp, 1), (Intrin B32:$spmetadata), (Intrin));
-  string SparseMetadataStr = !if(!eq(Sp, 1), ", [$spmetadata]", "");
-
-  int KindVal = !cond(
-                  !eq(KindStr, "f16")   : 0,
-                  !eq(KindStr, "tf32")  : 1,
-                  !eq(KindStr, "f8f6f4"): 2,
-                  !eq(KindStr, "i8")    : 3,
-                );
-
-  int CollectorUsageOp = !cond(
-    !eq(CollectorUsageOpStr, "discard"): 0,
-    !eq(CollectorUsageOpStr, "lastuse"): 1,
-    !eq(CollectorUsageOpStr, "fill")   : 2,
-    !eq(CollectorUsageOpStr, "use")    : 3,
-  );
-
-  string AOperandStr = !if(!eq(ASpace, "tensor"), "[$a]", "$a");
-  NVPTXRegClass ARegClass = !if(!eq(ASpace, "tensor"), B32, B64);
-
-  dag input = !con((ins B32:$dtmem,
-                        ARegClass:$a, B64:$b,
-                        B32:$idesc,
-                        B1:$enable_inp_d),
-                        SparseMetadataIns,
-                        ZeroColMaskIns);
+                            NVVM_TCGEN05_MMA_WS<IsSparse, ASpace, IsZeroColMask>.record_name);
+
+  dag ZeroColMaskIns = !if(IsZeroColMask, (ins B64:$zero_col_mask), (ins));
+  string ZeroColMaskStr = !if(IsZeroColMask, ", $zero_col_mask", "");
+  dag ZeroColMaskIntr = !if(IsZeroColMask,
+                            (Intrin i64:$zero_col_mask), (Intrin));
+
+  let InOperandList = !con(BaseInOperandList,
+                           ZeroColMaskIns);
+
+  let AsmString = Prefix
+                  # ".ws"
+                  # SpCtaKindStr
+                  # ".collector::b" # CollectorBufferB
+                  # "::" # CollectorUsage
+                  # BaseOperandsStr
+                  # InputDStr
+                  # ZeroColMaskStr
+                  # ";";
 
-  let InOperandList = input;
-  let OutOperandList = (outs);
-  let AsmString = "tcgen05.mma.ws"
-                   # !if(!eq(Sp, 1), ".sp", "")
-                   # ".cta_group::1"
-                   # ".kind::" # KindStr
-                   # ".collector::b" # CollectorBufferB
-                   # "::" # CollectorUsageOpStr
-                   # " [$dtmem], " # AOperandStr # ", $b"
-                   # SparseMetadataStr
-                   # ", $idesc, $enable_inp_d"
-                   # ZeroColMaskStr
-                   # ";";
-
-  dag IntrinsicPattern = !con((Intrin i32:$dtmem,
-                                      ARegClass:$a, i64:$b,
-                                      i32:$idesc,
-                                      i1:$enable_inp_d),
-                               SparseMetadataIntr,
-                               ZeroColMaskIntr);
+  dag IntrinsicPattern = !con(!foreach(tmp, BasePatternArgs, !subst(ins, Intrin, tmp)),
+                              ZeroColMaskIntr);
 
   dag FlagOperands = (Intrin (i32 KindVal), (i32 CollectorBufferB),
-                             (i32 CollectorUsageOp));
+                             (i32 CollectorUsageVal));
 
   let Pattern = [!con(IntrinsicPattern, FlagOperands)];
 }
 
 // tcgen05.mma.ws
-foreach sp = [0, 1] in {
+foreach sparse = [0, 1] in {
   foreach space = ["shared", "tensor"] in {
     foreach kind = ["f16", "tf32", "f8f6f4", "i8"] in {
       foreach collector_buffer_b = [0, 1, 2, 3] in {
         foreach collector_usage_op = ["discard", "fill", "use", "lastuse"] in {
           foreach zero_col_mask = [0, 1] in {
-              def : Tcgen05MMAWSInst<sp, space, kind, collector_buffer_b,
+              def : Tcgen05MMAWSInst<sparse, space, kind, collector_buffer_b,
                                      collector_usage_op, zero_col_mask>;
-          }
-        }
-      }
-    }
-  }
-}
+          } // zero_col_mask
+        } // collector_usage_op
+      } // collector_buffer_b
+    } //kind
----------------
schwarzschild-radius wrote:

Missing space here? [i.e. `//kind` should be `// kind`, matching the other closing-brace comments in this patch]

https://github.com/llvm/llvm-project/pull/176327


More information about the llvm-commits mailing list