[Mlir-commits] [mlir] 1dc48a9 - [mlir][spirv] Query target environment for mapping memory space

Stanley Winata llvmlistbot at llvm.org
Tue Sep 20 15:30:17 PDT 2022


Author: Stanley Winata
Date: 2022-09-20T15:28:58-07:00
New Revision: 1dc48a916a1bbf99799dbeefef26d9078e159e93

URL: https://github.com/llvm/llvm-project/commit/1dc48a916a1bbf99799dbeefef26d9078e159e93
DIFF: https://github.com/llvm/llvm-project/commit/1dc48a916a1bbf99799dbeefef26d9078e159e93.diff

LOG: [mlir][spirv] Query target environment for mapping memory space

Looks up the spirv::TargetEnv for the op and checks whether it allows the Kernel or Shader capability.
If Kernel is allowed, the OpenCL memory space mapping is used; otherwise, if Shader is allowed, the Vulkan mapping is used.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134317

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
    mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 006c69f9d0f67..d52eb4adf5226 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -325,6 +326,15 @@ void MapMemRefStorageClassPass::runOnOperation() {
   MLIRContext *context = &getContext();
   Operation *op = getOperation();
 
+  if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
+    spirv::TargetEnv targetEnv(attr);
+    if (targetEnv.allows(spirv::Capability::Kernel)) {
+      memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+    } else if (targetEnv.allows(spirv::Capability::Shader)) {
+      memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+    }
+  }
+
   auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
   spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
 

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
index 43bbf406eab9d..9b10e8d017ed6 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -114,3 +114,56 @@ func.func @missing_mapping() {
   %0 = "dialect.memref_producer"() : () -> (memref<f32, 2>)
   return
 }
+
+// -----
+
+/// Checks that memory spaces map to OpenCL storage classes if the Kernel capability is enabled.
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, #spv.resource_limits<>> } {
+func.func @operand_result() {
+  // CHECK: memref<f32, #spv.storage_class<CrossWorkgroup>>
+  %0 = "dialect.memref_producer"() : () -> (memref<f32>)
+  // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+  %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
+  // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+  %2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
+  // CHECK: memref<*xf16, #spv.storage_class<UniformConstant>>
+  %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
+
+
+  "dialect.memref_consumer"(%0) : (memref<f32>) -> ()
+  // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+  "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
+  // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+  "dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
+  // CHECK: memref<*xf16, #spv.storage_class<UniformConstant>>
+  "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
+
+  return
+}
+}
+
+// -----
+
+/// Checks that memory spaces map to Vulkan storage classes if the Shader capability is enabled.
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
+func.func @operand_result() {
+  // CHECK: memref<f32, #spv.storage_class<StorageBuffer>>
+  %0 = "dialect.memref_producer"() : () -> (memref<f32>)
+  // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+  %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
+  // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+  %2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
+  // CHECK: memref<*xf16, #spv.storage_class<Uniform>>
+  %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
+
+
+  "dialect.memref_consumer"(%0) : (memref<f32>) -> ()
+  // CHECK: memref<4xi32, #spv.storage_class<Generic>>
+  "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
+  // CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
+  "dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
+  // CHECK: memref<*xf16, #spv.storage_class<Uniform>>
+  "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
+  return
+}
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list