[Mlir-commits] [mlir] [MLIR][NVGPU] Adding `nvgpu.warpgroup.mma` Op for Hopper GPUs (PR #65440)
Guray Ozen
llvmlistbot at llvm.org
Wed Sep 13 07:05:25 PDT 2023
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/65440:
>From 91666bfb2b3420dad072c29f9194895896b2cd35 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 1 Sep 2023 09:30:56 +0200
Subject: [PATCH 1/4] [MLIR][NVGPU] Adding `nvgpu.wargroup.mma` Op for Hopper
GPUs
This work introduces a new operation called `wargroup.mma` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with sm_90a architecture.
Previously, the `nvvm.wgmma.mma_async` operation was introduced to support warpgroup-level matrix operations in the NVVM dialect. Achieving a desired shape requires multiple instances of `nvvm.wgmma.mma_async`. The new `nvgpu.wargroup.mma` operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations.
The `nvgpu.wargroup.mma` op does the following:
1) Corresponds to multiple `wgmma` instructions.
2) Iterates over the input matrix descriptors to achieve the desired computation shape.
3) Groups and runs the `wgmma` instructions asynchronously, and eventually waits for them. This is done with `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned`.
4) Returns the fragmented result matrices.
Here's an example usage of the `nvgpu.wargroup.mma` operation:
```
%wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
!nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
vector<128x128xf32>
-> !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
!nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
```
Differential Revision: https://reviews.llvm.org/D158434
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 48 +++++
.../mlir/Dialect/NVGPU/IR/NVGPUDialect.h | 2 +
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 169 +++++++++++++++++-
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 108 ++++++++++-
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 66 ++++++-
mlir/test/Dialect/NVGPU/invalid.mlir | 61 +++++++
6 files changed, 446 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a3245bf9196eed1..f891aae136eba25 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,15 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
let assemblyFormat = "`<` struct(params) `>`";
}
+// Type of each result of `nvgpu.wargroup.mma`. The single `tensor` parameter
+// carries the fragment of the result matrix described below; the lowering
+// tests instantiate it with an `!llvm.struct` of accumulator elements.
+def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", []> {
+ let parameters = (ins "Type":$tensor);
+ let assemblyFormat = "`<` struct(params) `>`";
+ let description = [{
+ It is fragmented result matrix from `nvgpu.wargroup.mma`.
+ [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVGPU Op Definitions
//===----------------------------------------------------------------------===//
@@ -664,5 +673,44 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
let hasVerifier = 1;
}
+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> {
+ let description = [{
+ The `nvgpu.wargroup.mma` op performs the warpgroup-level (4 warps)
+ matrix-multiply-and-accumulate (mma) operation that results in
+ `nvvm.wgmma.mma_async`.
+
+ The operands are `descriptorA` and `descriptorB`, which are wgmma matrix
+ descriptors describing the matrices in shared memory, and the accumulator
+ `matrixC`. The shape is deduced from the descriptor types and the
+ accumulator vector type. Each result holds the thread-level fragment of one
+ 64-row slice of the result matrix, so a 128-row `matrixC` yields two
+ results.
+
+ The op expands into multiple `nvvm.wgmma.mma_async` operations that together
+ cover the given shape. Since `nvvm.wgmma.mma_async` is asynchronous, this op
+ groups those operations, emitting `wgmma.fence.aligned` before them and
+ `wgmma.commit.group.sync.aligned` followed by
+ `wgmma.wait.group.sync.aligned` after them.
+
+ Example:
+ ```mlir
+ %res = nvgpu.wargroup.mma %wgmmaDescA, %wgmmaDescB, %acc:
+ !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = ...>
+ ```
+ }];
+
+ // `waitGroup` (default 1) is forwarded to `wgmma.wait.group.sync.aligned` by
+ // the NVVM lowering. The verifier currently rejects `transposeA` unless
+ // `transposeB` is also set.
+ let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
+ NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
+ AnyVector:$matrixC,
+ DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
+ OptionalAttr<UnitAttr>:$transposeA,
+ OptionalAttr<UnitAttr>:$transposeB);
+ // One result per 64-row slice of `matrixC`.
+ let results = (outs Variadic<NVGPU_WarpgroupResult>:$matrixD);
+ let assemblyFormat = [{
+ $descriptorA`,` $descriptorB`,` $matrixC (`,` `group` `=` $waitGroup^ )? attr-dict
+ `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
+ }];
+ let hasVerifier = 1;
+}
#endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 192afcb2dba7913..96af26842dafea2 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -21,6 +21,8 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
+/// Number of threads in a warp (32 threads per warp).
+constexpr int kWarpSize = 32;
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1a7..90d138bd206e010 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -34,6 +35,10 @@ namespace mlir {
using namespace mlir;
+/// Number of low-order bits dropped (by right-shifting) from byte offsets and
+/// strides when they are encoded into a wgmma matrix descriptor.
+/// NOTE(review): presumably because descriptor fields count 16-byte units;
+/// confirm against the PTX matrix-descriptor format.
+constexpr int exclude4LSB = 4;
+
/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
@@ -984,10 +989,9 @@ struct NVGPUGenerateGmmaDescriptorLowering
shiftLeft(val, startBit));
};
- int ex4LSB = 4;
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
- uint64_t strideDimVal = (layout << 3) >> ex4LSB;
- uint64_t leadDimVal = (sizeN * layout) >> ex4LSB;
+ uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
+ uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
uint64_t offsetVal = 0;
Value strideDim = makeConst(strideDimVal);
@@ -1141,6 +1145,164 @@ struct NVGPUTmaCreateDescriptorOpLowering
}
};
+/// Lowers `nvgpu.wargroup.mma` to NVVM wgmma ops.
+///
+/// The accumulator is processed in slices of `wgmmaShapeM` (64) rows. For each
+/// slice, a chain of `nvvm.wgmma.mma_async` ops walks the K dimension, each op
+/// taking the previous op's result as its accumulator; the last value of each
+/// chain becomes one result of the op. All chains are emitted after a single
+/// `nvvm.wgmma.fence.aligned` and before the trailing
+/// `nvvm.wgmma.commit.group.sync.aligned` / `nvvm.wgmma.wait.group.sync.aligned`.
+struct NVGPUWarpgroupMmaOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
+  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
+
+  /// Computes the shape of one wgmma instruction: M is always 64, N spans the
+  /// whole output width, and K depends on the input element type. Fails for
+  /// element types that have no wgmma variant. `sizeM` is currently unused.
+  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
+                              int &wgmmaShapeM, int &wgmmaShapeN,
+                              int &wgmmaShapeK) const {
+    wgmmaShapeM = 64;
+    wgmmaShapeN = sizeN;
+    if (inputElemType.isTF32()) {
+      wgmmaShapeK = 8;
+    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+      wgmmaShapeK = 16;
+    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
+               inputElemType.isInteger(8)) {
+      // All 8-bit inputs (fp8 and 8-bit integers) use K = 32; wgmma has no
+      // 16-bit integer variant.
+      wgmmaShapeK = 32;
+    } else if (inputElemType.isInteger(1)) {
+      wgmmaShapeK = 256;
+    } else {
+      return failure();
+    }
+    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
+                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
+                      << "]\n");
+    return success();
+  }
+
+  /// Maps an MLIR element type to the matching wgmma input type, or fails if
+  /// there is none. Unsigned i8 maps to u8; any other 8-bit integer maps to s8.
+  static FailureOr<NVVM::WGMMATypes> getWgmmaInputType(Type elemType) {
+    if (elemType.isF16())
+      return NVVM::WGMMATypes::f16;
+    if (elemType.isBF16())
+      return NVVM::WGMMATypes::bf16;
+    if (elemType.isTF32())
+      return NVVM::WGMMATypes::tf32;
+    if (elemType.isFloat8E4M3FN())
+      return NVVM::WGMMATypes::e4m3;
+    if (elemType.isFloat8E5M2())
+      return NVVM::WGMMATypes::e5m2;
+    if (elemType.isUnsignedInteger(8))
+      return NVVM::WGMMATypes::u8;
+    if (elemType.isInteger(8))
+      return NVVM::WGMMATypes::s8;
+    if (elemType.isInteger(1))
+      return NVVM::WGMMATypes::b1;
+    return failure();
+  }
+
+  /// Creates one `nvvm.wgmma.mma_async` op of shape m x n x k with A row-major,
+  /// B column-major, and scale-d/scale-a/scale-b all set to `one`. `inout` is
+  /// the incoming accumulator; the op's result is returned.
+  Value generateNVVMWgmmaOp(MLIRContext *ctx,
+                            ConversionPatternRewriter &rewriter, Location loc,
+                            int m, int n, int k, Type resultStructType,
+                            Value inout, Value descriptorA, Value descriptorB,
+                            NVVM::WGMMATypes inputTypeA,
+                            NVVM::WGMMATypes inputTypeB) const {
+    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
+    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
+    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
+    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
+    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
+    auto itypeA = NVVM::WGMMATypesAttr::get(ctx, inputTypeA);
+    auto itypeB = NVVM::WGMMATypesAttr::get(ctx, inputTypeB);
+    auto overflow =
+        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
+    // Pass the result type directly. A `TypeRange` initialized from a braced
+    // list would reference a temporary array that is destroyed at the end of
+    // its declaration, leaving a dangling range.
+    return rewriter.create<NVVM::WgmmaMmaAsyncOp>(
+        loc, resultStructType, inout, descriptorA, descriptorB, shape, itypeA,
+        itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+  }
+
+  /// Builds the LLVM struct type used for the accumulator fragment of one
+  /// 64 x `sizeN` slice: `sizeN / 2` elements for f32/i32, `sizeN / 4`
+  /// elements for f16, and an empty struct for any other element type.
+  /// NOTE(review): confirm the f16 case; f16 accumulators may need packed
+  /// elements rather than `sizeN / 4` scalars.
+  static Type buildOutputStructType(MLIRContext *ctx, Type outElemType,
+                                    int sizeN) {
+    int outputElements = 0;
+    if (outElemType.isF32() || outElemType.isInteger(32))
+      outputElements = sizeN / 2;
+    if (outElemType.isF16())
+      outputElements = sizeN / 4;
+    SmallVector<Type> structBody(outputElements, outElemType);
+    return LLVM::LLVMStructType::getLiteral(ctx, structBody);
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
+    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    int64_t sizeM = op.getMatrixC().getType().getDimSize(0);
+    int64_t sizeN = op.getMatrixC().getType().getDimSize(1);
+    int64_t sizeK = typeTensorA.getDimSize(1);
+
+    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
+                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
+                      << sizeN << "] ---===\n");
+
+    // Derive the instruction shape and the wgmma input types from the actual
+    // element types of A and B instead of assuming f16.
+    Type elemTypeA = typeTensorA.getElementType();
+    Type elemTypeB = typeTensorB.getElementType();
+    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
+    if (failed(getWgmmaShape(sizeM, sizeN, elemTypeA, wgmmaShapeM, wgmmaShapeN,
+                             wgmmaShapeK)))
+      return rewriter.notifyMatchFailure(op, "unsupported input element type");
+    FailureOr<NVVM::WGMMATypes> inputTypeA = getWgmmaInputType(elemTypeA);
+    FailureOr<NVVM::WGMMATypes> inputTypeB = getWgmmaInputType(elemTypeB);
+    if (failed(inputTypeA) || failed(inputTypeB))
+      return rewriter.notifyMatchFailure(op, "unsupported input element type");
+
+    // The M and K tiles must cover the GEMM exactly, and there must be one
+    // result per 64-row slice. Otherwise rows/columns would be silently
+    // dropped, or `replaceOp` would receive the wrong number of values.
+    if (sizeM % wgmmaShapeM != 0 || sizeK % wgmmaShapeK != 0)
+      return rewriter.notifyMatchFailure(
+          op, "M and K must be multiples of the wgmma instruction shape");
+    int64_t numTilesM = sizeM / wgmmaShapeM;
+    int64_t numTilesK = sizeK / wgmmaShapeK;
+    if (static_cast<int64_t>(op->getNumResults()) != numTilesM)
+      return rewriter.notifyMatchFailure(
+          op, "expected one result per 64-row slice of the accumulator");
+
+    Value descriptorA = adaptor.getDescriptorA();
+    Value descriptorB = adaptor.getDescriptorB();
+    Location loc = op->getLoc();
+    MLIRContext *ctx = op->getContext();
+    Type outElemType = op.getMatrixC().getType().getElementType();
+    Type stype = buildOutputStructType(ctx, outElemType, sizeN);
+
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    // Advances descriptor A to tile (iterM, iterK). A is treated as row-major:
+    // skip `iterM` slices of `wgmmaShapeM` rows of `sizeK` elements each, then
+    // `iterK` steps of `wgmmaShapeK` elements along a row. The byte offset is
+    // encoded without its low `exclude4LSB` bits.
+    auto iterateDescA = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle column major
+      int byte = typeTensorA.getElementTypeBitWidth() / 8;
+      int incrementVal =
+          ((wgmmaShapeK * iterK) + (wgmmaShapeM * sizeK * iterM)) * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
+                        << iterK << "] [wgmma descriptors] Descriptor A + "
+                        << incrementVal << " | \t ");
+      return incrementVal
+                 ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal))
+                 : desc;
+    };
+
+    // Advances descriptor B to K-tile `iterK`. B is treated as column-major,
+    // so each K step skips `wgmmaShapeK` columns of `dim0(B)` elements.
+    auto iterateDescB = [&](Value desc, int iterM, int iterN,
+                            int iterK) -> Value {
+      // todo : Handle row major
+      int byte = typeTensorB.getElementTypeBitWidth() / 8;
+      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+      incrementVal = incrementVal >> exclude4LSB;
+      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+      return incrementVal
+                 ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal))
+                 : desc;
+    };
+
+    SmallVector<Value> wgmmaResults;
+    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+    for (int iterM = 0; iterM < numTilesM; iterM++) {
+      // NOTE(review): `matrixC` is not read here. Each chain starts from an
+      // undefined accumulator even though scale-d is `one`; confirm how the
+      // incoming accumulator is meant to be threaded through.
+      Value inout = rewriter.create<LLVM::UndefOp>(loc, stype);
+      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
+                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
+                        << ":" << wgmmaShapeN << "] += \n");
+      for (int iterK = 0; iterK < numTilesK; iterK++) {
+        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
+        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
+        LLVM_DEBUG(DBGS() << "\t wgmma."
+                          << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
+                          << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
+                          << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
+                          << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
+                          << " B[" << (iterK * wgmmaShapeK) << ":"
+                          << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
+                          << ":" << wgmmaShapeN << "])\n");
+        inout = generateNVVMWgmmaOp(ctx, rewriter, loc, wgmmaShapeM,
+                                    wgmmaShapeN, wgmmaShapeK, stype, inout,
+                                    descA, descB, *inputTypeA, *inputTypeB);
+      }
+      wgmmaResults.push_back(inout);
+    }
+
+    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
+    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+
+    rewriter.replaceOp(op, wgmmaResults);
+    return success();
+  }
+};
+
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1156,6 +1318,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
+ NVGPUWarpgroupMmaOpLowering, // nvgpu.wargroup.mma
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d832a983a132d61..cd0d65ddd9a65c0 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -151,7 +151,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
// - For F32 (TF32), F16, S8, and S4 data
// types the fundamental tensor core operation is of shape 8-by-8-by-128b.
// - F64 is an exception and is of shape 8-by-8-by-256b.
- constexpr int kThreads = 32; // 32 threads per warp
int64_t shapeM = 8;
int64_t shapeN = 8;
int64_t shapeK; // set based on data type (128b for all data types except F64)
@@ -206,17 +205,17 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
// verify warp-wide size for vector a
int64_t sparseFactor = sparse ? 2 : 1;
- if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor)
+ if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor)
return op->emitOpError()
<< "expected " << m * k << " warp-wide matrix A elements";
// verify warp-wide size for vector b
- if (bShape[0] * bShape[1] * kThreads != k * n)
+ if (bShape[0] * bShape[1] * kWarpSize != k * n)
return op->emitOpError()
<< "expected " << k * n << " warp-wide matrix B elements";
// verify warp-wide size for vector c
- if (cShape[0] * cShape[1] * kThreads != m * n)
+ if (cShape[0] * cShape[1] * kWarpSize != m * n)
return op->emitOpError()
<< "expected " << m * n << " warp-wide matrix C elements";
@@ -402,6 +401,107 @@ LogicalResult GenerateGmmaDescriptorOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaOp
+//===----------------------------------------------------------------------===//
+
+/// Returns success if `typeD += typeA * typeB` is an element-type combination
+/// that wgmma accepts, and failure otherwise.
+LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
+  // f32 += f16 * f16
+  // f16 += f16 * f16
+  if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16()))
+    return success();
+  // f32 += tf32 * tf32
+  if (typeA.isTF32() && typeD.isF32() && typeB.isTF32())
+    return success();
+  // s32 += i8 * i8
+  // Integer wgmma inputs are 8 bits wide (there is no 16-bit integer
+  // variant); this also matches the i8 handling in isAllowedSizeN.
+  if (typeA.isInteger(8) && typeB.isInteger(8) && typeD.isInteger(32))
+    return success();
+  // s32 += i1 * i1
+  if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32))
+    return success();
+  // f32 += bf16 * bf16
+  // f16 += bf16 * bf16
+  if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16()))
+    return success();
+  // f16 += f8 * f8
+  // f32 += f8 * f8
+  if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) &&
+      (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) &&
+      (typeD.isF32() || typeD.isF16()))
+    return success();
+
+  return failure();
+}
+
+/// Returns success if `sizeN` is a legal wgmma N for input element `typeA`.
+/// 16-bit, tf32 and fp8 inputs accept every multiple of 8 in [8, 256].
+/// 8-bit integer and 1-bit inputs accept 8, 24, and every multiple of 16 in
+/// [16, 256].
+LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
+  const bool isFloatInput = typeA.isBF16() || typeA.isF16() ||
+                            typeA.isTF32() || typeA.isFloat8E4M3FN() ||
+                            typeA.isFloat8E5M2();
+  const bool isIntInput = typeA.isInteger(8) || typeA.isInteger(1);
+
+  const bool inRange = sizeN >= 8 && sizeN <= 256;
+  const bool fullN = inRange && sizeN % 8 == 0;
+  const bool shortN =
+      sizeN == 8 || sizeN == 24 || (inRange && sizeN % 16 == 0);
+
+  if ((isFloatInput && fullN) || (isIntInput && shortN))
+    return success();
+  return failure();
+}
+
+/// Verifies `nvgpu.wargroup.mma`. Checks that:
+///  - all operands are 2-D and the GEMM dimensions of A, B and C agree;
+///  - the element types form a combination wgmma supports;
+///  - N is legal for the input type;
+///  - the combination is one the current lowering handles;
+///  - the result count is one per 64-row slice of `matrixC`, which is what
+///    the NVVM lowering produces.
+LogicalResult WarpgroupMmaOp::verify() {
+  // NOTE(review): the message says only non-transposed A with transposed B is
+  // supported, but this condition rejects only "A transposed, B not".
+  if (getTransposeA() && !getTransposeB())
+    return emitOpError() << "supports non-transpose A (Row Major) "
+                            "and transpose B (Column Major) for the time being";
+  auto matrixA = getDescriptorA().getType().getTensor();
+  auto matrixB = getDescriptorB().getType().getTensor();
+  auto matrixC = getMatrixC().getType();
+  if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
+      matrixC.getRank() != 2)
+    return emitOpError()
+           << "has input matrices A, B and D, they must be 2 dimensional";
+
+  if (matrixA.getShape()[1] != matrixB.getShape()[0])
+    return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
+                         << ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
+                         << " )";
+  if (matrixA.getShape()[0] != matrixC.getShape()[0])
+    return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
+                         << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
+                         << " )";
+  if (matrixB.getShape()[1] != matrixC.getShape()[1])
+    return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1]
+                         << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1]
+                         << " )";
+
+  if (failed(isAllowedWGMMADataType(matrixC.getElementType(),
+                                    matrixA.getElementType(),
+                                    matrixB.getElementType())))
+    return emitOpError() << matrixC.getElementType()
+                         << " += " << matrixA.getElementType() << " * "
+                         << matrixB.getElementType()
+                         << ", it is not supported.";
+  // Check N
+  if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) {
+    return emitOpError() << "has input type " << matrixB << " n is set to "
+                         << matrixB.getDimSize(1) << ", it is not supported";
+  }
+
+  // Currently only f32 += f16 * f16 and f32 += bf16 * bf16 are supported.
+  // Any one unmet requirement is a limitation, hence `||`: the previous `&&`
+  // form only fired when all of them failed, letting e.g. tf32 inputs through.
+  Type elemTypeA = matrixA.getElementType();
+  if (!matrixC.getElementType().isF32() ||
+      (!elemTypeA.isF16() && !elemTypeA.isBF16())) {
+    return emitOpError() << "hit a limitation: " << matrixC.getElementType()
+                         << " += " << matrixA.getElementType() << " * "
+                         << matrixB.getElementType()
+                         << ", it is not supported yet";
+  }
+
+  // The lowering splits the accumulator into 64-row slices (the wgmma M size)
+  // and returns one result per slice.
+  constexpr int64_t kWgmmaSizeM = 64;
+  int64_t sizeM = matrixC.getDimSize(0);
+  if (sizeM % kWgmmaSizeM != 0)
+    return emitOpError() << "1st dim matrix-C ( " << sizeM
+                         << " ) must be a multiple of " << kWgmmaSizeM;
+  int64_t numResults = static_cast<int64_t>(getMatrixD().size());
+  if (numResults != sizeM / kWgmmaSizeM)
+    return emitOpError() << "expected " << sizeM / kWgmmaSizeM
+                         << " results (one per " << kWgmmaSizeM
+                         << "-row slice of matrix-C), but got " << numResults;
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0d7ace52ccb36c9..cafeb785e31ffe2 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -672,6 +672,70 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
}
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32)>
+
+// CHECK-LABEL: @warpgroup_mma_128_128_64(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+func.func @warpgroup_mma_128_128_64(
+ %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
+ %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ %D: memref<128x128xf32,3>)
+{
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %arg0 : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %arg1 : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: nvvm.wgmma.fence.aligned
+// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%3, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
+// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
+// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64
+// CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]] : i64
+// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], <m = 64, n = 128, k = 16>, D[%[[S4]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64
+// CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]] : i64
+// CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64
+// CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]] : i64
+// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], <m = 64, n = 128, k = 16>, D[%[[S9]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64
+// CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]] : i64
+// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
+// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64
+// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S20:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
+// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64
+// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S20]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64
+// CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]] : i64
+// CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64
+// CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]] : i64
+// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], <m = 64, n = 128, k = 16>, D[%[[S23]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64
+// CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]] : i64
+// CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64
+// CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]] : i64
+// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], <m = 64, n = 128, k = 16>, D[%[[S28]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64
+// CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]] : i64
+// CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64
+// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64
+// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: nvvm.wgmma.commit.group.sync.aligned
+// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+ %acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32>
+ %wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
+ !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = !accMatrixStruct>, !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+
+ return
+}
+
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
@@ -681,5 +745,5 @@ transform.sequence failures(propagate) {
} with type_converter {
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
{use_opaque_pointers = true}
- } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op
+ } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "vector", "scf"], partial_conversion} : !transform.any_op
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index ef721b18014071a..d7af22085c10b9e 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -221,3 +221,64 @@ func.func @async_cp_size_invalid_f64(
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 3: memref<128x128xf64> to memref<3x16x128xf64, 3>
return
}
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
+
+// N of matrix-B (121) does not match N of the accumulator (128).
+func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+ // expected-error @+1 {{'nvgpu.wargroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
+ %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+ return
+}
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
+// A rank-1 accumulator is rejected: A, B and the accumulator must all be 2-D.
+func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) {
+ // expected-error @+1 {{'nvgpu.wargroup.mma' op has input matrices A, B and D, they must be 2 dimensional}}
+ %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult
+ return
+}
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
+// Mixed f16 (A) and f32 (B) inputs are not an accepted element-type combination.
+func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+ // expected-error @+1 {{'nvgpu.wargroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
+ %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+ return
+}
+
+// -----
+
+!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
+ f32, f32, f32, f32, f32, f32)>
+!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
+// N = 512 is larger than the largest allowed wgmma N (256).
+func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %D: vector<128x512xf32>) {
+ // expected-error @+1 {{'nvgpu.wargroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}
+ %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult
+ return
+}
>From 69bf60f3b25b462ec025e56b02c18077b9d3ebdf Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 7 Sep 2023 11:13:05 +0200
Subject: [PATCH 2/4] Include WGMMA descriptor type in transform dialect
---
mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index a173317bbbdb3f4..d13f640147c52f5 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -64,6 +64,11 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
return llvmTypeConverter.convertType(
getMBarrierMemrefType(type.getContext(), type));
});
+ llvmTypeConverter.addConversion(
+ [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
+ return llvmTypeConverter.convertType(
+ IntegerType::get(type.getContext(), 64));
+ });
llvmTypeConverter.addConversion(
[&](nvgpu::TensorMapDescriptorType type) -> Type {
return llvmTypeConverter.getPointerType(
>From 744a69ae533616636e30aef7bd9a96a0401ea017 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 7 Sep 2023 11:13:40 +0200
Subject: [PATCH 3/4] wargroup -> warpgroup
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 8 ++++----
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 1 -
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 2 +-
mlir/test/Dialect/NVGPU/invalid.mlir | 16 ++++++++--------
4 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index f891aae136eba25..060fe656d32e840 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -196,7 +196,7 @@ def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", []
let parameters = (ins "Type":$tensor);
let assemblyFormat = "`<` struct(params) `>`";
let description = [{
- It is fragmented result matrix from `nvgpu.wargroup.mma`.
+ It is fragmented result matrix from `nvgpu.warpgroup.mma`.
[See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
}];
}
@@ -673,9 +673,9 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
let hasVerifier = 1;
}
-def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> {
+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let description = [{
- The `nvgpu.wargroup.mma` op performs the warpgroup-level (4 warps)
+ The `nvgpu.warpgroup.mma` op performs the warpgroup-level (4 warps)
matrix-multiply-and-accumulate (mma) operation that results in
`nvvm.wgmma.mma_async`.
@@ -692,7 +692,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> {
Example:
```mlir
- %res = nvgpu.wargroup.mma %wgmmaDescA, %wgmmaDescB, %acc:
+ %res = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc:
!nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = ...>
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 90d138bd206e010..c8d91e7c5893a3c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index cafeb785e31ffe2..fdd6cbc519b6a51 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -728,7 +728,7 @@ func.func @warpgroup_mma_128_128_64(
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
%acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32>
- %wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
+ %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
!nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = !accMatrixStruct>, !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index d7af22085c10b9e..a915f7f3b809582 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -233,8 +233,8 @@ func.func @async_cp_size_invalid_f64(
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
- // expected-error @+1 {{'nvgpu.wargroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
- %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+ // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
+ %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
return
}
@@ -248,8 +248,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vecto
!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) {
- // expected-error @+1 {{'nvgpu.wargroup.mma' op has input matrices A, B and D, they must be 2 dimensional}}
- %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult
+ // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input matrices A, B and D, they must be 2 dimensional}}
+ %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult
return
}
@@ -263,8 +263,8 @@ func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D:
!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
- // expected-error @+1 {{'nvgpu.wargroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
- %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+ // expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
+ %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
return
}
@@ -278,7 +278,7 @@ func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: v
!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %D: vector<128x512xf32>) {
- // expected-error @+1 {{'nvgpu.wargroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}
- %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult
+ // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}
+ %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult
return
}
>From be8b62104a01191a1e6c2e29c03110cfb23502e6 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 13 Sep 2023 16:05:05 +0200
Subject: [PATCH 4/4] Improve accumulator matrix type
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 36 ++++++----
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 66 +++++++++----------
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 39 +++++++++--
.../NVGPU/TransformOps/NVGPUTransformOps.cpp | 10 +++
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 36 +++++-----
mlir/test/Dialect/NVGPU/invalid.mlir | 45 ++++---------
6 files changed, 127 insertions(+), 105 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 060fe656d32e840..90381648dac6acc 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,12 +192,16 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
let assemblyFormat = "`<` struct(params) `>`";
}
-def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", []> {
- let parameters = (ins "Type":$tensor);
+def NVGPU_WarpgroupAccumulator : NVGPU_Type<"WarpgroupAccumulator", "warpgroup.accumulator", []> {
+ let parameters = (ins "VectorType":$fragmented);
let assemblyFormat = "`<` struct(params) `>`";
let description = [{
- It is fragmented result matrix from `nvgpu.warpgroup.mma`.
- [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+    This type represents the result matrix obtained from `nvgpu.warpgroup.mma`.
+    The `$fragmented` type signifies the distributed or fragmented result
+    vector that is collectively owned by all the threads in the warpgroup
+    that executed `nvgpu.warpgroup.mma`. For the register fragment layout of
+    accumulator matrix D, see
+    [wgmma-64n16-d](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d).
}];
}
@@ -685,29 +689,33 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
shape. The shape is deduced from the descriptor types and output vector.
The Op corresponds multiple `nvvm.wgmma.mma_async` operations to complete the
- given shape. As the the instruction `nvvm.wgmma.async` is an asyncronous,
+  given shape. As the instruction `nvvm.wgmma.mma_async` is asynchronous,
this Op groups the `nvvm.wgmma.async` and surrounds them between
`wgmma.fence.aligned` and `wgmma.commit.group.sync.aligned`,
`wgmma.wait.group.sync.aligned` Ops.
Example:
```mlir
- %res = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc:
- !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
- !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
- vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = ...>
+ %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
+ !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
+ !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+ ->
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
```
}];
let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
- NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
- AnyVector:$matrixC,
+ NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
- OptionalAttr<UnitAttr>:$transposeB);
- let results = (outs Variadic<NVGPU_WarpgroupResult>:$matrixD);
+ OptionalAttr<UnitAttr>:$transposeB,
+ Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+ let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
let assemblyFormat = [{
- $descriptorA`,` $descriptorB`,` $matrixC (`,` `group` `=` $waitGroup^ )? attr-dict
+ $descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index c8d91e7c5893a3c..046727e4ea9ab83 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,10 +17,12 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvgpu-to-nvvm"
@@ -423,6 +425,15 @@ struct ConvertNVGPUToNVVMPass
converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
+    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
+      // The fragment is spread over the 128 threads of a warpgroup, so each
+      // thread owns numElements/128 values (N/2 for a 64xN 32-bit tile).
+      VectorType vtype = type.getFragmented();
+      SmallVector<Type> structBody(vtype.getNumElements() / 128,
+                                   vtype.getElementType());
+      // NOTE(review): confirm the element count for packed f16 accumulators.
+      return converter.convertType(
+          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody));
+    });
converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
return converter.convertType(IntegerType::get(type.getContext(), 64));
});
@@ -442,6 +453,8 @@ struct ConvertNVGPUToNVVMPass
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::memref::MemRefDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
+ mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+ converter, patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -1163,7 +1176,7 @@ struct NVGPUWarpgroupMmaOpLowering
} else if (inputElemType.isInteger(1)) {
wgmmaShapeK = 256;
} else {
- return failure();
+ llvm_unreachable("msg: not supported K shape");
}
LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
<< ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
@@ -1192,26 +1205,11 @@ struct NVGPUWarpgroupMmaOpLowering
return res;
}
- static Type buildOutputStructType(MLIRContext *ctx, Type outElemType,
- int sizeN) {
- int outputElements = 0;
- if (outElemType.isF32() || outElemType.isInteger(32))
- outputElements = sizeN / 2;
- if (outElemType.isF16())
- outputElements = sizeN / 4;
- SmallVector<Type> structBody;
- for (int i = 0; i < outputElements; i++)
- structBody.push_back(outElemType);
- return LLVM::LLVMStructType::getLiteral(ctx, structBody);
- }
-
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> wgmmaResults;
-
- int64_t sizeM = op.getMatrixC().getType().getDimSize(0);
- int64_t sizeN = op.getMatrixC().getType().getDimSize(1);
+ int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
+ int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
@@ -1230,8 +1228,6 @@ struct NVGPUWarpgroupMmaOpLowering
// Generate wgmma group
auto loc = op->getLoc();
- Type outElemType = op.getMatrixC().getType().getElementType();
- Type stype = buildOutputStructType(op->getContext(), outElemType, sizeN);
MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
@@ -1250,9 +1246,9 @@ struct NVGPUWarpgroupMmaOpLowering
LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
<< iterK << "] [wgmma descriptors] Descriptor A + "
<< incrementVal << " | \t ");
- return incrementVal
- ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal))
- : desc;
+ if (!incrementVal)
+ return desc;
+ return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
};
auto iterateDescB = [&](Value desc, int iterM, int iterN,
@@ -1262,15 +1258,18 @@ struct NVGPUWarpgroupMmaOpLowering
int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
- return incrementVal
- ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal))
- : desc;
+ if (!incrementVal)
+ return desc;
+ return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
};
rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+
+ SmallVector<Value> wgmmaResults;
for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
- Value undefOp = rewriter.create<LLVM::UndefOp>(loc, stype);
- Value inout = undefOp;
+ Value matrixC = adaptor.getMatrixC()[iterM];
+ Value matrixD = op.getMatrixD()[iterM];
+ Type structType = getTypeConverter()->convertType(matrixD.getType());
LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
<< (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
<< ":" << wgmmaShapeN << "] += \n");
@@ -1286,13 +1285,12 @@ struct NVGPUWarpgroupMmaOpLowering
<< " B[" << (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
<< ":" << wgmmaShapeN << "])\n");
- inout = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
- wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
- stype, inout, descA, descB);
+ matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
+ wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+ structType, matrixC, descA, descB);
}
- wgmmaResults.push_back(inout);
+ wgmmaResults.push_back(matrixC);
}
-
rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
@@ -1317,7 +1315,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
- NVGPUWarpgroupMmaOpLowering, // nvgpu.wargroup.mma
+ NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index cd0d65ddd9a65c0..d96ed69982870b4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -456,19 +457,45 @@ LogicalResult WarpgroupMmaOp::verify() {
if (getTransposeA() && !getTransposeB())
return emitOpError() << "supports non-transpose A (Row Major) "
"and transpose B (Column Major) for the time being";
- auto matrixA = getDescriptorA().getType().getTensor();
- auto matrixB = getDescriptorB().getType().getTensor();
- auto matrixC = getMatrixC().getType();
+  MemRefType matrixA = getDescriptorA().getType().getTensor();
+  MemRefType matrixB = getDescriptorB().getType().getTensor();
+
+  if (getMatrixC().size() != getMatrixD().size())
+    return emitOpError() << "number of matrix C and matrix D must be the same";
+  if (getMatrixC().empty())
+    return emitOpError() << "expects at least one accumulator";
+
+  // Every accumulator operand (C) and every result (D) must share one type;
+  // the first operand's type is the reference.
+  Type accType = getMatrixC().front().getType();
+  if (!llvm::all_equal(getMatrixC().getTypes())) {
+    return emitOpError()
+           << "types of all operands in matrix C must be the same";
+  }
+  if (!llvm::all_of(getMatrixD().getTypes(),
+                    [&](Type t) { return t == accType; })) {
+    return emitOpError()
+           << "types of all operands in matrix D must be the same as matrix C";
+  }
+
+  // D has the same type as C (checked above), so they share one fragment.
+  // NOTE(review): the lowering indexes one accumulator per 64-row wgmma
+  // tile; confirm that C's 1st dim is constrained to that tile size.
+  VectorType matrixC = accType.cast<WarpgroupAccumulatorType>().getFragmented();
+  VectorType matrixD = matrixC;
+  unsigned sizeAcc = getMatrixC().size();
+
if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
- matrixC.getRank() != 2)
+ matrixC.getRank() != 2 || matrixD.getRank() != 2) {
return emitOpError()
- << "has input matrices A, B and D, they must be 2 dimensional";
+ << "has matrices A, B, C and D, they must be 2 dimensional";
+ }
if (matrixA.getShape()[1] != matrixB.getShape()[0])
return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
<< ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
<< " )";
- if (matrixA.getShape()[0] != matrixC.getShape()[0])
+ if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
<< " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
<< " )";
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index d13f640147c52f5..680c21ab74fe020 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -60,6 +60,16 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
return llvmTypeConverter.convertType(
IntegerType::get(type.getContext(), 64));
});
+  llvmTypeConverter.addConversion(
+      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
+        // The fragment is spread over the 128 threads of a warpgroup, so each
+        // thread owns numElements/128 values (N/2 for a 64xN 32-bit tile).
+        // NOTE(review): confirm the element count for packed f16 accumulators.
+        VectorType vtype = type.getFragmented();
+        SmallVector<Type> structBody(vtype.getNumElements() / 128,
+                                     vtype.getElementType());
+        return llvmTypeConverter.convertType(
+            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody));
+      });
llvmTypeConverter.addConversion([&](nvgpu::MBarrierType type) -> Type {
return llvmTypeConverter.convertType(
getMBarrierMemrefType(type.getContext(), type));
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index fdd6cbc519b6a51..b7aa0c7382d80d0 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -672,23 +672,20 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc
func.return %descA : !nvgpu.wgmma.descriptor<tensor=memref<128x64xf16,3>>
}
-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-
// CHECK-LABEL: @warpgroup_mma_128_128_64(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
func.func @warpgroup_mma_128_128_64(
%descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
%descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
- %D: memref<128x128xf32,3>)
+ %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>)
{
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %arg0 : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %arg1 : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.fence.aligned
-// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%3, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64
@@ -704,10 +701,9 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64
// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
-// CHECK: %[[S20:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64
-// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S20]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64
// CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]] : i64
// CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64
@@ -724,15 +720,15 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64
// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
// CHECK: nvvm.wgmma.commit.group.sync.aligned
-// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
- %c0 = arith.constant 0 : index
- %f0 = arith.constant 0.0 : f32
- %acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32>
- %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
+// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
+ %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
!nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
- vector<128x128xf32> -> !nvgpu.warpgroup.result<tensor = !accMatrixStruct>, !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
-
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ ->
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
return
}
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index a915f7f3b809582..ff391e469815d74 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -224,61 +224,44 @@ func.func @async_cp_size_invalid_f64(
// -----
-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x121xf16, 3>>
-func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
- %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+ %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
return
}
// -----
-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<128xf32>>
!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>
-func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) {
- // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input matrices A, B and D, they must be 2 dimensional}}
- %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult
+func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
+ // expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}}
+ %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
return
}
// -----
-
-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x128xf32, 3>>
-func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
// expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
- %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+ %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
return
}
// -----
-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result<tensor = !accMatrixStruct>
+!tResult = !nvgpu.warpgroup.accumulator<fragmented = vector<64x512xf32>>
!tDescA = !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
!tDescB = !nvgpu.wgmma.descriptor<tensor = memref<64x512xf16, 3>>
-func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %D: vector<128x512xf32>) {
-  // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}
-  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult
+func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
+  // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}
+  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
return
}
More information about the Mlir-commits
mailing list