[Mlir-commits] [mlir] [mlir][nvgpu] Support strided memref when creating TMA descriptor (PR #85652)

Guray Ozen llvmlistbot at llvm.org
Mon Mar 18 08:42:32 PDT 2024


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/85652

Currently, the runtime assumes that a memref is always contiguous, which prevents strided memrefs from being used. This PR adds support for strided memrefs when creating a TMA descriptor.

Co-authored-by: Adam Paszke <adam.paszke at gmail.com>

>From 952783a5dae5af30218d9332ac3eab1b6e5cf963 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 18 Mar 2024 15:39:35 +0000
Subject: [PATCH] [mlir][nvgpu] Use the strides of the memref descriptor to
 construct the TMA descriptor

Currently, the runtime assumes that a memref is always contiguous, which is not always the case. This PR uses the strides stored in the memref descriptor, so strided memrefs are supported as well.

Co-authored-by: Adam Paszke <adam.paszke at gmail.com>
---
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   |  49 +++---
 .../tma_load_128x128_stride_noswizzle.mlir    | 147 ++++++++++++++++++
 2 files changed, 172 insertions(+), 24 deletions(-)
 create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir

diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index b9a3429e37b885..9d406bdfc7cc9a 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -423,24 +423,42 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
               elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
 }
 
 namespace {
 
-template <int rank>
-void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
-                               uint64_t *globalDim) {
+/// Extracts the address, sizes and byte strides of a rank-`Rank` memref
+/// descriptor in the layout expected by `mgpuTensorMapEncodeTiled`.
+///
+/// Tensor-map arrays list the innermost dimension first, so memref sizes and
+/// strides are reversed. Only `Rank - 1` strides are written: the innermost
+/// memref stride is never read and is assumed to be 1. Memref strides and the
+/// offset count elements, so both are scaled by the size of `tensorDataType`.
+template <int Rank>
+void mgpuGetMemRefDataAndShape(void *rawDescriptor, char **addr,
+                               uint64_t *globalDim, uint64_t *globalStrides,
+                               const CUtensorMapDataType tensorDataType) {
+  // Byte size of every CUtensorMapDataType enumerator, indexed by its value.
+  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
+                                               4, 8, 2, 4, 4, 4};
+  const int64_t elementSize = elementSizeInBytes[tensorDataType];
   auto descriptor =
-      reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
-  *addr = descriptor->data;
-  for (int i = 0; i < rank; ++i) {
-    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
+      reinterpret_cast<StridedMemRefType<char, Rank> *>(rawDescriptor);
+  // `offset` counts elements of the memref's element type, but the descriptor
+  // was reinterpreted with `char` elements, so scale it to bytes.
+  *addr = descriptor->data + descriptor->offset * elementSize;
+  for (int i = 0; i < Rank; ++i) {
+    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[Rank - i - 1]);
+  }
+  for (int i = 0; i < Rank - 1; ++i) {
+    globalStrides[i] =
+        static_cast<uint64_t>(descriptor->strides[Rank - i - 2] * elementSize);
   }
 }
 
 } // namespace
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     int64_t tensorRank,                       // Dimensionality of tensor
-    void *ranked_descriptor,                  // Ranked MemRef descriptor
+    void *rankedDescriptor,                   // Ranked MemRef descriptor
     const CUtensorMapDataType tensorDataType, // Stride size (in bytes)
     CUtensorMapInterleave interleave,         // Type of interleaved layout
     CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
@@ -457,38 +475,36 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
   char *globalAddress = nullptr;
   switch (tensorRank) {
   case 1:
-    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<1>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 2:
-    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<2>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 3:
-    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<3>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 4:
-    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<4>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 5:
-    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<5>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   default:
     fprintf(
         stderr,
         "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
-    return NULL;
+    return nullptr;
   }
 
-  static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
-                                           4, 8, 2, 4, 4, 4};
   for (int64_t r = 0; r < tensorRank; ++r) {
-    elementStrides[r] = uint32_t(1);
     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
   }
 
-  globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
-  for (int r = 1; r < tensorRank - 1; r++)
-    globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
   ScopedContext scopedContext;
   mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                            globalAddress, globalDim, globalStrides, boxDim,
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir
new file mode 100644
index 00000000000000..54045b82d40da3
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/tma_load_128x128_stride_noswizzle.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN:  | FileCheck %s
+
+// Loads a 128x128xf16 matrix with TMA through a descriptor built from a
+// strided (transposed) memref view and checks the result element by element.
+//
+// The source is viewed as 2x64x2x64 and transposed to 2x2x64x64 with strides
+// [8192, 64, 128, 1], so the TMA load gathers four 64x64 tiles into shared
+// memory. Each tile is then copied back to its position in the destination,
+// which must therefore end up identical to the source.
+
+// CHECK: Correct Results :16384
+// CHECK: Incorrect Results :0
+
+module {
+  func.func @main() {
+    %c10000000 = arith.constant 10000000 : index
+    %false = arith.constant false
+    %c32768 = arith.constant 32768 : index
+    %c31_i32 = arith.constant 31 : i32
+    %c-1_i32 = arith.constant -1 : i32
+    %c5_i32 = arith.constant 5 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c2 = arith.constant 2 : index
+    %c32768_i32 = arith.constant 32768 : i32
+    %c128 = arith.constant 128 : index
+    %c1 = arith.constant 1 : index
+    %f0 = arith.constant 0.0 : f16
+    %f123 = arith.constant 1.123 : f16
+
+    // Fill the whole source matrix with values that depend on both the row
+    // and the column, so that misplaced columns can be caught, not only
+    // misplaced rows. The destination starts out as all zeros.
+    %srcMemref_host = memref.alloc() : memref<128x128xf16>
+    %dstMemref_host = memref.alloc() : memref<128x128xf16>
+    scf.for %arg0 = %c0 to %c128 step %c1 {
+      scf.for %arg1 = %c0 to %c128 step %c1 {
+        %d1 = arith.index_cast %arg0 : index to i32
+        %d2 = arith.index_cast %arg1 : index to i32
+        %d3 = arith.sitofp %d1 : i32 to f16
+        %d4 = arith.sitofp %d2 : i32 to f16
+        %rowcol = arith.addf %d3, %d4 : f16
+        %d5 = arith.addf %rowcol, %f123 : f16
+        %d6 = arith.constant 3.12 : f16
+        %d7 = arith.mulf %d5, %d6 : f16
+        %d8 = arith.addf %d7, %d5 : f16
+        %d9 = arith.constant 0.178 : f16
+        %d10 = arith.divf %d9, %d8 : f16
+        memref.store %d10, %srcMemref_host[%arg0, %arg1] : memref<128x128xf16>
+        memref.store %f0, %dstMemref_host[%arg0, %arg1] : memref<128x128xf16>
+      }
+    }
+
+    %s1 = gpu.wait async
+    %srcMemref, %s2 = gpu.alloc async [%s1] () : memref<128x128xf16>
+    %dstMemref, %s3 = gpu.alloc async [%s2] () : memref<128x128xf16>
+    %s4 = gpu.memcpy async [%s3] %srcMemref, %srcMemref_host : memref<128x128xf16>, memref<128x128xf16>
+    %s5 = gpu.memcpy async [%s4] %dstMemref, %dstMemref_host : memref<128x128xf16>, memref<128x128xf16>
+
+    %expand_shape = memref.expand_shape %srcMemref [[0, 1], [2, 3]] : memref<128x128xf16> into memref<2x64x2x64xf16>
+    %transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<2x64x2x64xf16> to memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>>
+    %cast = memref.cast %transpose : memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>> to memref<*xf16>
+    %24 = nvgpu.tma.create.descriptor %cast box[%c2, %c2, %c64, %c64] : memref<*xf16> -> <tensor = memref<2x2x64x64xf16, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+
+    gpu.launch
+      blocks(%arg2, %arg3, %arg4) in (%arg8 = %c1, %arg9 = %c1, %arg10 = %c1)
+      threads(%arg5, %arg6, %arg7) in (%arg11 = %c128, %arg12 = %c1, %arg13 = %c1)
+      dynamic_shared_memory_size %c32768_i32
+    {
+      %26 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+      %view = memref.view %26[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x2x64x64xf16, #gpu.address_space<workgroup>>
+      %27 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+      %thread_id_x = gpu.thread_id  x
+      %28 = arith.index_cast %thread_id_x : index to i32
+      %29 = arith.shrui %28, %c5_i32 : i32
+      %30 = nvvm.shfl.sync  idx %c-1_i32, %29, %c0_i32, %c31_i32 : i32 -> i32
+      %31 = arith.cmpi eq, %30, %c0_i32 : i32
+      %32 = nvvm.elect.sync -> i1
+      %33 = arith.andi %31, %32 : i1
+      scf.if %33 {
+        nvgpu.mbarrier.init %27[%c0], %c1 : <memorySpace = #gpu.address_space<workgroup>>
+      }
+      %34 = nvvm.shfl.sync  idx %c-1_i32, %29, %c0_i32, %c31_i32 : i32 -> i32
+      %35 = arith.cmpi eq, %34, %c0_i32 : i32
+      %36 = nvvm.elect.sync -> i1
+      %37 = arith.andi %35, %36 : i1
+      scf.if %37 {
+        nvgpu.mbarrier.arrive.expect_tx %27[%c0], %c32768 : <memorySpace = #gpu.address_space<workgroup>>
+        nvgpu.tma.async.load %24[%c0, %c0, %c0, %c0], %27[%c0] to %view : <tensor = memref<2x2x64x64xf16, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<2x2x64x64xf16, #gpu.address_space<workgroup>>
+      }
+      nvgpu.mbarrier.try_wait.parity %27[%c0], %false, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
+      scf.for %arg14 = %c0 to %c2 step %c1 {
+        scf.for %arg15 = %c0 to %c2 step %c1 {
+          %38 = arith.muli %arg14, %c64 : index
+          %39 = arith.muli %arg15, %c64 : index
+          %subview = memref.subview %view[%arg14, %arg15, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<2x2x64x64xf16, #gpu.address_space<workgroup>> to memref<64x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+          %subview_0 = memref.subview %dstMemref[%38, %39] [64, 64] [1, 1] : memref<128x128xf16> to memref<64x64xf16, strided<[128, 1], offset: ?>>
+          %block_dim_x = gpu.block_dim  x
+          %thread_id_y = gpu.thread_id  y
+          %40 = arith.muli %thread_id_y, %block_dim_x : index
+          %41 = arith.addi %thread_id_x, %40 : index
+          %block_dim_y = gpu.block_dim  y
+          %42 = arith.muli %block_dim_x, %block_dim_y : index
+          %thread_id_z = gpu.thread_id  z
+          %43 = arith.muli %thread_id_z, %42 : index
+          %44 = arith.addi %41, %43 : index
+          %45 = arith.cmpi eq, %44, %c0 : index
+          scf.if %45 {
+            scf.for %arg16 = %c0 to %c64 step %c1 {
+              scf.for %arg17 = %c0 to %c64 step %c1 {
+                %46 = memref.load %subview[%arg16, %arg17] : memref<64x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+                memref.store %46, %subview_0[%arg16, %arg17] : memref<64x64xf16, strided<[128, 1], offset: ?>>
+              }
+            }
+          }
+          gpu.barrier
+        }
+      }
+      gpu.terminator
+    }
+
+    %s6 = gpu.memcpy async [%s5] %dstMemref_host, %dstMemref  : memref<128x128xf16>, memref<128x128xf16>
+    gpu.wait [%s6]
+
+    // Compare every element of the destination against the source.
+    %errorCount, %correctCount = scf.for %arg0 = %c0 to %c128 step %c1 iter_args(%ecRow = %c0, %ccRow = %c0) -> (index, index) {
+      %ecNext, %ccNext = scf.for %arg1 = %c0 to %c128 step %c1 iter_args(%ec = %ecRow, %cc = %ccRow) -> (index, index) {
+        %v1 = memref.load %dstMemref_host[%arg0, %arg1] : memref<128x128xf16>
+        %v2 = memref.load %srcMemref_host[%arg0, %arg1] : memref<128x128xf16>
+        // `une` (unordered or not equal) also counts a NaN on either side as a mismatch.
+        %p = arith.cmpf une, %v1, %v2 : f16
+        %ecOut, %ccOut = scf.if %p -> (index, index) {
+          %ecInc = arith.addi %ec, %c1 : index
+          scf.yield %ecInc, %cc : index, index
+        } else {
+          %ccInc = arith.addi %cc, %c1 : index
+          scf.yield %ec, %ccInc : index, index
+        }
+        scf.yield %ecOut, %ccOut : index, index
+      }
+      scf.yield %ecNext, %ccNext : index, index
+    }
+
+    vector.print str "Correct Results :"
+    vector.print %correctCount : index
+    vector.print str "Incorrect Results :"
+    vector.print %errorCount : index
+    return
+  }
+}



More information about the Mlir-commits mailing list