[Mlir-commits] [mlir] [MLIR][NVVM] Support for dense and sparse MMA with block scaling (PR #170566)
Mehdi Amini
llvmlistbot at llvm.org
Wed Dec 10 01:21:12 PST 2025
================
@@ -1559,6 +1559,638 @@ LogicalResult MmaSpOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// MMA Block Scale Operations - Shared Helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Bundles the SSA values of one MMA operand fragment (A, B or C) together with
+// the fragment's label and the name of its PTX type attribute. Both StringRef
+// members are non-owning views, so the strings they refer to must outlive the
+// fragment object.
+struct MMAOperandFragment {
+  explicit MMAOperandFragment(StringRef label, StringRef typeAttrName)
+      : operandName(label), ptxTypeAttr(typeAttrName) {}
+
+  StringRef operandName;      // Fragment label, e.g. "A".
+  StringRef ptxTypeAttr;      // Name of the attribute carrying the PTX type.
+  SmallVector<Value, 4> regs; // Operand values belonging to this fragment.
+};
+
+// Prints ` <name>[<operands>]` to `p`: a leading space, the label, then the
+// operands as a comma-separated list inside square brackets.
+void printOperandList(OpAsmPrinter &p, StringRef name,
+                      ArrayRef<Value> operands) {
+  p << ' ' << name << '[';
+  p.printOperands(operands);
+  p << ']';
+}
+
+// Parses a fragment operand list of the form `name[%op0, %op1, ...]` and
+// appends the unresolved operands to `regs`.
+//
+// The keyword `operandName` is required. The bracketed list is parsed with
+// Delimiter::OptionalSquare, so the brackets (and the list) may be omitted.
+// Returns failure if either piece fails to parse.
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+  // ParseResult converts to `true` on failure, so `||` stops at the first
+  // error, exactly like checking each call separately.
+  if (parser.parseKeyword(operandName) ||
+      parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare))
+    return failure();
+  return success();
+}
+
+// Collects the operands of the A, B and C fragments of `op` into `frags`.
+// Records in `ignoreAttrNames` the PTX type attributes that can be inferred
+// from operand types.
+//
+// `frags[i]` is filled from ODS operand group `i`. Each fragment's PTX type is
+// inferred from the type of that fragment's own first operand. Index 2 (C) is
+// treated as the accumulator. Fragments with no operands contribute nothing to
+// `regTypes` or `ignoreAttrNames`.
+template <typename Op>
+void processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
+                             SmallVectorImpl<Type> &regTypes,
+                             SmallVectorImpl<StringRef> &ignoreAttrNames) {
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); ++fragIdx) {
+    MMAOperandFragment &frag = frags[fragIdx];
+    auto [firstIdx, length] = op.getODSOperandIndexAndLength(fragIdx);
+    for (unsigned operandIdx = firstIdx; operandIdx < firstIdx + length;
+         ++operandIdx)
+      frag.regs.push_back(op.getOperand(operandIdx));
+
+    // An empty variadic group has no type to infer from. Reading regs[0]
+    // here would be out of bounds.
+    if (frag.regs.empty())
+      continue;
+    Type fragType = frag.regs.front().getType();
+
+    // NOTE(review): this keeps the original regTypes layout. A's type is
+    // recorded twice and C's type is not recorded. Confirm whether callers
+    // depend on this before normalizing to one entry per fragment.
+    if (fragIdx == 0)
+      regTypes.push_back(fragType);
+    if (fragIdx < 2)
+      regTypes.push_back(fragType);
+
+    // Infer from this fragment's own type. regTypes.back() is not C's type
+    // when processing the C fragment, because C is never appended above.
+    if (MmaOp::inferOperandMMAType(fragType, /*isAccumulator=*/fragIdx >= 2))
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+}
+
+// Parses the operand type signature `: (A_type, B_type, C_type)` and appends
+// the three types, in order, to `operandTypes`. Returns failure on any
+// malformed token.
+LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
+                                    SmallVectorImpl<Type> &operandTypes) {
+  if (parser.parseColon() || parser.parseLParen())
+    return failure();
+
+  // Exactly one type per fragment: A, B and C.
+  constexpr unsigned kNumFragmentTypes = 3;
+  for (unsigned typeIdx = 0; typeIdx != kNumFragmentTypes; ++typeIdx) {
+    // Every type after the first is preceded by a comma.
+    if (typeIdx != 0 && parser.parseComma())
+      return failure();
+    Type fragmentType;
+    if (parser.parseType(fragmentType))
+      return failure();
+    operandTypes.push_back(fragmentType);
+  }
+
+  return parser.parseRParen();
+}
+
+// For each multiplicand (A = operandTypes[0], B = operandTypes[1]) whose PTX
+// type attribute is not already present in `attrs`, tries to infer the PTX
+// type from the operand type. On success, it stores the result as an
+// MMATypesAttr. Existing attributes are never overwritten.
+//
+// Takes ArrayRef because the types are only read. Existing SmallVector
+// arguments convert implicitly.
+void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
+                                  ArrayRef<Type> operandTypes) {
+  assert(operandTypes.size() >= 2 && "expected A and B operand types");
+  const std::pair<StringRef, Type> multiplicands[] = {
+      {"multiplicandAPtxType", operandTypes[0]},
+      {"multiplicandBPtxType", operandTypes[1]}};
+  for (const auto &[attrName, operandType] : multiplicands) {
+    // An explicitly provided attribute takes precedence over inference.
+    if (attrs.get(attrName))
+      continue;
+    if (std::optional<MMATypes> inferredType =
+            MmaOp::inferOperandMMAType(operandType, /*isAccumulator=*/false))
+      attrs.set(attrName, MMATypesAttr::get(ctx, *inferredType));
+  }
+}
+
+// Adds the attributes shared by the MMA block-scale ops to `result`:
+//   - `shape`, built from the three dimensions in shape[0..2];
+//   - `scaleVecSize`;
+//   - `blockScaleFormat`;
+//   - `kind`.
+//
+// NOTE(review): the reviewer asked whether these could be set through op
+// properties instead of attributes. Doing so needs the concrete op's
+// Properties type, which this untyped helper lacks. Consider templating it on
+// the op and using result.getOrAddProperties<...>() instead.
+void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
+                             ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
+                             BlockScaleFormat blockScaleFormat,
+                             MMABlockScaleKind kind) {
+  // `shape` is indexed directly below, so anything but 3 dims is a caller bug.
+  assert(shape.size() == 3 && "expected a 3-element MMA shape");
+  result.addAttribute(
+      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+  result.addAttribute("scaleVecSize",
+                      builder.getAttr<ScaleVecSizeAttr>(scaleVecSize));
+  result.addAttribute("blockScaleFormat",
+                      builder.getAttr<BlockScaleFormatAttr>(blockScaleFormat));
+  result.addAttribute("kind", builder.getAttr<MMABlockScaleKindAttr>(kind));
+}
----------------
joker-eph wrote:
Can this be done using properties instead of attributes?
https://github.com/llvm/llvm-project/pull/170566
More information about the Mlir-commits mailing list