[Mlir-commits] [mlir] [MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (PR #168686)

llvmlistbot at llvm.org
Wed Nov 19 01:45:32 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir

Author: Kirill Vedernikov (kvederni)

Changes:

This change adds sparse MMA intrinsics to the MLIR -> NVVM IR -> NVPTX flow. The NVVM and NVPTX implementations are based on PTX ISA 9.0.
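
For reference, the new op looks like this in MLIR (the example is taken from the op documentation added to NVVMOps.td in the patch below; a sparse f16 m16n8k32 multiply-accumulate):

```mlir
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
                      sparseMetadata[%meta] selector[%sel]
                      {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
```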

---

Patch is 97.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168686.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+276-1) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+474) 
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir (+221) 
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir (+411) 
- (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir (+390) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8d5bc7333d47f..b8f69f6b2cb98 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1955,6 +1955,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
 /// Generate the signature part of the mma intrinsic name.
 class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   list<WMMA_REGS> id_frags = !cond(
+     // FP8/F8F6F4 ops are identified by A, B inputs & accumulator & result type.
+     !or(!eq(A.ptx_elt_type, "e4m3"),
+         !eq(A.ptx_elt_type, "e5m2"),
+         !eq(A.ptx_elt_type, "e3m2"),
+         !eq(A.ptx_elt_type, "e2m3"),
+         !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
      // FP16 ops are identified by accumulator & result type.
      !eq(A.ptx_elt_type, "f16") : [D, C],
      // other ops are identified by input types.
@@ -2081,6 +2087,31 @@ class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
             tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
             fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+
+  list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,16>, GEOM<16,8,32>],
+            ["bf16"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,8>, GEOM<16,8,16>],
+            ["tf32"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,16>, GEOM<16,8,32>],
+            ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,64>],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,64>, GEOM<16,8,128>],
+            ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,32>, GEOM<16,8,64>],
+            ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat(
+            bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
+            subint_mma_sp_ops, int_mma_sp_ops);
+
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -2187,6 +2218,29 @@ def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflo
   let assemblyFormat = "`<` $value `>`";
 }
 
+/// Sparse MMA metadata types
+def MMASpMetadataStandard : I32EnumAttrCase<"standard", 0>;
+def MMASpMetadataOrdered : I32EnumAttrCase<"ordered", 1>;
+def MMASpMetadata : I32EnumAttr<"MMASpMetadata", "Sparse MMA metadata ordering",
+  [MMASpMetadataStandard, MMASpMetadataOrdered]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def MMASpMetadataAttr : EnumAttr<NVVM_Dialect, MMASpMetadata, "mma_sp_metadata"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// MMA kind types (for mixed-precision FP8 operations)
+def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>;
+def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind",
+  [MMAKindF8F6F4]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 /// Attribute to hold the MMA shape
 def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
   let summary = "Attribute for MMA operation shape.";
@@ -2330,12 +2384,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
 def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
 def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
 def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
+def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>;
+def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>;
+def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>;
+def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>;
+def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>;
 
 def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
   [MMATypeF16, MMATypeF32, MMATypeTF32,
   MMATypeBF16, MMATypeS8, MMATypeU8,
   MMATypeS32, MMATypeS4, MMATypeU4,
-  MMATypeB1, MMATypeF64]> {
+  MMATypeB1, MMATypeF64,
+  MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -2772,6 +2832,221 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
   let hasVerifier = 1;
 }
 
+/// Generate enum value of the mma.sync intrinsic.
+class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
+                       WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma"
+              # "_" # !subst("::", "_", Metadata)
+              # "_" # A.geom
+              # "_row_col"
+              # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+              # !if(Satfinite, "_satfinite", "")
+              # signature;
+}
+
+// Returns true if this combination of fragments/metadata/kind/satf for MMA.SP
+// ops is supported; false otherwise.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+                            string kind, int satf> {
+  // MMA.SP ops use the fixed row.col layout.
+  string a_type = frags[0].ptx_elt_type;
+  string b_type = frags[1].ptx_elt_type;
+  string c_type = frags[2].ptx_elt_type;
+  string d_type = frags[3].ptx_elt_type;
+  string geom = frags[0].geom;
+
+  bit is_int = !or(!eq(a_type, "s8"),
+                   !eq(a_type, "u8"),
+                   !eq(a_type, "s4"),
+                   !eq(a_type, "u4"));
+
+  bit ret = !cond(
+
+    // Limit satf to valid types
+    !and(!eq(satf, 1),
+         !eq(is_int, 0)): false,
+
+    // f16/bf16/tf32 require A and B to be the same type.
+    !and(!or(!eq(a_type, "f16"),
+             !eq(a_type, "bf16"),
+             !eq(a_type, "tf32")),
+         !ne(a_type, b_type)): false,
+
+    // m16n8k16, m16n8k32, and m16n8k64 require C and D to be the same type.
+    !and(!or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32"),
+             !eq(geom, "m16n8k64")),
+         !ne(c_type, d_type)): false,
+
+    // e3m2/e2m3/e2m1 inputs are only valid with a kind qualifier (f8f6f4).
+    !and(!eq(kind, ""),
+         !or(!eq(a_type, "e3m2"),
+             !eq(a_type, "e2m3"),
+             !eq(a_type, "e2m1"),
+             !eq(b_type, "e3m2"),
+             !eq(b_type, "e2m3"),
+             !eq(b_type, "e2m1"))): false,
+
+    // Without a kind qualifier, the FP8 m16n8k64 ops require an f32
+    // accumulator and result.
+    !and(!eq(kind, ""),
+         !eq(geom, "m16n8k64"),
+         !or(!eq(c_type, "f16"),
+             !eq(d_type, "f16"))): false,
+
+    // A kind qualifier requires ordered metadata, the m16n8k64 geometry, and
+    // non-integer types.
+    !and(!ne(kind, ""),
+         !or(!eq(metadata, "sp"),
+             !ne(geom, "m16n8k64"),
+             !eq(is_int, 1))): false,
+
+    // All others are OK.
+    true: true
+  );
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.sync
+/// intrinsic enum value.
+class MMA_SP_SYNC_INTR {
+  list<list<list<list<string>>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
+      !foreach(metadata, ["sp", "sp::ordered_metadata"],
+        !foreach(kind, ["", "kind::f8f6f4"],
+          !foreach (satf, [0, 1],
+            !if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
+                "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+                # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+                # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+                # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+                # " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
+                # " && (satf.has_value()  ? " # satf # " == static_cast<int>(*satf) : true)"
+                # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata")
+                # " && " # !if(!eq(kind, ""), "!hasKind", "hasKind") # ")\n"
+                # "  return " #
+                MMA_SP_SYNC_NAME<metadata, kind, satf, op[0], op[1], op[2], op[3]>.id # ";",
+                "") // if supported
+          ) // satf
+        ) // kind
+      ) // metadata
+    ); // all_mma_sp_sync_ops
+  list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+                                       !listconcat(acc, el));
+  list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+  string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
+
+  let summary = "cooperative sparse matrix-multiply and accumulate";
+
+  let description = [{
+    The `nvvm.mma.sp.sync` operation collectively performs the sparse operation
+    `D = matmul(A_sparse, B) + C` using all threads in a warp.
+
+    This operation is similar to `nvvm.mma.sync` but with structured sparsity
+    in the A operand. The sparsity follows the 2:4 structured sparse pattern
+    where 2 out of every 4 elements are non-zero.
+
+    All the threads in the warp must execute the same `mma.sp.sync` operation.
+
+    The `sparseMetadata` operand provides the sparsity indices that indicate
+    which elements in the A operand are non-zero. The `sparsitySelector`
+    controls how the indices are distributed among threads in the warp and
+    should typically be 0 or 1.
+
+    The optional `metadataType` attribute specifies the metadata ordering:
+    - `standard` (default): Uses standard sparse metadata ordering
+    - `ordered`: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
+    
+    The optional `kind` attribute specifies mixed-precision modes for narrow
+    floating-point operations:
+    - `f8f6f4`: Enables the e3m2/e2m3 (FP6) and e2m1 (FP4) types and an f16
+      accumulator (PTX ISA 8.7+, sm_90+)
+    - Only valid with ordered metadata and the m16n8k64 shape
+
+    The shapes, layouts, and data types follow the same constraints as the
+    regular `nvvm.mma.sync` operation, but the A operand contains only the
+    non-zero elements in compressed format.
+
+    Example:
+    ```mlir
+    %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+                          sparseMetadata[%meta] selector[%sel]
+                          {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
+        : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    
+    // With ordered metadata:
+    %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+                          sparseMetadata[%meta] selector[%sel]
+                          {metadataType = #nvvm.mma_sp_metadata<ordered>,
+                           shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
+        : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let results = (outs LLVM_AnyStruct:$res);
+  let arguments = (ins NVVM_MMAShapeAttr:$shape,
+             OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
+             OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+             OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+             OptionalAttr<MMASpMetadataAttr>:$metadataType,
+             OptionalAttr<MMAKindAttr>:$kind,
+             Variadic<LLVM_Type>:$operandA,
+             Variadic<LLVM_Type>:$operandB,
+             Variadic<LLVM_Type>:$operandC,
+             I32:$sparseMetadata,
+             I32:$sparsitySelector);
+
+  let extraClassDeclaration = !strconcat([{
+      static llvm::Intrinsic::ID getIntrinsicID(
+            int64_t m, int64_t n, uint64_t k,
+            std::optional<MMAIntOverflow> satf,
+            std::optional<MMASpMetadata> metadata,
+            std::optional<MMAKind> kind,
+            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+            mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
+        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+        llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
+        bool orderedMetadata = metadata.has_value() &&
+                               *metadata == MMASpMetadata::ordered;
+        bool hasKind = kind.has_value();
+        }],
+        MMA_SP_SYNC_INTR<>.id, [{
+          return 0;
+      }
+
+      static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
+        bool isAccumulator);
+
+      MMATypes accumPtxType();
+      MMATypes resultPtxType();
+
+      static mlir::NVVM::IDArgPair
+      getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                            llvm::IRBuilderBase& builder);
+    }]);
+
+  let builders = [
+      OpBuilder<(ins  "Type":$resultType, "ValueRange":$operandA,
+        "ValueRange":$operandB, "ValueRange":$operandC,
+        "Value":$sparseMetadata, "Value":$sparsitySelector,
+        "ArrayRef<int64_t>":$shape,
+        "std::optional<MMAIntOverflow>":$intOverflow,
+        "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes)>
+    ];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs(
+                      *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ac427dbe3941..8db724dd0a25b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -940,6 +940,480 @@ LogicalResult MmaOp::verify() {
   return success();
 }
 
+MMATypes MmaSpOp::accumPtxType() {
+  std::optional<mlir::NVVM::MMATypes> val = MmaOp::inferOperandMMAType(
+      getODSOperands(2).getTypes().front(), /*isAccumulator=*/true);
+  assert(val.has_value() && "accumulator PTX type should always be inferrable");
+  return val.value();
+}
+
+MMATypes MmaSpOp::resultPtxType() {
+  std::optional<mlir::NVVM::MMATypes> val =
+      MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true);
+  assert(val.has_value() && "result PTX type should always be inferrable");
+  return val.value();
+}
+
+mlir::NVVM::IDArgPair
+MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::MmaSpOp>(op);
+
+  // Get operands
+  llvm::SmallVector<llvm::Value *> args;
+  for (mlir::Value v : thisOp.getOperands())
+    args.push_back(mt.lookupValue(v));
+
+  // Get intrinsic ID using the existing getIntrinsicID method
+  auto intId = MmaSpOp::getIntrinsicID(
+      thisOp.getShape().getM(), thisOp.getShape().getN(), thisOp.getShape().getK(),
+      thisOp.getIntOverflowBehavior(),
+      thisOp.getMetadataType(),
+      thisOp.getKind(),
+      *thisOp.getMultiplicandAPtxType(),
+      *thisOp.getMultiplicandBPtxType(),
+      thisOp.accumPtxType(),
+      thisOp.resultPtxType());
+
+  return {intId, args};
+}
+
+void MmaSpOp::print(OpAsmPrinter &p) {
+  SmallVector<Type, 4> regTypes;
+  struct OperandFragment {
+    StringRef operandName;
+    StringRef ptxTypeAttr;
+    SmallVector<Value, 4> regs;
+    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+        : operandName(name), ptxTypeAttr(ptxTypeName) {}
+  };
+
+  std::array<OperandFragment, 5> frags{
+      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+      OperandFragment("C", ""),
+      OperandFragment("sparseMetadata", ""),
+      OperandFragment("selector", "")};
+  SmallVector<StringRef, 4> ignoreAttrNames{
+      mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
+
+  // Handle variadic operands A, B, C
+  for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++) {
+      frag.regs.push_back(this->getOperand(operandIdx));
+      if (operandIdx == varOperandSpec.first) {
+        regTypes.push_back(this->getOperand(operandIdx).getType());
+      }
+    }
+    std::optional<MMATypes> inferredType =
+        MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+    if (inferredType)
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+
+  // Handle sparse metadata and selector (single operands)
+  frags[3].regs.push_back(getSparseMetadata());
+  frags[4].regs.push_back(getSparsitySelector());
+
+  auto printMmaSpOperand = [&](const OperandFragment &frag) -> void {
+    p << " " << frag.operandName;
+    p << "[";
+    p.printOperands(frag.regs);
+    p << "]";
+  };
+
+  for (const auto &frag : frags)
+    printMmaSpOperand(frag);
+
+  p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
+  p << " : ";
+  p << "(";
+  for (int i = 0; i < 3; ++i) {
+    p << regTypes[i];
+    if (i < 2)
+      p << ", ";
+  }
+  p << ") -> " << getResult().getType();
+}
+
+void MmaSpOp::build(OpBuilder &builder, OperationState &result,
+                Type resultType, ValueRange operandA, ValueRange operandB,
+                ValueRange operandC, Value sparseMetadata, Value sparsitySelector,
+                ArrayRef<int64_t> shape,
+                std::optional<MMAIntOverflow> intOverflow,
+                std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+  MLIRContext *ctx = builder.getContext();
+  result.addAttribute(
+      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+
+  result.addOperands(operandA);
+  result.addOperands(operandB);
+  result.addOperands(operandC);
+  result.addOperands(sparseMetadata);
+  result.addOperands(sparsitySelector);
+
+  if (multiplicandPtxTypes) {
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+  } else {
+    if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+    if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+  }
+
+  if (intOverflow.has_value())
+    result.addAttribute("intOverflowBehavior",
+                        MMAIntOverflowAttr::get(ctx, *intOverflow));
+
+  result.addTypes(resultType);
+  result.addAttribute(
+      MmaSpOp::getOperandSegmentSizeAttr(),
+      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
+                                    static_cast<int32_t>(operandB.size()),
+                                    static_cast<int32_t>(operandC.size()),
+                                    1, 1})); // sparseMetadata and sparsitySelector
+}
+
+ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
+  struct OperandFragment {
+    std::optional<MMATypes> elemtype;
+    SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
+    SmallVector<Type> regTypes;
+  };
+
+  Builder &builder = parser.getBuilder();
+  std::array<OperandFragment, 6> frags; // A, B, C, sparseMetadata, selector
+
+  NamedAttrList namedAttributes;
+
+  // A helper to parse the operand segments.
+  auto parseMmaSpOperand = [&](StringRef operandName,
+                               OperandFragment &frag) -> LogicalResult {
+    if (parser.parseKeyword(operandName).failed())
+      return failure();
+    if (parser
+            .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
+            .failed())
+      return failure();
+    return success();
+  };
+
+  // Parse the operand segments.
+  if (parseMmaSpOperand("A", frags[0]).failed())
+    return failure();
+  if (parseMmaSpOperand("B", frags[1]).failed(...
[truncated]

``````````
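
A note on the intrinsic naming for readers skimming the TableGen above: `MMA_SP_SYNC_NAME` assembles the intrinsic enum from the metadata ordering (`sp` vs. `sp::ordered_metadata`), the geometry, the fixed `row.col` layout, an optional `kind` qualifier, an optional `_satfinite` suffix, and the type signature. A minimal sketch of the resulting mapping; the enum names below are derived from that construction for illustration, not copied from the patch:

```mlir
// Standard metadata, f16, m16n8k32:
//   selects llvm::Intrinsic::nvvm_mma_sp_m16n8k32_row_col_f16_f16
%d0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
                       sparseMetadata[%meta] selector[%sel]
                       {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>

// Ordered metadata, same configuration:
//   selects llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_m16n8k32_row_col_f16_f16
%d1 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
                       sparseMetadata[%meta] selector[%sel]
                       {metadataType = #nvvm.mma_sp_metadata<ordered>,
                        shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
```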



https://github.com/llvm/llvm-project/pull/168686

