[Mlir-commits] [mlir] [mlir][nvgpu] Add `nvgpu.tma.async.store` (PR #77811)
llvmlistbot at llvm.org
Thu Jan 11 10:23:25 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-nvgpu
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
This PR adds the `nvgpu.tma.async.store` op for asynchronous stores using the Tensor Memory Access (TMA) unit, together with its lowering to the NVVM dialect.
The op currently performs an asynchronous store of a tile-shaped memory region from shared memory to global memory for a single CTA.
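To give a feel for the syntax, here is a minimal usage sketch assembled from the tests in this PR; the SSA names and the descriptor attribute values (`swizzle`, `l2promo`, `oob`, `interleave`) are illustrative placeholders, and a real descriptor would come from `nvgpu.tma.create.descriptor`:

```mlir
// Store a 32x32 f32 tile from shared memory (address space 3) to the
// global-memory region described by the TMA descriptor %tensorMap.
%c0 = arith.constant 0 : index
nvgpu.tma.async.store %smemBuffer to %tensorMap[%c0, %c0]
    : memref<32x32xf32, 3> -> !nvgpu.tensormap.descriptor<
        tensor = memref<32x32xf32, 3>, swizzle = none, l2promo = none,
        oob = zero, interleave = none>
```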
---
Full diff: https://github.com/llvm/llvm-project/pull/77811.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+22)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+24)
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+23)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+46)
``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 7e139663d74b47..239a5f1e2bc298 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -661,6 +661,28 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
}
+def NVGPU_TmaAsyncStoreOp : NVGPU_Op<"tma.async.store", [AttrSizedOperandSegments]> {
+ let summary = "TMA asynchronous store";
+ let description = [{
+ The Op stores a tile memory region from shared memory to global memory
+ using the Tensor Memory Access (TMA) unit.
+
+ `$tensorMapDescriptor` is a tensor map descriptor that holds information
+ about the tile shape. The descriptor is created by
+ `nvgpu.tma.create.descriptor`.
+ }];
+ let arguments = (ins Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
+ NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
+ Variadic<Index>:$coordinates,
+ Optional<I1>:$predicate);
+ let assemblyFormat = [{
+ $src `to` $tensorMapDescriptor `[` $coordinates `]`
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict `:` type($src)
+ `->` type($tensorMapDescriptor)
+ }];
+ let hasVerifier = 1;
+}
+
def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
let summary = "TMA create descriptor";
let description = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index db84e5cf62a5e9..759766275de4a5 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -995,6 +995,29 @@ struct NVGPUTmaAsyncLoadOpLowering
return success();
}
};
+
+struct NVGPUTmaAsyncStoreOpLowering
+ : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
+ using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
+ LogicalResult
+ matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
+ Value srcPtr = getStridedElementPtr(op->getLoc(), srcMemrefType,
+ adaptor.getSrc(), {}, rewriter);
+ SmallVector<Value> coords = adaptor.getCoordinates();
+ for (auto [index, value] : llvm::enumerate(coords)) {
+ coords[index] = truncToI32(b, value);
+ }
+
+ rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
+ op, adaptor.getTensorMapDescriptor(), srcPtr, coords,
+ adaptor.getPredicate());
+ return success();
+ }
+};
+
struct NVGPUGenerateWarpgroupDescriptorLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
using ConvertOpToLLVMPattern<
@@ -1639,6 +1662,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test_wait_parity
NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait_parity
NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load
+ NVGPUTmaAsyncStoreOpLowering, // nvgpu.tma.async.store
NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor
NVGPUTmaPrefetchOpLowering, // nvgpu.tma.prefetch.descriptor
NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index c9756ae8fc11ce..5ffa854e97cb17 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -402,6 +402,29 @@ LogicalResult TmaAsyncLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaAsyncStoreOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TmaAsyncStoreOp::verify() {
+ std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+ *this, getTensorMapDescriptor().getType(), getSrc().getType());
+ if (error.has_value())
+ return error.value();
+
+ if (getCoordinates().size() > kMaxTMATensorDimension) {
+ return emitError() << "Maximum " << kMaxTMATensorDimension
+ << " coordinates are supported.";
+ }
+ if (getCoordinates().size() !=
+ size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
+ return emitError() << "number of coordinates do not match with the rank of "
+ "tensor descriptor map.";
+ }
+
+ return success();
+}
+
LogicalResult TmaCreateDescriptorOp::verify() {
if (getBoxDimensions().size() > kMaxTMATensorDimension) {
return emitError() << "Maximum " << kMaxTMATensorDimension
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b8a0f75d1cc8b9..edccd7e80603bd 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
func.return
}
+func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+ %buffer1d: memref<128xf32,3>,
+ %buffer2d: memref<32x32xf32,3>,
+ %buffer3d: memref<2x32x32xf32,3>,
+ %buffer4d: memref<2x2x32x32xf32,3>,
+ %buffer5d: memref<2x2x2x32x32xf32,3>) {
+ %c0 = arith.constant 0 : index
+ %crd0 = arith.constant 0 : index
+ %crd1 = arith.constant 0 : index
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}]
+ nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x32xf32,3> -> !tensorMap2d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x32xf32,3> -> !tensorMap3d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0] : memref<2x2x32x32xf32,3> -> !tensorMap4d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+ nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0] : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+ func.return
+}
+
+
+func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+ %buffer1d: memref<128xf32,3>,
+ %buffer2d: memref<32x32xf32,3>,
+ %buffer3d: memref<2x32x32xf32,3>,
+ %buffer4d: memref<2x2x32x32xf32,3>,
+ %buffer5d: memref<2x2x2x32x32xf32,3>,
+ %p: i1) {
+ %c0 = arith.constant 0 : index
+ %crd0 = arith.constant 0 : index
+ %crd1 = arith.constant 0 : index
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x32xf32,3> -> !tensorMap3d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p : memref<2x2x32x32xf32,3> -> !tensorMap4d
+ // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+ nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+ func.return
+}
+
func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
%crd0 = arith.constant 64 : index
%crd1 = arith.constant 128 : index
``````````
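For reference, the lowered IR that the CHECK lines above match looks roughly like the sketch below; SSA names are illustrative, result types are omitted, and per the lowering pattern the shared-memory pointer is computed via `getStridedElementPtr` while each `index` coordinate is first truncated to `i32`:

```mlir
// Approximate post-lowering form of a 2-D store (names illustrative):
// the descriptor and shared-memory pointer, then the i32 tile coordinates.
nvvm.cp.async.bulk.tensor.global.shared.cta %tensorMap %srcPtr, box[%crd0, %crd1]
```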
</details>
https://github.com/llvm/llvm-project/pull/77811