[Mlir-commits] [mlir] [MLIR][NVVM] Support for dense and sparse MMA with block scaling (PR #170566)

Kirill Vedernikov llvmlistbot at llvm.org
Tue Dec 9 06:52:59 PST 2025


================
@@ -1559,6 +1559,650 @@ LogicalResult MmaSpOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MMA Block Scale Operations - Shared Helpers
+//===----------------------------------------------------------------------===//
+
+namespace {
+// One MMA operand fragment (A, B, or C): the keyword it is printed under, the
+// name of the PTX type attribute associated with it, and its register values.
+struct OperandFragment {
+  // Keyword emitted in front of the bracketed register list.
+  StringRef operandName;
+  // Name of the attribute holding this fragment's PTX type.
+  StringRef ptxTypeAttr;
+  // Register values of the fragment; empty right after construction.
+  SmallVector<Value, 4> regs;
+
+  explicit OperandFragment(StringRef fragName, StringRef typeAttrName)
+      : operandName{fragName}, ptxTypeAttr{typeAttrName} {}
+};
+
+// Emits ` name[op0, op1, ...]`: a leading space, the keyword `name`, and the
+// given operands enclosed in square brackets.
+void printOperandList(OpAsmPrinter &p, StringRef name,
+                      ArrayRef<Value> operands) {
+  p << " " << name;
+  p << "[";
+  p.printOperands(operands);
+  p << "]";
+}
+
+// Parses the keyword `operandName` followed by an operand list that may be
+// wrapped in square brackets, appending the parsed operands to `regs`.
+// Fails as soon as either step fails; the operand list is not attempted when
+// the keyword is missing.
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+  return failure(
+      parser.parseKeyword(operandName).failed() ||
+      parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
+          .failed());
+}
+
+// Collects the registers of the A, B and C fragments of `op` and records which
+// PTX type attributes can be left out when printing.
+//
+// For each fragment, the operands reported by
+// `op.getODSOperandIndexAndLength(fragIdx)` are appended to `frag.regs`, and
+// the type of the fragment's first register is appended to `regTypes`. When
+// `MmaOp::inferOperandMMAType` yields a type for that register type, the
+// fragment's attribute name is appended to `ignoreAttrNames`.
+//
+// A fragment without registers is skipped rather than indexed, and every
+// fragment (including the accumulator) is inferred from its own register
+// type.
+template <typename Op>
+void processOperandFragments(Op &op, std::array<OperandFragment, 3> &frags,
+                             SmallVectorImpl<Type> &regTypes,
+                             SmallVectorImpl<StringRef> &ignoreAttrNames) {
+  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
+    OperandFragment &frag = frags[fragIdx];
+    auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
+    for (auto operandIdx = varOperandSpec.first;
+         operandIdx < varOperandSpec.first + varOperandSpec.second;
+         operandIdx++)
+      frag.regs.push_back(op.getOperand(operandIdx));
+
+    // Nothing can be inferred from an empty fragment; do not read regs[0].
+    if (frag.regs.empty())
+      continue;
+
+    Type fragType = frag.regs.front().getType();
+    regTypes.push_back(fragType);
+
+    // A fragment with no associated attribute name has nothing to elide.
+    if (frag.ptxTypeAttr.empty())
+      continue;
+    if (MmaOp::inferOperandMMAType(fragType, /*isAccumulator=*/fragIdx >= 2))
+      ignoreAttrNames.push_back(frag.ptxTypeAttr);
+  }
+}
+
+// Parses `: (A_type, B_type, C_type)` and appends the three parsed types to
+// `operandTypes` in that order.
+LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
+                                    SmallVectorImpl<Type> &operandTypes) {
+  if (parser.parseColon() || parser.parseLParen())
+    return failure();
+
+  constexpr int kNumOperandTypes = 3;
+  for (int typeIdx = 0; typeIdx != kNumOperandTypes; ++typeIdx) {
+    // Every type but the first is preceded by a comma.
+    if (typeIdx != 0 && parser.parseComma())
+      return failure();
+    Type parsedType;
+    if (parser.parseType(parsedType))
+      return failure();
+    operandTypes.push_back(parsedType);
+  }
+
+  return parser.parseRParen();
+}
+
+// Fills in "multiplicandAPtxType" / "multiplicandBPtxType" in `attrs` when
+// they are absent and a PTX type can be inferred from `operandTypes[0]` /
+// `operandTypes[1]` respectively. Attributes already present are left alone.
+void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
+                                  const SmallVectorImpl<Type> &operandTypes) {
+  auto setIfMissing = [&](StringRef attrName, unsigned typeIdx) {
+    if (attrs.get(attrName))
+      return;
+    std::optional<MMATypes> inferred = MmaOp::inferOperandMMAType(
+        operandTypes[typeIdx], /*isAccumulator=*/false);
+    if (inferred)
+      attrs.set(attrName, MMATypesAttr::get(ctx, *inferred));
+  };
+
+  setIfMissing("multiplicandAPtxType", 0);
+  setIfMissing("multiplicandBPtxType", 1);
+}
+
+// Attaches the attributes shared by the block-scale MMA builders to `result`:
+// "shape" (from the first three entries of `shape`), "scaleVecSize",
+// "blockScaleFormat" and "kind".
+void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
+                             ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
+                             BlockScaleFormat blockScaleFormat,
+                             MMABlockScaleKind kind) {
+  MLIRContext *context = builder.getContext();
+
+  auto shapeAttr = MMAShapeAttr::get(context, shape[0], shape[1], shape[2]);
+  result.addAttribute("shape", shapeAttr);
+
+  auto scaleVecSizeAttr = ScaleVecSizeAttr::get(context, scaleVecSize);
+  result.addAttribute("scaleVecSize", scaleVecSizeAttr);
+
+  auto formatAttr = BlockScaleFormatAttr::get(context, blockScaleFormat);
+  result.addAttribute("blockScaleFormat", formatAttr);
+
+  auto kindAttr = MMABlockScaleKindAttr::get(context, kind);
+  result.addAttribute("kind", kindAttr);
+}
+
+// Adds the "multiplicandAPtxType" / "multiplicandBPtxType" attributes to
+// `result`.
+//
+// When `multiplicandPtxTypes` is provided, its two entries are used verbatim
+// for A and B. Otherwise each type is inferred from the first register of the
+// corresponding operand list; the attribute is left unset when inference
+// yields nothing or when the operand list is empty (there is no register to
+// inspect in that case).
+void addInferredMultiplicandTypes(
+    MLIRContext *ctx, OperationState &result, ValueRange operandA,
+    ValueRange operandB,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+  if (multiplicandPtxTypes) {
+    result.addAttribute("multiplicandAPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+    result.addAttribute("multiplicandBPtxType",
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+    return;
+  }
+
+  // Guard against empty operand lists before indexing the first register.
+  if (!operandA.empty()) {
+    if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
+      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
+  }
+  if (!operandB.empty()) {
+    if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
+      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
+  }
+}
+
+// Shared implementation of accumPtxType/resultPtxType: infers the PTX type
+// from the first element type of the op's LLVM struct result, treating it as
+// an accumulator type.
+//
+// The result must be an LLVM struct type (enforced by `cast`), must have at
+// least one element, and its first element type must be inferrable; the
+// latter two conditions are asserted instead of being silently dereferenced.
+template <typename OpTy>
+MMATypes inferPtxTypeFromResult(OpTy op) {
+  auto resultStruct = cast<LLVM::LLVMStructType>(op.getRes().getType());
+  assert(!resultStruct.getBody().empty() &&
+         "expected result struct to have at least one element");
+  std::optional<MMATypes> inferred =
+      MmaOp::inferOperandMMAType(resultStruct.getBody()[0],
+                                 /*isAccumulator=*/true);
+  assert(inferred.has_value() &&
+         "accumulator PTX type should always be inferrable");
+  return *inferred;
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// MmaBlockScaleOp
+//===----------------------------------------------------------------------===//
+
+// Prints the custom assembly form:
+//   A[...] B[...] C[...] scaleA[...] scaleB[...] attr-dict
+//     : (A_type, B_type, C_type) -> result_type
+void MmaBlockScaleOp::print(OpAsmPrinter &p) {
+  std::array<OperandFragment, 3> frags{
+      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+
+  // The operand segment sizes are never printed; processOperandFragments may
+  // add further attribute names to leave out.
+  SmallVector<StringRef, 4> elidedAttrs{getOperandSegmentSizeAttr()};
+  SmallVector<Type, 4> fragTypes;
+  processOperandFragments(*this, frags, fragTypes, elidedAttrs);
+
+  // A, B and C register lists.
+  for (const OperandFragment &frag : frags)
+    printOperandList(p, frag.operandName, frag.regs);
+
+  // Scale operand triples for A and B.
+  const Value scaleAOperands[] = {getScaleAData(), getByteIdA(),
+                                  getThreadIdA()};
+  printOperandList(p, "scaleA", scaleAOperands);
+  const Value scaleBOperands[] = {getScaleBData(), getByteIdB(),
+                                  getThreadIdB()};
+  printOperandList(p, "scaleB", scaleBOperands);
+
+  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+
+  // Type signature: one type per fragment, taken from its first register.
+  const Type signatureTypes[] = {frags[0].regs[0].getType(),
+                                 frags[1].regs[0].getType(),
+                                 frags[2].regs[0].getType()};
+  p << " : (";
+  llvm::interleaveComma(signatureTypes, p);
+  p << ")";
+  p.printArrowTypeList(TypeRange{getRes().getType()});
+}
+
+// Parses the custom assembly form:
+//   A[...] B[...] C[...] scaleA[...] scaleB[...] attr-dict?
+//     : (A_type, B_type, C_type) -> result_type
+//
+// All registers of a fragment are resolved against the single type given for
+// that fragment in the signature. Each scale group is resolved against
+// (i32, i16, i16). Multiplicand PTX type attributes that were not written
+// explicitly are inferred from the A/B types where possible.
+ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
+                                   OperationState &result) {
+  Builder &builder = parser.getBuilder();
+  // Unresolved registers of the A, B and C fragments, in that order.
+  std::array<SmallVector<OpAsmParser::UnresolvedOperand, 4>, 3> fragRegs;
+  NamedAttrList namedAttributes;
+
+  // Parse A[...] B[...] C[...]
+  if (parseMmaOperand(parser, "A", fragRegs[0]).failed() ||
+      parseMmaOperand(parser, "B", fragRegs[1]).failed() ||
+      parseMmaOperand(parser, "C", fragRegs[2]).failed())
+    return failure();
+
+  // Parse scale operands: scaleA[...] scaleB[...]
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
+  if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
+      parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(namedAttributes).failed())
+    return failure();
+
+  // Parse type signature
+  SmallVector<Type, 3> operandTypes;
+  if (parseMmaTypeSignature(parser, operandTypes).failed())
+    return failure();
+
+  // Parse result type
+  SmallVector<Type, 1> resultTypes;
+  if (parser.parseArrowTypeList(resultTypes).failed())
+    return failure();
+
+  // Resolve the fragment registers; operand order must be A, B, C.
+  for (unsigned idx = 0; idx < fragRegs.size(); ++idx) {
+    if (parser
+            .resolveOperands(fragRegs[idx], operandTypes[idx],
+                             parser.getNameLoc(), result.operands)
+            .failed())
+      return failure();
+  }
+
+  // Resolve scale operands
+  SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
+                                     builder.getI16Type()};
+  if (parser
+          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed() ||
+      parser
+          .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed())
+    return failure();
+
+  // Add attributes
+  result.addAttributes(namedAttributes);
+  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+                               operandTypes);
+
+  result.addTypes(resultTypes);
+  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(fragRegs[0].size()),
+                          static_cast<int32_t>(fragRegs[1].size()),
+                          static_cast<int32_t>(fragRegs[2].size()),
+                          1, // scaleAData
+                          1, // byteIdA
+                          1, // threadIdA
+                          1, // scaleBData
+                          1, // byteIdB
+                          1  // threadIdB
+                      }));
+  return success();
+}
+
+// Populates `result` for an MmaBlockScaleOp: block-scale attributes, the
+// A/B/C fragment operands followed by the six scale operands, the multiplicand
+// PTX type attributes (explicit if given, otherwise inferred), the result
+// type, and the operand segment sizes. `shape` must hold exactly (m, n, k).
+void MmaBlockScaleOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType,
+    ValueRange operandA, ValueRange operandB, ValueRange operandC,
+    Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
+    Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+    MMABlockScaleKind kind) {
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+
+  addBlockScaleAttributes(builder, result, shape, scaleVecSize,
+                          blockScaleFormat, kind);
+
+  // Operand order: A, B, C fragments, then the scale operands for A and B.
+  for (ValueRange fragment : {operandA, operandB, operandC})
+    result.addOperands(fragment);
+  const Value scaleOperands[] = {scaleAData, byteIdA,  threadIdA,
+                                 scaleBData, byteIdB, threadIdB};
+  result.addOperands(scaleOperands);
+
+  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+                               multiplicandPtxTypes);
+
+  result.addTypes(resultType);
+
+  // Variadic fragments carry their actual sizes; every scale operand is a
+  // single value.
+  const int32_t segmentSizes[] = {
+      static_cast<int32_t>(operandA.size()),
+      static_cast<int32_t>(operandB.size()),
+      static_cast<int32_t>(operandC.size()),
+      1, // scaleAData
+      1, // byteIdA
+      1, // threadIdA
+      1, // scaleBData
+      1, // byteIdB
+      1  // threadIdB
+  };
+  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(segmentSizes));
+}
+
+// The accumulator PTX type is inferred from the op's result type.
+MMATypes MmaBlockScaleOp::accumPtxType() {
+  return inferPtxTypeFromResult<MmaBlockScaleOp>(*this);
+}
+
+// The result PTX type is inferred the same way as the accumulator type.
+MMATypes MmaBlockScaleOp::resultPtxType() {
+  return inferPtxTypeFromResult<MmaBlockScaleOp>(*this);
+}
+
+// Returns the intrinsic ID selected for `op` together with its translated
+// arguments. Argument order: A, B, C fragment registers, then scaleAData,
+// byteIdA, threadIdA, scaleBData, byteIdB, threadIdB.
+NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::MmaBlockScaleOp>(op);
+
+  SmallVector<llvm::Value *> args;
+  auto appendTranslated = [&](Value value) {
+    args.push_back(mt.lookupValue(value));
+  };
+
+  // Fragment registers.
+  for (Value reg : curOp.getOperandA())
+    appendTranslated(reg);
+  for (Value reg : curOp.getOperandB())
+    appendTranslated(reg);
+  for (Value reg : curOp.getOperandC())
+    appendTranslated(reg);
+
+  // Scale operands for A, then for B.
+  const Value scaleOperands[] = {
+      curOp.getScaleAData(), curOp.getByteIdA(), curOp.getThreadIdA(),
+      curOp.getScaleBData(), curOp.getByteIdB(), curOp.getThreadIdB()};
+  for (Value scaleOperand : scaleOperands)
+    appendTranslated(scaleOperand);
+
+  auto shape = curOp.getShape();
+  unsigned intId = MmaBlockScaleOp::getIntrinsicID(
+      shape.getM(), shape.getN(), shape.getK(),
+      *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
+      curOp.accumPtxType(), curOp.getScaleVecSize(),
+      curOp.getBlockScaleFormat(), curOp.getKind());
+
+  return {intId, args};
+}
+
+// Verifies the combination of shape, multiplicand types, kind, scale vector
+// size and block scale format.
+//
+// Accepted configurations:
+//   m16n8k64: A and B must be e2m1, and either
+//     - kind MXF4 with ScaleVecSize X2 and BlockScaleFormat UE8M0, or
+//     - kind MXF4NVF4 with (X2, UE8M0) or (X4, UE4M3).
+//   m16n8k32: kind MXF8F6F4 with ScaleVecSize X1 and BlockScaleFormat UE8M0;
+//     both multiplicand type attributes must be present.
+// Every violated rule emits its own diagnostic; the function fails if any was
+// emitted.
+LogicalResult MmaBlockScaleOp::verify() {
+  LogicalResult result = success();
+  int m = getShape().getM();
+  int n = getShape().getN();
+  int k = getShape().getK();
+
+  if (m == 16 && n == 8 && k == 64) {
+    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
+        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1) {
+      result = emitOpError(
+          "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)");
+    }
+    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
+      if (getScaleVecSize() != NVVM::ScaleVecSize::X2) {
+        result = emitOpError(
+            "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4");
+      }
+      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0) {
+        result = emitOpError(
+            "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
+      }
+    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
+      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) {
+        result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
+                             "attributes for mma.m16n8k64.mxf4nvf4");
+      }
+    } else {
+      result = emitOpError("unsupported Kind attribute for mma.m16n8k64");
+    }
+  } else if (m == 16 && n == 8 && k == 32) {
+    // getIntrinsicIDAndArgs dereferences both multiplicand type attributes
+    // unconditionally, and nothing else in this branch looks at them, so
+    // reject an op on which either one is missing.
+    if (!getMultiplicandAPtxType() || !getMultiplicandBPtxType()) {
+      result = emitOpError("missing MMATypes attribute for mma.m16n8k32");
+    }
+    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
+          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
+          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0)) {
+      result =
+          emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+                      "attributes for mma.m16n8k32");
+    }
+  } else {
+    result = emitOpError("unsupported Geom for mma with block scaling");
+  }
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// MmaSpBlockScaleOp
+//===----------------------------------------------------------------------===//
+
+// Prints the custom assembly form:
+//   A[...] B[...] C[...] sparseMetadata[...] selector[...]
+//     scaleA[...] scaleB[...] attr-dict
+//     : (A_type, B_type, C_type) -> result_type
+void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
+  std::array<OperandFragment, 3> frags{
+      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
+      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
+      OperandFragment("C", "")};
+
+  // The operand segment sizes are never printed; processOperandFragments may
+  // add further attribute names to leave out.
+  SmallVector<StringRef, 4> elidedAttrs{getOperandSegmentSizeAttr()};
+  SmallVector<Type, 4> fragTypes;
+  processOperandFragments(*this, frags, fragTypes, elidedAttrs);
+
+  // A, B and C register lists.
+  for (const OperandFragment &frag : frags)
+    printOperandList(p, frag.operandName, frag.regs);
+
+  // Sparse-specific operands.
+  const Value metadataOperand[] = {getSparseMetadata()};
+  printOperandList(p, "sparseMetadata", metadataOperand);
+  const Value selectorOperand[] = {getSparsitySelector()};
+  printOperandList(p, "selector", selectorOperand);
+
+  // Scale operand triples for A and B.
+  const Value scaleAOperands[] = {getScaleAData(), getByteIdA(),
+                                  getThreadIdA()};
+  printOperandList(p, "scaleA", scaleAOperands);
+  const Value scaleBOperands[] = {getScaleBData(), getByteIdB(),
+                                  getThreadIdB()};
+  printOperandList(p, "scaleB", scaleBOperands);
+
+  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+
+  // Type signature: one type per fragment, taken from its first register.
+  const Type signatureTypes[] = {frags[0].regs[0].getType(),
+                                 frags[1].regs[0].getType(),
+                                 frags[2].regs[0].getType()};
+  p << " : (";
+  llvm::interleaveComma(signatureTypes, p);
+  p << ")";
+  p.printArrowTypeList(TypeRange{getRes().getType()});
+}
+
+// Parses the custom assembly form:
+//   A[...] B[...] C[...] sparseMetadata[...] selector[...]
+//     scaleA[...] scaleB[...] attr-dict?
+//     : (A_type, B_type, C_type) -> result_type
+//
+// All registers of a fragment are resolved against the single type given for
+// that fragment in the signature. The sparse metadata and selector are
+// resolved as i32, and each scale group against (i32, i16, i16). Multiplicand
+// PTX type attributes that were not written explicitly are inferred from the
+// A/B types where possible, and "orderedMetadata" is added when absent.
+ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  Builder &builder = parser.getBuilder();
+  // Unresolved registers of the A, B and C fragments, in that order.
+  std::array<SmallVector<OpAsmParser::UnresolvedOperand, 4>, 3> fragRegs;
+  NamedAttrList namedAttributes;
+
+  // Parse A[...] B[...] C[...]
+  if (parseMmaOperand(parser, "A", fragRegs[0]).failed() ||
+      parseMmaOperand(parser, "B", fragRegs[1]).failed() ||
+      parseMmaOperand(parser, "C", fragRegs[2]).failed())
+    return failure();
+
+  // Parse sparse-specific operands
+  SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands,
+      selectorOperands;
+  if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
+      parseMmaOperand(parser, "selector", selectorOperands).failed())
+    return failure();
+
+  // Parse scale operands
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> scaleAOperands, scaleBOperands;
+  if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() ||
+      parseMmaOperand(parser, "scaleB", scaleBOperands).failed())
+    return failure();
+
+  if (parser.parseOptionalAttrDict(namedAttributes).failed())
+    return failure();
+
+  // Parse type signature
+  SmallVector<Type, 3> operandTypes;
+  if (parseMmaTypeSignature(parser, operandTypes).failed())
+    return failure();
+
+  // Parse result type
+  SmallVector<Type, 1> resultTypes;
+  if (parser.parseArrowTypeList(resultTypes).failed())
+    return failure();
+
+  // Resolve the fragment registers; operand order must be A, B, C.
+  for (unsigned idx = 0; idx < fragRegs.size(); ++idx) {
+    if (parser
+            .resolveOperands(fragRegs[idx], operandTypes[idx],
+                             parser.getNameLoc(), result.operands)
+            .failed())
+      return failure();
+  }
+
+  // Resolve sparse metadata and selector
+  Type i32Type = builder.getI32Type();
+  if (parser
+          .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
+                           result.operands)
+          .failed() ||
+      parser
+          .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
+                           result.operands)
+          .failed())
+    return failure();
+
+  // Resolve scale operands
+  SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
+                                     builder.getI16Type()};
+  if (parser
+          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed() ||
+      parser
+          .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed())
+    return failure();
+
+  // Add attributes
+  result.addAttributes(namedAttributes);
+  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+                               operandTypes);
+
+  // orderedMetadata is mandatory
+  if (!result.attributes.get("orderedMetadata"))
+    result.addAttribute("orderedMetadata", builder.getUnitAttr());
+
+  result.addTypes(resultTypes);
+  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(fragRegs[0].size()),
+                          static_cast<int32_t>(fragRegs[1].size()),
+                          static_cast<int32_t>(fragRegs[2].size()),
+                          1, // sparseMetadata
+                          1, // sparsitySelector
+                          1, // scaleAData
+                          1, // byteIdA
+                          1, // threadIdA
+                          1, // scaleBData
+                          1, // byteIdB
+                          1  // threadIdB
+                      }));
+  return success();
+}
+
+// Populates `result` for an MmaSpBlockScaleOp: block-scale attributes plus
+// "orderedMetadata", the A/B/C fragment operands followed by the sparse
+// metadata, the sparsity selector and the six scale operands, the multiplicand
+// PTX type attributes (explicit if given, otherwise inferred), the result
+// type, and the operand segment sizes. `shape` must hold exactly (m, n, k).
+void MmaSpBlockScaleOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType,
+    ValueRange operandA, ValueRange operandB, ValueRange operandC,
+    Value sparseMetadata, Value sparsitySelector, Value scaleAData,
+    Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
+    Value threadIdB, ArrayRef<int64_t> shape,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+    MMABlockScaleKind kind) {
+  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
+
+  addBlockScaleAttributes(builder, result, shape, scaleVecSize,
+                          blockScaleFormat, kind);
+  result.addAttribute("orderedMetadata", builder.getUnitAttr());
+
+  // Operand order: A, B, C fragments, sparse operands, then scale operands.
+  for (ValueRange fragment : {operandA, operandB, operandC})
+    result.addOperands(fragment);
+  const Value trailingOperands[] = {sparseMetadata, sparsitySelector,
+                                    scaleAData,     byteIdA,
+                                    threadIdA,      scaleBData,
+                                    byteIdB,        threadIdB};
+  result.addOperands(trailingOperands);
+
+  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+                               multiplicandPtxTypes);
+
+  result.addTypes(resultType);
+
+  // Variadic fragments carry their actual sizes; every other operand is a
+  // single value.
+  const int32_t segmentSizes[] = {
+      static_cast<int32_t>(operandA.size()),
+      static_cast<int32_t>(operandB.size()),
+      static_cast<int32_t>(operandC.size()),
+      1, // sparseMetadata
+      1, // sparsitySelector
+      1, // scaleAData
+      1, // byteIdA
+      1, // threadIdA
+      1, // scaleBData
+      1, // byteIdB
+      1  // threadIdB
+  };
+  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(segmentSizes));
+}
+
+// The accumulator PTX type is inferred from the op's result type.
+MMATypes MmaSpBlockScaleOp::accumPtxType() {
+  return inferPtxTypeFromResult<MmaSpBlockScaleOp>(*this);
+}
+
+// The result PTX type is inferred the same way as the accumulator type.
+MMATypes MmaSpBlockScaleOp::resultPtxType() {
+  return inferPtxTypeFromResult<MmaSpBlockScaleOp>(*this);
+}
+
+// Returns the intrinsic ID selected for `op` together with its translated
+// arguments. Argument order: A, B, C fragment registers, sparse metadata,
+// sparsity selector, then scaleAData, byteIdA, threadIdA, scaleBData, byteIdB,
+// threadIdB.
+NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::MmaSpBlockScaleOp>(op);
+
+  SmallVector<llvm::Value *> args;
+  auto appendTranslated = [&](Value value) {
+    args.push_back(mt.lookupValue(value));
+  };
+
+  // Fragment registers.
+  for (Value reg : curOp.getOperandA())
+    appendTranslated(reg);
+  for (Value reg : curOp.getOperandB())
+    appendTranslated(reg);
+  for (Value reg : curOp.getOperandC())
+    appendTranslated(reg);
+
+  // Sparse operands first, then the scale operands for A and B.
+  const Value trailingOperands[] = {
+      curOp.getSparseMetadata(), curOp.getSparsitySelector(),
+      curOp.getScaleAData(),     curOp.getByteIdA(),
+      curOp.getThreadIdA(),      curOp.getScaleBData(),
+      curOp.getByteIdB(),        curOp.getThreadIdB()};
+  for (Value trailingOperand : trailingOperands)
+    appendTranslated(trailingOperand);
+
+  auto shape = curOp.getShape();
+  unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
+      shape.getM(), shape.getN(), shape.getK(),
+      *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
+      curOp.accumPtxType(), curOp.getScaleVecSize(),
+      curOp.getBlockScaleFormat(), curOp.getKind());
+
+  return {intId, args};
+}
+
+LogicalResult MmaSpBlockScaleOp::verify() {
+  // Check that orderedMetadata is present
+  if (!getOrderedMetadata()) {
+    return emitOpError("'orderedMetadata' attribute is mandatory");
+  }
+
+  LogicalResult result = success();
+  int m = getShape().getM();
+  int n = getShape().getN();
+  int k = getShape().getK();
+
+  if (m == 16 && n == 8 && k == 128) {
+    if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 ||
+        getMultiplicandBPtxType() != NVVM::MMATypes::e2m1)
+      result = emitOpError(
+          "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)");
+    if (getKind() == NVVM::MMABlockScaleKind::MXF4) {
+      if (getScaleVecSize() != NVVM::ScaleVecSize::X2)
+        result = emitOpError(
+            "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4");
+      if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0)
+        result = emitOpError(
+            "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
+    } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
+      if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+        result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
+                             "attributes for mma.m16n8k128.mxf4nvf4");
+    } else
+      result = emitOpError("unsupported Kind attribute for mma.m16n8k128");
+  } else if (m == 16 && n == 8 && k == 64) {
+    if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
+          getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
+          getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
+      result =
+          emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+                      "attributes for mma.m16n8k64");
+  } else
----------------
kvederni wrote:

The braces were added in [5febdda](https://github.com/llvm/llvm-project/pull/170566/commits/5febdda3f2742ddec22d3b299ea7e378c14e22d8).

https://github.com/llvm/llvm-project/pull/170566


More information about the Mlir-commits mailing list