[Mlir-commits] [mlir] 1513555 - [mlir][spirv] Use functors for default memory space mappings

Lei Zhang llvmlistbot at llvm.org
Tue Aug 9 11:39:49 PDT 2022


Author: Lei Zhang
Date: 2022-08-09T14:38:27-04:00
New Revision: 15135553c4cf34d3915e45b55e915154b33ab67b

URL: https://github.com/llvm/llvm-project/commit/15135553c4cf34d3915e45b55e915154b33ab67b
DIFF: https://github.com/llvm/llvm-project/commit/15135553c4cf34d3915e45b55e915154b33ab67b.diff

LOG: [mlir][spirv] Use functors for default memory space mappings

This makes the default mapping easier to use as a utility function for
querying in either direction (memory space to storage class, and the reverse).

This commit also drops the mappings for some storage classes that
aren't needed for now.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131411

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 60730246a7c65..a8bb28bdd1aab 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -22,9 +22,15 @@ class SPIRVTypeConverter;
 
 namespace spirv {
 /// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
-using MemorySpaceToStorageClassMap = DenseMap<unsigned, spirv::StorageClass>;
-/// Returns the default map for targeting Vulkan-flavored SPIR-V.
-MemorySpaceToStorageClassMap getDefaultVulkanStorageClassMap();
+using MemorySpaceToStorageClassMap =
+    std::function<Optional<spirv::StorageClass>(unsigned)>;
+
+/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
+/// using the default rule. Returns None if the memory space is unknown.
+Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
+/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
+/// using the default rule. Returns None if the storage class is unsupported.
+Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
 
 /// Type converter for converting numeric MemRef memory spaces into SPIR-V
 /// symbolic ones.
@@ -34,7 +40,7 @@ class MemorySpaceToStorageClassConverter : public TypeConverter {
       const MemorySpaceToStorageClassMap &memorySpaceMap);
 
 private:
-  const MemorySpaceToStorageClassMap &memorySpaceMap;
+  MemorySpaceToStorageClassMap memorySpaceMap;
 };
 
 /// Creates the target that populates legality of ops with MemRef types.

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 6d20e989a71a7..fb5c89244b549 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -64,7 +64,7 @@ void GPUToSPIRVPass::runOnOperation() {
     std::unique_ptr<ConversionTarget> target =
         spirv::getMemorySpaceToStorageClassTarget(*context);
     spirv::MemorySpaceToStorageClassMap memorySpaceMap =
-        spirv::getDefaultVulkanStorageClassMap();
+        spirv::mapMemorySpaceToVulkanStorageClass;
     spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
 
     RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 535613714d53d..e11e4ef085f76 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -30,7 +31,6 @@ using namespace mlir;
 // Mappings
 //===----------------------------------------------------------------------===//
 
-spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
 /// Mapping between SPIR-V storage classes to memref memory spaces.
 ///
 /// Note: memref does not have a defined semantics for each memory space; it
@@ -47,29 +47,42 @@ spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
   MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
   MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
   MAP_FN(spirv::StorageClass::Input, 9)                                        \
-  MAP_FN(spirv::StorageClass::Output, 10)                                      \
-  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
-  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
-  MAP_FN(spirv::StorageClass::Image, 13)                                       \
-  MAP_FN(spirv::StorageClass::CallableDataKHR, 14)                             \
-  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15)                     \
-  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16)                               \
-  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17)                             \
-  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18)                       \
-  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19)                       \
-  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)                       \
-  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21)                            \
-  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22)                             \
-  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
-
-#define STORAGE_SPACE_MAP_FN(storage, space) {space, storage},
-
-  return {STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)};
+  MAP_FN(spirv::StorageClass::Output, 10)
+
+Optional<spirv::StorageClass>
+spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
+#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
+  case space:                                                                  \
+    return storage;
+  // One `case <space>: return <storage>;` per list entry; others -> None.
+  switch (memorySpace) {
+    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+  default:
+    break;
+  }
+  return llvm::None;
+  // Undefine the helper so the reverse mapping below can redefine it.
+#undef STORAGE_SPACE_MAP_FN
+}
+
+Optional<unsigned>
+spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
+#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
+  case storage:                                                                \
+    return space;
+  // Reverse lookup: one `case <storage>: return <space>;`; unlisted -> None.
+  switch (storageClass) {
+    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+  default:
+    break;
+  }
+  return llvm::None;

 #undef STORAGE_SPACE_MAP_FN
-#undef STORAGE_SPACE_MAP_LIST
 }
 
+#undef STORAGE_SPACE_MAP_LIST
+
 //===----------------------------------------------------------------------===//
 // Type Converter
 //===----------------------------------------------------------------------===//
@@ -91,8 +104,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
     }
 
     unsigned space = memRefType.getMemorySpaceAsInt();
-    auto it = this->memorySpaceMap.find(space);
-    if (it == this->memorySpaceMap.end()) {
+    auto storage = this->memorySpaceMap(space);
+    if (!storage) {
       LLVM_DEBUG(llvm::dbgs()
                  << "cannot convert " << memRefType
                  << " due to being unable to find memory space in map\n");
@@ -100,7 +113,7 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
     }
 
     auto storageAttr =
-        spirv::StorageClassAttr::get(memRefType.getContext(), it->second);
+        spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
     if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
       return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
                              rankedType.getLayout(), storageAttr);
@@ -231,16 +244,7 @@ class MapMemRefStorageClassPass final
     : public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
 public:
   explicit MapMemRefStorageClassPass() {
-    memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
-
-    LLVM_DEBUG({
-      llvm::dbgs() << "memory space to storage class mapping:\n";
-      if (memorySpaceMap.empty())
-        llvm::dbgs() << "  [empty]\n";
-      for (auto kv : memorySpaceMap)
-        llvm::dbgs() << "  " << kv.first << " -> "
-                     << spirv::stringifyStorageClass(kv.second) << "\n";
-    });
+    memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
   }
   explicit MapMemRefStorageClassPass(
       const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)


        


More information about the Mlir-commits mailing list