[Mlir-commits] [mlir] [MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (PR #168686)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 19 01:46:50 PST 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --diff_from_common_commit
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e0f1087f5..dd864fe9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -969,7 +969,7 @@ MMATypes MmaSpOp::resultPtxType() {
 
 mlir::NVVM::IDArgPair
 MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
-                                llvm::IRBuilderBase &builder) {
+                               llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::MmaSpOp>(op);
 
   // Get operands
@@ -979,14 +979,11 @@ MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
 
   // Get intrinsic ID using the existing getIntrinsicID method
   auto intId = MmaSpOp::getIntrinsicID(
-      thisOp.getShape().getM(), thisOp.getShape().getN(), thisOp.getShape().getK(),
-      thisOp.getIntOverflowBehavior(),
-      thisOp.getMetadataType(),
-      thisOp.getKind(),
-      *thisOp.getMultiplicandAPtxType(),
-      *thisOp.getMultiplicandBPtxType(),
-      thisOp.accumPtxType(),
-      thisOp.resultPtxType());
+      thisOp.getShape().getM(), thisOp.getShape().getN(),
+      thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
+      thisOp.getMetadataType(), thisOp.getKind(),
+      *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
+      thisOp.accumPtxType(), thisOp.resultPtxType());
 
   return {intId, args};
 }
@@ -1004,8 +1001,7 @@ void MmaSpOp::print(OpAsmPrinter &p) {
   std::array<OperandFragment, 5> frags{
       OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
       OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
-      OperandFragment("C", ""),
-      OperandFragment("sparseMetadata", ""),
+      OperandFragment("C", ""), OperandFragment("sparseMetadata", ""),
       OperandFragment("selector", "")};
   SmallVector<StringRef, 4> ignoreAttrNames{
       mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
@@ -1022,8 +1018,8 @@ void MmaSpOp::print(OpAsmPrinter &p) {
         regTypes.push_back(this->getOperand(operandIdx).getType());
       }
     }
-    std::optional<MMATypes> inferredType =
-        MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+    std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+        regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
     if (inferredType)
       ignoreAttrNames.push_back(frag.ptxTypeAttr);
   }
@@ -1047,17 +1043,18 @@ void MmaSpOp::print(OpAsmPrinter &p) {
   p << "(";
   for (int i = 0; i < 3; ++i) {
     p << regTypes[i];
-    if (i < 2) p << ", ";
+    if (i < 2)
+      p << ", ";
   }
   p << ") -> " << getResult().getType();
 }
 
-void MmaSpOp::build(OpBuilder &builder, OperationState &result,
-                Type resultType, ValueRange operandA, ValueRange operandB,
-                ValueRange operandC, Value sparseMetadata, Value sparsitySelector,
-                ArrayRef<int64_t> shape,
-                std::optional<MMAIntOverflow> intOverflow,
-                std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+void MmaSpOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType,
+    ValueRange operandA, ValueRange operandB, ValueRange operandC,
+    Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
+    std::optional<MMAIntOverflow> intOverflow,
+    std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
 
   assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
   MLIRContext *ctx = builder.getContext();
@@ -1091,8 +1088,8 @@ void MmaSpOp::build(OpBuilder &builder, OperationState &result,
       MmaSpOp::getOperandSegmentSizeAttr(),
       builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
                                     static_cast<int32_t>(operandB.size()),
-                                    static_cast<int32_t>(operandC.size()),
-                                    1, 1})); // sparseMetadata and sparsitySelector
+                                    static_cast<int32_t>(operandC.size()), 1,
+                                    1})); // sparseMetadata and sparsitySelector
 }
 
 ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1155,23 +1152,27 @@ ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
                                       parser.getNameLoc(), result.operands)))
       return failure();
-    frag.elemtype = MmaOp::inferOperandMMAType(frag.regTypes[0],
-                                               /*isAccumulator*/ iter.index() >= 2);
+    frag.elemtype =
+        MmaOp::inferOperandMMAType(frag.regTypes[0],
+                                   /*isAccumulator*/ iter.index() >= 2);
   }
 
   Type resultType;
   if (parser.parseArrow() || parser.parseType(resultType))
     return failure();
-  frags[5].elemtype = MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
+  frags[5].elemtype =
+      MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
 
   // Resolve sparse metadata and selector (assume i32 type)
   Type i32Type = builder.getIntegerType(32);
-  if (parser.resolveOperands(frags[3].regs, i32Type,
-                             parser.getCurrentLocation(), result.operands)
+  if (parser
+          .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
+                           result.operands)
           .failed())
     return failure();
-  if (parser.resolveOperands(frags[4].regs, i32Type,
-                             parser.getCurrentLocation(), result.operands)
+  if (parser
+          .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
+                           result.operands)
           .failed())
     return failure();
 
@@ -1316,7 +1317,8 @@ LogicalResult MmaSpOp::verify() {
       expectedC.emplace_back(4, f32Ty);
     }
 
-    // For sparse MMA, A operand is compressed (2:4 sparsity means half the elements)
+    // For sparse MMA, A operand is compressed (2:4 sparsity means half the
+    // elements)
     int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
     int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
     expectedA.emplace_back(unitA, multiplicandFragType);

``````````

</details>


https://github.com/llvm/llvm-project/pull/168686


More information about the Mlir-commits mailing list