[Mlir-commits] [mlir] 708185f - [mlir][NVGPU] Add support for structured sparsity MMA variants
Christopher Bate
llvmlistbot at llvm.org
Mon Nov 7 08:43:10 PST 2022
Author: Christopher Bate
Date: 2022-11-07T09:43:03-07:00
New Revision: 708185f03ff480b3481132802b7b63461564f0ab
URL: https://github.com/llvm/llvm-project/commit/708185f03ff480b3481132802b7b63461564f0ab
DIFF: https://github.com/llvm/llvm-project/commit/708185f03ff480b3481132802b7b63461564f0ab.diff
LOG: [mlir][NVGPU] Add support for structured sparsity MMA variants
This change adds a new NVGPU operation that targets the PTX `mma.sp.sync`
instruction variants. A lowering to NVVM is provided using inline
assembly.
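For illustration, a minimal usage sketch of the new op, mirroring the f16 m16n8k32 case exercised in the roundtrip test added below (the SSA names `%a`, `%b`, `%c`, `%meta`, `%d` are placeholders):

```mlir
// Sparse warp-level MMA (f16, m16n8k32): %a holds the warp-distributed
// non-zero values of operand A, %meta (vector<2xi16>) holds the sparsity
// indices for those values.
%d = nvgpu.mma.sp.sync(%a, %b, %c) metadata(%meta) {mmaShape = [16, 8, 32]} :
       (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
```

The NVGPUToNVVM lowering turns this into an `llvm.inline_asm` call that emits the
`mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16` PTX instruction, as shown
in the conversion tests further down.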
Reviewed By: ThomasRaoux, manishucsd
Differential Revision: https://reviews.llvm.org/D137202
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Dialect/NVGPU/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 138ffc896cb2a..db4ee53252fb3 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -98,10 +98,24 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [
let hasVerifier = 1;
}
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
- Pure,
- PredOpTrait<"matrixA and matrixB have same element type",
- TCopVTEtIsSameAs<0, 1>>]> {
+class NVGPU_MmaSyncOp<string mnemonic> :
+ NVGPU_Op<mnemonic, [Pure,
+ PredOpTrait<"matrixA and matrixB have same element type",
+ TCopVTEtIsSameAs<0, 1>>]> {
+ code extraBaseClassDeclaration = [{
+ std::array<int64_t, 3> getMmaShapeAsArray() {
+ ArrayAttr mmaShape = this->getMmaShape();
+ assert(mmaShape.size() == 3 && "mmaShape should be three integers");
+ return {mmaShape[0].cast<IntegerAttr>().getInt(),
+ mmaShape[1].cast<IntegerAttr>().getInt(),
+ mmaShape[2].cast<IntegerAttr>().getInt()};
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
let description = [{
The `nvgpu.mma.sync` op represents the warp-level matrix-multiply-and-
accumulate (mma) operation that is compatible with `nvvm.mma.sync`.
@@ -143,9 +157,63 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
- let hasVerifier = 1;
+ let extraClassDeclaration = extraBaseClassDeclaration;
}
+def NVGPU_MmaSparseSyncMetadataType : FixedVectorOfLengthAndType<[2], [I16]>,
+ BuildableType<"::mlir::VectorType::get("
+ "{2},$_builder.getI16Type())">;
+
+def NVGPU_MmaSparseSyncOp : NVGPU_MmaSyncOp<"mma.sp.sync"> {
+ let description = [{
+ The `nvgpu.mma.sp.sync` operation performs a warp-distributed MMA operation
+ where operand A is "structured sparse". In this case, the `matrixA` operand
+ represents the (warp-distributed) non-zero values of operand A, and the
+ `sparse_metadata` operand provides the indices.
+
+ The full description of the sparsity storage format and distribution scheme is
+ given in the PTX docs. This operation is meant to follow the semantics
+ described in the PTX documentation here:
+ https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
+
+ The way the indices are distributed among the threads in a warp is controlled
+ by the optional `sparsity_selector` operand, which is `0` by default. For
+ more information, please consult the PTX documentation linked above.
+
+ Example (targeting the f16 16x8x32 `mma.sp` PTX instruction):
+
+ ```mlir
+ nvgpu.mma.sp.sync (%a, %b, %c) metadata (%meta) {mmaShape = [16, 8, 32]} :
+ (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ ```
+ }];
+
+ let arguments = (ins AnyVector:$matrixA,
+ AnyVector:$matrixB,
+ AnyVector:$matrixC,
+ NVGPU_MmaSparseSyncMetadataType:$sparseMetadata,
+ I64ArrayAttr:$mmaShape,
+ DefaultValuedAttr<I32Attr, "0">:$sparsitySelector,
+ OptionalAttr<UnitAttr>:$tf32Enabled
+ );
+
+ let results = (outs AnyVector:$res);
+
+ let builders = [
+ OpBuilder<(ins "Value":$matrixA,
+ "Value":$matrixB,
+ "Value":$matrixC,
+ "Value":$sparseMetadata,
+ "ArrayRef<int64_t>":$mmaShape)>
+ ];
+
+ let assemblyFormat = [{
+ `(` $matrixA`,` $matrixB`,` $matrixC `)` `metadata` `(` $sparseMetadata `)` attr-dict
+ `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
+ }];
+
+ let extraClassDeclaration = extraBaseClassDeclaration;
+}
def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
AttrSizedOperandSegments]> {
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index c4c49f2edd5ff..d9f54b8cb55d7 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,8 +11,10 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -253,6 +255,23 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
}
};
+/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
+/// enum).
+static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
+ Type elType = getElementTypeOrSelf(t);
+ if (elType.isInteger(8))
+ return NVVM::MMATypes::s8;
+ if (elType.isInteger(4))
+ return NVVM::MMATypes::s4;
+ if (elType.isF16())
+ return NVVM::MMATypes::f16;
+ if (elType.isF64())
+ return NVVM::MMATypes::f64;
+ if (elType.isF32())
+ return NVVM::MMATypes::tf32;
+ return failure();
+}
+
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;
@@ -262,53 +281,38 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
Location loc = op->getLoc();
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
- auto aType = op.getMatrixA().getType().cast<VectorType>();
- auto cType = op.getMatrixC().getType().cast<VectorType>();
+ VectorType aType = op.getMatrixA().getType();
+ VectorType bType = op.getMatrixB().getType();
+ VectorType cType = op.getMatrixC().getType();
- int64_t m = op.getMmaShape()[0].cast<IntegerAttr>().getInt();
- int64_t n = op.getMmaShape()[1].cast<IntegerAttr>().getInt();
- int64_t k = op.getMmaShape()[2].cast<IntegerAttr>().getInt();
- std::array<int64_t, 3> gemmShape{m, n, k};
+ std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();
+
+ // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
+ bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
+ if (aType.getElementType().isF32() && !tf32Enabled)
+ return failure();
- NVVM::MMATypes ptxTypeA;
- NVVM::MMATypes ptxTypeB;
+ FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
+ if (failed(ptxTypeA))
+ return op->emitOpError("failed to deduce operand PTX types");
+ FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
+ if (failed(ptxTypeB))
+ return op->emitOpError("failed to deduce operand PTX types");
Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
cType.getElementType(), /*isAccumulator=*/true);
if (!ptxTypeC)
return op->emitError(
"could not infer the PTX type for the accumulator/result");
- // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
- bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
- if (aType.getElementType().isF32() && !tf32Enabled)
- return failure();
-
+ // TODO: add an attribute to the op to customize this behavior.
Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
- if (aType.getElementType().isInteger(8)) {
- ptxTypeA = NVVM::MMATypes::s8;
- ptxTypeB = NVVM::MMATypes::s8;
+ if (aType.getElementType().isa<IntegerType>())
overflow = NVVM::MMAIntOverflow::satfinite;
- } else if (aType.getElementType().isInteger(4)) {
- ptxTypeA = NVVM::MMATypes::s4;
- ptxTypeB = NVVM::MMATypes::s4;
- overflow = NVVM::MMAIntOverflow::satfinite;
- } else if (aType.getElementType().isF16()) {
- ptxTypeA = NVVM::MMATypes::f16;
- ptxTypeB = NVVM::MMATypes::f16;
- } else if (aType.getElementType().isF64()) {
- ptxTypeA = NVVM::MMATypes::f64;
- ptxTypeB = NVVM::MMATypes::f64;
- } else if (aType.getElementType().isF32()) {
- ptxTypeA = NVVM::MMATypes::tf32;
- ptxTypeB = NVVM::MMATypes::tf32;
- } else {
- return op->emitError("could not deduce operand PTX types");
- }
SmallVector<Value> matA =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), ptxTypeA);
+ unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
- unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), ptxTypeB);
+ unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
@@ -321,7 +325,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
/*b1Op=*/llvm::None,
/*intOverflow=*/overflow,
/*multiplicandPtxTypes=*/
- std::array<NVVM::MMATypes, 2>{ptxTypeA, ptxTypeB},
+ std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
/*multiplicandLayouts=*/
std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
NVVM::MMALayout::col});
@@ -376,13 +380,182 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
rewriter.create<LLVM::InlineAsmOp>(
- loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals,
+ loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
+ /*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/asmConstraints, /*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
}
+/// Returns the constraints for the sparse MMA inline assembly instruction.
+static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
+ unsigned matBSize,
+ unsigned matCSize) {
+ std::string str;
+ llvm::raw_string_ostream ss(str);
+ for (unsigned i = 0; i < matCSize; i++)
+ ss << "=r,";
+ for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
+ ss << "r,";
+ // The final two operands are for the sparsity metadata and sparsity selector.
+ ss << "r,r";
+ ss.flush();
+ return str;
+}
+
+/// Returns the string for the `mma.sp.sync` instruction that corresponds to
+/// the given parameters. Note that this function doesn't do any validation;
+/// it's expected that the provided parameters correspond to a valid
+/// instruction.
+static std::string
+buildMmaSparseAsmString(const std::array<int64_t, 3> &shape, unsigned matASize,
+ unsigned matBSize, unsigned matCSize,
+ NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+ NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
+ Optional<NVVM::MMAIntOverflow> overflow) {
+ auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
+ return NVVM::stringifyMMATypes(ptxType);
+ };
+
+ std::string asmStr;
+ llvm::raw_string_ostream ss(asmStr);
+ ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
+ << shape[2] << ".row.col.";
+
+ if (overflow)
+ ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";
+
+ ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
+ << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
+ unsigned asmArgIdx = 0;
+
+ // The operand string is structured into sections `{matC elements...},
+ // {matA elements...}, {matB elements...}, {matC elements}`.
+ for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
+ ss << "{";
+ for (unsigned i = 0; i < arrSize; i++)
+ ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
+ ss << "},";
+ }
+ ss << "$" << asmArgIdx++ << ",$" << asmArgIdx++ << ";";
+ ss.flush();
+ return asmStr;
+}
+
+/// Builds an inline assembly operation corresponding to the specified MMA
+/// sparse sync operation.
+static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
+ Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+ NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
+ Optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
+ ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
+ int64_t metadataSelector, const std::array<int64_t, 3> &shape,
+ Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
+ auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+ LLVM::AsmDialect::AD_ATT);
+
+ std::string asmStr = buildMmaSparseAsmString(
+ shape, unpackedAData.size(), unpackedB.size(), unpackedC.size(), ptxTypeA,
+ ptxTypeB, ptxTypeC, ptxTypeD, overflow);
+ std::string constraintStr = buildMmaSparseAsmConstraintString(
+ unpackedAData.size(), unpackedB.size(), unpackedC.size());
+
+ Value selectorVal = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(metadataSelector));
+
+ SmallVector<Value> asmVals;
+ asmVals.reserve(unpackedAData.size() + unpackedB.size() + unpackedC.size() +
+ 2);
+ for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
+ llvm::append_range(asmVals, args);
+ asmVals.push_back(indexData);
+ asmVals.push_back(selectorVal);
+
+ return rewriter.create<LLVM::InlineAsmOp>(loc,
+ /*resultTypes=*/intrinsicResultType,
+ /*operands=*/asmVals,
+ /*asm_string=*/asmStr,
+ /*constraints=*/constraintStr,
+ /*has_side_effects=*/true,
+ /*is_align_stack=*/false,
+ /*asm_dialect=*/asmDialectAttr,
+ /*operand_attrs=*/ArrayAttr());
+}
+
+/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
+struct NVGPUMmaSparseSyncLowering
+ : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
+ using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ // Get the shapes of the MMAMatrix type being used. The shapes will
+ // choose which intrinsic this op will be lowered to.
+ VectorType aType = op.getMatrixA().getType();
+ VectorType bType = op.getMatrixB().getType();
+ VectorType cType = op.getMatrixC().getType();
+
+ FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
+ if (failed(ptxTypeA))
+ return op->emitOpError("failed to deduce operand PTX types");
+ FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
+ if (failed(ptxTypeB))
+ return op->emitOpError("failed to deduce operand PTX types");
+ Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
+ cType.getElementType(), /*isAccumulator=*/true);
+ if (!ptxTypeC)
+ return op->emitError(
+ "could not infer the PTX type for the accumulator/result");
+
+ // As with `mma.sync`, F32 works only with TensorFloat32 (TF32).
+ bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
+ if (aType.getElementType().isF32() && !tf32Enabled)
+ return failure();
+
+ // TODO: add an attribute to the op to customize this behavior.
+ Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
+ if (aType.getElementType().isa<IntegerType>())
+ overflow = NVVM::MMAIntOverflow::satfinite;
+
+ SmallVector<Value> matA =
+ unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+ SmallVector<Value> matB =
+ unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+ SmallVector<Value> matC =
+ unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+
+ Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
+ Type intrinsicResTy = inferIntrinsicResultType(
+ typeConverter->convertType(op->getResultTypes()[0]));
+
+ // Bitcast the sparse metadata from vector<2xf16> to an i32.
+ Value sparseMetadata = adaptor.getSparseMetadata();
+ if (sparseMetadata.getType() !=
+ LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
+ return op->emitOpError() << "Expected metadata type to be LLVM "
+ "VectorType of 2 i16 elements";
+ sparseMetadata = rewriter.create<LLVM::BitcastOp>(
+ loc, rewriter.getI32Type(), sparseMetadata);
+
+ FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
+ loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
+ matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
+ intrinsicResTy, rewriter);
+ if (failed(intrinsicResult))
+ return failure();
+
+ assert((*intrinsicResult).getNumResults() == 1 &&
+ "expected inline asm op to return a single LLVM struct type");
+ rewriter.replaceOp(
+ op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
+ (*intrinsicResult)->getResult(0), rewriter));
+ return success();
+ }
+};
+
struct NVGPUAsyncCopyLowering
: public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
using ConvertOpToLLVMPattern<
@@ -488,8 +661,8 @@ struct NVGPUAsyncWaitLowering
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
- NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering>(
- converter);
+ NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
+ NVGPUMmaSparseSyncLowering>(converter);
}
std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 9ed04b45aa1c8..24f70cb986e23 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -13,9 +13,11 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Verifier.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -80,13 +82,21 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
mmaShape, UnitAttr());
}
-LogicalResult MmaSyncOp::verify() {
-
- // Fundamental tensor core mma.sync op
- // For F32 (TF32), F16, S8, and S4 data types fundamental tensor core
- // operation is of shape: 8-by-8-by-128b. F64 is an exception. The
- // verification for mma.sync covering various shapes and data types is based
- // on the fundamental tensor core operionation.
+/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
+static LogicalResult verifyMmaSyncOp(Operation *op,
+ TypedValue<VectorType> matrixA,
+ TypedValue<VectorType> matrixB,
+ TypedValue<VectorType> matrixC,
+ const std::array<int64_t, 3> &mmaShape,
+ bool tf32Enabled, bool sparse = false) {
+
+ // The verification for mma.sync covering various shapes and data types is
+ // based on the fundamental tensor core shape.
+
+ // "Fundamental" tensor core shapes:
+ // - For F32 (TF32), F16, S8, and S4 data
+ // types the fundamental tensor core operation is of shape 8-by-8-by-128b.
+ // - F64 is an exception and is of shape 8-by-8-by-256b.
constexpr int kThreads = 32; // 32 threads per warp
int64_t shapeM = 8;
int64_t shapeN = 8;
@@ -98,9 +108,9 @@ LogicalResult MmaSyncOp::verify() {
int64_t numElementC{2}; // two accumulator elements per fundamental tile
// nvgpu.mma.sync vector operands (per thread)
- auto aVector = getMatrixA().getType().cast<VectorType>();
- auto bVector = getMatrixB().getType().cast<VectorType>();
- auto cVector = getMatrixC().getType().cast<VectorType>();
+ auto aVector = matrixA.getType();
+ auto bVector = matrixB.getType();
+ auto cVector = matrixC.getType();
// vector shapes
ArrayRef<int64_t> aShape = aVector.getShape();
@@ -110,13 +120,9 @@ LogicalResult MmaSyncOp::verify() {
// vector element type
Type aType = aVector.getElementType();
- // tensor float32 (TF32) enabled
- bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName());
-
- // nvgpu.mma.sync shape (per 32 threads or per warp)
- int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
- int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
- int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();
+ // Certain data types are not allowed in sparse mode.
+ if (sparse && aType.isF64())
+ return op->emitError() << "f64 is not supported for sparse mode";
if (aType.isF64()) {
// exception to 8-by-8-128b fundamental tensor core tile size
@@ -127,36 +133,43 @@ LogicalResult MmaSyncOp::verify() {
aType.isInteger(8) || aType.isInteger(4)) {
// 8-by-8-128b fundamental tensor core tile size
int operandBitwidth = aType.getIntOrFloatBitWidth();
- shapeK = 128 / operandBitwidth; // 128b wide shapeK
+ shapeK = 128 / operandBitwidth; // 128b wide shapeK
+
numElementA = 32 / operandBitwidth; // 32b wide operand A
numElementB = 32 / operandBitwidth; // 32b wide operand B
} else {
- return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
- "supported by nvgpu.mma.sync";
+ return op->emitError()
+ << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+ "supported by "
+ << op->getName();
}
//
// Basic verification
//
+ auto [m, n, k] = mmaShape;
+
// verify warp-wide size for vector a
- if (aShape[0] * aShape[1] * kThreads != m * k)
- return emitOpError() << "expected " << m * k
- << " warp-wide matrix A elements";
+ int64_t sparseFactor = sparse ? 2 : 1;
+ if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor)
+ return op->emitOpError()
+ << "expected " << m * k << " warp-wide matrix A elements";
// verify warp-wide size for vector b
if (bShape[0] * bShape[1] * kThreads != k * n)
- return emitOpError() << "expected " << k * n
- << " warp-wide matrix B elements";
+ return op->emitOpError()
+ << "expected " << k * n << " warp-wide matrix B elements";
// verify warp-wide size for vector c
if (cShape[0] * cShape[1] * kThreads != m * n)
- return emitOpError() << "expected " << m * n
- << " warp-wide matrix C elements";
+ return op->emitOpError()
+ << "expected " << m * n << " warp-wide matrix C elements";
// verify tf32 tensor cores are enabled for only F32 datatype
if (tf32Enabled && !(aType.isF32()))
- return emitOpError() << "expected tf32 tensor cores only for F32 operands";
+ return op->emitOpError()
+ << "expected tf32 tensor cores only for F32 operands";
//
// Extended verification
@@ -168,23 +181,48 @@ LogicalResult MmaSyncOp::verify() {
int64_t kTile = k / shapeK;
// verify shape of aVector
- if ((aShape[0] != mTile * kTile) || (aShape[1] != numElementA))
- return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
- << " x " << numElementA << ")";
+ if ((aShape[0] != mTile * kTile / (sparse ? 2 : 1)) ||
+ (aShape[1] != numElementA))
+ return op->emitOpError() << "expected matrix A to be shaped ("
+ << mTile * kTile << " x " << numElementA << ")";
// verify shape of bVector
if ((bShape[0] != kTile * nTile) || (bShape[1] != numElementB))
- return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
- << " x " << numElementB << ")";
+ return op->emitOpError() << "expected matrix B to be shaped ("
+ << kTile * nTile << " x " << numElementB << ")";
// verify shape of cVector
if ((cShape[0] != mTile * nTile) || (cShape[1] != numElementC))
- return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
- << " x " << numElementC << ")";
+ return op->emitOpError() << "expected matrix C to be shaped ("
+ << mTile * nTile << " x " << numElementC << ")";
return success();
}
+LogicalResult MmaSyncOp::verify() {
+ return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
+ getMatrixC(), getMmaShapeAsArray(),
+ getOperation()->hasAttr(getTf32EnabledAttrName()));
+}
+
+//===----------------------------------------------------------------------===//
+// NVGPU_MmaSparseSyncOp
+//===----------------------------------------------------------------------===//
+void MmaSparseSyncOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, Value matrixA,
+ Value matrixB, Value matrixC, Value sparseMetadata,
+ ArrayRef<int64_t> mmaShape) {
+ build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
+ sparseMetadata, odsBuilder.getI64ArrayAttr(mmaShape), 0, UnitAttr());
+}
+
+LogicalResult MmaSparseSyncOp::verify() {
+ return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(),
+ getMatrixC(), getMmaShapeAsArray(),
+ getOperation()->hasAttr(getTf32EnabledAttrName()),
+ true);
+}
+
//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0a9f8d5611903..c95b2fca9dffd 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -313,3 +313,119 @@ func.func @async_cp_zfill(
return
}
+
+// -----
+
+// CHECK-LABEL: func @mma_sp_sync_f16_16832(
+func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
+ %arg1: vector<4x2xf16>,
+ %arg2: vector<2x2xf16>,
+ %arg3: vector<2xi16>) -> vector<2x2xf16> {
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+
+ // CHECK-NOT: llvm.extractvalue
+
+ // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
+ // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+ // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,$13;"
+ // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r,r"
+ // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+ // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+ (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+
+ // CHECK-DAG: llvm.extractvalue %[[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK-DAG: llvm.extractvalue %[[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.insertvalue %{{.+}}, %{{.+}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.insertvalue %{{.+}}, %{{.+}}[1] : !llvm.array<2 x vector<2xf16>>
+ return %d : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @mma_sp_sync_f16_16816(
+func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
+ %arg1: vector<2x2xf16>,
+ %arg2: vector<2x2xf16>,
+ %arg3: vector<2xi16>) -> vector<2x2xf16> {
+
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+
+ // CHECK-NOT: llvm.extractvalue
+
+ // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
+ // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+ // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,$9;"
+ // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r"
+ // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+ // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+ (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: func @mma_sp_sync_i8_16864(
+func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
+ %arg1: vector<4x4xi8>,
+ %arg2: vector<2x2xi32>,
+ %arg3: vector<2xi16>) -> vector<2x2xi32> {
+
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+ // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+ // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<4xi8>>
+
+
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+ // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<4xi8>>
+ // CHECK: llvm.bitcast %{{.+}} : vector<4xi8> to i32
+
+ // CHECK: llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+ // CHECK: llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+
+ // CHECK-NOT: llvm.extractvalue
+
+ // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
+ // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+ // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
+ // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32
+ // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r,r"
+ // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+ // CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
+
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+ (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ return %d : vector<2x2xi32>
+}
diff --git a/mlir/test/Dialect/NVGPU/roundtrip.mlir b/mlir/test/Dialect/NVGPU/roundtrip.mlir
index 524f1fd6907b7..ad516b4d2c200 100644
--- a/mlir/test/Dialect/NVGPU/roundtrip.mlir
+++ b/mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -19,6 +19,44 @@ func.func @mma_sync(%arg0: vector<4x2xf16>,
return %d : vector<2x2xf16>
}
+// CHECK-LABEL: func @mma_sp_sync_f16_16832(
+func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
+ %arg1: vector<4x2xf16>,
+ %arg2: vector<2x2xf16>,
+ %arg3: vector<2xi16>) -> vector<2x2xf16> {
+ // CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
+ // CHECK-SAME: mmaShape = [16, 8, 32]
+ // CHECK-SAME: (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
+ (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func @mma_sp_sync_f16_16816(
+func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
+ %arg1: vector<2x2xf16>,
+ %arg2: vector<2x2xf16>,
+ %arg3: vector<2xi16>) -> vector<2x2xf16> {
+ // CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
+ // CHECK-SAME: mmaShape = [16, 8, 16]
+ // CHECK-SAME: (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
+ (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+
+// CHECK-LABEL: func @mma_sp_sync_i8_16864(
+func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
+ %arg1: vector<4x4xi8>,
+ %arg2: vector<2x2xi32>,
+ %arg3: vector<2xi16>) -> vector<2x2xi32> {
+ // CHECK: nvgpu.mma.sp.sync(%{{.*}}, %{{.*}}, %{{.*}}) metadata(%{{.+}}) {
+ // CHECK-SAME: mmaShape = [16, 8, 64]
+ // CHECK-SAME: (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
+ (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ return %d : vector<2x2xi32>
+}
func.func @async_cp(%dst : memref<2x7x5xf32, 3>, %src : memref<4x5xf32>){
// CHECK-LABEL: func @async_cp