[Mlir-commits] [mlir] 35a56e5 - [mlir][spirv] Map memory space to OpenCL/Kernel storage class

Lei Zhang llvmlistbot at llvm.org
Tue Aug 23 13:02:08 PDT 2022


Author: Stanley Winata
Date: 2022-08-23T16:01:54-04:00
New Revision: 35a56e5ddc0354d0d317ef981a0a5790a145f2e0

URL: https://github.com/llvm/llvm-project/commit/35a56e5ddc0354d0d317ef981a0a5790a145f2e0
DIFF: https://github.com/llvm/llvm-project/commit/35a56e5ddc0354d0d317ef981a0a5790a145f2e0.diff

LOG: [mlir][spirv] Map memory space to OpenCL/Kernel storage class

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D132428

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
    mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
    mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index a8bb28bdd1aab..38c0a48079882 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -32,6 +32,13 @@ Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
 /// using the default rule. Returns None if the storage class is unsupported.
 Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
 
+/// Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V
+/// using the default rule. Returns None if the memory space is unknown.
+Optional<spirv::StorageClass> mapMemorySpaceToOpenCLStorageClass(unsigned);
+/// Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces
+/// using the default rule. Returns None if the storage class is unsupported.
+Optional<unsigned> mapOpenCLStorageClassToMemorySpace(spirv::StorageClass);
+
 /// Type converter for converting numeric MemRef memory spaces into SPIR-V
 /// symbolic ones.
 class MemorySpaceToStorageClassConverter : public TypeConverter {

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index e11e4ef085f76..e63f0594cf230 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -37,7 +37,7 @@ using namespace mlir;
 /// depends on the context where it is used. There are no particular reasons
 /// behind the number assignments; we try to follow NVVM conventions and largely
 /// give common storage classes a smaller number.
-#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
+#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
   MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
   MAP_FN(spirv::StorageClass::Generic, 1)                                      \
   MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
@@ -56,7 +56,7 @@ spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
     return storage;
 
   switch (memorySpace) {
-    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+    VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
   default:
     break;
   }
@@ -72,7 +72,7 @@ spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
     return space;
 
   switch (storageClass) {
-    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+    VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
   default:
     break;
   }
@@ -81,7 +81,50 @@ spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
 #undef STORAGE_SPACE_MAP_FN
 }
 
-#undef STORAGE_SPACE_MAP_LIST
+#undef VULKAN_STORAGE_SPACE_MAP_LIST
+
+#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN)                                  \
+  MAP_FN(spirv::StorageClass::CrossWorkgroup, 0)                               \
+  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
+  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
+  MAP_FN(spirv::StorageClass::UniformConstant, 4)                              \
+  MAP_FN(spirv::StorageClass::Private, 5)                                      \
+  MAP_FN(spirv::StorageClass::Function, 6)                                     \
+  MAP_FN(spirv::StorageClass::Image, 7)
+
+Optional<spirv::StorageClass>
+spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) {
+#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
+  case space:                                                                  \
+    return storage;
+
+  switch (memorySpace) {
+    OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+  default:
+    break;
+  }
+  return llvm::None;
+
+#undef STORAGE_SPACE_MAP_FN
+}
+
+Optional<unsigned>
+spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
+#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
+  case storage:                                                                \
+    return space;
+
+  switch (storageClass) {
+    OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+  default:
+    break;
+  }
+  return llvm::None;
+
+#undef STORAGE_SPACE_MAP_FN
+}
+
+#undef OPENCL_STORAGE_SPACE_MAP_LIST
 
 //===----------------------------------------------------------------------===//
 // Type Converter
@@ -263,7 +306,11 @@ LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
   if (failed(Pass::initializeOptions(options)))
     return failure();
 
-  if (clientAPI != "vulkan")
+  if (clientAPI == "opencl") {
+    memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+  }
+
+  if (clientAPI != "vulkan" && clientAPI != "opencl")
     return failure();
 
   return success();

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
index fa0a1723d171d..43bbf406eab9d 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=opencl' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=OPENCL
 
 // Vulkan Mappings:
 //   0 -> StorageBuffer
@@ -7,24 +8,39 @@
 //   3 -> Workgroup
 //   4 -> Uniform
 
+// OpenCL Mappings:
+//   0 -> CrossWorkgroup
+//   1 -> Generic
+//   2 -> [null]
+//   3 -> Workgroup
+//   4 -> UniformConstant
+
 // VULKAN-LABEL: func @operand_result
+// OPENCL-LABEL: func @operand_result
 func.func @operand_result() {
   // VULKAN: memref<f32, #spv.storage_class<StorageBuffer>>
+  // OPENCL: memref<f32, #spv.storage_class<CrossWorkgroup>>
   %0 = "dialect.memref_producer"() : () -> (memref<f32>)
   // VULKAN: memref<4xi32, #spv.storage_class<Generic>>
+  // OPENCL: memref<4xi32, #spv.storage_class<Generic>>
   %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
   // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
+  // OPENCL: memref<?x4xf16, #spv.storage_class<Workgroup>>
   %2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
   // VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
+  // OPENCL: memref<*xf16, #spv.storage_class<UniformConstant>>
   %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
 
 
   "dialect.memref_consumer"(%0) : (memref<f32>) -> ()
   // VULKAN: memref<4xi32, #spv.storage_class<Generic>>
+  // OPENCL: memref<4xi32, #spv.storage_class<Generic>>
   "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
   // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
+  // OPENCL: memref<?x4xf16, #spv.storage_class<Workgroup>>
   "dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
   // VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
+  // OPENCL: memref<*xf16, #spv.storage_class<UniformConstant>>
   "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
 
   return
@@ -33,8 +49,10 @@ func.func @operand_result() {
 // -----
 
 // VULKAN-LABEL: func @type_attribute
+// OPENCL-LABEL: func @type_attribute
 func.func @type_attribute() {
   // VULKAN: attr = memref<i32, #spv.storage_class<Generic>>
+  // OPENCL: attr = memref<i32, #spv.storage_class<Generic>>
   "dialect.memref_producer"() { attr = memref<i32, 1> } : () -> ()
   return
 }
@@ -42,10 +60,13 @@ func.func @type_attribute() {
 // -----
 
 // VULKAN-LABEL: func.func @function_io
+// OPENCL-LABEL: func.func @function_io
 func.func @function_io
   // VULKAN-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>)
+  // OPENCL-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>)
   (%arg0: memref<f64, 1>, %arg1: memref<4xi32, 3>)
   // VULKAN-SAME: -> (memref<f64, #spv.storage_class<Generic>>, memref<4xi32, #spv.storage_class<Workgroup>>)
+  // OPENCL-SAME: -> (memref<f64, #spv.storage_class<Generic>>, memref<4xi32, #spv.storage_class<Workgroup>>)
   -> (memref<f64, 1>, memref<4xi32, 3>) {
   return %arg0, %arg1: memref<f64, 1>, memref<4xi32, 3>
 }
@@ -54,17 +75,22 @@ func.func @function_io
 
 gpu.module @kernel {
 // VULKAN-LABEL: gpu.func @function_io
+// OPENCL-LABEL: gpu.func @function_io
 // VULKAN-SAME: memref<8xi32, #spv.storage_class<StorageBuffer>>
+// OPENCL-SAME: memref<8xi32, #spv.storage_class<CrossWorkgroup>>
 gpu.func @function_io(%arg0 : memref<8xi32>) kernel { gpu.return }
 }
 
 // -----
 
 // VULKAN-LABEL: func.func @region
+// OPENCL-LABEL: func.func @region
 func.func @region(%cond: i1, %arg0: memref<f32, 1>) {
   scf.if %cond {
     //      VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>}
+    //      OPENCL: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>}
     // VULKAN-SAME: (memref<f32, #spv.storage_class<Generic>>) -> memref<f32, #spv.storage_class<Generic>>
+    // OPENCL-SAME: (memref<f32, #spv.storage_class<Generic>>) -> memref<f32, #spv.storage_class<Generic>>
     %0 = "dialect.memref_consumer"(%arg0) { attr = memref<i64, 3> } : (memref<f32, 1>) -> (memref<f32, 1>)
   }
   return
@@ -73,8 +99,10 @@ func.func @region(%cond: i1, %arg0: memref<f32, 1>) {
 // -----
 
 // VULKAN-LABEL: func @non_memref_types
+// OPENCL-LABEL: func @non_memref_types
 func.func @non_memref_types(%arg: f32) -> f32 {
   // VULKAN: "dialect.op"(%{{.+}}) {attr = 16 : i64} : (f32) -> f32
+  // OPENCL: "dialect.op"(%{{.+}}) {attr = 16 : i64} : (f32) -> f32
   %0 = "dialect.op"(%arg) { attr = 16 } : (f32) -> (f32)
   return %0 : f32
 }


        


More information about the Mlir-commits mailing list