[Mlir-commits] [mlir] [MLIR] Support for dense and sparse MMA with block scaling (PR #170566)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 3 14:01:47 PST 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter clang-format found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --diff_from_common_commit
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a06fe19b8..3387c6902 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -693,8 +693,8 @@ void MmaOp::print(OpAsmPrinter &p) {
         regTypes.push_back(this->getOperand(operandIdx).getType());
       }
     }
-    std::optional<MMATypes> inferredType =
-        MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+        regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
     if (inferredType)
       ignoreAttrNames.push_back(frag.ptxTypeAttr);
   }
@@ -1582,21 +1582,23 @@ void printOperandList(OpAsmPrinter &p, StringRef name,
 }
 
 // Helper to parse operand list in the format: name[operands]
-LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName,
-                              SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+                SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
   if (parser.parseKeyword(operandName).failed())
     return failure();
-  if (parser.parseOperandList(regs,
-                              OpAsmParser::Delimiter::OptionalSquare).failed())
+  if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
+          .failed())
     return failure();
   return success();
 }
 
-// Helper to process operand fragments and determine which attributes can be inferred
+// Helper to process operand fragments and determine which attributes can be
+// inferred
 template <typename Op>
 void processOperandFragments(Op &op, std::array<OperandFragment, 3> &frags,
-                              SmallVectorImpl<Type> &regTypes,
-                              SmallVectorImpl<StringRef> &ignoreAttrNames) {
+                             SmallVectorImpl<Type> &regTypes,
+                             SmallVectorImpl<StringRef> &ignoreAttrNames) {
   for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
     auto &frag = frags[fragIdx];
     auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
@@ -1639,16 +1641,16 @@ LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
 
 // Helper to infer and set multiplicand PTX type attributes
 void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
-                                   const SmallVectorImpl<Type> &operandTypes) {
+                                  const SmallVectorImpl<Type> &operandTypes) {
   if (!attrs.get("multiplicandAPtxType")) {
-    if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[0],
-                                                       false)) {
+    if (auto inferredType =
+            MmaOp::inferOperandMMAType(operandTypes[0], false)) {
       attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
     }
   }
   if (!attrs.get("multiplicandBPtxType")) {
-    if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[1],
-                                                       false)) {
+    if (auto inferredType =
+            MmaOp::inferOperandMMAType(operandTypes[1], false)) {
       attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
     }
   }
@@ -1656,37 +1658,33 @@ void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
 
 // Helper to add common block scale attributes
 void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
-                             ArrayRef<int64_t> shape,
-                             ScaleVecSize scaleVecSize,
+                             ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
                              BlockScaleFormat blockScaleFormat,
                              MMABlockScaleKind kind) {
   MLIRContext *ctx = builder.getContext();
-  result.addAttribute("shape",
-                     builder.getAttr<MMAShapeAttr>(shape[0], shape[1],
-                                                   shape[2]));
-  result.addAttribute("scaleVecSize",
-                      ScaleVecSizeAttr::get(ctx, scaleVecSize));
+  result.addAttribute(
+      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+  result.addAttribute("scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize));
   result.addAttribute("blockScaleFormat",
                       BlockScaleFormatAttr::get(ctx, blockScaleFormat));
   result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind));
 }
 
 // Helper to infer and add multiplicand PTX types to builder
-void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result,
-                                  ValueRange operandA, ValueRange operandB,
-                                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+void addInferredMultiplicandTypes(
+    MLIRContext *ctx, OperationState &result, ValueRange operandA,
+    ValueRange operandB,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
   if (multiplicandPtxTypes) {
     result.addAttribute("multiplicandAPtxType",
-                       MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
     result.addAttribute("multiplicandBPtxType",
-                       MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
   } else {
     if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
-      result.addAttribute("multiplicandAPtxType",
-                          MMATypesAttr::get(ctx, *res));
+      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
     if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
-      result.addAttribute("multiplicandBPtxType",
-                          MMATypesAttr::get(ctx, *res));
+      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
   }
 }
 
@@ -1730,7 +1728,8 @@ void MmaBlockScaleOp::print(OpAsmPrinter &p) {
   p << " : (";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
-                                             frags[2].regs[0].getType()}, p);
+                                             frags[2].regs[0].getType()},
+                        p);
   p << ")";
   p.printArrowTypeList(TypeRange{this->getRes().getType()});
 }
@@ -1775,52 +1774,55 @@ ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
   for (const auto &[idx, frag] : llvm::enumerate(frags)) {
     frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
                                                /*isAccumulator=*/idx >= 2);
-    if (parser.resolveOperands(frag.regs, operandTypes[idx],
-                               parser.getNameLoc(), result.operands).failed())
+    if (parser
+            .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
+                             result.operands)
+            .failed())
       return failure();
   }
 
   // Resolve scale operands
-  SmallVector<Type, 3> scaleTypes = {
-    builder.getI32Type(), builder.getI16Type(), builder.getI16Type()
-  };
-  if (parser.resolveOperands(scaleAOperands, scaleTypes,
-                             parser.getNameLoc(), result.operands).failed() ||
-      parser.resolveOperands(scaleBOperands, scaleTypes,
-                             parser.getNameLoc(), result.operands).failed())
+  SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
+                                     builder.getI16Type()};
+  if (parser
+          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed() ||
+      parser
+          .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed())
     return failure();
 
   // Add attributes
   result.addAttributes(namedAttributes);
-  inferAndSetMultiplicandTypes(parser.getContext(),
-                               result.attributes, operandTypes);
+  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+                               operandTypes);
 
   result.addTypes(resultTypes);
-  result.addAttribute(
-      MmaBlockScaleOp::getOperandSegmentSizeAttr(),
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
-                                    static_cast<int32_t>(frags[1].regs.size()),
-                                    static_cast<int32_t>(frags[2].regs.size()),
-                                    1, // scaleAData
-                                    1, // byteIdA
-                                    1, // threadIdA
-                                    1, // scaleBData
-                                    1, // byteIdB
-                                    1  // threadIdB
-                                   }));
+  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(frags[0].regs.size()),
+                          static_cast<int32_t>(frags[1].regs.size()),
+                          static_cast<int32_t>(frags[2].regs.size()),
+                          1, // scaleAData
+                          1, // byteIdA
+                          1, // threadIdA
+                          1, // scaleBData
+                          1, // byteIdB
+                          1  // threadIdB
+                      }));
   return success();
 }
 
-void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result,
-                            Type resultType, ValueRange operandA,
-                            ValueRange operandB, ValueRange operandC,
-                            Value scaleAData, Value byteIdA, Value threadIdA,
-                            Value scaleBData, Value byteIdB, Value threadIdB,
-                            ArrayRef<int64_t> shape,
-                            std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
-                            ScaleVecSize scaleVecSize,
-                            BlockScaleFormat blockScaleFormat,
-                            MMABlockScaleKind kind) {
+void MmaBlockScaleOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType,
+    ValueRange operandA, ValueRange operandB, ValueRange operandC,
+    Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
+    Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+    MMABlockScaleKind kind) {
   assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
 
   addBlockScaleAttributes(builder, result, shape, scaleVecSize,
@@ -1829,25 +1831,25 @@ void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result,
   result.addOperands(operandA);
   result.addOperands(operandB);
   result.addOperands(operandC);
-  result.addOperands({scaleAData, byteIdA, threadIdA,
-                      scaleBData, byteIdB, threadIdB});
+  result.addOperands(
+      {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
 
-  addInferredMultiplicandTypes(builder.getContext(), result,
-                               operandA, operandB, multiplicandPtxTypes);
+  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+                               multiplicandPtxTypes);
 
   result.addTypes(resultType);
-  result.addAttribute(
-      MmaBlockScaleOp::getOperandSegmentSizeAttr(),
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
-                                    static_cast<int32_t>(operandB.size()),
-                                    static_cast<int32_t>(operandC.size()),
-                                    1, // scaleAData
-                                    1, // byteIdA
-                                    1, // threadIdA
-                                    1, // scaleBData
-                                    1, // byteIdB
-                                    1  // threadIdB
-                                   }));
+  result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(operandA.size()),
+                          static_cast<int32_t>(operandB.size()),
+                          static_cast<int32_t>(operandC.size()),
+                          1, // scaleAData
+                          1, // byteIdA
+                          1, // threadIdA
+                          1, // scaleBData
+                          1, // byteIdB
+                          1  // threadIdB
+                      }));
 }
 
 MMATypes MmaBlockScaleOp::accumPtxType() {
@@ -1882,8 +1884,8 @@ NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
   unsigned intId = MmaBlockScaleOp::getIntrinsicID(
       curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
       *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
-      curOp.accumPtxType(),
-      curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind());
+      curOp.accumPtxType(), curOp.getScaleVecSize(),
+      curOp.getBlockScaleFormat(), curOp.getKind());
 
   return {intId, args};
 }
@@ -1908,9 +1910,9 @@ LogicalResult MmaBlockScaleOp::verify() {
             "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
     } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
       if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
-           getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
-          (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
-           getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
         result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
                              "attributes for mma.m16n8k64.mxf4nvf4");
     } else
@@ -1919,8 +1921,9 @@ LogicalResult MmaBlockScaleOp::verify() {
     if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
           getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
           getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
-      result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
-                           "attributes for mma.m16n8k32");
+      result =
+          emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+                      "attributes for mma.m16n8k32");
   } else
     result = emitOpError("unsupported Geom for mma with block scaling");
   return result;
@@ -1961,7 +1964,8 @@ void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
   p << " : (";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
-                                             frags[2].regs[0].getType()}, p);
+                                             frags[2].regs[0].getType()},
+                        p);
   p << ")";
   p.printArrowTypeList(TypeRange{this->getRes().getType()});
 }
@@ -1984,7 +1988,8 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
     return failure();
 
   // Parse sparse-specific operands
-  SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands, selectorOperands;
+  SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands,
+      selectorOperands;
   if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
       parseMmaOperand(parser, "selector", selectorOperands).failed())
     return failure();
@@ -2012,67 +2017,74 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
   for (const auto &[idx, frag] : llvm::enumerate(frags)) {
     frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
                                                /*isAccumulator=*/idx >= 2);
-    if (parser.resolveOperands(frag.regs, operandTypes[idx],
-                               parser.getNameLoc(), result.operands).failed())
+    if (parser
+            .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
+                             result.operands)
+            .failed())
       return failure();
   }
 
   // Resolve sparse metadata and selector
   Type i32Type = builder.getI32Type();
-  if (parser.resolveOperands(metadataOperands, i32Type,
-                             parser.getNameLoc(), result.operands).failed() ||
-      parser.resolveOperands(selectorOperands, i32Type,
-                             parser.getNameLoc(), result.operands).failed())
+  if (parser
+          .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
+                           result.operands)
+          .failed() ||
+      parser
+          .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
+                           result.operands)
+          .failed())
     return failure();
 
   // Resolve scale operands
-  SmallVector<Type, 3> scaleTypes = {
-    i32Type, builder.getI16Type(), builder.getI16Type()
-  };
-  if (parser.resolveOperands(scaleAOperands, scaleTypes,
-                             parser.getNameLoc(), result.operands).failed() ||
-      parser.resolveOperands(scaleBOperands, scaleTypes,
-                             parser.getNameLoc(), result.operands).failed())
+  SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
+                                     builder.getI16Type()};
+  if (parser
+          .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed() ||
+      parser
+          .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+                           result.operands)
+          .failed())
     return failure();
 
   // Add attributes
   result.addAttributes(namedAttributes);
-  inferAndSetMultiplicandTypes(parser.getContext(),
-                               result.attributes, operandTypes);
+  inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+                               operandTypes);
 
   // orderedMetadata is mandatory
   if (!result.attributes.get("orderedMetadata"))
     result.addAttribute("orderedMetadata", builder.getUnitAttr());
 
   result.addTypes(resultTypes);
-  result.addAttribute(
-      MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
-                                    static_cast<int32_t>(frags[1].regs.size()),
-                                    static_cast<int32_t>(frags[2].regs.size()),
-                                    1, // sparseMetadata
-                                    1, // sparsitySelector
-                                    1, // scaleAData
-                                    1, // byteIdA
-                                    1, // threadIdA
-                                    1, // scaleBData
-                                    1, // byteIdB
-                                    1  // threadIdB
-                                   }));
+  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(frags[0].regs.size()),
+                          static_cast<int32_t>(frags[1].regs.size()),
+                          static_cast<int32_t>(frags[2].regs.size()),
+                          1, // sparseMetadata
+                          1, // sparsitySelector
+                          1, // scaleAData
+                          1, // byteIdA
+                          1, // threadIdA
+                          1, // scaleBData
+                          1, // byteIdB
+                          1  // threadIdB
+                      }));
   return success();
 }
 
-void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result,
-                              Type resultType, ValueRange operandA,
-                              ValueRange operandB, ValueRange operandC,
-                              Value sparseMetadata, Value sparsitySelector,
-                              Value scaleAData, Value byteIdA, Value threadIdA,
-                              Value scaleBData, Value byteIdB, Value threadIdB,
-                              ArrayRef<int64_t> shape,
-                              std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
-                              ScaleVecSize scaleVecSize,
-                              BlockScaleFormat blockScaleFormat,
-                              MMABlockScaleKind kind) {
+void MmaSpBlockScaleOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType,
+    ValueRange operandA, ValueRange operandB, ValueRange operandC,
+    Value sparseMetadata, Value sparsitySelector, Value scaleAData,
+    Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
+    Value threadIdB, ArrayRef<int64_t> shape,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+    ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+    MMABlockScaleKind kind) {
   assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
 
   addBlockScaleAttributes(builder, result, shape, scaleVecSize,
@@ -2082,28 +2094,27 @@ void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result,
   result.addOperands(operandA);
   result.addOperands(operandB);
   result.addOperands(operandC);
-  result.addOperands({sparseMetadata, sparsitySelector,
-                     scaleAData, byteIdA, threadIdA,
-                     scaleBData, byteIdB, threadIdB});
+  result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
+                      threadIdA, scaleBData, byteIdB, threadIdB});
 
-  addInferredMultiplicandTypes(builder.getContext(), result,
-                               operandA, operandB, multiplicandPtxTypes);
+  addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+                               multiplicandPtxTypes);
 
   result.addTypes(resultType);
-  result.addAttribute(
-      MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
-      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
-                                    static_cast<int32_t>(operandB.size()),
-                                    static_cast<int32_t>(operandC.size()),
-                                    1, // sparseMetadata
-                                    1, // sparsitySelector
-                                    1, // scaleAData
-                                    1, // byteIdA
-                                    1, // threadIdA
-                                    1, // scaleBData
-                                    1, // byteIdB
-                                    1  // threadIdB
-                                   }));
+  result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr({
+                          static_cast<int32_t>(operandA.size()),
+                          static_cast<int32_t>(operandB.size()),
+                          static_cast<int32_t>(operandC.size()),
+                          1, // sparseMetadata
+                          1, // sparsitySelector
+                          1, // scaleAData
+                          1, // byteIdA
+                          1, // threadIdA
+                          1, // scaleBData
+                          1, // byteIdB
+                          1  // threadIdB
+                      }));
 }
 
 MMATypes MmaSpBlockScaleOp::accumPtxType() {
@@ -2142,8 +2153,8 @@ NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
   unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
       curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
       *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
-      curOp.accumPtxType(),
-      curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind());
+      curOp.accumPtxType(), curOp.getScaleVecSize(),
+      curOp.getBlockScaleFormat(), curOp.getKind());
 
   return {intId, args};
 }
@@ -2173,9 +2184,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
             "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
     } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
       if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
-           getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
-          (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
-           getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+            (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+             getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
         result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
                              "attributes for mma.m16n8k128.mxf4nvf4");
     } else
@@ -2184,8 +2195,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
     if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
           getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
           getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
-      result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
-                           "attributes for mma.m16n8k64");
+      result =
+          emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+                      "attributes for mma.m16n8k64");
   } else
     result = emitOpError("unsupported Geom for sparse mma with block scaling");
   return result;

``````````

</details>


https://github.com/llvm/llvm-project/pull/170566


More information about the Mlir-commits mailing list