[Mlir-commits] [mlir] [mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (PR #68728)
Guray Ozen
llvmlistbot at llvm.org
Mon Oct 16 05:58:50 PDT 2023
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/68728
>From acce6abce8e5f0da1df64b0978b64dce64ba4d6d Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 10 Oct 2023 02:33:20 +0200
Subject: [PATCH 1/2] [mlir][nvgpu] Simplify `NVGPU_WarpgroupAccumulator`
accumulator
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 10 +-
.../mlir/Dialect/NVGPU/IR/NVGPUDialect.h | 3 +
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 112 +++++++++++-------
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 99 ++++++----------
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 65 +++++-----
5 files changed, 145 insertions(+), 144 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 79183acfb71b61e..fd16376be366912 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -719,8 +719,8 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
- Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
- let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+ NVGPU_WarpgroupAccumulator:$matrixC);
+ let results = (outs NVGPU_WarpgroupAccumulator:$matrixD);
let assemblyFormat = [{
$descriptorA`,` $descriptorB`,` $matrixC attr-dict
`:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
@@ -739,11 +739,11 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
Note that, the op must be run with warp group.
}];
- let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+ let arguments = (ins NVGPU_WarpgroupAccumulator:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
let assemblyFormat = [{
- `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+ $matrixD `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
}];
let hasVerifier = 1;
}
@@ -755,7 +755,7 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
This Op generates and initializes the accumulator matrix for
`nvgpu.warpgroup.mma` op to perform matrix-multiply-and-accumulate.
}];
- let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+ let results = (outs NVGPU_WarpgroupAccumulator:$matrixC);
let assemblyFormat = "attr-dict `->` type($matrixC)";
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 96af26842dafea2..e6bba7e6082964b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -23,6 +23,9 @@
constexpr int kWarpSize = 32;
+/// Size of the M dimension covered by a single wgmma.mma_async instruction.
+constexpr int kWgmmaSizeM = 64;
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 84f53a4572294ad..2d43230938526b9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -412,10 +412,28 @@ struct ConvertNVGPUToNVVMPass
return converter.convertType(IntegerType::get(type.getContext(), 32));
});
converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
- VectorType vtype = type.getFragmented();
+ Type elemType = type.getFragmented().getElementType();
+ int64_t sizeM = type.getFragmented().getDimSize(0);
+ int64_t sizeN = type.getFragmented().getDimSize(1);
+
+ unsigned numMembers;
+ if (elemType.isF32() || elemType.isInteger(32))
+ numMembers = sizeN / 2;
+ else if (elemType.isF16())
+ numMembers = sizeN / 4;
+ else
+ llvm_unreachable("unsupported type for warpgroup accumulator");
+
+ SmallVector<Type> innerStructBody;
+ for (unsigned i = 0; i < numMembers; i++)
+ innerStructBody.push_back(elemType);
+ auto innerStructType =
+ LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
SmallVector<Type> structBody;
- for (unsigned i = 0; i < vtype.getDimSize(0); i++)
- structBody.push_back(vtype.getElementType());
+ for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+ structBody.push_back(innerStructType);
+
auto convertedType =
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
return converter.convertType(convertedType);
@@ -1186,7 +1204,6 @@ struct NVGPUWarpgroupMmaOpLowering
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
- const LLVMTypeConverter &typeConverter;
// Entire shape of the given Op
int64_t totalM, totalN, totalK;
@@ -1330,7 +1347,7 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
- Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
+ Value generateWgmma(int i, int j, int k, Value matrixC) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
@@ -1359,34 +1376,36 @@ struct NVGPUWarpgroupMmaOpLowering
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);
- Type resultStructType = typeConverter.convertType(matrixD.getType());
-
return b.create<NVVM::WgmmaMmaAsyncOp>(
- resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
+ matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
- SmallVector<Value> generateWgmmaGroup() {
- SmallVector<Value> wgmmaResults;
+ Value generateWgmmaGroup() {
+ Value wgmmaResult =
+ b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());
// Perform GEMM
+ SmallVector<Value> wgmmaResults;
for (int i = 0; i < iterationM; ++i) {
- Value matrixC = adaptor.getMatrixC()[i];
- Value matrixD = op.getMatrixD()[i];
+ Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
- matrixC = generateWgmma(i, j, k, matrixC, matrixD);
+ matrixC = generateWgmma(i, j, k, matrixC);
wgmmaResults.push_back(matrixC);
}
-
- return wgmmaResults;
+ for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
+ wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
+ wgmmaResult, matrix, idx);
+ }
+ return wgmmaResult;
}
public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
- OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
- : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
+ OpAdaptor adaptor)
+ : op(op), b(b), adaptor(adaptor) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
@@ -1411,27 +1430,27 @@ struct NVGPUWarpgroupMmaOpLowering
/// instructions and group synchronization, as well as waiting
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
- SmallVector<Value> generateWarpgroupMma() {
+ Value generateWarpgroupMma() {
b.create<NVVM::WgmmaFenceAlignedOp>();
- SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+ Value wgmmaResult = generateWgmmaGroup();
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
- return wgmmaResults;
+ return wgmmaResult;
}
};
-
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+
// Step 1. Build a helper class
- WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+ WarpgroupGemm warpgroupGemm(op, b, adaptor);
// Step 2. Get the entire GEMM Shape
- SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+ Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();
// Step 3. Replace fragmented result struct with the op results
- rewriter.replaceOp(op, wgmmaResults);
+ rewriter.replaceOp(op, wgmmaResult);
return success();
}
};
@@ -1535,10 +1554,13 @@ struct NVGPUWarpgroupMmaStoreOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int offset = 0;
- ImplicitLocOpBuilder lb(op->getLoc(), rewriter);
- for (Value matrixD : adaptor.getMatrixD()) {
- auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
- storeFragmentedMatrix(lb, matrixD, op.getDstMemref(), offset);
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ Value matriDValue = adaptor.getMatrixD();
+ auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
+ for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
+ auto structType = matrixD.cast<LLVM::LLVMStructType>();
+ Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
+ storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
}
rewriter.eraseOp(op);
@@ -1554,23 +1576,27 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
- SmallVector<Value> results;
- for (OpResult m : op.getMatrixC()) {
- nvgpu::WarpgroupAccumulatorType mType =
- m.getType().cast<nvgpu::WarpgroupAccumulatorType>();
- Type stype = getTypeConverter()->convertType(mType);
- Value undefStruct = b.create<LLVM::UndefOp>(stype);
- Type elemType = mType.getFragmented().getElementType();
- int64_t elemSize = mType.getFragmented().getDimSize(0);
- Value zero =
- b.create<LLVM::ConstantOp>(elemType, rewriter.getZeroAttr(elemType));
- for (int64_t i = 0; i < elemSize; ++i) {
- undefStruct = b.create<LLVM::InsertValueOp>(stype, undefStruct, zero,
- ArrayRef<int64_t>({i}));
+ LLVM::LLVMStructType structType =
+ getTypeConverter()
+ ->convertType(op.getMatrixC().getType())
+ .cast<LLVM::LLVMStructType>();
+ Type elemType = structType.getBody()
+ .front()
+ .cast<LLVM::LLVMStructType>()
+ .getBody()
+ .front();
+ Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
+ Value structValue = b.create<LLVM::UndefOp>(structType);
+ for (auto [idx, s] : llvm::enumerate(structType.getBody())) {
+ auto innerStructType = s.cast<LLVM::LLVMStructType>();
+ int ii = idx;
+ Value innerStructValue = b.create<LLVM::ExtractValueOp>(structValue, ii);
+ for (unsigned i = 0; i < innerStructType.getBody().size(); ++i) {
+ innerStructValue = b.create<LLVM::InsertValueOp>(
+ innerStructType, innerStructValue, zero, ArrayRef<int64_t>({i}));
}
- results.push_back(undefStruct);
}
- rewriter.replaceOp(op, results);
+ rewriter.replaceOp(op, structValue);
return success();
}
};
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index fe71eae899cd63d..f5b02fe1b515591 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -435,7 +435,11 @@ LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) {
return failure();
}
-LogicalResult isAllowedSizeM(int sizeM) { return success(sizeM == 64); }
+LogicalResult isAllowedSizeM(int sizeM) {
+ if (sizeM % kWgmmaSizeM)
+ return failure();
+ return success();
+}
LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64,
@@ -458,35 +462,16 @@ LogicalResult isAllowedSizeN(int sizeN, Type typeA) {
LogicalResult WarpgroupMmaOp::verify() {
if (getTransposeA() && !getTransposeB())
- return emitOpError() << "supports non-transpose A (Row Major) "
- "and transpose B (Column Major) for the time being";
+ return emitOpError()
+ << "supports non-transpose A (Row Major) "
+ "and transpose B (Column Major) for the time being ";
MemRefType matrixA = getDescriptorA().getType().getTensor();
MemRefType matrixB = getDescriptorB().getType().getTensor();
- VectorType matrixC = getMatrixC()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
- VectorType matrixD = getMatrixD()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
- unsigned sizeAcc = getMatrixC().size();
-
- if (getMatrixC().size() != getMatrixD().size())
- return emitOpError() << "number of matrix C and matrix D must be the same";
-
- if (llvm::all_of(getMatrixC(),
- [&](Value rhs) { return rhs.getType() == matrixC; })) {
- return emitOpError()
- << "types of all operands in matrix C must be the same";
- }
- if (llvm::all_of(getMatrixD(),
- [&](Value rhs) { return rhs.getType() == matrixC; })) {
- return emitOpError()
- << "types of all operands in matrix D must be the same as matrix C";
- }
+ VectorType matrixC = getMatrixC().getType().getFragmented();
+ VectorType matrixD = getMatrixD().getType().getFragmented();
+
+ if (matrixC != matrixD)
+ return emitOpError() << "type of matrix C and matrix D must be the same";
if (matrixA.getRank() != 2 || matrixB.getRank() != 2 ||
matrixC.getRank() != 2 || matrixD.getRank() != 2) {
@@ -498,7 +483,7 @@ LogicalResult WarpgroupMmaOp::verify() {
return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1]
<< ")!= 1st dim matrix-B (" << matrixB.getShape()[0]
<< " )";
- if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc))
+ if (matrixA.getShape()[0] != matrixC.getShape()[0])
return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0]
<< " )!= 1st dim matrix-C ( " << matrixC.getShape()[0]
<< " )";
@@ -534,29 +519,16 @@ LogicalResult WarpgroupMmaOp::verify() {
LogicalResult WarpgroupMmaStoreOp::verify() {
MemRefType dstMemrefType = getDstMemref().getType();
- VectorType firstVtype = getMatrixD()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
-
- int64_t totalFirstDimension = 0;
- for (Value result : getMatrixD()) {
- VectorType vtype =
- result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
- if (vtype != firstVtype)
- return emitOpError() << "all fragmented types must be the same";
- // Limitation
- if (!vtype.getElementType().isF32()) {
- return emitOpError()
- << "hit a limitation: only f32 results for the time being";
- }
- totalFirstDimension += vtype.getDimSize(0);
+ VectorType vtype = getMatrixD().getType().getFragmented();
+
+ // Limitation
+ if (!vtype.getElementType().isF32()) {
+ return emitOpError()
+ << "hit a limitation: only f32 results for the time being";
}
- if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
- firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
- return emitOpError() << "results [" << totalFirstDimension << "]["
- << firstVtype.getDimSize(1)
+ if (vtype.getDimSize(0) != dstMemrefType.getDimSize(0) ||
+ vtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+ return emitOpError() << "results [" << vtype << "][" << vtype.getDimSize(1)
<< "] values. However, destination memref["
<< dstMemrefType.getDimSize(0) << "]["
<< dstMemrefType.getDimSize(1)
@@ -570,19 +542,18 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
- for (OpResult matrix : getMatrixC()) {
- VectorType vectorType = matrix.getType()
- .cast<nvgpu::WarpgroupAccumulatorType>()
- .getFragmented();
- // Check [M][N] shape
- if (failed(isAllowedSizeM(vectorType.getDimSize(0))) ||
- failed(isAllowedSizeN(vectorType.getDimSize(1),
- vectorType.getElementType()))) {
- return emitOpError() << "has type " << vectorType
- << ". It does not fit into warp-group "
- "level (wgmma) matrix multiplication instruction "
- "(or not supported yet)";
- }
+
+ nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+ int64_t sizeM = accType.getFragmented().getDimSize(0);
+ int64_t sizeN = accType.getFragmented().getDimSize(1);
+ Type elemType = accType.getFragmented().getElementType();
+
+ if (failed(isAllowedSizeM(sizeM)) ||
+ failed(isAllowedSizeN(sizeN, elemType))) {
+ return emitOpError() << "has type " << accType.getFragmented()
+ << ". It does not fit into warp-group "
+ "level (wgmma) matrix multiplication instruction "
+ "(or not supported yet)";
}
return success();
}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index ca030575e5e961e..bf660e2683158e5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -713,18 +713,18 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.warpgroup.
}
// CHECK-LABEL: @warpgroup_mma_128_128_64(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
func.func @warpgroup_mma_128_128_64(
%descA: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
%descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
- %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
- %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>)
+ %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
{
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: nvvm.wgmma.fence.aligned
+// CHECK: %[[UD:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S2]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64
// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64
@@ -741,6 +741,7 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64
// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64
// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], <m = 64, n = 128, k = 16>, D[%[[S14]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S3:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64
// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64
// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], <m = 64, n = 128, k = 16>, D[%[[S3]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
@@ -759,27 +760,26 @@ func.func @warpgroup_mma_128_128_64(
// CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64
// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64
// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], <m = 64, n = 128, k = 16>, D[%[[S33]], <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <col>] : !llvm.struct
+// CHECK: %[[S40:.+]] = llvm.insertvalue %[[S19]], %[[UD]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK: %[[S41:.+]] = llvm.insertvalue %[[S38]], %[[S40]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: nvvm.wgmma.commit.group.sync.aligned
// CHECK: nvvm.wgmma.wait.group.sync.aligned 1
- %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
+ %wgmmaResult = nvgpu.warpgroup.mma %descA, %descB, %acc {transposeB}:
!nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
!nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
->
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
return
}
// CHECK-LABEL: @warpgroup_mma_store(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
func.func @warpgroup_mma_store(
- %result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
- %result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ %result : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>,
%matrixD: memref<128x128xf32,3>) {
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK: %[[EX1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
@@ -807,8 +807,8 @@ func.func @warpgroup_mma_store(
// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
-// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
-// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+// CHECK: %[[S26:.+]] = llvm.extractvalue %[[EX1]][0] : !llvm.struct
+// CHECK: %[[S27:.+]] = llvm.extractvalue %[[EX1]][1] : !llvm.struct
// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
@@ -821,8 +821,8 @@ func.func @warpgroup_mma_store(
// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
-// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
-// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+// CHECK: %[[S35:.+]] = llvm.extractvalue %[[EX1]][4] : !llvm.struct<
+// CHECK: %[[S36:.+]] = llvm.extractvalue %[[EX1]][5] : !llvm.struct<
// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
@@ -835,8 +835,8 @@ func.func @warpgroup_mma_store(
// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
-// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
-// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+// CHECK: %[[S44:.+]] = llvm.extractvalue %[[EX1]][8] : !llvm.struct<
+// CHECK: %[[S45:.+]] = llvm.extractvalue %[[EX1]][9] : !llvm.struct<
// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
@@ -849,8 +849,8 @@ func.func @warpgroup_mma_store(
// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
-// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
-// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+// CHECK: %[[S53:.+]] = llvm.extractvalue %[[EX1]][12] : !llvm.struct<
+// CHECK: %[[S54:.+]] = llvm.extractvalue %[[EX1]][13] : !llvm.struct<
// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
@@ -860,7 +860,7 @@ func.func @warpgroup_mma_store(
// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
// ### Store {d64, d65} of each thread ###
-
+// CHECK: %[[EX2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[S312:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
@@ -887,24 +887,24 @@ func.func @warpgroup_mma_store(
// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
-// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
-// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
+// CHECK: %[[S337:.+]] = llvm.extractvalue %[[EX2]][0]
+// CHECK: %[[S338:.+]] = llvm.extractvalue %[[EX2]][1]
// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
// Pattern continues similarly 31x times until {... d126, d127}
- nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
- !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ nvgpu.warpgroup.mma.store %result, %matrixD :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<128x128xf32>>
to memref<128x128xf32,3>
return
}
func.func @warpgroup_mma_init() {
- //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
//CHECK: %[[S1:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f3
- //CHECK: %[[S2:.+]] = llvm.insertvalue %[[S1]], %[[S0]][0] : !llvm.struct
+ //CHECK: %[[S0:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+ //CHECK: %[[EX:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+ //CHECK: %[[S2:.+]] = llvm.insertvalue %[[S1]], %[[EX]][0] : !llvm.struct
//CHECK: %[[S3:.+]] = llvm.insertvalue %[[S1]], %[[S2]][1] : !llvm.struct
//CHECK: %[[S4:.+]] = llvm.insertvalue %[[S1]], %[[S3]][2] : !llvm.struct
//CHECK: %[[S5:.+]] = llvm.insertvalue %[[S1]], %[[S4]][3] : !llvm.struct
@@ -968,10 +968,11 @@ func.func @warpgroup_mma_init() {
//CHECK: %[[S63:.+]] = llvm.insertvalue %[[S1]], %[[S62]][61] : !llvm.struct
//CHECK: %[[S64:.+]] = llvm.insertvalue %[[S1]], %[[S63]][62] : !llvm.struct
//CHECK: %[[S65:.+]] = llvm.insertvalue %[[S1]], %[[S64]][63] : !llvm.struct
- %matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ %matrixC = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator< fragmented = vector<128x128xf32>>
return
}
+
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
>From 04582c019cec9c7c8fa9122b35ace2e301e6348c Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 16 Oct 2023 14:58:16 +0200
Subject: [PATCH 2/2] fix transform dialect
---
.../NVGPU/TransformOps/NVGPUTransformOps.cpp | 24 ++++++++++++++++---
1 file changed, 21 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 94d7d565ff1a905..eaaadbbea4d0a75 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -62,10 +62,28 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
});
llvmTypeConverter.addConversion(
[&](nvgpu::WarpgroupAccumulatorType type) -> Type {
- VectorType vtype = type.getFragmented();
+ Type elemType = type.getFragmented().getElementType();
+ int64_t sizeM = type.getFragmented().getDimSize(0);
+ int64_t sizeN = type.getFragmented().getDimSize(1);
+
+ unsigned numMembers;
+ if (elemType.isF32() || elemType.isInteger(32))
+ numMembers = sizeN / 2;
+ else if (elemType.isF16())
+ numMembers = sizeN / 4;
+ else
+ llvm_unreachable("unsupported type for warpgroup accumulator");
+
+ SmallVector<Type> innerStructBody;
+ for (unsigned i = 0; i < numMembers; i++)
+ innerStructBody.push_back(elemType);
+ auto innerStructType = LLVM::LLVMStructType::getLiteral(
+ type.getContext(), innerStructBody);
+
SmallVector<Type> structBody;
- for (unsigned i = 0; i < vtype.getDimSize(0); i++)
- structBody.push_back(vtype.getElementType());
+ for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+ structBody.push_back(innerStructType);
+
auto convertedType =
LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
return llvmTypeConverter.convertType(convertedType);
More information about the Mlir-commits
mailing list