[Mlir-commits] [mlir] 1dc48a9 - [mlir][spirv] Query target environment for mapping memory space
Stanley Winata
llvmlistbot at llvm.org
Tue Sep 20 15:30:17 PDT 2022
Author: Stanley Winata
Date: 2022-09-20T15:28:58-07:00
New Revision: 1dc48a916a1bbf99799dbeefef26d9078e159e93
URL: https://github.com/llvm/llvm-project/commit/1dc48a916a1bbf99799dbeefef26d9078e159e93
DIFF: https://github.com/llvm/llvm-project/commit/1dc48a916a1bbf99799dbeefef26d9078e159e93.diff
LOG: [mlir][spirv] Query target environment for mapping memory space
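Queries the spirv::TargetEnv for the op and checks whether it allows the Kernel or Shader capability.
If Kernel is allowed, the OpenCL memory space mapping is used; otherwise, if Shader is allowed, the Vulkan mapping is used. If neither applies, the existing mapping is kept.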
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D134317
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 006c69f9d0f67..d52eb4adf5226 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -325,6 +326,15 @@ void MapMemRefStorageClassPass::runOnOperation() {
MLIRContext *context = &getContext();
Operation *op = getOperation();
+ if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
+ spirv::TargetEnv targetEnv(attr);
+ if (targetEnv.allows(spirv::Capability::Kernel)) {
+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ } else if (targetEnv.allows(spirv::Capability::Shader)) {
+ memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+ }
+ }
+
auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
index 43bbf406eab9d..9b10e8d017ed6 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -114,3 +114,56 @@ func.func @missing_mapping() {
%0 = "dialect.memref_producer"() : () -> (memref<f32, 2>)
return
}
+
+// -----
+
+/// Checks that memory spaces map to OpenCL storage classes when the target env allows the Kernel capability.
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, #spv.resource_limits<>> } {
+func.func @operand_result() {
+ // CHECK: memref<f32, #spv.storage_class<CrossWorkgroup>>
+ %0 = "dialect.memref_producer"() : () -> (memref<f32>)
+ // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+ %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
+ // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+ %2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
+ // CHECK: memref<*xf16, #spv.storage_class<UniformConstant>>
+ %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
+
+ // CHECK: memref<f32, #spv.storage_class<CrossWorkgroup>>
+ "dialect.memref_consumer"(%0) : (memref<f32>) -> ()
+ // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+ "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
+ // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+ "dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
+ // CHECK: memref<*xf16, #spv.storage_class<UniformConstant>>
+ "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
+
+ return
+}
+}
+
+// -----
+
+/// Checks that memory spaces map to Vulkan storage classes when the target env allows Shader (and not Kernel).
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
+func.func @operand_result() {
+ // CHECK: memref<f32, #spv.storage_class<StorageBuffer>>
+ %0 = "dialect.memref_producer"() : () -> (memref<f32>)
+ // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+ %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
+ // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+ %2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
+ // CHECK: memref<*xf16, #spv.storage_class<Uniform>>
+ %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
+
+ // CHECK: memref<f32, #spv.storage_class<StorageBuffer>>
+ "dialect.memref_consumer"(%0) : (memref<f32>) -> ()
+ // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+ "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
+ // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+ "dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
+ // CHECK: memref<*xf16, #spv.storage_class<Uniform>>
+ "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
+ return
+}
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list