[Mlir-commits] [mlir] [MLIR] Support for dense and sparse MMA with block scaling (PR #170566)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 3 14:01:47 PST 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --diff_from_common_commit
``````````
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a06fe19b8..3387c6902 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -693,8 +693,8 @@ void MmaOp::print(OpAsmPrinter &p) {
regTypes.push_back(this->getOperand(operandIdx).getType());
}
}
- std::optional<MMATypes> inferredType =
- MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
+ std::optional<MMATypes> inferredType = MmaOp::inferOperandMMAType(
+ regTypes.back(), /*isAccumulator=*/fragIdx >= 2);
if (inferredType)
ignoreAttrNames.push_back(frag.ptxTypeAttr);
}
@@ -1582,21 +1582,23 @@ void printOperandList(OpAsmPrinter &p, StringRef name,
}
// Helper to parse operand list in the format: name[operands]
-LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
+LogicalResult
+parseMmaOperand(OpAsmParser &parser, StringRef operandName,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regs) {
if (parser.parseKeyword(operandName).failed())
return failure();
- if (parser.parseOperandList(regs,
- OpAsmParser::Delimiter::OptionalSquare).failed())
+ if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare)
+ .failed())
return failure();
return success();
}
-// Helper to process operand fragments and determine which attributes can be inferred
+// Helper to process operand fragments and determine which attributes can be
+// inferred
template <typename Op>
void processOperandFragments(Op &op, std::array<OperandFragment, 3> &frags,
- SmallVectorImpl<Type> &regTypes,
- SmallVectorImpl<StringRef> &ignoreAttrNames) {
+ SmallVectorImpl<Type> &regTypes,
+ SmallVectorImpl<StringRef> &ignoreAttrNames) {
for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
auto &frag = frags[fragIdx];
auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx);
@@ -1639,16 +1641,16 @@ LogicalResult parseMmaTypeSignature(OpAsmParser &parser,
// Helper to infer and set multiplicand PTX type attributes
void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
- const SmallVectorImpl<Type> &operandTypes) {
+ const SmallVectorImpl<Type> &operandTypes) {
if (!attrs.get("multiplicandAPtxType")) {
- if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[0],
- false)) {
+ if (auto inferredType =
+ MmaOp::inferOperandMMAType(operandTypes[0], false)) {
attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType));
}
}
if (!attrs.get("multiplicandBPtxType")) {
- if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[1],
- false)) {
+ if (auto inferredType =
+ MmaOp::inferOperandMMAType(operandTypes[1], false)) {
attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType));
}
}
@@ -1656,37 +1658,33 @@ void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs,
// Helper to add common block scale attributes
void addBlockScaleAttributes(OpBuilder &builder, OperationState &result,
- ArrayRef<int64_t> shape,
- ScaleVecSize scaleVecSize,
+ ArrayRef<int64_t> shape, ScaleVecSize scaleVecSize,
BlockScaleFormat blockScaleFormat,
MMABlockScaleKind kind) {
MLIRContext *ctx = builder.getContext();
- result.addAttribute("shape",
- builder.getAttr<MMAShapeAttr>(shape[0], shape[1],
- shape[2]));
- result.addAttribute("scaleVecSize",
- ScaleVecSizeAttr::get(ctx, scaleVecSize));
+ result.addAttribute(
+ "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
+ result.addAttribute("scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize));
result.addAttribute("blockScaleFormat",
BlockScaleFormatAttr::get(ctx, blockScaleFormat));
result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind));
}
// Helper to infer and add multiplicand PTX types to builder
-void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result,
- ValueRange operandA, ValueRange operandB,
- std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
+void addInferredMultiplicandTypes(
+ MLIRContext *ctx, OperationState &result, ValueRange operandA,
+ ValueRange operandB,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes) {
if (multiplicandPtxTypes) {
result.addAttribute("multiplicandAPtxType",
- MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
result.addAttribute("multiplicandBPtxType",
- MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
+ MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
} else {
if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false))
- result.addAttribute("multiplicandAPtxType",
- MMATypesAttr::get(ctx, *res));
+ result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false))
- result.addAttribute("multiplicandBPtxType",
- MMATypesAttr::get(ctx, *res));
+ result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
}
}
@@ -1730,7 +1728,8 @@ void MmaBlockScaleOp::print(OpAsmPrinter &p) {
p << " : (";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
- frags[2].regs[0].getType()}, p);
+ frags[2].regs[0].getType()},
+ p);
p << ")";
p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
@@ -1775,52 +1774,55 @@ ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser,
for (const auto &[idx, frag] : llvm::enumerate(frags)) {
frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
/*isAccumulator=*/idx >= 2);
- if (parser.resolveOperands(frag.regs, operandTypes[idx],
- parser.getNameLoc(), result.operands).failed())
+ if (parser
+ .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
}
// Resolve scale operands
- SmallVector<Type, 3> scaleTypes = {
- builder.getI32Type(), builder.getI16Type(), builder.getI16Type()
- };
- if (parser.resolveOperands(scaleAOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed() ||
- parser.resolveOperands(scaleBOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed())
+ SmallVector<Type, 3> scaleTypes = {builder.getI32Type(), builder.getI16Type(),
+ builder.getI16Type()};
+ if (parser
+ .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed() ||
+ parser
+ .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
// Add attributes
result.addAttributes(namedAttributes);
- inferAndSetMultiplicandTypes(parser.getContext(),
- result.attributes, operandTypes);
+ inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+ operandTypes);
result.addTypes(resultTypes);
- result.addAttribute(
- MmaBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
- static_cast<int32_t>(frags[1].regs.size()),
- static_cast<int32_t>(frags[2].regs.size()),
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
return success();
}
-void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result,
- Type resultType, ValueRange operandA,
- ValueRange operandB, ValueRange operandC,
- Value scaleAData, Value byteIdA, Value threadIdA,
- Value scaleBData, Value byteIdB, Value threadIdB,
- ArrayRef<int64_t> shape,
- std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
- ScaleVecSize scaleVecSize,
- BlockScaleFormat blockScaleFormat,
- MMABlockScaleKind kind) {
+void MmaBlockScaleOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData,
+ Value byteIdB, Value threadIdB, ArrayRef<int64_t> shape,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+ ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+ MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
addBlockScaleAttributes(builder, result, shape, scaleVecSize,
@@ -1829,25 +1831,25 @@ void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(operandA);
result.addOperands(operandB);
result.addOperands(operandC);
- result.addOperands({scaleAData, byteIdA, threadIdA,
- scaleBData, byteIdB, threadIdB});
+ result.addOperands(
+ {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB});
- addInferredMultiplicandTypes(builder.getContext(), result,
- operandA, operandB, multiplicandPtxTypes);
+ addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+ multiplicandPtxTypes);
result.addTypes(resultType);
- result.addAttribute(
- MmaBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
- static_cast<int32_t>(operandB.size()),
- static_cast<int32_t>(operandC.size()),
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()),
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
}
MMATypes MmaBlockScaleOp::accumPtxType() {
@@ -1882,8 +1884,8 @@ NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs(
unsigned intId = MmaBlockScaleOp::getIntrinsicID(
curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
*curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
- curOp.accumPtxType(),
- curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind());
+ curOp.accumPtxType(), curOp.getScaleVecSize(),
+ curOp.getBlockScaleFormat(), curOp.getKind());
return {intId, args};
}
@@ -1908,9 +1910,9 @@ LogicalResult MmaBlockScaleOp::verify() {
"unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4");
} else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
- (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+ (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k64.mxf4nvf4");
} else
@@ -1919,8 +1921,9 @@ LogicalResult MmaBlockScaleOp::verify() {
if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
- result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
- "attributes for mma.m16n8k32");
+ result =
+ emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+ "attributes for mma.m16n8k32");
} else
result = emitOpError("unsupported Geom for mma with block scaling");
return result;
@@ -1961,7 +1964,8 @@ void MmaSpBlockScaleOp::print(OpAsmPrinter &p) {
p << " : (";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
- frags[2].regs[0].getType()}, p);
+ frags[2].regs[0].getType()},
+ p);
p << ")";
p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
@@ -1984,7 +1988,8 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
return failure();
// Parse sparse-specific operands
- SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands, selectorOperands;
+ SmallVector<OpAsmParser::UnresolvedOperand, 1> metadataOperands,
+ selectorOperands;
if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() ||
parseMmaOperand(parser, "selector", selectorOperands).failed())
return failure();
@@ -2012,67 +2017,74 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser,
for (const auto &[idx, frag] : llvm::enumerate(frags)) {
frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx],
/*isAccumulator=*/idx >= 2);
- if (parser.resolveOperands(frag.regs, operandTypes[idx],
- parser.getNameLoc(), result.operands).failed())
+ if (parser
+ .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
}
// Resolve sparse metadata and selector
Type i32Type = builder.getI32Type();
- if (parser.resolveOperands(metadataOperands, i32Type,
- parser.getNameLoc(), result.operands).failed() ||
- parser.resolveOperands(selectorOperands, i32Type,
- parser.getNameLoc(), result.operands).failed())
+ if (parser
+ .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(),
+ result.operands)
+ .failed() ||
+ parser
+ .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
// Resolve scale operands
- SmallVector<Type, 3> scaleTypes = {
- i32Type, builder.getI16Type(), builder.getI16Type()
- };
- if (parser.resolveOperands(scaleAOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed() ||
- parser.resolveOperands(scaleBOperands, scaleTypes,
- parser.getNameLoc(), result.operands).failed())
+ SmallVector<Type, 3> scaleTypes = {i32Type, builder.getI16Type(),
+ builder.getI16Type()};
+ if (parser
+ .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed() ||
+ parser
+ .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(),
+ result.operands)
+ .failed())
return failure();
// Add attributes
result.addAttributes(namedAttributes);
- inferAndSetMultiplicandTypes(parser.getContext(),
- result.attributes, operandTypes);
+ inferAndSetMultiplicandTypes(parser.getContext(), result.attributes,
+ operandTypes);
// orderedMetadata is mandatory
if (!result.attributes.get("orderedMetadata"))
result.addAttribute("orderedMetadata", builder.getUnitAttr());
result.addTypes(resultTypes);
- result.addAttribute(
- MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(frags[0].regs.size()),
- static_cast<int32_t>(frags[1].regs.size()),
- static_cast<int32_t>(frags[2].regs.size()),
- 1, // sparseMetadata
- 1, // sparsitySelector
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(frags[0].regs.size()),
+ static_cast<int32_t>(frags[1].regs.size()),
+ static_cast<int32_t>(frags[2].regs.size()),
+ 1, // sparseMetadata
+ 1, // sparsitySelector
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
return success();
}
-void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result,
- Type resultType, ValueRange operandA,
- ValueRange operandB, ValueRange operandC,
- Value sparseMetadata, Value sparsitySelector,
- Value scaleAData, Value byteIdA, Value threadIdA,
- Value scaleBData, Value byteIdB, Value threadIdB,
- ArrayRef<int64_t> shape,
- std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
- ScaleVecSize scaleVecSize,
- BlockScaleFormat blockScaleFormat,
- MMABlockScaleKind kind) {
+void MmaSpBlockScaleOp::build(
+ OpBuilder &builder, OperationState &result, Type resultType,
+ ValueRange operandA, ValueRange operandB, ValueRange operandC,
+ Value sparseMetadata, Value sparsitySelector, Value scaleAData,
+ Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB,
+ Value threadIdB, ArrayRef<int64_t> shape,
+ std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
+ ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat,
+ MMABlockScaleKind kind) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
addBlockScaleAttributes(builder, result, shape, scaleVecSize,
@@ -2082,28 +2094,27 @@ void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(operandA);
result.addOperands(operandB);
result.addOperands(operandC);
- result.addOperands({sparseMetadata, sparsitySelector,
- scaleAData, byteIdA, threadIdA,
- scaleBData, byteIdB, threadIdB});
+ result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA,
+ threadIdA, scaleBData, byteIdB, threadIdB});
- addInferredMultiplicandTypes(builder.getContext(), result,
- operandA, operandB, multiplicandPtxTypes);
+ addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB,
+ multiplicandPtxTypes);
result.addTypes(resultType);
- result.addAttribute(
- MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr({static_cast<int32_t>(operandA.size()),
- static_cast<int32_t>(operandB.size()),
- static_cast<int32_t>(operandC.size()),
- 1, // sparseMetadata
- 1, // sparsitySelector
- 1, // scaleAData
- 1, // byteIdA
- 1, // threadIdA
- 1, // scaleBData
- 1, // byteIdB
- 1 // threadIdB
- }));
+ result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({
+ static_cast<int32_t>(operandA.size()),
+ static_cast<int32_t>(operandB.size()),
+ static_cast<int32_t>(operandC.size()),
+ 1, // sparseMetadata
+ 1, // sparsitySelector
+ 1, // scaleAData
+ 1, // byteIdA
+ 1, // threadIdA
+ 1, // scaleBData
+ 1, // byteIdB
+ 1 // threadIdB
+ }));
}
MMATypes MmaSpBlockScaleOp::accumPtxType() {
@@ -2142,8 +2153,8 @@ NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs(
unsigned intId = MmaSpBlockScaleOp::getIntrinsicID(
curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(),
*curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(),
- curOp.accumPtxType(),
- curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind());
+ curOp.accumPtxType(), curOp.getScaleVecSize(),
+ curOp.getBlockScaleFormat(), curOp.getKind());
return {intId, args};
}
@@ -2173,9 +2184,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
"unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4");
} else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) {
if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
- (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
- getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) ||
+ (getScaleVecSize() == NVVM::ScaleVecSize::X4 &&
+ getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3)))
result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k128.mxf4nvf4");
} else
@@ -2184,8 +2195,9 @@ LogicalResult MmaSpBlockScaleOp::verify() {
if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 &&
getScaleVecSize() == NVVM::ScaleVecSize::X1 &&
getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0))
- result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
- "attributes for mma.m16n8k64");
+ result =
+ emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
+ "attributes for mma.m16n8k64");
} else
result = emitOpError("unsupported Geom for sparse mma with block scaling");
return result;
``````````
</details>
https://github.com/llvm/llvm-project/pull/170566
More information about the Mlir-commits mailing list