[Mlir-commits] [mlir] [MLIR][NVVM] Support for dense and sparse MMA with block scaling (PR #170566)
Kirill Vedernikov
llvmlistbot at llvm.org
Tue Dec 9 21:15:19 PST 2025
================
@@ -1559,6 +1559,650 @@ LogicalResult MmaSpOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// MMA Block Scale Operations - Shared Helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Bundles the values that make up one MMA operand fragment (A, B or C) with
+// the label used for it in the custom assembly and the name of the attribute
+// that carries its PTX element type.
+struct OperandFragment {
+  // Label printed in front of the bracketed operand list.
+  StringRef operandName;
+  // Name of the PTX-type attribute associated with this fragment.
+  StringRef ptxTypeAttr;
+  // Values belonging to the fragment; empty right after construction.
+  SmallVector<Value, 4> regs;
+
+  explicit OperandFragment(StringRef name, StringRef ptxTypeName)
+      : operandName{name}, ptxTypeAttr{ptxTypeName} {}
+};
+
+// Prints ` name[op0, op1, ...]`: a leading space, the label `name`, and the
+// comma-separated `operands` enclosed in square brackets.
+void printOperandList(OpAsmPrinter &p, StringRef name,
+                      ArrayRef<Value> operands) {
+  p << " " << name;
+  p << "[";
+  llvm::interleaveComma(operands, p);
+  p << "]";
+}
+
+// Parses an operand group written as `name[operands]`.
+//
+// The keyword `operandName` must come first; it is followed by an operand list
+// parsed with the OptionalSquare delimiter. The parsed (still unresolved)
+// operands are stored in `regs`. Returns failure if either the keyword or the
+// operand list fails to parse.
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+  if (parser.parseKeyword(operandName).failed())
+    return failure();
+  if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
+          .failed())
+    return failure();
+  return success();
+}
+
+// Gathers the values of the first three ODS operand groups of `op` into
+// `frags` (groups 0, 1 and 2, in that order) and appends to `ignoreAttrNames`
+// the `ptxTypeAttr` name of every fragment for which
+// MmaOp::inferOperandMMAType returns a value. `regTypes` receives the operand
+// types that were consulted for the inference.
+template <typename Op>
+void processOperandFragments(Op &op, std::array<OperandFragment, 3> &frags,
+                             SmallVectorImpl<Type> &regTypes,
+                             SmallVectorImpl<StringRef> &ignoreAttrNames) {
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++) {
+      frag.regs.push_back(op.getOperand(operandIdx));
+      if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
+        regTypes.push_back(op.getOperand(operandIdx).getType());
+      }
+    }
+    // NOTE(review): indexes regs[0] unconditionally, so fragments A and B are
+    // assumed to hold at least one value - confirm the op guarantees this.
+    if (fragIdx < 2) {
+      regTypes.push_back(frag.regs[0].getType());
+    }
+    // NOTE(review): nothing is pushed for fragIdx == 2, so back() is still the
+    // type recorded for fragment B when the accumulator is processed - confirm
+    // this is intended.
+    std::optional<MMATypes> inferredType =
+        MmaOp::inferOperandMMAType(regTypes.back(),
+                                   /*isAccumulator=*/fragIdx >= 2);
+    if (inferredType)
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+}
+
+// Parses the operand type signature `: (A_type, B_type, C_type)` and appends
+// the three parsed types, in order, to `operandTypes`. Fails as soon as any
+// punctuation or type cannot be parsed.
+LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
+                                    SmallVectorImpl<Type> &operandTypes) {
+  if (failed(parser.parseColon()) || failed(parser.parseLParen()))
+    return failure();
+
+  constexpr unsigned kNumSignatureTypes = 3;
+  for (unsigned typeIdx = 0; typeIdx != kNumSignatureTypes; ++typeIdx) {
+    // Every type except the first is preceded by a comma.
+    if (typeIdx != 0 && failed(parser.parseComma()))
+      return failure();
+    Type parsedType;
+    if (failed(parser.parseType(parsedType)))
+      return failure();
+    operandTypes.push_back(parsedType);
+  }
+
+  return parser.parseRParen();
+}
+
+// Fills in the multiplicand PTX-type attributes that are absent from `attrs`.
+//
+// For A (operandTypes[0]) and then B (operandTypes[1]): if the corresponding
+// attribute is not already present, the type is inferred with
+// MmaOp::inferOperandMMAType (non-accumulator) and, when inference succeeds,
+// stored in `attrs`. Attributes that are already present are left untouched.
+void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
+                                  const SmallVectorImpl<Type> &operandTypes) {
+  auto setIfMissing = [&](StringRef attrName, unsigned typeIdx) {
+    if (attrs.get(attrName))
+      return;
+    // The operand type is only looked up when the attribute is missing.
+    std::optional<MMATypes> ptxType = MmaOp::inferOperandMMAType(
+        operandTypes[typeIdx], /*isAccumulator=*/false);
+    if (ptxType)
+      attrs.set(attrName, MMATypesAttr::get(ctx, *ptxType));
+  };
+
+  setIfMissing("multiplicandAPtxType", 0);
+  setIfMissing("multiplicandBPtxType", 1);
+}
+
+// Adds the attributes shared by the block-scale MMA builders to `result`:
+// `shape` (built from the first three entries of `shape`), `scaleVecSize`,
+// `blockScaleFormat` and `kind`. `shape` must hold at least three entries.
+void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
+                             ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
+                             BlockScaleFormat blockScaleFormat,
+                             MMABlockScaleKind kind) {
+  MLIRContext *context = builder.getContext();
+
+  auto shapeAttr = builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]);
+  result.addAttribute("shape", shapeAttr);
+
+  auto scaleVecSizeAttr = ScaleVecSizeAttr::get(context, scaleVecSize);
+  result.addAttribute("scaleVecSize", scaleVecSizeAttr);
+
+  auto formatAttr = BlockScaleFormatAttr::get(context, blockScaleFormat);
+  result.addAttribute("blockScaleFormat", formatAttr);
+
+  auto kindAttr = MMABlockScaleKindAttr::get(context, kind);
+  result.addAttribute("kind", kindAttr);
+}
+
+// Adds the `multiplicandAPtxType` / `multiplicandBPtxType` attributes to
+// `result`.
+//
+// When `multiplicandPtxTypes` is provided, its two entries are used verbatim
+// for A and B. Otherwise each type is inferred from the type of the first
+// value of `operandA` / `operandB`; an attribute is only added when inference
+// succeeds. In the inference path both ranges are indexed at 0, so they must
+// be non-empty.
+void addInferredMultiplicandTypes(
+    MLIRContext *ctx, OperationState &result, ValueRange operandA,
+    ValueRange operandB,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+  // Explicit types take precedence over inference.
+  if (multiplicandPtxTypes) {
+    const std::array<MMATypes, 2> &explicitTypes = *multiplicandPtxTypes;
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, explicitTypes[0]));
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, explicitTypes[1]));
+    return;
+  }
+
+  std::optional<MMATypes> inferredA =
+      MmaOp::inferOperandMMAType(operandA[0].getType(), false);
+  if (inferredA)
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, *inferredA));
+
+  std::optional<MMATypes> inferredB =
+      MmaOp::inferOperandMMAType(operandB[0].getType(), false);
+  if (inferredB)
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, *inferredB));
+}
+
+// Shared implementation for the accumPtxType/resultPtxType accessors: infers
+// the PTX type from the first body element of the op's result, which is cast
+// to an LLVM struct type. The inference result is dereferenced without a
+// check, so the caller must ensure inference succeeds for that element type.
+template <typename OpTy>
+MMATypes inferPtxTypeFromResult(OpTy op) {
+  auto resultStructTy = cast<LLVM::LLVMStructType>(op.getRes().getType());
+  Type firstElementTy = resultStructTy.getBody()[0];
+  std::optional<MMATypes> ptxType =
+      MmaOp::inferOperandMMAType(firstElementTy, /*isAccumulator=*/true);
+  return *ptxType;
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// MmaBlockScaleOp
+//===----------------------------------------------------------------------===//
+
+// Prints the op in its custom assembly form:
+//   A[...] B[...] C[...] scaleA[...] scaleB[...] {attrs}
+//       : (A_type, B_type, C_type) -> result_type
+// The operand segment size attribute and any attribute names collected by
+// processOperandFragments are omitted from the attribute dictionary.
+void MmaBlockScaleOp::print(OpAsmPrinter &p) {
+  std::array<OperandFragment, 3> frags{
+      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+  SmallVector<Type, 4> regTypes;
+  SmallVector<StringRef, 4> ignoreAttrNames{
+      mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()};
+  processOperandFragments(*this, frags, regTypes, ignoreAttrNames);
+
+  // A, B and C operand groups.
+  for (const OperandFragment &frag : frags)
+    printOperandList(p, frag.operandName, frag.regs);
+
+  // Scale operand groups: data, byte id, thread id.
+  const Value scaleAValues[] = {getScaleAData(), getByteIdA(), getThreadIdA()};
+  printOperandList(p, "scaleA", scaleAValues);
+  const Value scaleBValues[] = {getScaleBData(), getByteIdB(), getThreadIdB()};
+  printOperandList(p, "scaleB", scaleBValues);
+
+  p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames);
+
+  // Type signature: the type of the first value of each fragment.
+  const Type typeA = frags[0].regs[0].getType();
+  const Type typeB = frags[1].regs[0].getType();
+  const Type typeC = frags[2].regs[0].getType();
+  p << " : (" << typeA << ", " << typeB << ", " << typeC << ")";
+  p.printArrowTypeList(TypeRange{getRes().getType()});
+}
+
+// Parses the custom assembly form:
+//   A[...] B[...] C[...] scaleA[...] scaleB[...] {attrs}
+//       : (A_type, B_type, C_type) -> result_type
+//
+// Every value of a fragment is resolved against that fragment's single type
+// from the signature. The scale groups are resolved against (i32, i16, i16).
+// Multiplicand PTX-type attributes missing from the attribute dictionary are
+// inferred from the A/B types when possible, and the operand segment size
+// attribute is derived from the number of parsed operands.
+ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
+                                   OperationState &result) {
+  Builder &builder = parser.getBuilder();
+  // Unresolved operands of the A, B and C fragments, in that order.
+  std::array<SmallVector<OpAsmParser::UnresolvedOperand, 4>, 3> fragRegs;
+  NamedAttrList namedAttributes;
+
+  // Parse A[...] B[...] C[...]
+  if (parseMmaOperand(parser, "A", fragRegs[0]).failed() ||
+      parseMmaOperand(parser, "B", fragRegs[1]).failed() ||
+      parseMmaOperand(parser, "C", fragRegs[2]).failed())
+    return failure();
+
+  // Parse scale operands: scaleA[...] scaleB[...]
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
+  if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
+      parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(namedAttributes).failed())
+    return failure();
+
+  // Parse type signature
+  SmallVector<Type, 3> operandTypes;
+  if (parseMmaTypeSignature(parser, operandTypes).failed())
+    return failure();
+
+  // Parse result type
+  SmallVector<Type, 1> resultTypes;
+  if (parser.parseArrowTypeList(resultTypes).failed())
+    return failure();
+
+  // Resolve the fragment operands. Order matters: result.operands must hold
+  // A, B, C and then the scale operands to match the segment sizes below.
+  for (unsigned idx = 0; idx < fragRegs.size(); ++idx) {
+    if (parser
+            .resolveOperands(fragRegs[idx], operandTypes[idx],
+                             parser.getNameLoc(), result.operands)
+            .failed())
+      return failure();
+  }
+
+  // Resolve scale operands
+  SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
+                                     builder.getI16Type()};
+  if (parser
+          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed() ||
+      parser
+          .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed())
+    return failure();
+
+  // Add attributes; explicitly written PTX types win over inferred ones.
+  result.addAttributes(namedAttributes);
+  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+                               operandTypes);
+
+  result.addTypes(resultTypes);
+  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(fragRegs[0].size()),
+                          static_cast<int32_t>(fragRegs[1].size()),
+                          static_cast<int32_t>(fragRegs[2].size()),
+                          1, // scaleAData
+                          1, // byteIdA
+                          1, // threadIdA
+                          1, // scaleBData
+                          1, // byteIdB
+                          1  // threadIdB
+                      }));
+  return success();
+}
+
+// Builds an MmaBlockScaleOp.
+//
+// Operands are added in the order A, B, C, then the six scale operands.
+// `shape` must have exactly three entries (m, n, k). When
+// `multiplicandPtxTypes` is not provided, the A/B PTX types are inferred from
+// the first value of `operandA` / `operandB`. The operand segment size
+// attribute records the A/B/C group sizes followed by one entry per scale
+// operand.
+void MmaBlockScaleOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType,
+    ValueRange operandA, ValueRange operandB, ValueRange operandC,
+    Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
+    Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+    MMABlockScaleKind kind) {
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+
+  addBlockScaleAttributes(builder, result, shape, scaleVecSize,
+                          blockScaleFormat, kind);
+
+  // Variadic fragments first, then the fixed scale operands.
+  for (ValueRange fragment : {operandA, operandB, operandC})
+    result.addOperands(fragment);
+  result.addOperands(
+      {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
+
+  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+                               multiplicandPtxTypes);
+
+  result.addTypes(resultType);
+
+  const int32_t segmentSizes[] = {
+      static_cast<int32_t>(operandA.size()),
+      static_cast<int32_t>(operandB.size()),
+      static_cast<int32_t>(operandC.size()),
+      1, // scaleAData
+      1, // byteIdA
+      1, // threadIdA
+      1, // scaleBData
+      1, // byteIdB
+      1  // threadIdB
+  };
+  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(segmentSizes));
+}
+
+MMATypes MmaBlockScaleOp::accumPtxType() {
----------------
kvederni wrote:
I'm sorry. I missed it. It was fixed in [d78c4b9](https://github.com/llvm/llvm-project/pull/170566/commits/d78c4b996bd2a41fcab652c35cce5a9034bc7417)
https://github.com/llvm/llvm-project/pull/170566
More information about the Mlir-commits
mailing list