[Mlir-commits] [mlir] [mlir][nvgpu] Add `nvgpu.tma.async.store` (PR #77811)

Guray Ozen llvmlistbot at llvm.org
Thu Jan 11 10:22:58 PST 2024


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/77811

PR adds `nvgpu.tma.async.store` Op for asynchronous stores using the Tensor Memory Access (TMA) unit.

It also implements Op lowering to NVVM dialect. The Op currently performs asynchronous stores of a tile memory region from shared to global memory for a single CTA.

>From bac1944406c9a9b391c18f78f5c3149ec665ed5d Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 11 Jan 2024 19:21:56 +0100
Subject: [PATCH] [mlir][nvgpu] Add `nvgpu.tma.async.store`

PR adds `nvgpu.tma.async.store` Op for asynchronous stores using the Tensor Memory Access (TMA) unit.

It also implements Op lowering to NVVM dialect. The Op currently performs asynchronous stores of a tile memory region from shared to global memory for a single CTA.
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   | 22 +++++++++
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 24 ++++++++++
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 23 ++++++++++
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 46 +++++++++++++++++++
 4 files changed, 115 insertions(+)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 7e139663d74b47..239a5f1e2bc298 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -661,6 +661,28 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]
 
 }
 
+def NVGPU_TmaAsyncStoreOp : NVGPU_Op<"tma.async.store", [AttrSizedOperandSegments]> {
+  let summary = "TMA asynchronous store";
+  let description = [{
+    The Op stores a tile memory region from shared memory to global memory by
+    Tensor Memory Access (TMA).
+    
+    `$tensorMapDescriptor` is the tensor map descriptor, which has information
+    about the tile shape. It is created by `nvgpu.tma.create.descriptor`.
+  }];  
+  let arguments = (ins  Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
+                        NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
+                        Variadic<Index>:$coordinates, 
+                        Optional<I1>:$predicate);
+  let assemblyFormat = [{
+      $src `to` $tensorMapDescriptor `[` $coordinates `]`
+      (`,` `predicate` `=` $predicate^)?
+      attr-dict `:` type($src)
+      `->` type($tensorMapDescriptor)
+  }];
+  let hasVerifier = 1;
+}
+
 def NVGPU_TmaCreateDescriptorOp : NVGPU_Op<"tma.create.descriptor", []> {
   let summary = "TMA create descriptor";
   let description = [{
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index db84e5cf62a5e9..759766275de4a5 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -995,6 +995,29 @@ struct NVGPUTmaAsyncLoadOpLowering
     return success();
   }
 };
+
+struct NVGPUTmaAsyncStoreOpLowering
+    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
+  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
+  LogicalResult
+  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
+    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
+                                      adaptor.getSrc(), {}, rewriter);
+    SmallVector<Value> coords = adaptor.getCoordinates();
+    for (auto [index, value] : llvm::enumerate(coords)) {
+      coords[index] = truncToI32(b, value);
+    }
+
+    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
+        op, adaptor.getTensorMapDescriptor(), dest, coords,
+        adaptor.getPredicate());
+    return success();
+  }
+};
+
 struct NVGPUGenerateWarpgroupDescriptorLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
   using ConvertOpToLLVMPattern<
@@ -1639,6 +1662,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierTestWaitLowering,         // nvgpu.mbarrier.test_wait_parity
       NVGPUMBarrierTryWaitParityLowering,    // nvgpu.mbarrier.try_wait_parity
       NVGPUTmaAsyncLoadOpLowering,           // nvgpu.tma.async.load
+      NVGPUTmaAsyncStoreOpLowering,          // nvgpu.tma.async.store
       NVGPUTmaCreateDescriptorOpLowering,    // nvgpu.tma.create.descriptor
       NVGPUTmaPrefetchOpLowering,            // nvgpu.tma.prefetch.descriptor
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index c9756ae8fc11ce..5ffa854e97cb17 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -402,6 +402,29 @@ LogicalResult TmaAsyncLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaAsyncStoreOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TmaAsyncStoreOp::verify() {
+  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+      *this, getTensorMapDescriptor().getType(), getSrc().getType());
+  if (error.has_value())
+    return error.value();
+
+  if (getCoordinates().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
+  }
+  if (getCoordinates().size() !=
+      size_t(getTensorMapDescriptor().getType().getTensor().getRank())) {
+    return emitError() << "number of coordinates do not match with the rank of "
+                          "tensor descriptor map.";
+  }
+
+  return success();
+}
+
 LogicalResult TmaCreateDescriptorOp::verify() {
   if (getBoxDimensions().size() > kMaxTMATensorDimension) {
     return emitError() << "Maximum " << kMaxTMATensorDimension
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b8a0f75d1cc8b9..edccd7e80603bd 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
   func.return 
 }
 
+func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d, 
+                           %buffer1d: memref<128xf32,3>,      
+                           %buffer2d: memref<32x32xf32,3>,    
+                           %buffer3d: memref<2x32x32xf32,3>,  
+                           %buffer4d: memref<2x2x32x32xf32,3>,  
+                           %buffer5d: memref<2x2x2x32x32xf32,3>) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}] 
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d 
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1]  : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0]  : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0]  : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] 
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0]  : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return 
+}
+
+
+func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d, 
+                           %buffer1d: memref<128xf32,3>,      
+                           %buffer2d: memref<32x32xf32,3>,    
+                           %buffer3d: memref<2x32x32xf32,3>,  
+                           %buffer4d: memref<2x2x32x32xf32,3>,  
+                           %buffer5d: memref<2x2x2x32x32xf32,3>,
+                           %p: i1) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p  : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p  : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p  : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p  : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return 
+}
+
 func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
   %crd0 = arith.constant 64 : index
   %crd1 = arith.constant 128 : index



More information about the Mlir-commits mailing list