[Mlir-commits] [mlir] [MLIR][NVGPU] Introduce `nvgpu.wargroup.mma.store` Op for Hopper GPUs (PR #65441)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 22 07:56:19 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-nvgpu
<details>
<summary>Changes</summary>
[MLIR][NVGPU] Introduce `nvgpu.wargroup.mma.store` Op for Hopper GPUs
This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented results of a WGMMA operation to the given memref.
An example of fragmentation is given here :
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d
The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates the per-thread indexes within the warp group and stores the data into the given memref.
Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.warpgroup.mma ...
// Stores the fragmented result to the given memory
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
!nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
!nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
to memref<128x128xf32,3>
```
Depends on #<!-- -->65440
---
Full diff: https://github.com/llvm/llvm-project/pull/65441.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+20)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+81-2)
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+32)
``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 90381648dac6acc..4e80c33aec6043d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -721,4 +721,24 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
let hasVerifier = 1;
}
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+ let description = [{
+ The `nvgpu.warpgroup.mma.store` op performs the store of fragmented result
+ in $matrixD to give memref.
+
+ [See the details of register fragment layout for accumulator matrix D]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+
+ Note that, the op must be run with warp group.
+ }];
+
+ let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+ Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+
+ let assemblyFormat = [{
+ `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+ }];
+ let hasVerifier = 1;
+}
+
#endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f74aa05c0c4c4ff..006ecbef2546e3e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass
using Base::Base;
void getDependentDialects(DialectRegistry ®istry) const override {
- registry
- .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+ registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+ arith::ArithDialect>();
}
void runOnOperation() override {
@@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass
populateNVGPUToNVVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+ target.addLegalDialect<::mlir::arith::ArithDialect>();
target.addLegalDialect<::mlir::memref::MemRefDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1299,6 +1301,82 @@ struct NVGPUWarpgroupMmaOpLowering
}
};
+struct NVGPUWarpgroupMmaStoreOpLowering
+ : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+ using ConvertOpToLLVMPattern<
+ nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+ void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ int offset) const {
+ Location loc = op->getLoc();
+ Type i32 = rewriter.getI32Type();
+
+ auto makeConst = [&](int32_t index) -> Value {
+ return rewriter.create<LLVM::ConstantOp>(
+ loc, i32, rewriter.getI32IntegerAttr(index));
+ };
+ Value c4 = makeConst(4);
+ Value c32 = makeConst(kWarpSize);
+ Value c8 = makeConst(8);
+ Value c2 = makeConst(2);
+ Value c1 = makeConst(1);
+ Value c16 = makeConst(16);
+
+ auto makeMul = [&](Value lhs, Value rhs) -> Value {
+ return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+ };
+ auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+ return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+ };
+
+ Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+ Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+ Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+ Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+ Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+ auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+ TypedValue<::mlir::MemRefType> memref) {
+ Type it = rewriter.getIndexType();
+ Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+ Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+ Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+ Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+ Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+ rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+ rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+ };
+
+ Value tj = makeMul(lane4modId, c2);
+ Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+ if (offset)
+ ti = makeAdd(ti, makeConst(offset));
+ for (int i = 0; i < 2; ++i) {
+ Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+ for (int j = 0; j < 16; ++j) {
+ Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+ int sIndex = i * 2 + j * 4;
+ makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+ }
+ }
+ }
+
+ LogicalResult
+ matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ int offset = 0;
+ for (auto result : adaptor.getMatrixD()) {
+ auto stype = result.getType().cast<LLVM::LLVMStructType>();
+ storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+ offset += stype.getBody().size();
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1315,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma
+ NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store`
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d96ed69982870b4..1486bba5d3e57f6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,37 @@ LogicalResult WarpgroupMmaOp::verify() {
return success();
}
+LogicalResult WarpgroupMmaStoreOp::verify() {
+ Type stype = getMatrixD()
+ .front()
+ .getType()
+ .cast<WarpgroupAccumulatorType>()
+ .getFragmented();
+
+ for (auto result : getMatrixD()) {
+ auto resultStype = result.getType()
+ .cast<WarpgroupAccumulatorType>()
+ .getFragmented()
+ .dyn_cast<LLVM::LLVMStructType>();
+ if (!resultStype)
+ return emitOpError() << "result is " << result.getType()
+ << " but must keep type of llvm struct";
+ if (stype != resultStype)
+ return emitOpError() << "all results must be the same type";
+
+ // todo improve this limitation
+ if (!resultStype.getBody().front().isF32()) {
+ return emitOpError() << "supporst only f32 results for the time being";
+ }
+ }
+
+ if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+ return emitOpError() << "all element types must be equal ";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/65441
More information about the Mlir-commits
mailing list