[Mlir-commits] [mlir] [MLIR][NVVM] Support for dense and sparse MMA with block scaling (PR #170566)
Kirill Vedernikov
llvmlistbot at llvm.org
Wed Dec 10 07:16:59 PST 2025
================
@@ -1559,6 +1559,638 @@ LogicalResult MmaSpOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// MMA Block Scale Operations - Shared Helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+// One operand group (A, B or C) of an MMA block-scale op, collected while
+// printing:
+//   operandName - the fragment's label;
+//   ptxTypeAttr - name of the attribute carrying the fragment's PTX type
+//                 (processOperandFragments appends it to the ignored-attribute
+//                 list when the type can be inferred);
+//   regs        - the SSA values making up the fragment.
+struct MMAOperandFragment {
+  StringRef operandName;
+  StringRef ptxTypeAttr;
+  SmallVector<Value, 4> regs;
+  explicit MMAOperandFragment(StringRef fragName, StringRef typeAttrName)
+      : operandName{fragName}, ptxTypeAttr{typeAttrName} {}
+};
+
+// Emits ` <name>[<operands>]` (note the leading space), with the operands
+// rendered by OpAsmPrinter::printOperands.
+void printOperandList(OpAsmPrinter &p, StringRef name,
+                      ArrayRef<Value> operands) {
+  p << " ";
+  p << name;
+  p << "[";
+  p.printOperands(operands);
+  p << "]";
+}
+
+// Parses a fragment label followed by its operand list, i.e. the form
+// `<operandName>[<operands>]` that printOperandList emits.
+//
+// The list is parsed with Delimiter::OptionalSquare, so the square brackets
+// are optional in the input. NOTE(review): per OpAsmParser's OptionalSquare
+// semantics, a missing `[` is accepted with no operands added to `regs`;
+// confirm that a bracket-less fragment is meant to be legal.
+//
+// Returns failure (with a diagnostic emitted by `parser`) if the keyword does
+// not match `operandName` or the operand list is malformed.
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+  // The fragment label (presumably "A", "B" or "C") must come first.
+  if (parser.parseKeyword(operandName).failed())
+    return failure();
+  return parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare);
+}
+
+// Collects the A, B and C operand fragments of `op` for printing.
+//
+// For each fragment index i in [0, 3):
+//   * the operands of ODS operand group i
+//     (op.getODSOperandIndexAndLength(i)) are appended to `frags[i].regs`;
+//   * a type is appended to `regTypes` (see the note below);
+//   * if MmaOp::inferOperandMMAType can infer a PTX type from
+//     `regTypes.back()` (treated as an accumulator for i >= 2), the
+//     fragment's `ptxTypeAttr` is appended to `ignoreAttrNames`.
+//
+// NOTE(review): for i == 0, A's first operand type is pushed twice (once in
+// the inner loop, once by the `fragIdx < 2` push). For i == 1, B's first
+// operand type is pushed. Nothing is pushed for i == 2, so C's inference runs
+// on B's type with isAccumulator=true. Confirm this is intended; C presumably
+// should use frags[2].regs[0].getType().
+// NOTE(review): `frag.regs[0]` assumes the A and B groups are non-empty.
+template <typename Op>
+void processOperandFragments(Op &op, std::array<MMAOperandFragment, 3> &frags,
+                             SmallVectorImpl<Type> &regTypes,
+                             SmallVectorImpl<StringRef> &ignoreAttrNames) {
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    auto &frag = frags[fragIdx];
+    auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++) {
+      frag.regs.push_back(op.getOperand(operandIdx));
+      if (fragIdx == 0 && operandIdx == varOperandSpec.first) {
+        regTypes.push_back(op.getOperand(operandIdx).getType());
+      }
+    }
+    if (fragIdx < 2) {
+      regTypes.push_back(frag.regs[0].getType());
+    }
+    std::optional<MMATypes> inferredType =
+        MmaOp::inferOperandMMAType(regTypes.back(),
+                                   /*isAccumulator=*/fragIdx >= 2);
+    if (inferredType)
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+}
+
+// Helper to parse type signature: (A_type, B_type, C_type)
+LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
+ SmallVectorImpl<Type> &operandTypes) {
+ if (parser.parseColon().failed() || parser.parseLParen().failed())
+ return failure();
+
+ for (int i = 0; i < 3; i++) {
+ if (i > 0 && parser.parseComma().failed())
+ return failure();
+ Type ty;
+ if (parser.parseType(ty).failed())
+ return failure();
----------------
kvederni wrote:
It's suitable here. The PR has been updated.
https://github.com/llvm/llvm-project/pull/170566
More information about the Mlir-commits
mailing list