[Mlir-commits] [mlir] [mlir][nvgpu] Add address space attribute converter in nvgpu-to-nvvm pass (PR #74075)

Guray Ozen llvmlistbot at llvm.org
Fri Dec 1 05:42:42 PST 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/74075

The GPU dialect uses `#gpu.address_space<workgroup>` to denote the shared memory used by NVGPU (address space 3). However, when IR combines the NVGPU and GPU dialects, the `nvgpu-to-nvvm` pass fails because this memory-space attribute is never converted.

This PR adds `populateGpuMemorySpaceAttributeConversions` to the nvgpu-to-nvvm lowering, so that `#gpu.address_space<workgroup>` can be used with the `nvgpu-to-nvvm` pass.

>From 2494eb82a21dfa1e6a4cf745b4b70d95becae926 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 1 Dec 2023 14:41:42 +0100
Subject: [PATCH] [mlir][nvgpu] Add memref address space convert

The GPU dialect uses `#gpu.address_space<workgroup>` to denote the shared memory used by NVGPU (address space 3). However, when IR combines the NVGPU and GPU dialects, the `nvgpu-to-nvvm` pass fails because this memory-space attribute is never converted.

This PR adds `populateGpuMemorySpaceAttributeConversions` to the nvgpu-to-nvvm lowering, so that `#gpu.address_space<workgroup>` can be used with the `nvgpu-to-nvvm` pass.
---
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp   | 15 +++++++++++++++
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir     | 13 +++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index c2e7d387a4420b4..9cd3a5ce65ce5f6 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -405,6 +405,21 @@ struct ConvertNVGPUToNVVMPass
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext(), options);
     IRRewriter rewriter(&getContext());
+    populateGpuMemorySpaceAttributeConversions(
+        converter, [](gpu::AddressSpace space) -> unsigned {
+          switch (space) {
+          case gpu::AddressSpace::Global:
+            return static_cast<unsigned>(
+                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
+          case gpu::AddressSpace::Workgroup:
+            return static_cast<unsigned>(
+                NVVM::NVVMMemorySpace::kSharedMemorySpace);
+          case gpu::AddressSpace::Private:
+            return 0;
+          }
+          llvm_unreachable("unknown address space enum value");
+          return 0;
+        });
     /// device-side async tokens cannot be materialized in nvvm. We just
     /// convert them to a dummy i32 type in order to easily drop them during
     /// conversion.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 26a5961b43829f3..e11449e6f7c457c 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -666,6 +666,19 @@ func.func @async_tma_load(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d
   func.return 
 }
 
+// CHECK-LABEL: func @async_tma_load_gpu_address_space
+!tensorMap1dgpuspace = !nvgpu.tensormap.descriptor<tensor = memref<128xf32, #gpu.address_space<workgroup>>,         swizzle=none,        l2promo = none,        oob = nan,  interleave = none>
+func.func @async_tma_load_gpu_address_space(%tensorMap1d: !tensorMap1dgpuspace,
+                          %buffer1d: memref<128xf32, #gpu.address_space<workgroup>>,
+                          %mbarrier: !mbarrier) {
+  // Destination buffer uses #gpu.address_space<workgroup>; lowering must convert that attribute.
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}} box[%{{.*}}]
+  nvgpu.tma.async.load %tensorMap1d[%crd0], %mbarrier[%c0] to %buffer1d : !tensorMap1dgpuspace, !mbarrier -> memref<128xf32, #gpu.address_space<workgroup>>
+  func.return
+}
+
 // CHECK-LABEL: func @async_tma_load_pred
 func.func @async_tma_load_pred(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d, 
                               %buffer1d: memref<128xf32,3>,      



More information about the Mlir-commits mailing list