[Mlir-commits] [mlir] [mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (PR #68728)
llvmlistbot at llvm.org
Tue Oct 10 10:56:58 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-nvgpu
Author: Guray Ozen (grypp)
`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type that holds the accumulator matrix used by warp-group level matrix multiplication. Having a dedicated type is handy because the matrix is distributed among the threads of the warp group. However, the current transformations require creating and carrying multiple `WarpgroupAccumulator` values whenever the GEMM shape is larger than the shape supported by a single `wgmma.mma_async` instruction, which makes the IR needlessly dense.
This PR improves the handling of the `WarpgroupAccumulator` type in every nvgpu op that uses it, so a single value now covers the entire GEMM shape.
**Example: Current GEMM in NVGPU-IR**
```
// Init
%m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->
           !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
           !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// GEMM
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}:
           !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
           !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
           !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
           !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
           ->
           !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
           !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// Epilogue
nvgpu.warpgroup.mma.store [%r1, %r2] to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
  into memref<128x128xf32,3>
```
**Example: with this PR, the IR simplifies to:**
```
// Init
%m = nvgpu.warpgroup.mma.init.accumulator ->
     !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// GEMM
%r1 = nvgpu.warpgroup.mma %descA, %descB, %m {transposeB}:
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
      ->
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// Epilogue
nvgpu.warpgroup.mma.store %r1 to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
  into memref<128x128xf32,3>
```
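
To see what the single accumulator buys, it helps to look at how the type now lowers: the conversion in `NVGPUToNVVM.cpp` splits the fragmented vector into one inner LLVM struct per 64-row `wgmma` tile, with `sizeN/2` members for f32/i32 elements (or `sizeN/4` for f16). Below is a minimal standalone sketch of that layout arithmetic only; `AccumulatorLayout` and `computeAccumulatorLayout` are illustrative names, not part of the patch.

```cpp
#include <cassert>
#include <cstdio>

constexpr int kWgmmaSizeM = 64; // M size of one wgmma.mma_async tile

struct AccumulatorLayout {
  int numOuterStructs; // one inner struct per 64-row wgmma tile
  int numInnerMembers; // scalar members in each inner struct
};

AccumulatorLayout computeAccumulatorLayout(int sizeM, int sizeN,
                                           bool is32BitElem) {
  assert(sizeM % kWgmmaSizeM == 0 && "M must be a multiple of 64");
  // f32/i32 elements pack two columns per member; f16 packs four.
  int numInnerMembers = is32BitElem ? sizeN / 2 : sizeN / 4;
  return {sizeM / kWgmmaSizeM, numInnerMembers};
}

int main() {
  AccumulatorLayout l = computeAccumulatorLayout(128, 128, /*is32BitElem=*/true);
  std::printf("outer structs: %d, members each: %d\n", l.numOuterStructs,
              l.numInnerMembers);
}
```

For the `vector<128x128xf32>` accumulator above this gives two inner structs of 64 f32 members each, i.e. `struct<(struct<(f32 x 64)>, struct<(f32 x 64)>)>`.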
---
Patch is 43.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68728.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+16-4)
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+3)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+84-28)
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+46-51)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+101-29)
``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 57cd1a3806c2ed6..fd16376be366912 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
- Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
- let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+ NVGPU_WarpgroupAccumulator:$matrixC);
+ let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
let assemblyFormat = [{
$descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
@@ -739,13 +739,25 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
Note that, the op must be run with warp group.
}];
- let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+ let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
let assemblyFormat = [{
- `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+ $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
+def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulator"> {
+ let summary = "Initializes the accumulator matrix";
+
+ let description = [{
+ This Op generates and initializes the accumulator matrix for
+ `nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
+ }];
+ let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
+ let assemblyFormat = "attr-dict `->` type($matrixC)";
+ let hasVerifier = 1;
+}
+
#endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 96af26842dafea2..e6bba7e6082964b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
constexpr int kWarpSize = 32;
+/// M size of wgmma.mma_async instruction
+constexpr int kWgmmaSizeM = 64;
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 99c4d4223351352..2d43230938526b9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
- VectorType vtype = type.getFragmented();
+ Type elemType = type.getFragmented().getElementType();
+ int64_t sizeM = type.getFragmented().getDimSize(0);
+ int64_t sizeN = type.getFragmented().getDimSize(1);
+
+ unsigned numMembers;
+ if (elemType.isF32() || elemType.isInteger(32))
+ numMembers = sizeN / 2;
+ else if (elemType.isF16())
+ numMembers = sizeN / 4;
+ else
+ llvm_unreachable("unsupported type for warpgroup accumulator");
+
+ SmallVector<Type> innerStructBody;
+ for (unsigned i = 0; i < numMembers; i++)
+ innerStructBody.push_back(elemType);
+ auto innerStructType =
+ LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
SmallVector<Type> structBody;
- for (unsigned i = 0; i < vtype.getDimSize(0); i++)
- structBody.push_back(vtype.getElementType());
+ for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+ structBody.push_back(innerStructType);
+
auto convertedType =
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
return converter.convertType(convertedType);
@@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
- const LLVMTypeConverter &typeConverter;
// Entire shape of the given Op
int64_t totalM, totalN, totalK;
@@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
- Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+ Value generateWgmma(int i, int j, int k, Value matrixC) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
@@ -1359,34 +1376,36 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);
- Type resultStructType = typeConverter.convertType(matrixD.getType());
-
return b.create<NVVM::WgmmaMmaAsyncOp>(
- resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+ matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
- SmallVector<Value> generateWgmmaGroup() {
- SmallVector<Value> wgmmaResults;
+ Value generateWgmmaGroup() {
+ Value wgmmaResult =
+ b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
// Perform GEMM
+ SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
- Value matrixC = adaptor.getMatrixC()[i];
- Value matrixD = op.getMatrixD()[i];
+ Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
- matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+ matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}
-
- return wgmmaResults;
+ for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
+ wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
+ wgmmaResult, matrix, idx);
+ }
+ return wgmmaResult;
}
public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
- OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
- : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+ OpAdaptor adaptor)
+ : op(op), b(b), adaptor(adaptor) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1411,27 +1430,27 @@ struct NVGPUWarpgroupMmaOpLowering
/// instructions and group synchronization, as well as waiting
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
- SmallVector<Value> generateWarpgroupMma() {
+ Value generateWarpgroupMma() {
b.create<NVVM::WgmmaFenceAlignedOp>();
- SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+ Value wgmmaResult = generateWgmmaGroup();
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
- return wgmmaResults;
+ return wgmmaResult;
}
};
-
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+
// Step 1. Build a helper class
- WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+ WarpgroupGemm warpgroupGemm(op, b, adaptor);
// Step 2. Get the entire GEMM Shape
- SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+ Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
// Step 3. Replace fragmented result struct with the op results
- rewriter.replaceOp(op, wgmmaResults);
+ rewriter.replaceOp(op, wgmmaResult);
return success();
}
};
@@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int offset = 0;
- ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
- for (Value matrixD : adaptor.getMatrixD()) {
- auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
- storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value matriDValue = adaptor.getMatrixD();
+ auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+ for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
+ auto structType = matrixD.cast<LLVM::LLVMStructType>();
+ Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+ storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
@@ -1546,6 +1568,39 @@ struct NVGPUWarpgroupMmaStoreOpLowering
}
};
+struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ LLVM::LLVMStructType structType =
+ getTypeConverter()
+ ->convertType(op.getMatrixC().getType())
+ .cast<LLVM::LLVMStructType>();
+ Type elemType = structType.getBody()
+ .front()
+ .cast<LLVM::LLVMStructType>()
+ .getBody()
+ .front();
+ Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
+ Value structValue = b.create<LLVM::UndefOp>(structType);
+ for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
+ auto innerStructType = s.cast<LLVM::LLVMStructType>();
+ int ii = idx;
+ Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
+ for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
+ innerStructValue = b.create<LLVM::InsertValueOp>(
+ innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
+ }
+ }
+ rewriter.replaceOp(op, structValue);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1563,6 +1618,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
+ NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index e8ecd0faa4c86d3..f5b02fe1b515591 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -435,6 +435,12 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return failure();
}
+LogicalResult isAllowedSizeM(int sizeM) {
+ if (sizeM % kWgmmaSizeM)
+ return failure();
+ return success();
+}
+
LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
72, 80, 88, 96, 104, 112, 120, 128,
@@ -443,7 +449,7 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64,
80, 96, 112, 128, 144, 160,
176, 192, 208, 224, 240, 256};
- if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() ||
+ if (typeA.isBF16() || typeA.isF16() || typeA.isF32() || typeA.isTF32() ||
typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2())
if (llvm::is_contained(allowedN, sizeN))
return success();
@@ -456,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
LogicalResult WarpgroupMmaOp::verify() {
if (getTransposeA() && !getTransposeB())
- return emitOpError() << "supports non-transpose A (Row Major) "
- "and transpose B (Column Major) for the time being";
+ return emitOpError()
+ << "supports non-transpose A (Row Major) "
+ "and transpose B (Column Major) for the time being ";
MemRefType matrixA = getDescriptorA().getType().getTensor();
MemRefType matrixB = getDescriptorB().getType().getTensor();
- VectorType matrixC = getMatrixC()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
- VectorType matrixD = getMatrixD()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
- unsigned sizeAcc = getMatrixC().size();
-
- if (getMatrixC().size() != getMatrixD().size())
- return emitOpError() << "number of matrix C and matrix D must be the same";
-
- if (llvm::all_of(getMatrixC(),
- [&](Value rhs) { return rhs.getType() == matrixC; })) {
- return emitOpError()
- << "types of all operands in matrix C must be the same";
- }
- if (llvm::all_of(getMatrixD(),
- [&](Value rhs) { return rhs.getType() == matrixC; })) {
- return emitOpError()
- << "types of all operands in matrix D must be the same as matrix C";
- }
+ VectorType matrixC = getMatrixC().getType().getFragmented();
+ VectorType matrixD = getMatrixD().getType().getFragmented();
+
+ if (matrixC != matrixD)
+ return emitOpError() << "type of matrix C and matrix D must be the same";
if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
matrixC.getRank() != 2 || matrixD.getRank() != 2) {
@@ -496,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
<< ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
<< " )";
- if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
+ if (matrixA.getShape()[0] != matrixC.getShape()[0])
return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
<< " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
<< " )";
@@ -532,29 +519,16 @@ LogicalResult WarpgroupMmaOp::verify() {
LogicalResult WarpgroupMmaStoreOp::verify() {
MemRefType dstMemrefType = getDstMemref().getType();
- VectorType firstVtype = getMatrixD()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
-
- int64_t totalFirstDimension = 0;
- for (Value result : getMatrixD()) {
- VectorType vtype =
- result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
- if (vtype != firstVtype)
- return emitOpError() << "all fragmented types must be the same";
- // Limitation
- if (!vtype.getElementType().isF32()) {
- return emitOpError()
- << "hit a limitation: only f32 results for the time being";
- }
- totalFirstDimension += vtype.getDimSize(0);
+ VectorType vtype = getMatrixD().getType().getFragmented();
+
+ // Limitation
+ if (!vtype.getElementType().isF32()) {
+ return emitOpError()
+ << "hit a limitation: only f32 results for the time being";
}
- if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
- firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
- return emitOpError() << "results [" << totalFirstDimension << "]["
- << firstVtype.getDimSize(1)
+ if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
+ vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+ return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
<< "] values. However, destination memref["
<< dstMemrefType.getDimSize(0) << "]["
<< dstMemrefType.getDimSize(1)
@@ -563,6 +537,27 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// WarpgroupMmaInitAccumulatorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
+
+ nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+ int64_t sizeM = accType.getFragmented().getDimSize(0);
+ int64_t sizeN = accType.getFragmented().getDimSize(1);
+ Type elemType = accType.getFragmented().getElementType();
+
+ if (failed(isAllowedSizeM(sizeM)) ||
+ failed(isAllowedSizeN(sizeN, elemType))) {
+ return emitOpError() << "has type " << accType.getFragmented()
+ << ". It does not fit into warp-group "
+ "level (wgmma) matrix multiplication instruction "
+ "(or not supported yet)";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index e54b62a06d4313a..bf660e2683158e5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -713,18 +713,18 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.
}
// CHECK-LABEL: @warpgroup_mma_128_128_64(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
func.func @warpgroup_mma_128_128_64(
%descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
%descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
- %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
- %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>)
+ %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
{
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[ar...
[truncated]
``````````
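The new `NVGPUWarpgroupMmaInitAccumulatorOpLowering` builds the converted struct value and writes a zero constant into every scalar member of each inner struct. Here is a conceptual sketch of the value it produces, with nested `std::vector`s standing in for the LLVM struct-of-structs; this illustrates the zero-fill shape under the f32 layout above, and is not the pass code itself.

```cpp
#include <cstdio>
#include <vector>

constexpr int kWgmmaSizeM = 64; // M covered by one wgmma.mma_async tile

// Stand-ins for the nested LLVM struct the type converter produces.
using InnerStruct = std::vector<float>;             // one 64-row tile
using AccumulatorStruct = std::vector<InnerStruct>; // sizeM/64 tiles

// Zero every member of every inner struct, mirroring the
// ExtractValue/InsertValue loop in the lowering (f32: sizeN/2 members).
AccumulatorStruct initAccumulator(int sizeM, int sizeN) {
  return AccumulatorStruct(sizeM / kWgmmaSizeM,
                           InnerStruct(sizeN / 2, 0.0f));
}

int main() {
  AccumulatorStruct acc = initAccumulator(128, 128);
  std::printf("tiles: %zu, f32 members per tile: %zu\n", acc.size(),
              acc.front().size());
}
```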
https://github.com/llvm/llvm-project/pull/68728