[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op for Hopper GPUs (PR #65441)
Guray Ozen
llvmlistbot at llvm.org
Wed Sep 27 00:43:59 PDT 2023
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/65441
>From 7b71da55fca8fe2a7dbe4982b1959be6a6175fa1 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 7 Sep 2023 11:52:38 +0200
Subject: [PATCH 1/4] [MLIR][NVGPU] Introduce `nvgpu.warpgroup.mma.store` Op
for Hopper GPUs
This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of WGMMA to a given memref.
An example of fragmentation is given here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d
The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates per-thread indexes within the warp group and stores the data into the given memref (see the index-math sketch after the example below).
Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.warpgroup.mma ...
// Stores the fragmented results to the given memref
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
!nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
!nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
to memref<128x128xf32,3>
```
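Below is a minimal standalone sketch of the per-thread index arithmetic that the lowering in this patch emits as LLVM/arith ops (plain C++, illustrative only; the helper name `computeStoreCoords` and the `rowOffset` parameter are not part of the patch; `rowOffset` stands for the per-fragment row base, 0 for the first fragment and 64 for the second in the 128x128 example):
```
#include <cstdint>
#include <utility>

// For struct element pair (i, j) of one 64x128 f32 fragment, the lowering
// stores d[i * 2 + j * 4] at (row, col) and d[i * 2 + j * 4 + 1] at
// (row, col + 1) in the destination memref.
std::pair<int32_t, int32_t> computeStoreCoords(int32_t tid, int32_t i,
                                               int32_t j, int32_t rowOffset) {
  int32_t laneId = tid % 32;     // lane within the warp
  int32_t warpId = tid / 32;     // warp within the warp group
  int32_t lane4Id = laneId / 4;  // which row group of 4 lanes owns this row
  int32_t lane4Mod = laneId % 4; // position within that group of 4 lanes
  int32_t row = lane4Id + warpId * 16 + i * 8 + rowOffset; // i in [0, 2)
  int32_t col = lane4Mod * 2 + j * 8;                      // j in [0, 16)
  return {row, col};
}
```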
Depends on #65440
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 19 +++++
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 83 ++++++++++++++++++-
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 29 +++++++
3 files changed, 129 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 90381648dac6acc..e102ae0dc581013 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -721,4 +721,23 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let hasVerifier = 1;
}
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+ let description = [{
+ The `nvgpu.warpgroup.mma.store` op stores the fragmented results
+ in $matrixD to the given memref.
+
+ [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+
+ Note that the op must be run by a warp group.
+ }];
+
+ let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+ Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+
+ let assemblyFormat = [{
+ `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f74aa05c0c4c4ff..4f1a0bc651e81b7 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass
using Base::Base;
void getDependentDialects(DialectRegistry &registry) const override {
- registry
- .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+ registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+ arith::ArithDialect>();
}
void runOnOperation() override {
@@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass
populateNVGPUToNVVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+ target.addLegalDialect<::mlir::arith::ArithDialect>();
target.addLegalDialect<::mlir::memref::MemRefDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1299,11 +1301,88 @@ struct NVGPUWarpgroupMmaOpLowering
}
};
+struct NVGPUWarpgroupMmaStoreOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+ void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ int offset) const {
+ Location loc = op->getLoc();
+ Type i32 = rewriter.getI32Type();
+
+ auto makeConst = [&](int32_t index) -> Value {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, i32, rewriter.getI32IntegerAttr(index));
+ };
+ Value c4 = makeConst(4);
+ Value c32 = makeConst(kWarpSize);
+ Value c8 = makeConst(8);
+ Value c2 = makeConst(2);
+ Value c1 = makeConst(1);
+ Value c16 = makeConst(16);
+
+ auto makeMul = [&](Value lhs, Value rhs) -> Value {
+ return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+ };
+ auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+ return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+ };
+
+ Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+ Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+ Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+ Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+ Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+ auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+ TypedValue<::mlir::MemRefType> memref) {
+ Type it = rewriter.getIndexType();
+ Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+ Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+ Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+ Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+ Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+ rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+ rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+ };
+
+ Value tj = makeMul(lane4modId, c2);
+ Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+ if (offset)
+ ti = makeAdd(ti, makeConst(offset));
+ for (int i = 0; i < 2; ++i) {
+ Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+ for (int j = 0; j < 16; ++j) {
+ Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+ int sIndex = i * 2 + j * 4;
+ makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+ }
+ }
+ }
+
+ LogicalResult
+ matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ int offset = 0;
+ for (auto result : adaptor.getMatrixD()) {
+ auto stype = result.getType().cast<LLVM::LLVMStructType>();
+ storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+ offset += stype.getBody().size();
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<
+ NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d96ed69982870b4..fc85df1654198d5 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,34 @@ LogicalResult WarpgroupMmaOp::verify() {
return success();
}
+LogicalResult WarpgroupMmaStoreOp::verify() {
+ Type stype =
+ getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+
+ for (auto result : getMatrixD()) {
+ auto resultStype = result.getType()
+ .cast<WarpgroupResultType>()
+ .getTensor()
+ .dyn_cast<LLVM::LLVMStructType>();
+ if (!resultStype)
+ return emitOpError() << "result is " << result.getType()
+ << " but it must be an LLVM struct type";
+ if (stype != resultStype)
+ return emitOpError() << "all results must be the same type";
+
+ // todo improve this limitation
+ if (!resultStype.getBody().front().isF32()) {
+ return emitOpError() << "supports only f32 results for the time being";
+ }
+ }
+
+ if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+ return emitOpError() << "all element types must be equal ";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
>From 4a1824f3e6ae955b78f7262178fa1b8e4608e3da Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 22 Sep 2023 16:53:21 +0200
Subject: [PATCH 2/4] use new type `WarpgroupAccumulator`
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 5 +++--
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 2 +-
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 11 +++++++----
3 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index e102ae0dc581013..4e80c33aec6043d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -726,12 +726,13 @@ def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
The `nvgpu.warpgroup.mma.store` op stores the fragmented results
in $matrixD to the given memref.
- [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+ [See the details of register fragment layout for accumulator matrix D]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
Note that the op must be run by a warp group.
}];
- let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+ let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 4f1a0bc651e81b7..006ecbef2546e3e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1382,7 +1382,6 @@ struct NVGPUWarpgroupMmaStoreOpLowering
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<
- NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
@@ -1394,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
+ NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index fc85df1654198d5..1486bba5d3e57f6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -531,13 +531,16 @@ LogicalResult WarpgroupMmaOp::verify() {
}
LogicalResult WarpgroupMmaStoreOp::verify() {
- Type stype =
- getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+ Type stype = getMatrixD()
+ .front()
+ .getType()
+ .cast<WarpgroupAccumulatorType>()
+ .getFragmented();
for (auto result : getMatrixD()) {
auto resultStype = result.getType()
- .cast<WarpgroupResultType>()
- .getTensor()
+ .cast<WarpgroupAccumulatorType>()
+ .getFragmented()
.dyn_cast<LLVM::LLVMStructType>();
if (!resultStype)
return emitOpError() << "result is " << result.getType()
>From e60310d10c8e43669402e432cd130383cdf7a837 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 27 Sep 2023 09:41:45 +0200
Subject: [PATCH 3/4] add test
---
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 127 ++++++++++++++++++
1 file changed, 127 insertions(+)
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index f011007e040ce9c..93123cecbc38f94 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -732,6 +732,133 @@ func.func @warpgroup_mma_128_128_64(
return
}
+// CHECK-LABEL: @warpgroup_mma_store(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>)
+func.func @warpgroup_mma_store(
+ %result1 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ %result2 : !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ %matrixD: memref<128x128xf32,3>) {
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast %[[arg2]] :
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
+
+// ### Store {d0, d1} of each thread ###
+
+// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]] : i32
+// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]] : i32
+// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
+// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
+// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
+// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
+// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
+// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
+// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
+// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
+// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
+// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
+// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
+// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
+// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
+// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+// CHECK: memref.store %[[S26]], %[[arg2]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S27]], %[[arg2]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
+
+// ### Store {d2, d3} of each thread ###
+
+// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
+// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
+// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
+// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
+// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
+// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
+// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+// CHECK: memref.store %[[S35]], %[[arg2]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S36]], %[[arg2]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
+
+// ### Store {d4, d5} of each thread ###
+
+// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
+// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
+// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
+// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
+// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
+// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
+// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+// CHECK: memref.store %[[S44]], %[[arg2]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S45]], %[[arg2]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
+
+// ### Store {d6, d7} of each thread ###
+
+// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
+// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
+// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
+// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
+// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
+// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
+// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+// CHECK: memref.store %[[S53]], %[[arg2]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S54]], %[[arg2]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
+
+// The pattern continues similarly 28 more times, until {d62, d63}
+
+// ### Store {d64, d65} of each thread ###
+
+// CHECK: %[[S311:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S312:.+]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[S313:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S314:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S315:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S316:.+]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[S317:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S318:.+]] = llvm.urem %[[S317]], %[[S312]] : i32
+// CHECK: %[[S319:.+]] = llvm.udiv %[[S317]], %[[S312]] : i32
+// CHECK: %[[S320:.+]] = llvm.udiv %[[S318]]
+// CHECK: %[[S321:.+]] = llvm.urem %[[S318]]
+// CHECK: %[[S322:.+]] = llvm.mul %[[S321]], %[[S314]] : i32
+// CHECK: %[[S323:.+]] = llvm.mul %[[S319]], %[[S316]] : i32
+// CHECK: %[[S324:.+]] = llvm.add %[[S320]], %[[S323]] : i32
+// CHECK: %[[S325:.+]] = llvm.mlir.constant(64 : i32) : i32
+// CHECK: %[[S326:.+]] = llvm.add %[[S324]], %[[S325]] : i32
+// CHECK: %[[S327:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S328:.+]] = llvm.mul %[[S327]], %[[S313]] : i32
+// CHECK: %[[S329:.+]] = llvm.add %[[S326]], %[[S328]] : i32
+// CHECK: %[[S330:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S331:.+]] = llvm.mul %[[S330]], %[[S313]] : i32
+// CHECK: %[[S332:.+]] = llvm.add %[[S322]], %[[S331]] : i32
+// CHECK: %[[S333:.+]] = arith.index_cast %[[S329]] : i32 to index
+// CHECK: %[[S334:.+]] = arith.index_cast %[[S332]] : i32 to index
+// CHECK: %[[S335:.+]] = llvm.add %[[S332]], %[[S315]] : i32
+// CHECK: %[[S336:.+]] = arith.index_cast %[[S335]] : i32 to index
+// CHECK: %[[S337:.+]] = llvm.extractvalue %[[S1]][0]
+// CHECK: %[[S338:.+]] = llvm.extractvalue %[[S1]][1]
+// CHECK: memref.store %[[S337]], %[[arg2]][%[[S333]], %[[S334]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S338]], %[[arg2]][%[[S333]], %[[S336]]] : memref<128x128xf32, 3>
+
+// The pattern continues similarly 31 more times, until {d126, d127}
+
+ nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD :
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
+ !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>
+ to memref<128x128xf32,3>
+ return
+}
+
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["func.func"]} in %arg1
>From fe03a52c573c287efba3e9c77837a5d91a1e3ad1 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 27 Sep 2023 09:41:53 +0200
Subject: [PATCH 4/4] better verification
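For reference, a minimal sketch of the shape rule that the reworked verifier below enforces (plain C++, illustrative only; `accumulatorShapesMatchMemref` is a hypothetical helper, not part of the patch): fragments stack along the first dimension and the second dimension must match the destination exactly, so two vector<64x128xf32> accumulators are valid only against a memref<128x128xf32>.
```
#include <cstdint>
#include <utility>
#include <vector>

// Each pair is the (rows, cols) shape of one fragmented accumulator.
bool accumulatorShapesMatchMemref(
    const std::vector<std::pair<int64_t, int64_t>> &fragShapes,
    int64_t dstRows, int64_t dstCols) {
  int64_t totalRows = 0;
  for (const auto &[rows, cols] : fragShapes) {
    if (rows != fragShapes.front().first || cols != fragShapes.front().second)
      return false; // all fragmented types must be the same
    totalRows += rows;
  }
  // Fragments stack along dim 0; dim 1 must match exactly.
  return totalRows == dstRows && fragShapes.front().second == dstCols;
}

// e.g. accumulatorShapesMatchMemref({{64, 128}, {64, 128}}, 128, 128) == true
```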
---
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 48 +++++++++++-----------
1 file changed, 25 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 1486bba5d3e57f6..b9994aced0be7f4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -531,33 +531,35 @@ LogicalResult WarpgroupMmaOp::verify() {
}
LogicalResult WarpgroupMmaStoreOp::verify() {
- Type stype = getMatrixD()
- .front()
- .getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented();
-
+ MemRefType dstMemrefType = getDstMemref().getType();
+ VectorType firstVtype = getMatrixD()
+ .front()
+ .getType()
+ .cast<WarpgroupAccumulatorType>()
+ .getFragmented();
+
+ int64_t totalFirstDimension = 0;
for (auto result : getMatrixD()) {
- auto resultStype = result.getType()
- .cast<WarpgroupAccumulatorType>()
- .getFragmented()
- .dyn_cast<LLVM::LLVMStructType>();
- if (!resultStype)
- return emitOpError() << "result is " << result.getType()
- << " but it must be an LLVM struct type";
- if (stype != resultStype)
- return emitOpError() << "all results must be the same type";
-
- // todo improve this limitation
- if (!resultStype.getBody().front().isF32()) {
- return emitOpError() << "supports only f32 results for the time being";
+ VectorType vtype =
+ result.getType().cast<WarpgroupAccumulatorType>().getFragmented();
+ if (vtype != firstVtype)
+ return emitOpError() << "all fragmented types must be the same";
+ // Limitation
+ if (!vtype.getElementType().isF32()) {
+ return emitOpError()
+ << "hit a limitation: only f32 results for the time being";
}
+ totalFirstDimension += vtype.getDimSize(0);
}
-
- if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
- return emitOpError() << "all element types must be equal ";
+ if (totalFirstDimension != dstMemrefType.getDimSize(0) ||
+ firstVtype.getDimSize(1) != dstMemrefType.getDimSize(1)) {
+ return emitOpError() << "results hold [" << totalFirstDimension << "]["
+ << firstVtype.getDimSize(1)
+ << "] values, but the destination memref["
+ << dstMemrefType.getDimSize(0) << "]["
+ << dstMemrefType.getDimSize(1)
+ << "] does not have the same size";
}
-
return success();
}