[Mlir-commits] [mlir] [mlir][nvgpu] Mark TMA descriptor as MemWriteAt in `tma.async.store` (PR #79427)

llvmlistbot at llvm.org
Thu Jan 25 01:27:41 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-nvgpu

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

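The canonicalizer considers `nvgpu.tma.async.store` trivially dead and erases it. The op has no results, and its only declared memory effect is a read of `$src`, so nothing marks it as writing memory. This PR addresses the issue by attaching a `MemWriteAt<0, FullEffect>` effect to the TMA descriptor operand.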
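
To make the rule concrete, here is a minimal sketch, assuming a simplified view of MLIR's trivially-dead check: an op whose results are unused may be erased when every memory effect it declares is a read. The `Op` class and `is_trivially_dead` function are hypothetical stand-ins, not MLIR APIs, and the real check also tolerates allocate effects on the op's own results and inspects nested regions.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

# Hypothetical, simplified model of MLIR's trivially-dead check for an op
# whose results (if any) are unused. MLIR also tolerates "allocate" effects
# on the op's own results and recurses into regions; both are omitted here.

@dataclass
class Op:
    name: str
    # Each effect is (kind, operand_name), e.g. ("read", "src").
    effects: List[Tuple[str, str]] = field(default_factory=list)

def is_trivially_dead(op: Op) -> bool:
    # An unused op may be erased only if every effect it declares is a read.
    return all(kind == "read" for kind, _ in op.effects)

# Before the patch: only MemReadAt<0, FullEffect> on $src.
store_before = Op("nvgpu.tma.async.store", [("read", "src")])
# After the patch: MemWriteAt<0, FullEffect> on $tensorMapDescriptor as well.
store_after = Op("nvgpu.tma.async.store",
                 [("read", "src"), ("write", "tensorMapDescriptor")])

print(is_trivially_dead(store_before))  # True: the canonicalizer may erase it
print(is_trivially_dead(store_after))   # False: the write keeps it alive
```

A single write effect on any operand is enough to keep the op, which is why marking the descriptor works even though the descriptor itself is not what the hardware mutates.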

The op copies data from shared memory to global memory asynchronously, so this fix may not be fully precise: global memory is not written when the op executes, but only once the asynchronous copy completes.

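The asynchronous behavior is controlled by the two NVVM ops below:
- `nvvm.cp.async.bulk.commit.group`: batches all preceding, not-yet-committed `nvgpu.tma.async.store` ops into a new bulk async-group and commits it.
- `nvvm.cp.async.bulk.wait_group 1`: waits until at most 1 of the most recently committed groups is still pending, i.e. all older groups have completed.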
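
The group bookkeeping can be sketched as a small per-thread model. This is a hypothetical illustration only: `BulkAsyncGroups`, `issue_store`, `commit_group`, and `wait_group` are invented names, and the model assumes groups complete in commit order. It shows why `wait_group 1` returns immediately when only one group has been committed, while `wait_group 0` (as used in the test below) waits for everything.

```python
from collections import deque

# Hypothetical per-thread model of bulk async-groups (not a real API).
# Each "store" stands for one issued nvgpu.tma.async.store.

class BulkAsyncGroups:
    def __init__(self):
        self.uncommitted = []    # stores issued since the last commit
        self.pending = deque()   # committed groups, oldest first

    def issue_store(self, name):
        self.uncommitted.append(name)

    def commit_group(self):
        # commit.group: batch all uncommitted stores into one new group.
        self.pending.append(list(self.uncommitted))
        self.uncommitted.clear()

    def wait_group(self, n):
        # wait_group n: block until at most n of the most recent groups are
        # still pending; returns the groups that had to complete first.
        completed = []
        while len(self.pending) > n:
            completed.append(self.pending.popleft())
        return completed

g = BulkAsyncGroups()
g.issue_store("tile0")
g.commit_group()
print(g.wait_group(1))  # [] -> the only group may still be in flight
print(g.wait_group(0))  # [['tile0']] -> every committed group has completed
```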

Here's a simplified representation of the code:
```
gpu.func ...  {
  // Write something to shared memory
  %shmem = ...

  // Perform asynchronous store shared memory -> global memory
  nvgpu.tma.async.store %shmem to %arg0[%c0, %c0], predicate = %1
    : memref<128x32xf32, #gpu.address_space<workgroup>> ->
      <tensor = memref<128x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>

  // Control asynchronous execution
  nvvm.cp.async.bulk.commit.group
  nvvm.cp.async.bulk.wait_group 1
}
```

---
Full diff: https://github.com/llvm/llvm-project/pull/79427.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+1-1) 
- (added) mlir/test/Dialect/NVGPU/canonicalization.mlir (+30) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 239a5f1e2bc2985..a0c0d4cfd8714ba 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -671,7 +671,7 @@ def NVGPU_TmaAsyncStoreOp : NVGPU_Op<"tma.async.store", [AttrSizedOperandSegment
     tile shape. The descriptor is created by `nvgpu.tma.create.descriptor`
   }];  
   let arguments = (ins  Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
-                        NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
+                        Arg<NVGPU_TensorMapDescriptor, "", [MemWriteAt<0, FullEffect>]>:$tensorMapDescriptor,
                         Variadic<Index>:$coordinates, 
                         Optional<I1>:$predicate);
   let assemblyFormat = [{
diff --git a/mlir/test/Dialect/NVGPU/canonicalization.mlir b/mlir/test/Dialect/NVGPU/canonicalization.mlir
new file mode 100644
index 000000000000000..a7fbfd80673957c
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/canonicalization.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s | mlir-opt -canonicalize -cse | FileCheck %s
+
+gpu.module @main_kernel {
+
+// CHECK-LABEL: @main_kernel(
+//  CHECK-SAME: %[[arg0:.*]]: !nvgpu.tensormap.descriptor
+  gpu.func @main_kernel(%arg0: !nvgpu.tensormap.descriptor<
+        tensor = memref<128x32xf32, 3>, swizzle = none, l2promo = none, 
+        oob = zero, interleave = none>) kernel attributes 
+        { gpu.known_block_size = array<i32: 128, 1, 1>, 
+          gpu.known_grid_size = array<i32: 1, 1, 1>
+        } 
+  {
+    // CHECK: %[[c0:.+]] = arith.constant 0 : index 
+    // CHECK: %[[S0:.+]] = gpu.thread_id  x
+    // CHECK: %[[S1:.+]] = arith.cmpi eq, %[[S0]], %[[c0]] : index
+    // CHECK: %[[S2:.+]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+    // CHECK: %[[S3:.+]] = memref.view %[[S2]][%[[c0]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x32xf32, #gpu.address_space<workgroup>>
+    // CHECK: nvgpu.tma.async.store %[[S3]] to %[[arg0]][%[[c0]], %[[c0]]], predicate = %[[S1]] : memref<128x32xf32, #gpu.address_space<workgroup>> -> <tensor = memref<128x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+    %c0 = arith.constant 0 : index
+    %0 = gpu.thread_id  x
+    %1 = arith.cmpi eq, %0, %c0 : index
+    %2 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+    %view = memref.view %2[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x32xf32, #gpu.address_space<workgroup>>
+    nvgpu.tma.async.store %view to %arg0[%c0, %c0], predicate = %1 : memref<128x32xf32, #gpu.address_space<workgroup>> -> <tensor = memref<128x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+    nvvm.cp.async.bulk.commit.group
+    nvvm.cp.async.bulk.wait_group 0
+    gpu.return
+  }
+}
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/79427

