[Mlir-commits] [mlir] [MLIR][NVVM] Support for dense and sparse MMA with block scaling (PR #170566)

Mehdi Amini llvmlistbot at llvm.org
Wed Dec 10 01:21:12 PST 2025


================
@@ -1559,6 +1559,638 @@ LogicalResult MmaSpOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MMA Block Scale Operations - Shared Helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Per-operand record shared by the MMA helpers; one instance is used for each
+// of the A, B and C fragments.
+struct MMAOperandFragment {
+  // Name of the operand this fragment belongs to.
+  StringRef operandName;
+  // Name of the attribute that carries this fragment's PTX type.
+  StringRef ptxTypeAttr;
+  // Register values of the fragment; empty right after construction.
+  SmallVector<Value, 4> regs;
+
+  explicit MMAOperandFragment(StringRef name, StringRef ptxTypeName)
+      : operandName{name}, ptxTypeAttr{ptxTypeName} {}
+};
+
+// Emits ` name[operands]`: a leading space, the given name, then the operand
+// list wrapped in square brackets.
+void printOperandList(OpAsmPrinter &p, StringRef name,
+                      ArrayRef<Value> operands) {
+  // Prefix: separator and operand name.
+  p << " ";
+  p << name;
+  // Bracketed operand list.
+  p << "[";
+  p.printOperands(operands);
+  p << "]";
+}
+
+// Parses `operandName` as a keyword followed by an operand list that may be
+// wrapped in square brackets, appending the parsed operands to `regs`.
+// Returns failure() as soon as either step fails.
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+  // `||` short-circuits, so the operand list is only attempted once the
+  // keyword has been consumed successfully.
+  const bool parseFailed =
+      parser.parseKeyword(operandName).failed() ||
+      parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
+          .failed();
+  if (parseFailed)
+    return failure();
+  return success();
+}
+
+// Copies each fragment's operands from `op` into `frags`, appends register
+// types to `regTypes`, and lists inferable PTX type attrs in `ignoreAttrNames`.
+template <typename Op>
+void processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
+                             SmallVectorImpl<Type> &regTypes,
+                             SmallVectorImpl<StringRef> &ignoreAttrNames) {
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx); // pair: .first = start index, .second = operand count (as used below)
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++) {
+      frag.regs.push_back(op.getOperand(operandIdx));
+      if (fragIdx == 0 && operandIdx == varOperandSpec.first) { // only the first operand of fragment 0
+        regTypes.push_back(op.getOperand(operandIdx).getType());
+      }
+    }
+    if (fragIdx < 2) { // NOTE(review): combined with the push above, fragment 0's first type is recorded twice -- confirm this is intended
+      regTypes.push_back(frag.regs[0].getType()); // indexes regs[0]: assumes the fragment has at least one register -- TODO confirm
+    }
+    std::optional<MMATypes> inferredType =
+        MmaOp::inferOperandMMAType(regTypes.back(), // NOTE(review): nothing is pushed for fragIdx == 2, so back() is still fragment 1's type here -- confirm
+                                   /*isAccumulator=*/fragIdx >= 2);
+    if (inferredType) // a type could be inferred, so the attribute name is added to the ignore list
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+}
+
+// Parses a trailing type signature of the form `: (type, type, type)` and
+// appends each type to `operandTypes` as soon as it has been parsed.
+// Returns failure() on the first token that does not match.
+LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
+                                    SmallVectorImpl<Type> &operandTypes) {
+  if (parser.parseColon().failed())
+    return failure();
+  if (parser.parseLParen().failed())
+    return failure();
+
+  // Exactly three comma-separated types are expected.
+  constexpr int kNumSignatureTypes = 3;
+  for (int typeIdx = 0; typeIdx != kNumSignatureTypes; ++typeIdx) {
+    // Every type after the first is preceded by a comma.
+    const bool expectsComma = typeIdx != 0;
+    if (expectsComma && parser.parseComma().failed())
+      return failure();
+
+    Type parsedType;
+    if (parser.parseType(parsedType).failed())
+      return failure();
+    operandTypes.push_back(parsedType);
+  }
+
+  return parser.parseRParen();
+}
+
+// Fills in "multiplicandAPtxType" / "multiplicandBPtxType" on `attrs` when they
+// are not already present and a type can be inferred from operandTypes[0] /
+// operandTypes[1] respectively. Existing attributes are left untouched.
+void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
+                                  const SmallVectorImpl<Type> &operandTypes) {
+  // Handles one multiplicand: `operandTypes` is only indexed when the
+  // attribute is missing.
+  auto setIfMissing = [&](StringRef attrName, unsigned typeIdx) {
+    if (attrs.get(attrName))
+      return;
+    auto inferredType = MmaOp::inferOperandMMAType(operandTypes[typeIdx],
+                                                   /*isAccumulator=*/false);
+    if (inferredType)
+      attrs.set(attrName, MMATypesAttr::get(ctx, *inferredType));
+  };
+
+  setIfMissing("multiplicandAPtxType", 0);
+  setIfMissing("multiplicandBPtxType", 1);
+}
+
+// Adds the four common block-scale attributes to `result`, in this order:
+// "shape", "scaleVecSize", "blockScaleFormat", "kind".
+void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
+                             ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
+                             BlockScaleFormat blockScaleFormat,
+                             MMABlockScaleKind kind) {
+  MLIRContext *context = builder.getContext();
+
+  // `shape` is read at indices 0, 1 and 2.
+  auto shapeAttr = builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]);
+  result.addAttribute("shape", shapeAttr);
+
+  auto scaleVecSizeAttr = ScaleVecSizeAttr::get(context, scaleVecSize);
+  result.addAttribute("scaleVecSize", scaleVecSizeAttr);
+
+  auto blockScaleFormatAttr =
+      BlockScaleFormatAttr::get(context, blockScaleFormat);
+  result.addAttribute("blockScaleFormat", blockScaleFormatAttr);
+
+  auto kindAttr = MMABlockScaleKindAttr::get(context, kind);
+  result.addAttribute("kind", kindAttr);
+}
----------------
joker-eph wrote:

Can this be done using properties instead of attributes?

https://github.com/llvm/llvm-project/pull/170566


More information about the Mlir-commits mailing list