[Mlir-commits] [mlir] [MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (PR #168686)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 19 01:46:50 PST 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --diff_from_common_commit
``````````
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e0f1087f5..dd864fe9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -969,7 +969,7 @@ MMATypes MmaSpOp::resultPtxType() {
mlir::NVVM::IDArgPair
MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
- llvm::IRBuilderBase &builder) {
+ llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MmaSpOp>(op);
// Get operands
@@ -979,14 +979,11 @@ MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
// Get intrinsic ID using the existing getIntrinsicID method
auto intId = MmaSpOp::getIntrinsicID(
- thisOp.getShape().getM(), thisOp.getShape().getN(), thisOp.getShape().getK(),
- thisOp.getIntOverflowBehavior(),
- thisOp.getMetadataType(),
- thisOp.getKind(),
- *thisOp.getMultiplicandAPtxType(),
- *thisOp.getMultiplicandBPtxType(),
- thisOp.accumPtxType(),
- thisOp.resultPtxType());
+ thisOp.getShape().getM(), thisOp.getShape().getN(),
+ thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(),
+ thisOp.getMetadataType(), thisOp.getKind(),
+ *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(),
+ thisOp.accumPtxType(), thisOp.resultPtxType());
return {intId, args};
}
@@ -1004,8 +1001,7 @@ void MmaSpOp::print(OpAsmPrinter &p) {
std::array<OperandFragment, 5> frags{
OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
- OperandFragment("C", ""),
- OperandFragment("sparseMetadata", ""),
+ OperandFragment("C", ""), OperandFragment("sparseMetadata", ""),
OperandFragment("selector", "")};
SmallVector<StringRef, 4> ignoreAttrNames{
mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()};
@@ -1022,8 +1018,8 @@ void MmaSpOp::print(OpAsmPrinter &p) {
regTypes.push_back(this->getOperand(operandIdx).getType());
}
}
- std::optional<MMATypes> inferredType =
- MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+ regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
if (inferredType)
ignoreAttrNames.push_back(frag.ptxTypeAttr);
}
@@ -1047,17 +1043,18 @@ void MmaSpOp::print(OpAsmPrinter &p) {
p << "(";
for (int i = 0; i < 3; ++i) {
p << regTypes[i];
- if (i < 2) p << ", ";
+ if (i < 2)
+ p << ", ";
}
p << ") -> " << getResult().getType();
}
-void MmaSpOp::build(OpBuilder &builder, OperationState &result,
- Type resultType, ValueRange operandA, ValueRange operandB,
- ValueRange operandC, Value sparseMetadata, Value sparsitySelector,
- ArrayRef<int64_t> shape,
- std::optional<MMAIntOverflow> intOverflow,
- std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+void MmaSpOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value sparseMetadata, Value sparsitySelector, ArrayRef<int64_t> shape,
+ std::optional<MMAIntOverflow> intOverflow,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
MLIRContext *ctx = builder.getContext();
@@ -1091,8 +1088,8 @@ void MmaSpOp::build(OpBuilder &builder, OperationState &result,
MmaSpOp::getOperandSegmentSizeAttr(),
builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
static_cast<int32_t>(operandB.size()),
- static_cast<int32_t>(operandC.size()),
- 1, 1})); // sparseMetadata and sparsitySelector
+ static_cast<int32_t>(operandC.size()), 1,
+ 1})); // sparseMetadata and sparsitySelector
}
ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1155,23 +1152,27 @@ ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) {
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
parser.getNameLoc(), result.operands)))
return failure();
- frag.elemtype = MmaOp::inferOperandMMAType(frag.regTypes[0],
- /*isAccumulator*/ iter.index() >= 2);
+ frag.elemtype =
+ MmaOp::inferOperandMMAType(frag.regTypes[0],
+ /*isAccumulator*/ iter.index() >= 2);
}
Type resultType;
if (parser.parseArrow() || parser.parseType(resultType))
return failure();
- frags[5].elemtype = MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
+ frags[5].elemtype =
+ MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true);
// Resolve sparse metadata and selector (assume i32 type)
Type i32Type = builder.getIntegerType(32);
- if (parser.resolveOperands(frags[3].regs, i32Type,
- parser.getCurrentLocation(), result.operands)
+ if (parser
+ .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
.failed())
return failure();
- if (parser.resolveOperands(frags[4].regs, i32Type,
- parser.getCurrentLocation(), result.operands)
+ if (parser
+ .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(),
+ result.operands)
.failed())
return failure();
@@ -1316,7 +1317,8 @@ LogicalResult MmaSpOp::verify() {
expectedC.emplace_back(4, f32Ty);
}
- // For sparse MMA, A operand is compressed (2:4 sparsity means half the elements)
+ // For sparse MMA, A operand is compressed (2:4 sparsity means half the
+ // elements)
int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2;
int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
expectedA.emplace_back(unitA, multiplicandFragType);
``````````
</details>
https://github.com/llvm/llvm-project/pull/168686
More information about the Mlir-commits mailing list