[Mlir-commits] [mlir] 1513555 - [mlir][spirv] Use functors for default memory space mappings
Lei Zhang
llvmlistbot at llvm.org
Tue Aug 9 11:39:49 PDT 2022
Author: Lei Zhang
Date: 2022-08-09T14:38:27-04:00
New Revision: 15135553c4cf34d3915e45b55e915154b33ab67b
URL: https://github.com/llvm/llvm-project/commit/15135553c4cf34d3915e45b55e915154b33ab67b
DIFF: https://github.com/llvm/llvm-project/commit/15135553c4cf34d3915e45b55e915154b33ab67b.diff
LOG: [mlir][spirv] Use functors for default memory space mappings
This makes the default mappings easier to use as utility functions for
querying in either direction (memory space to storage class, and the reverse).
This commit also drops the storage classes previously mapped to memory
spaces 11-23 (CrossWorkgroup through HostOnlyINTEL), which are not needed
for now.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D131411
Added:
Modified:
mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 60730246a7c65..a8bb28bdd1aab 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -22,9 +22,15 @@ class SPIRVTypeConverter;
namespace spirv {
/// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones.
-using MemorySpaceToStorageClassMap = DenseMap<unsigned, spirv::StorageClass>;
-/// Returns the default map for targeting Vulkan-flavored SPIR-V.
-MemorySpaceToStorageClassMap getDefaultVulkanStorageClassMap();
+using MemorySpaceToStorageClassMap =
+ std::function<Optional<spirv::StorageClass>(unsigned)>;
+
+/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V
+/// using the default rule. Returns None if the memory space is unknown.
+Optional<spirv::StorageClass> mapMemorySpaceToVulkanStorageClass(unsigned);
+/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces
+/// using the default rule. Returns None if the storage class is unsupported.
+Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass);
/// Type converter for converting numeric MemRef memory spaces into SPIR-V
/// symbolic ones.
@@ -34,7 +40,7 @@ class MemorySpaceToStorageClassConverter : public TypeConverter {
const MemorySpaceToStorageClassMap &memorySpaceMap);
private:
- const MemorySpaceToStorageClassMap &memorySpaceMap;
+ MemorySpaceToStorageClassMap memorySpaceMap;
};
/// Creates the target that populates legality of ops with MemRef types.
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 6d20e989a71a7..fb5c89244b549 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -64,7 +64,7 @@ void GPUToSPIRVPass::runOnOperation() {
std::unique_ptr<ConversionTarget> target =
spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
- spirv::getDefaultVulkanStorageClassMap();
+ spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 535613714d53d..e11e4ef085f76 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -30,7 +31,6 @@ using namespace mlir;
// Mappings
//===----------------------------------------------------------------------===//
-spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
/// Mapping between SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have a defined semantics for each memory space; it
@@ -47,29 +47,42 @@ spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() {
MAP_FN(spirv::StorageClass::PushConstant, 7) \
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
MAP_FN(spirv::StorageClass::Input, 9) \
- MAP_FN(spirv::StorageClass::Output, 10) \
- MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \
- MAP_FN(spirv::StorageClass::AtomicCounter, 12) \
- MAP_FN(spirv::StorageClass::Image, 13) \
- MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \
- MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \
- MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \
- MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \
- MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \
- MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \
- MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \
- MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \
- MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \
- MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
-
-#define STORAGE_SPACE_MAP_FN(storage, space) {space, storage},
-
- return {STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)};
+ MAP_FN(spirv::StorageClass::Output, 10)
+
+Optional<spirv::StorageClass>
+spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) {
+#define STORAGE_SPACE_MAP_FN(storage, space) \
+ case space: \
+ return storage;
+
+ switch (memorySpace) {
+ STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+ default:
+ break;
+ }
+ return llvm::None;
+
+#undef STORAGE_SPACE_MAP_FN
+}
+
+Optional<unsigned>
+spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
+#define STORAGE_SPACE_MAP_FN(storage, space) \
+ case storage: \
+ return space;
+
+ switch (storageClass) {
+ STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+ default:
+ break;
+ }
+ return llvm::None;
#undef STORAGE_SPACE_MAP_FN
-#undef STORAGE_SPACE_MAP_LIST
}
+#undef STORAGE_SPACE_MAP_LIST
+
//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//
@@ -91,8 +104,8 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
}
unsigned space = memRefType.getMemorySpaceAsInt();
- auto it = this->memorySpaceMap.find(space);
- if (it == this->memorySpaceMap.end()) {
+ auto storage = this->memorySpaceMap(space);
+ if (!storage) {
LLVM_DEBUG(llvm::dbgs()
<< "cannot convert " << memRefType
<< " due to being unable to find memory space in map\n");
@@ -100,7 +113,7 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
}
auto storageAttr =
- spirv::StorageClassAttr::get(memRefType.getContext(), it->second);
+ spirv::StorageClassAttr::get(memRefType.getContext(), *storage);
if (auto rankedType = memRefType.dyn_cast<MemRefType>()) {
return MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
rankedType.getLayout(), storageAttr);
@@ -231,16 +244,7 @@ class MapMemRefStorageClassPass final
: public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
public:
explicit MapMemRefStorageClassPass() {
- memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
-
- LLVM_DEBUG({
- llvm::dbgs() << "memory space to storage class mapping:\n";
- if (memorySpaceMap.empty())
- llvm::dbgs() << " [empty]\n";
- for (auto kv : memorySpaceMap)
- llvm::dbgs() << " " << kv.first << " -> "
- << spirv::stringifyStorageClass(kv.second) << "\n";
- });
+ memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
}
explicit MapMemRefStorageClassPass(
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
More information about the Mlir-commits mailing list