[Mlir-commits] [mlir] [MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (PR #168686)

Kirill Vedernikov llvmlistbot at llvm.org
Thu Nov 20 00:35:00 PST 2025


================
@@ -2772,6 +2832,221 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
   let hasVerifier = 1;
 }
 
+/// Builds the name of the intrinsic enum value for a sparse MMA
+/// (mma.sp.sync) variant.
+///
+/// `id` is the concatenation, in this order, of:
+///   - "llvm::Intrinsic::nvvm_mma"
+///   - "_" followed by Metadata with every "::" replaced by "_"
+///   - "_" followed by the geometry of fragment A
+///   - "_row_col"
+///   - "_" followed by Kind with every "::" replaced by "_"
+///     (only when Kind is non-empty)
+///   - "_satfinite" (only when Satfinite is set)
+///   - the signature string produced by MMA_SIGNATURE<A, B, C, D>
+class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
+                       WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma"
+              # "_" # !subst("::", "_", Metadata)
+              # "_" # A.geom
+              # "_row_col"
+              # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+              # !if(Satfinite, "_satfinite", "")
+              # signature;
+}
+
+// Returns true if this combination of fragments/metadata/kind/satf for MMA.SP
+// ops is supported; false otherwise.
+//
+// Parameters:
+//   frags    - fragment descriptors; indices 0..3 are read as A, B, C and D.
+//   metadata - metadata selector string (compared against "sp" below).
+//   kind     - kind string; the empty string means "no kind".
+//   satf     - 1 when the satfinite variant is requested.
+//
+// The conditions in `ret` are evaluated in order; the first one that matches
+// rejects the combination.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+                            string kind, int satf> {
+  // Fragment element types and geometry used by the checks below.
+  // NOTE(review): this comment used to read "MMA.SP ops check both layouts",
+  // but no layout parameter is visible in this class - confirm.
+  string a_type = frags[0].ptx_elt_type;
+  string b_type = frags[1].ptx_elt_type;
+  string c_type = frags[2].ptx_elt_type;
+  string d_type = frags[3].ptx_elt_type;
+  string geom = frags[0].geom;
+
+  // True when the A fragment holds one of the integer types s8/u8/s4/u4.
+  bit is_int = !or(!eq(a_type, "s8"),
+                   !eq(a_type, "u8"),
+                   !eq(a_type, "s4"),
+                   !eq(a_type, "u4"));
+
+  bit ret = !cond(
+
+    // Limit satf to valid types: satf is rejected unless A is an integer type.
+    !and(!eq(satf, 1),
+         !eq(is_int, 0)): false,
+
+    // f16/bf16/tf32 requires A and B to be the same type.
+    !and(!or(!eq(a_type, "f16"),
+             !eq(a_type, "bf16"),
+             !eq(a_type, "tf32")),
+         !ne(a_type, b_type)): false,
+
+    // m16n8k16, m16n8k32 and m16n8k64 require C and D to be the same type.
+    !and(!or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32"),
+             !eq(geom, "m16n8k64")),
+         !ne(c_type, d_type)): false,
+
+    // Without a kind, e3m2/e2m3/e2m1 are rejected for both A and B.
+    !and(!eq(kind, ""),
+         !or(!eq(a_type, "e3m2"),
+             !eq(a_type, "e2m3"),
+             !eq(a_type, "e2m1"),
+             !eq(b_type, "e3m2"),
+             !eq(b_type, "e2m3"),
+             !eq(b_type, "e2m1"))): false,
+
+    // Without a kind, m16n8k64 is rejected when C or D is f16.
+    !and(!eq(kind, ""),
+         !eq(geom, "m16n8k64"),
+         !or(!eq(c_type, "f16"),
+             !eq(d_type, "f16"))): false,
+
+    // With a kind, reject plain "sp" metadata, any geometry other than
+    // m16n8k64, and integer A types.
+    !and(!ne(kind, ""),
+         !or(!eq(metadata, "sp"),
+             !ne(geom, "m16n8k64"),
+             !eq(is_int, 1))): false,
+
+    // All other combinations are OK.
+    true: true
+  );
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.sync
+/// intrinsic enum value.
+class MMA_SP_SYNC_INTR {
+  list<list<list<list<string>>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
+      !foreach(metadata, ["sp", "sp::ordered_metadata"],
+        !foreach(kind, ["", "kind::f8f6f4"],
+          !foreach (satf, [0, 1],
+            !if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
+                "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+                # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+                # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+                # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+                # " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
+                # " && (satf.has_value()  ? " # satf # " == static_cast<int>(*satf) : true)"
+                # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata")
+                # " && " # !if(!eq(kind, ""), "!hasKind", "hasKind") # ")\n"
----------------
kvederni wrote:

It has been fixed in 156919a.

https://github.com/llvm/llvm-project/pull/168686


More information about the Mlir-commits mailing list