[Mlir-commits] [mlir] 89b595e - [mlir][spirv] Detach memory space mapping from type conversion

Lei Zhang llvmlistbot at llvm.org
Tue Aug 9 11:30:49 PDT 2022


Author: Lei Zhang
Date: 2022-08-09T14:30:43-04:00
New Revision: 89b595e1418b22cd7c982276aa32872863c5e186

URL: https://github.com/llvm/llvm-project/commit/89b595e1418b22cd7c982276aa32872863c5e186
DIFF: https://github.com/llvm/llvm-project/commit/89b595e1418b22cd7c982276aa32872863c5e186.diff

LOG: [mlir][spirv] Detach memory space mapping from type conversion

This commit moves MemRef memory space to SPIR-V storage class
conversion out of the main SPIR-V type converter. Now the mapping
should happen as a preliminary step before performing the final
conversion to SPIR-V. Flows are expected to write their own memory
space mappings like the `MapMemRefStorageClassPass` to handle
memory space mappings according to their needs.

This is needed because SPIR-V is serving multiple client APIs,
including Vulkan and OpenCL. Different client APIs might want
to use different storage classes for buffers in a particular
memory space, e.g., `StorageBuffer` for Vulkan vs. `CrossWorkgroup`
for OpenCL when converting the default memory space 0. Hardcoding
a specific mapping makes that hard. While it's possible to embed
selection logic further inside the main type converter, doing so would
make the main type converter even more complicated. So it's better to
separate the concerns, as mapping the memory space is really
concretizing the meaning of those numeric memory spaces in the
particular context of SPIR-V lowering.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131410

Added: 
    mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
    mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir

Modified: 
    mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
    mlir/test/Conversion/GPUToSPIRV/load-store.mlir
    mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
    mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
    mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
    mlir/test/Conversion/SCFToSPIRV/for.mlir
    mlir/test/Conversion/SCFToSPIRV/if.mlir
    mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
    mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Removed: 
    mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
    mlir/test/Conversion/GPUToSPIRV/simple.mlir


################################################################################
diff  --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
index dba8d274c7a92..8867c96d0a048 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
@@ -21,9 +21,12 @@ class ModuleOp;
 template <typename T>
 class OperationPass;
 
-/// Creates a pass to convert GPU Ops to SPIR-V ops. For a gpu.func to be
-/// converted, it should have a spv.entry_point_abi attribute.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertGPUToSPIRVPass();
+/// Creates a pass to convert GPU kernel ops to corresponding SPIR-V ops. For a
+/// gpu.func to be converted, it should have a spv.entry_point_abi attribute.
+/// If `mapMemorySpace` is true, performs MemRef memory space to SPIR-V mapping
+/// according to default Vulkan rules first.
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertGPUToSPIRVPass(bool mapMemorySpace = false);
 
 } // namespace mlir
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRVPASS_H

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 2eede784c208a..d0f09eaa2e753 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -69,15 +69,6 @@ class SPIRVTypeConverter : public TypeConverter {
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
-  /// Returns the corresponding memory space for memref given a SPIR-V storage
-  /// class.
-  static unsigned getMemorySpaceForStorageClass(spirv::StorageClass);
-
-  /// Returns the SPIR-V storage class given a memory space for memref. Return
-  /// llvm::None if the memory space does not map to any SPIR-V storage class.
-  static Optional<spirv::StorageClass>
-  getStorageClassForMemorySpace(unsigned space);
-
   /// Returns the options controlling the SPIR-V type converter.
   const Options &getOptions() const;
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index b9d4b3fd78cb6..0e0083b60797a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -204,6 +204,8 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, TypeConverter &typeConverter,
     for (const auto &argType :
          enumerate(funcOp.getFunctionType().getInputs())) {
       auto convertedType = typeConverter.convertType(argType.value());
+      if (!convertedType)
+        return nullptr;
       signatureConverter.addInputs(argType.index(), convertedType);
     }
   }

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 5eaa099278b28..6d20e989a71a7 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -35,8 +35,14 @@ namespace {
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-struct GPUToSPIRVPass : public ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
+class GPUToSPIRVPass : public ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
+public:
+  explicit GPUToSPIRVPass(bool mapMemorySpace)
+      : mapMemorySpace(mapMemorySpace) {}
   void runOnOperation() override;
+
+private:
+  bool mapMemorySpace;
 };
 } // namespace
 
@@ -44,16 +50,30 @@ void GPUToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
 
-  SmallVector<Operation *, 1> kernelModules;
+  SmallVector<Operation *, 1> gpuModules;
   OpBuilder builder(context);
-  module.walk([&builder, &kernelModules](gpu::GPUModuleOp moduleOp) {
-    // For each kernel module (should be only 1 for now, but that is not a
-    // requirement here), clone the module for conversion because the
-    // gpu.launch function still needs the kernel module.
+  module.walk([&](gpu::GPUModuleOp moduleOp) {
+    // Clone each GPU kernel module for conversion, given that the GPU
+    // launch op still needs the original GPU kernel module.
     builder.setInsertionPoint(moduleOp.getOperation());
-    kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
+    gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
   });
 
+  // Map MemRef memory space to SPIR-V storage class first if requested.
+  if (mapMemorySpace) {
+    std::unique_ptr<ConversionTarget> target =
+        spirv::getMemorySpaceToStorageClassTarget(*context);
+    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+        spirv::getDefaultVulkanStorageClassMap();
+    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+
+    RewritePatternSet patterns(context);
+    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+
+    if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
+      return signalPassFailure();
+  }
+
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
@@ -68,10 +88,11 @@ void GPUToSPIRVPass::runOnOperation() {
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
   populateFuncToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
+  if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertGPUToSPIRVPass() {
-  return std::make_unique<GPUToSPIRVPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
+  return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
 }

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 55da39cb98b16..d72802be6e1df 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -90,12 +90,12 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
 /// can be lowered to SPIR-V.
 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
   if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
-    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
+    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
       return false;
   } else if (isa<memref::AllocaOp>(allocOp)) {
-    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
+    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    if (!sc || sc.getValue() != spirv::StorageClass::Function)
       return false;
   } else {
     return false;
@@ -116,12 +116,8 @@ static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
 /// operations of unsupported integer bitwidths, based on the memref
 /// type. Returns None on failure.
 static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass)
-    return {};
-  switch (*storageClass) {
+  auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  switch (sc.getValue()) {
   case spirv::StorageClass::StorageBuffer:
     return spirv::Scope::Device;
   case spirv::StorageClass::Workgroup:

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 9a94f195805d1..9154c811facac 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
@@ -117,65 +118,6 @@ Type SPIRVTypeConverter::getIndexType() const {
   return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
 }
 
-/// Mapping between SPIR-V storage classes to memref memory spaces.
-///
-/// Note: memref does not have a defined semantics for each memory space; it
-/// depends on the context where it is used. There are no particular reasons
-/// behind the number assignments; we try to follow NVVM conventions and largely
-/// give common storage classes a smaller number. The hope is use symbolic
-/// memory space representation eventually after memref supports it.
-// TODO: swap Generic and StorageBuffer assignment to be more akin
-// to NVVM.
-#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
-  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
-  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
-  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
-  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
-  MAP_FN(spirv::StorageClass::Private, 5)                                      \
-  MAP_FN(spirv::StorageClass::Function, 6)                                     \
-  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
-  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
-  MAP_FN(spirv::StorageClass::Input, 9)                                        \
-  MAP_FN(spirv::StorageClass::Output, 10)                                      \
-  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
-  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
-  MAP_FN(spirv::StorageClass::Image, 13)                                       \
-  MAP_FN(spirv::StorageClass::CallableDataKHR, 14)                             \
-  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15)                     \
-  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16)                               \
-  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17)                             \
-  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18)                       \
-  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19)                       \
-  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)                       \
-  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21)                            \
-  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22)                             \
-  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
-
-unsigned
-SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
-#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
-  case storage:                                                                \
-    return space;
-
-  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
-#undef STORAGE_SPACE_MAP_FN
-  llvm_unreachable("unhandled storage class!");
-}
-
-Optional<spirv::StorageClass>
-SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
-#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
-  case space:                                                                  \
-    return storage;
-
-  switch (space) {
-    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
-  default:
-    return llvm::None;
-  }
-#undef STORAGE_SPACE_MAP_FN
-}
-
 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
   return options;
 }
@@ -184,8 +126,6 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
   return targetEnv.getAttr().getContext();
 }
 
-#undef STORAGE_SPACE_MAP_LIST
-
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t>
@@ -375,16 +315,8 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
 
 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                                   const SPIRVTypeConverter::Options &options,
-                                  MemRefType type) {
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert memory space\n");
-    return nullptr;
-  }
-
+                                  MemRefType type,
+                                  spirv::StorageClass storageClass) {
   unsigned numBoolBits = options.boolNumBits;
   if (numBoolBits != 8) {
     LLVM_DEBUG(llvm::dbgs()
@@ -407,34 +339,37 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   if (!type.hasStaticShape()) {
-    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
-    return wrapInStructAndGetPointer(arrayType, *storageClass);
+    return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
   auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
-  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
-  return wrapInStructAndGetPointer(arrayType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                               const SPIRVTypeConverter::Options &options,
                               MemRefType type) {
-  if (type.getElementType().isa<IntegerType>() &&
-      type.getElementTypeBitWidth() == 1) {
-    return convertBoolMemrefType(targetEnv, options, type);
+  auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!attr) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << type
+        << " illegal: expected memory space to be a SPIR-V storage class "
+           "attribute; please use MemorySpaceToStorageClassConverter to map "
+           "numeric memory spaces beforehand\n");
+    return nullptr;
   }
+  spirv::StorageClass storageClass = attr.getValue();
 
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert memory space\n");
-    return nullptr;
+  if (type.getElementType().isa<IntegerType>() &&
+      type.getElementTypeBitWidth() == 1) {
+    return convertBoolMemrefType(targetEnv, options, type, storageClass);
   }
 
   Type arrayElemType;
@@ -463,9 +398,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   if (!type.hasStaticShape()) {
-    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
-    return wrapInStructAndGetPointer(arrayType, *storageClass);
+    return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
@@ -476,10 +411,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
-  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
-  return wrapInStructAndGetPointer(arrayType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index a85642c8e0337..719af7f0a21de 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -274,29 +274,51 @@ module attributes {
 // CHECK-SAME: Private
 // CHECK-SAME: Function
 func.func @memref_mem_space(
-    %arg0: memref<4xf32, 0>,
-    %arg1: memref<4xf32, 4>,
-    %arg2: memref<4xf32, 3>,
-    %arg3: memref<4xf32, 7>,
-    %arg4: memref<4xf32, 5>,
-    %arg5: memref<4xf32, 6>
+    %arg0: memref<4xf32, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4xf32, #spv.storage_class<Uniform>>,
+    %arg2: memref<4xf32, #spv.storage_class<Workgroup>>,
+    %arg3: memref<4xf32, #spv.storage_class<PushConstant>>,
+    %arg4: memref<4xf32, #spv.storage_class<Private>>,
+    %arg5: memref<4xf32, #spv.storage_class<Function>>
 ) { return }
 
 // CHECK-LABEL: func @memref_1bit_type
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32>)>, Function>
 // NOEMU-LABEL: func @memref_1bit_type
-// NOEMU-SAME: memref<4x8xi1>
-// NOEMU-SAME: memref<4x8xi1, 6>
+// NOEMU-SAME: memref<4x8xi1, #spv.storage_class<StorageBuffer>>
+// NOEMU-SAME: memref<4x8xi1, #spv.storage_class<Function>>
 func.func @memref_1bit_type(
-    %arg0: memref<4x8xi1, 0>,
-    %arg1: memref<4x8xi1, 6>
+    %arg0: memref<4x8xi1, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4x8xi1, #spv.storage_class<Function>>
 ) { return }
 
 } // end module
 
 // -----
 
+// Reject numeric (non-storage-class) memory spaces.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], [SPV_KHR_storage_buffer_storage_class]>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @numeric_memref_mem_space1
+// CHECK-SAME: memref<4xf32>
+// NOEMU-LABEL: func @numeric_memref_mem_space1
+// NOEMU-SAME: memref<4xf32>
+func.func @numeric_memref_mem_space1(%arg0: memref<4xf32>) { return }
+
+// CHECK-LABEL: func @numeric_memref_mem_space2
+// CHECK-SAME: memref<4xf32, 3>
+// NOEMU-LABEL: func @numeric_memref_mem_space2
+// NOEMU-SAME: memref<4xf32, 3>
+func.func @numeric_memref_mem_space2(%arg0: memref<4xf32, 3>) { return }
+
+} // end module
+
+// -----
+
 // Check that using non-32-bit scalar types in interface storage classes
 // requires special capability and extension: convert them to 32-bit if not
 // satisfied.
@@ -308,86 +330,86 @@ module attributes {
 // CHECK-LABEL: spv.func @memref_1bit_type
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<2 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_1bit_type
-// NOEMU-SAME: memref<5xi1>
-func.func @memref_1bit_type(%arg0: memref<5xi1>) { return }
+// NOEMU-SAME: memref<5xi1, #spv.storage_class<StorageBuffer>>
+func.func @memref_1bit_type(%arg0: memref<5xi1, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_8bit_StorageBuffer
-// NOEMU-SAME: memref<16xi8>
-func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
+// NOEMU-SAME: memref<16xi8, #spv.storage_class<StorageBuffer>>
+func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_8bit_Uniform
-// NOEMU-SAME: memref<16xsi8, 4>
-func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return }
+// NOEMU-SAME: memref<16xsi8, #spv.storage_class<Uniform>>
+func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_8bit_PushConstant
-// NOEMU-SAME: memref<16xui8, 7>
-func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return }
+// NOEMU-SAME: memref<16xui8, #spv.storage_class<PushConstant>>
+func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_16bit_StorageBuffer
-// NOEMU-SAME: memref<16xi16>
-func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return }
+// NOEMU-SAME: memref<16xi16, #spv.storage_class<StorageBuffer>>
+func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_16bit_Uniform
-// NOEMU-SAME: memref<16xsi16, 4>
-func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
+// NOEMU-SAME: memref<16xsi16, #spv.storage_class<Uniform>>
+func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_16bit_PushConstant
-// NOEMU-SAME: memref<16xui16, 7>
-func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
+// NOEMU-SAME: memref<16xui16, #spv.storage_class<PushConstant>>
+func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
-// NOEMU-SAME: memref<16xf16, 9>
-func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
+// NOEMU-SAME: memref<16xf16, #spv.storage_class<Input>>
+func.func @memref_16bit_Input(%arg3: memref<16xf16, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
-// NOEMU-SAME: memref<16xf16, 10>
-func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
+// NOEMU-SAME: memref<16xf16, #spv.storage_class<Output>>
+func.func @memref_16bit_Output(%arg4: memref<16xf16, #spv.storage_class<Output>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_64bit_StorageBuffer
-// NOEMU-SAME: memref<16xi64>
-func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, 0>) { return }
+// NOEMU-SAME: memref<16xi64, #spv.storage_class<StorageBuffer>>
+func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_64bit_Uniform
-// NOEMU-SAME: memref<16xsi64, 4>
-func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, 4>) { return }
+// NOEMU-SAME: memref<16xsi64, #spv.storage_class<Uniform>>
+func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_64bit_PushConstant
-// NOEMU-SAME: memref<16xui64, 7>
-func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, 7>) { return }
+// NOEMU-SAME: memref<16xui64, #spv.storage_class<PushConstant>>
+func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Input>
 // NOEMU-LABEL: func @memref_64bit_Input
-// NOEMU-SAME: memref<16xf64, 9>
-func.func @memref_64bit_Input(%arg3: memref<16xf64, 9>) { return }
+// NOEMU-SAME: memref<16xf64, #spv.storage_class<Input>>
+func.func @memref_64bit_Input(%arg3: memref<16xf64, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Output>
 // NOEMU-LABEL: func @memref_64bit_Output
-// NOEMU-SAME: memref<16xf64, 10>
-func.func @memref_64bit_Output(%arg4: memref<16xf64, 10>) { return }
+// NOEMU-SAME: memref<16xf64, #spv.storage_class<Output>>
+func.func @memref_64bit_Output(%arg4: memref<16xf64, #spv.storage_class<Output>>) { return }
 
 } // end module
 
@@ -406,7 +428,7 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, PushConstant>
 // NOEMU-LABEL: spv.func @memref_8bit_PushConstant
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, PushConstant>
-func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return }
+func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, PushConstant>
@@ -415,8 +437,8 @@ func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, PushConstant>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, PushConstant>
 func.func @memref_16bit_PushConstant(
-  %arg0: memref<16xi16, 7>,
-  %arg1: memref<16xf16, 7>
+  %arg0: memref<16xi16, #spv.storage_class<PushConstant>>,
+  %arg1: memref<16xf16, #spv.storage_class<PushConstant>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_PushConstant
@@ -426,8 +448,8 @@ func.func @memref_16bit_PushConstant(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, PushConstant>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, PushConstant>
 func.func @memref_64bit_PushConstant(
-  %arg0: memref<16xi64, 7>,
-  %arg1: memref<16xf64, 7>
+  %arg0: memref<16xi64, #spv.storage_class<PushConstant>>,
+  %arg1: memref<16xf64, #spv.storage_class<PushConstant>>
 ) { return }
 
 } // end module
@@ -447,7 +469,7 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, StorageBuffer>
 // NOEMU-LABEL: spv.func @memref_8bit_StorageBuffer
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, StorageBuffer>
-func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
+func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, StorageBuffer>
@@ -456,8 +478,8 @@ func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, StorageBuffer>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, StorageBuffer>
 func.func @memref_16bit_StorageBuffer(
-  %arg0: memref<16xi16, 0>,
-  %arg1: memref<16xf16, 0>
+  %arg0: memref<16xi16, #spv.storage_class<StorageBuffer>>,
+  %arg1: memref<16xf16, #spv.storage_class<StorageBuffer>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
@@ -467,8 +489,8 @@ func.func @memref_16bit_StorageBuffer(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, StorageBuffer>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, StorageBuffer>
 func.func @memref_64bit_StorageBuffer(
-  %arg0: memref<16xi64, 0>,
-  %arg1: memref<16xf64, 0>
+  %arg0: memref<16xi64, #spv.storage_class<StorageBuffer>>,
+  %arg1: memref<16xf64, #spv.storage_class<StorageBuffer>>
 ) { return }
 
 } // end module
@@ -488,7 +510,7 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, Uniform>
 // NOEMU-LABEL: spv.func @memref_8bit_Uniform
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i8, stride=1> [0])>, Uniform>
-func.func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return }
+func.func @memref_8bit_Uniform(%arg0: memref<16xi8, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Uniform>
@@ -497,8 +519,8 @@ func.func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Uniform>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Uniform>
 func.func @memref_16bit_Uniform(
-  %arg0: memref<16xi16, 4>,
-  %arg1: memref<16xf16, 4>
+  %arg0: memref<16xi16, #spv.storage_class<Uniform>>,
+  %arg1: memref<16xf16, #spv.storage_class<Uniform>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Uniform
@@ -508,8 +530,8 @@ func.func @memref_16bit_Uniform(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, Uniform>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, Uniform>
 func.func @memref_64bit_Uniform(
-  %arg0: memref<16xi64, 4>,
-  %arg1: memref<16xf64, 4>
+  %arg0: memref<16xi64, #spv.storage_class<Uniform>>,
+  %arg1: memref<16xf64, #spv.storage_class<Uniform>>
 ) { return }
 
 } // end module
@@ -528,13 +550,13 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16>)>, Input>
 // NOEMU-LABEL: spv.func @memref_16bit_Input
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16>)>, Input>
-func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
+func.func @memref_16bit_Input(%arg3: memref<16xf16, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
 // NOEMU-LABEL: spv.func @memref_16bit_Output
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
-func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
+func.func @memref_16bit_Output(%arg4: memref<16xi16, #spv.storage_class<Output>>) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
@@ -543,8 +565,8 @@ func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Input>
 func.func @memref_64bit_Input(
-  %arg0: memref<16xi64, 9>,
-  %arg1: memref<16xf64, 9>
+  %arg0: memref<16xi64, #spv.storage_class<Input>>,
+  %arg1: memref<16xf64, #spv.storage_class<Input>>
 ) { return }
 
 // CHECK-LABEL: spv.func @memref_64bit_Output
@@ -554,8 +576,8 @@ func.func @memref_64bit_Input(
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Output>
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Output>
 func.func @memref_64bit_Output(
-  %arg0: memref<16xi64, 10>,
-  %arg1: memref<16xf64, 10>
+  %arg0: memref<16xi64, #spv.storage_class<Output>>,
+  %arg1: memref<16xf64, #spv.storage_class<Output>>
 ) { return }
 
 } // end module
@@ -575,22 +597,22 @@ func.func @memref_offset_strides(
 // CHECK-SAME: !spv.array<256 x f32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<88 x f32, stride=4> [0])>, StorageBuffer>
-  %arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>,  // tightly packed; row major
-  %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>,  // offset 8
-  %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row
-  %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major
-  %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 4 after each col
+  %arg0: memref<16x4xf32, offset: 0, strides: [4, 1], #spv.storage_class<StorageBuffer>>,  // tightly packed; row major
+  %arg1: memref<16x4xf32, offset: 8, strides: [4, 1], #spv.storage_class<StorageBuffer>>,  // offset 8
+  %arg2: memref<16x4xf32, offset: 0, strides: [16, 1], #spv.storage_class<StorageBuffer>>, // pad 12 after each row
+  %arg3: memref<16x4xf32, offset: 0, strides: [1, 16], #spv.storage_class<StorageBuffer>>, // tightly packed; col major
+  %arg4: memref<16x4xf32, offset: 0, strides: [1, 22], #spv.storage_class<StorageBuffer>>, // pad 4 after each col
 
 // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<72 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<256 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.array<88 x f16, stride=2> [0])>, StorageBuffer>
-  %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>,
-  %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>,
-  %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>,
-  %arg8: memref<16x4xf16, offset: 0, strides: [1, 16]>,
-  %arg9: memref<16x4xf16, offset: 0, strides: [1, 22]>
+  %arg5: memref<16x4xf16, offset: 0, strides: [4, 1], #spv.storage_class<StorageBuffer>>,
+  %arg6: memref<16x4xf16, offset: 8, strides: [4, 1], #spv.storage_class<StorageBuffer>>,
+  %arg7: memref<16x4xf16, offset: 0, strides: [16, 1], #spv.storage_class<StorageBuffer>>,
+  %arg8: memref<16x4xf16, offset: 0, strides: [1, 16], #spv.storage_class<StorageBuffer>>,
+  %arg9: memref<16x4xf16, offset: 0, strides: [1, 22], #spv.storage_class<StorageBuffer>>
 ) { return }
 
 } // end module
@@ -610,14 +632,15 @@ func.func @unranked_memref(%arg0: memref<*xi32>) { return }
 // CHECK-LABEL: func @memref_1bit_type
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_1bit_type
-// NOEMU-SAME: memref<?xi1>
-func.func @memref_1bit_type(%arg0: memref<?xi1>) { return }
+// NOEMU-SAME: memref<?xi1, #spv.storage_class<StorageBuffer>>
+func.func @memref_1bit_type(%arg0: memref<?xi1, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: func @dynamic_dim_memref
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
-func.func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
-                         %arg1: memref<?x?xf32>) { return }
+func.func @dynamic_dim_memref(
+    %arg0: memref<8x?xi32, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<?x?xf32, #spv.storage_class<StorageBuffer>>) { return }
 
 // Check that using non-32-bit scalar types in interface storage classes
 // requires special capability and extension: convert them to 32-bit if not
@@ -626,50 +649,50 @@ func.func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
 // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_8bit_StorageBuffer
-// NOEMU-SAME: memref<?xi8>
-func.func @memref_8bit_StorageBuffer(%arg0: memref<?xi8, 0>) { return }
+// NOEMU-SAME: memref<?xi8, #spv.storage_class<StorageBuffer>>
+func.func @memref_8bit_StorageBuffer(%arg0: memref<?xi8, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_8bit_Uniform
-// NOEMU-SAME: memref<?xsi8, 4>
-func.func @memref_8bit_Uniform(%arg0: memref<?xsi8, 4>) { return }
+// NOEMU-SAME: memref<?xsi8, #spv.storage_class<Uniform>>
+func.func @memref_8bit_Uniform(%arg0: memref<?xsi8, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_8bit_PushConstant
-// NOEMU-SAME: memref<?xui8, 7>
-func.func @memref_8bit_PushConstant(%arg0: memref<?xui8, 7>) { return }
+// NOEMU-SAME: memref<?xui8, #spv.storage_class<PushConstant>>
+func.func @memref_8bit_PushConstant(%arg0: memref<?xui8, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_16bit_StorageBuffer
-// NOEMU-SAME: memref<?xi16>
-func.func @memref_16bit_StorageBuffer(%arg0: memref<?xi16, 0>) { return }
+// NOEMU-SAME: memref<?xi16, #spv.storage_class<StorageBuffer>>
+func.func @memref_16bit_StorageBuffer(%arg0: memref<?xi16, #spv.storage_class<StorageBuffer>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Uniform
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_16bit_Uniform
-// NOEMU-SAME: memref<?xsi16, 4>
-func.func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
+// NOEMU-SAME: memref<?xsi16, #spv.storage_class<Uniform>>
+func.func @memref_16bit_Uniform(%arg0: memref<?xsi16, #spv.storage_class<Uniform>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_PushConstant
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_16bit_PushConstant
-// NOEMU-SAME: memref<?xui16, 7>
-func.func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
+// NOEMU-SAME: memref<?xui16, #spv.storage_class<PushConstant>>
+func.func @memref_16bit_PushConstant(%arg0: memref<?xui16, #spv.storage_class<PushConstant>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
-// NOEMU-SAME: memref<?xf16, 9>
-func.func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
+// NOEMU-SAME: memref<?xf16, #spv.storage_class<Input>>
+func.func @memref_16bit_Input(%arg3: memref<?xf16, #spv.storage_class<Input>>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
-// NOEMU-SAME: memref<?xf16, 10>
-func.func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }
+// NOEMU-SAME: memref<?xf16, #spv.storage_class<Output>>
+func.func @memref_16bit_Output(%arg4: memref<?xf16, #spv.storage_class<Output>>) { return }
 
 } // end module
 
@@ -684,15 +707,16 @@ module attributes {
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x vector<2xf32>, stride=8> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16> [0])>, Uniform>
 func.func @memref_vector(
-    %arg0: memref<4xvector<2xf32>, 0>,
-    %arg1: memref<4xvector<4xf32>, 4>)
+    %arg0: memref<4xvector<2xf32>, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4xvector<4xf32>, #spv.storage_class<Uniform>>)
 { return }
 
 // CHECK-LABEL: func @dynamic_dim_memref_vector
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xi32>, stride=16> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
-func.func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>,
-                         %arg1: memref<?x?xvector<2xf32>>)
+func.func @dynamic_dim_memref_vector(
+    %arg0: memref<8x?xvector<4xi32>, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<?x?xvector<2xf32>, #spv.storage_class<StorageBuffer>>)
 { return }
 
 } // end module
@@ -705,9 +729,9 @@ module attributes {
 } {
 
 // CHECK-LABEL: func @memref_vector_wrong_size
-// CHECK-SAME: memref<4xvector<5xf32>>
+// CHECK-SAME: memref<4xvector<5xf32>, #spv.storage_class<StorageBuffer>>
 func.func @memref_vector_wrong_size(
-    %arg0: memref<4xvector<5xf32>, 0>)
+    %arg0: memref<4xvector<5xf32>, #spv.storage_class<StorageBuffer>>)
 { return }
 
 } // end module

diff  --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
similarity index 83%
rename from mlir/test/Conversion/GPUToSPIRV/simple.mlir
rename to mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
index c6b45ba491905..00a97fc1a990f 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir
@@ -7,7 +7,7 @@ module attributes {gpu.container_module} {
     // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>}
     // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4> [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
     // CHECK-SAME: spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
-    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) kernel
+    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel
       attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       // CHECK: spv.Return
       gpu.return
@@ -16,11 +16,11 @@ module attributes {gpu.container_module} {
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<StorageBuffer>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@basic_module_structure
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<StorageBuffer>>)
     return
   }
 }
@@ -39,7 +39,7 @@ module attributes {gpu.container_module} {
     gpu.func @basic_module_structure_preset_ABI(
       %arg0 : f32
         {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
-      %arg1 : memref<12xf32>
+      %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>
         {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
       attributes
         {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
@@ -55,18 +55,18 @@ module attributes {gpu.container_module} {
   gpu.module @kernels {
     // expected-error @below {{failed to legalize operation 'gpu.func'}}
     // expected-remark @below {{match failure: missing 'spv.entry_point_abi' attribute}}
-    gpu.func @missing_entry_point_abi(%arg0 : f32, %arg1 : memref<12xf32>) kernel {
+    gpu.func @missing_entry_point_abi(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel {
       gpu.return
     }
   }
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<StorageBuffer>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@missing_entry_point_abi
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<StorageBuffer>>)
     return
   }
 }
@@ -80,7 +80,7 @@ module attributes {gpu.container_module} {
     gpu.func @missing_entry_point_abi(
       %arg0 : f32
         {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
-      %arg1 : memref<12xf32>) kernel
+      %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel
     attributes
       {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       gpu.return
@@ -96,7 +96,7 @@ module attributes {gpu.container_module} {
     // expected-remark @below {{match failure: missing 'spv.interface_var_abi' attribute at argument 0}}
     gpu.func @missing_entry_point_abi(
       %arg0 : f32,
-      %arg1 : memref<12xf32>
+      %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>
         {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
     attributes
       {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
@@ -110,7 +110,7 @@ module attributes {gpu.container_module} {
 module attributes {gpu.container_module} {
   gpu.module @kernels {
     // CHECK-LABEL: spv.func @barrier
-    gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel
+    gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<StorageBuffer>>) kernel
       attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       // CHECK: spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
       gpu.barrier
@@ -120,11 +120,11 @@ module attributes {gpu.container_module} {
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<StorageBuffer>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@barrier
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<StorageBuffer>>)
     return
   }
 }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index bbc01bffff5c3..abce5d7542d7c 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -5,7 +5,7 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spv.resource_limits<>>
 } {
-  func.func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) {
+  func.func @load_store(%arg0: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg1: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg2: memref<12x4xf32, #spv.storage_class<StorageBuffer>>) {
     %c0 = arith.constant 0 : index
     %c12 = arith.constant 12 : index
     %0 = arith.subi %c12, %c0 : index
@@ -17,7 +17,7 @@ module attributes {
     %c1_2 = arith.constant 1 : index
     gpu.launch_func @kernels::@load_store_kernel
         blocks in (%0, %c1_2, %c1_2) threads in (%1, %c1_2, %c1_2)
-        args(%arg0 : memref<12x4xf32>, %arg1 : memref<12x4xf32>, %arg2 : memref<12x4xf32>,
+        args(%arg0 : memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg1 : memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg2 : memref<12x4xf32, #spv.storage_class<StorageBuffer>>,
              %c0 : index, %c0_0 : index, %c1 : index, %c1_1 : index)
     return
   }
@@ -35,7 +35,7 @@ module attributes {
     // CHECK-SAME: %[[ARG4:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>}
     // CHECK-SAME: %[[ARG5:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>}
     // CHECK-SAME: %[[ARG6:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}
-    gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel
+    gpu.func @load_store_kernel(%arg0: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg1: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg2: memref<12x4xf32, #spv.storage_class<StorageBuffer>>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel
       attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[16, 1, 1]>: vector<3xi32>>} {
       // CHECK: %[[ADDRESSWORKGROUPID:.*]] = spv.mlir.addressof @[[$WORKGROUPIDVAR]]
       // CHECK: %[[WORKGROUPID:.*]] = spv.Load "Input" %[[ADDRESSWORKGROUPID]]
@@ -69,15 +69,15 @@ module attributes {
       // CHECK: %[[OFFSET1_2:.*]] = spv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
       // CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: %[[VAL1:.*]] = spv.Load "StorageBuffer" %[[PTR1]]
-      %14 = memref.load %arg0[%12, %13] : memref<12x4xf32>
+      %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spv.storage_class<StorageBuffer>>
       // CHECK: %[[PTR2:.*]] = spv.AccessChain %[[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
       // CHECK-NEXT: %[[VAL2:.*]] = spv.Load "StorageBuffer" %[[PTR2]]
-      %15 = memref.load %arg1[%12, %13] : memref<12x4xf32>
+      %15 = memref.load %arg1[%12, %13] : memref<12x4xf32, #spv.storage_class<StorageBuffer>>
       // CHECK: %[[VAL3:.*]] = spv.FAdd %[[VAL1]], %[[VAL2]]
       %16 = arith.addf %14, %15 : f32
       // CHECK: %[[PTR3:.*]] = spv.AccessChain %[[ARG2]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
       // CHECK-NEXT: spv.Store "StorageBuffer" %[[PTR3]], %[[VAL3]]
-      memref.store %16, %arg2[%12, %13] : memref<12x4xf32>
+      memref.store %16, %arg2[%12, %13] : memref<12x4xf32, #spv.storage_class<StorageBuffer>>
       gpu.return
     }
   }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
similarity index 84%
rename from mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
rename to mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
index 2d500e35c61a2..ac22d13815e6e 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
@@ -12,7 +12,7 @@ module attributes {
     //  CHECK-SAME:     {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>
     //   CHECK-NOT:     spv.interface_var_abi
     //  CHECK-SAME:     spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
-    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel
+    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<CrossWorkgroup>>) kernel
         attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
       gpu.return
     }
@@ -20,11 +20,11 @@ module attributes {
 
   func.func @main() {
     %0 = "op"() : () -> (f32)
-    %1 = "op"() : () -> (memref<12xf32, 11>)
+    %1 = "op"() : () -> (memref<12xf32, #spv.storage_class<CrossWorkgroup>>)
     %cst = arith.constant 1 : index
     gpu.launch_func @kernels::@basic_module_structure
         blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
-        args(%0 : f32, %1 : memref<12xf32, 11>)
+        args(%0 : f32, %1 : memref<12xf32, #spv.storage_class<CrossWorkgroup>>)
     return
   }
 }

diff  --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
index 3b0af88a299e6..6aeb60c8cc3e4 100644
--- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
+++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
@@ -44,12 +44,12 @@ module attributes {
 // CHECK:        }
 // CHECK:        spv.Return
 
-func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
+func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spv.storage_class<StorageBuffer>>) attributes {
   spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[16, 1, 1]>: vector<3xi32>>
 } {
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16xi32>)
-     outs(%output : memref<1xi32>) {
+      ins(%input : memref<16xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<1xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32
@@ -74,11 +74,11 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spv.resource_limits<>>
 } {
-func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) {
+func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spv.storage_class<StorageBuffer>>) {
   // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16xi32>)
-     outs(%output : memref<1xi32>) {
+      ins(%input : memref<16xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<1xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32
@@ -103,13 +103,13 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spv.resource_limits<>>
 } {
-func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
+func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class<StorageBuffer>>, %output: memref<1xi32, #spv.storage_class<StorageBuffer>>) attributes {
   spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 1, 1]>: vector<3xi32>>
 } {
   // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16xi32>)
-     outs(%output : memref<1xi32>) {
+      ins(%input : memref<16xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<1xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32
@@ -134,13 +134,13 @@ module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, #spv.resource_limits<>>
 } {
-func.func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes {
+func.func @single_workgroup_reduction(%input: memref<16x8xi32, #spv.storage_class<StorageBuffer>>, %output: memref<16xi32, #spv.storage_class<StorageBuffer>>) attributes {
   spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[16, 8, 1]>: vector<3xi32>>
 } {
   // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
   linalg.generic #single_workgroup_reduction_trait
-      ins(%input : memref<16x8xi32>)
-     outs(%output : memref<16xi32>) {
+      ins(%input : memref<16x8xi32, #spv.storage_class<StorageBuffer>>)
+     outs(%output : memref<16xi32, #spv.storage_class<StorageBuffer>>) {
     ^bb(%in: i32, %out: i32):
       %sum = arith.addi %in, %out : i32
       linalg.yield %sum : i32

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index e07b382bafe99..2edc37eb82e68 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -6,10 +6,10 @@ module attributes {
   }
 {
   func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
-    %0 = memref.alloc() : memref<4x5xf32, 3>
-    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 3>
-    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 3>
-    memref.dealloc %0 : memref<4x5xf32, 3>
+    %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xf32, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -31,10 +31,10 @@ module attributes {
   }
 {
   func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
-    %0 = memref.alloc() : memref<4x5xi16, 3>
-    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xi16, 3>
-    memref.store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
-    memref.dealloc %0 : memref<4x5xi16, 3>
+    %0 = memref.alloc() : memref<4x5xi16, #spv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xi16, #spv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xi16, #spv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xi16, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -60,8 +60,8 @@ module attributes {
   }
 {
   func.func @two_allocs() {
-    %0 = memref.alloc() : memref<4x5xf32, 3>
-    %1 = memref.alloc() : memref<2x3xi32, 3>
+    %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class<Workgroup>>
+    %1 = memref.alloc() : memref<2x3xi32, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -80,8 +80,8 @@ module attributes {
   }
 {
   func.func @two_allocs_vector() {
-    %0 = memref.alloc() : memref<4xvector<4xf32>, 3>
-    %1 = memref.alloc() : memref<2xvector<2xi32>, 3>
+    %0 = memref.alloc() : memref<4xvector<4xf32>, #spv.storage_class<Workgroup>>
+    %1 = memref.alloc() : memref<2xvector<2xi32>, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -103,8 +103,8 @@ module attributes {
   // CHECK-LABEL: func @alloc_dynamic_size
   func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
     // CHECK: memref.alloc
-    %0 = memref.alloc(%arg0) : memref<4x?xf32, 3>
-    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3>
+    %0 = memref.alloc(%arg0) : memref<4x?xf32, #spv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, #spv.storage_class<Workgroup>>
     return %1: f32
   }
 }
@@ -119,8 +119,8 @@ module attributes {
   // CHECK-LABEL: func @alloc_unsupported_memory_space
   func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
     // CHECK: memref.alloc
-    %0 = memref.alloc() : memref<4x5xf32>
-    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+    %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class<StorageBuffer>>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32, #spv.storage_class<StorageBuffer>>
     return %1: f32
   }
 }
@@ -134,9 +134,9 @@ module attributes {
   }
 {
   // CHECK-LABEL: func @dealloc_dynamic_size
-  func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) {
+  func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, #spv.storage_class<Workgroup>>) {
     // CHECK: memref.dealloc
-    memref.dealloc %arg0 : memref<4x?xf32, 3>
+    memref.dealloc %arg0 : memref<4x?xf32, #spv.storage_class<Workgroup>>
     return
   }
 }
@@ -149,9 +149,9 @@ module attributes {
   }
 {
   // CHECK-LABEL: func @dealloc_unsupported_memory_space
-  func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) {
+  func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32, #spv.storage_class<StorageBuffer>>) {
     // CHECK: memref.dealloc
-    memref.dealloc %arg0 : memref<4x5xf32>
+    memref.dealloc %arg0 : memref<4x5xf32, #spv.storage_class<StorageBuffer>>
     return
   }
 }

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
index e2cd90e3204d0..80081280eb41d 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
@@ -2,9 +2,9 @@
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, #spv.resource_limits<>>} {
   func.func @alloc_function_variable(%arg0 : index, %arg1 : index) {
-    %0 = memref.alloca() : memref<4x5xf32, 6>
-    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6>
-    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6>
+    %0 = memref.alloca() : memref<4x5xf32, #spv.storage_class<Function>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Function>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class<Function>>
     return
   }
 }
@@ -21,8 +21,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, #spv.resource_limits<>>} {
   func.func @two_allocs() {
-    %0 = memref.alloca() : memref<4x5xf32, 6>
-    %1 = memref.alloca() : memref<2x3xi32, 6>
+    %0 = memref.alloca() : memref<4x5xf32, #spv.storage_class<Function>>
+    %1 = memref.alloca() : memref<2x3xi32, #spv.storage_class<Function>>
     return
   }
 }
@@ -35,8 +35,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
 
 module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, #spv.resource_limits<>>} {
   func.func @two_allocs_vector() {
-    %0 = memref.alloca() : memref<4xvector<4xf32>, 6>
-    %1 = memref.alloca() : memref<2xvector<2xi32>, 6>
+    %0 = memref.alloca() : memref<4xvector<4xf32>, #spv.storage_class<Function>>
+    %1 = memref.alloca() : memref<2xvector<2xi32>, #spv.storage_class<Function>>
     return
   }
 }
@@ -52,8 +52,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
   // CHECK-LABEL: func @alloc_dynamic_size
   func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
     // CHECK: memref.alloca
-    %0 = memref.alloca(%arg0) : memref<4x?xf32, 6>
-    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6>
+    %0 = memref.alloca(%arg0) : memref<4x?xf32, #spv.storage_class<Function>>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, #spv.storage_class<Function>>
     return %1: f32
   }
 }

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 1cd94f9b761f7..212363c15584b 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -15,60 +15,60 @@ module attributes {
 } {
 
 // CHECK-LABEL: @load_store_zero_rank_float
-func.func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
-  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
-  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+func.func @load_store_zero_rank_float(%arg0: memref<f32, #spv.storage_class<StorageBuffer>>, %arg1: memref<f32, #spv.storage_class<StorageBuffer>>) {
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
   //      CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG0]][
   // CHECK-SAME: [[ZERO1]], [[ZERO1]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Load "StorageBuffer" %{{.*}} : f32
-  %0 = memref.load %arg0[] : memref<f32>
+  %0 = memref.load %arg0[] : memref<f32, #spv.storage_class<StorageBuffer>>
   //      CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG1]][
   // CHECK-SAME: [[ZERO2]], [[ZERO2]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Store "StorageBuffer" %{{.*}} : f32
-  memref.store %0, %arg1[] : memref<f32>
+  memref.store %0, %arg1[] : memref<f32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_store_zero_rank_int
-func.func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
-  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
-  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+func.func @load_store_zero_rank_int(%arg0: memref<i32, #spv.storage_class<StorageBuffer>>, %arg1: memref<i32, #spv.storage_class<StorageBuffer>>) {
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
   //      CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG0]][
   // CHECK-SAME: [[ZERO1]], [[ZERO1]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
-  %0 = memref.load %arg0[] : memref<i32>
+  %0 = memref.load %arg0[] : memref<i32, #spv.storage_class<StorageBuffer>>
   //      CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
   //      CHECK: spv.AccessChain [[ARG1]][
   // CHECK-SAME: [[ZERO2]], [[ZERO2]]
   // CHECK-SAME: ] :
   //      CHECK: spv.Store "StorageBuffer" %{{.*}} : i32
-  memref.store %0, %arg1[] : memref<i32>
+  memref.store %0, %arg1[] : memref<i32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: func @load_store_unknown_dim
-func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?xi32>) {
-  // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
-  // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spv.storage_class<StorageBuffer>>, %dest: memref<?xi32, #spv.storage_class<StorageBuffer>>) {
+  // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+  // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
   // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]]
   // CHECK: spv.Load "StorageBuffer" %[[AC0]]
-  %0 = memref.load %source[%i] : memref<?xi32>
+  %0 = memref.load %source[%i] : memref<?xi32, #spv.storage_class<StorageBuffer>>
   // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]]
   // CHECK: spv.Store "StorageBuffer" %[[AC1]]
-  memref.store %0, %dest[%i]: memref<?xi32>
+  memref.store %0, %dest[%i]: memref<?xi32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: func @load_i1
-//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index)
-func.func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
-  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spv.storage_class<StorageBuffer>>, %[[IDX:.+]]: index)
+func.func @load_i1(%src: memref<4xi1, #spv.storage_class<StorageBuffer>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
   // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
@@ -79,17 +79,17 @@ func.func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
   // CHECK: %[[VAL:.+]] = spv.Load "StorageBuffer" %[[ADDR]] : i8
   // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
   // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
-  %0 = memref.load %src[%i] : memref<4xi1>
+  %0 = memref.load %src[%i] : memref<4xi1, #spv.storage_class<StorageBuffer>>
   // CHECK: return %[[BOOL]]
   return %0: i1
 }
 
 // CHECK-LABEL: func @store_i1
-//  CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
+//  CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spv.storage_class<StorageBuffer>>,
 //  CHECK-SAME: %[[IDX:.+]]: index
-func.func @store_i1(%dst: memref<4xi1>, %i: index) {
+func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<StorageBuffer>>, %i: index) {
   %true = arith.constant true
-  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spv.storage_class<StorageBuffer>> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
   // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
@@ -101,7 +101,7 @@ func.func @store_i1(%dst: memref<4xi1>, %i: index) {
   // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
   // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
   // CHECK: spv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8
-  memref.store %true, %dst[%i]: memref<4xi1>
+  memref.store %true, %dst[%i]: memref<4xi1, #spv.storage_class<StorageBuffer>>
   return
 }
 
@@ -118,7 +118,7 @@ module attributes {
 } {
 
 // CHECK-LABEL: @load_i1
-func.func @load_i1(%arg0: memref<i1>) -> i1 {
+func.func @load_i1(%arg0: memref<i1, #spv.storage_class<StorageBuffer>>) -> i1 {
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
   //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
@@ -138,12 +138,12 @@ func.func @load_i1(%arg0: memref<i1>) -> i1 {
   //     CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
   //     CHECK: %[[RES:.+]]  = spv.IEqual %[[T4]], %[[ONE]] : i32
   //     CHECK: return %[[RES]]
-  %0 = memref.load %arg0[] : memref<i1>
+  %0 = memref.load %arg0[] : memref<i1, #spv.storage_class<StorageBuffer>>
   return %0 : i1
 }
 
 // CHECK-LABEL: @load_i8
-func.func @load_i8(%arg0: memref<i8>) {
+func.func @load_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>) {
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
   //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
@@ -159,13 +159,13 @@ func.func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[T2:.+]] = spv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
-  %0 = memref.load %arg0[] : memref<i8>
+  %0 = memref.load %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_i16
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: index)
-func.func @load_i16(%arg0: memref<10xi16>, %index : index) {
+func.func @load_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index : index) {
   //     CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32
@@ -186,31 +186,31 @@ func.func @load_i16(%arg0: memref<10xi16>, %index : index) {
   //     CHECK: %[[T2:.+]] = spv.Constant 16 : i32
   //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
-  %0 = memref.load %arg0[%index] : memref<10xi16>
+  %0 = memref.load %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_i32
-func.func @load_i32(%arg0: memref<i32>) {
+func.func @load_i32(%arg0: memref<i32, #spv.storage_class<StorageBuffer>>) {
   // CHECK-NOT: spv.SDiv
   //     CHECK: spv.Load
   // CHECK-NOT: spv.ShiftRightArithmetic
-  %0 = memref.load %arg0[] : memref<i32>
+  %0 = memref.load %arg0[] : memref<i32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_f32
-func.func @load_f32(%arg0: memref<f32>) {
+func.func @load_f32(%arg0: memref<f32, #spv.storage_class<StorageBuffer>>) {
   // CHECK-NOT: spv.SDiv
   //     CHECK: spv.Load
   // CHECK-NOT: spv.ShiftRightArithmetic
-  %0 = memref.load %arg0[] : memref<f32>
+  %0 = memref.load %arg0[] : memref<f32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i1
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1)
-func.func @store_i1(%arg0: memref<i1>, %value: i1) {
+func.func @store_i1(%arg0: memref<i1, #spv.storage_class<StorageBuffer>>, %value: i1) {
   //     CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
@@ -230,13 +230,13 @@ func.func @store_i1(%arg0: memref<i1>, %value: i1) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i1>
+  memref.store %value, %arg0[] : memref<i1, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i8
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
-func.func @store_i8(%arg0: memref<i8>, %value: i8) {
+func.func @store_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>, %value: i8) {
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
@@ -254,13 +254,13 @@ func.func @store_i8(%arg0: memref<i8>, %value: i8) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i8>
+  memref.store %value, %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i16
-//       CHECK: (%[[ARG0:.+]]: memref<10xi16>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
-func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+//       CHECK: (%[[ARG0:.+]]: memref<10xi16, #spv.storage_class<StorageBuffer>>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
+func.func @store_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
   //     CHECK-DAG: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
@@ -283,25 +283,25 @@ func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[%index] : memref<10xi16>
+  memref.store %value, %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i32
-func.func @store_i32(%arg0: memref<i32>, %value: i32) {
+func.func @store_i32(%arg0: memref<i32, #spv.storage_class<StorageBuffer>>, %value: i32) {
   //     CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[] : memref<i32>
+  memref.store %value, %arg0[] : memref<i32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_f32
-func.func @store_f32(%arg0: memref<f32>, %value: f32) {
+func.func @store_f32(%arg0: memref<f32, #spv.storage_class<StorageBuffer>>, %value: f32) {
   //     CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[] : memref<f32>
+  memref.store %value, %arg0[] : memref<f32, #spv.storage_class<StorageBuffer>>
   return
 }
 
@@ -318,7 +318,7 @@ module attributes {
 } {
 
 // CHECK-LABEL: @load_i8
-func.func @load_i8(%arg0: memref<i8>) {
+func.func @load_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>) {
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   //     CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
   //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
@@ -334,22 +334,22 @@ func.func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[T2:.+]] = spv.Constant 24 : i32
   //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
-  %0 = memref.load %arg0[] : memref<i8>
+  %0 = memref.load %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @load_i16
-func.func @load_i16(%arg0: memref<i16>) {
+func.func @load_i16(%arg0: memref<i16, #spv.storage_class<StorageBuffer>>) {
   // CHECK-NOT: spv.SDiv
   //     CHECK: spv.Load
   // CHECK-NOT: spv.ShiftRightArithmetic
-  %0 = memref.load %arg0[] : memref<i16>
+  %0 = memref.load %arg0[] : memref<i16, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i8
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
-func.func @store_i8(%arg0: memref<i8>, %value: i8) {
+func.func @store_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>, %value: i8) {
   //     CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   //     CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
@@ -367,16 +367,16 @@ func.func @store_i8(%arg0: memref<i8>, %value: i8) {
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i8>
+  memref.store %value, %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // CHECK-LABEL: @store_i16
-func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+func.func @store_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
   //     CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[%index] : memref<10xi16>
+  memref.store %value, %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }
 

diff  --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
index 517163758b8ef..54a8e938187d3 100644
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -5,7 +5,7 @@ module attributes {
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spv.resource_limits<>>
 } {
 
-func.func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+func.func @loop_kernel(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>) {
   // CHECK: %[[LB:.*]] = spv.Constant 4 : i32
   %lb = arith.constant 4 : index
   // CHECK: %[[UB:.*]] = spv.Constant 42 : i32
@@ -36,14 +36,14 @@ func.func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
   // CHECK:        spv.mlir.merge
   // CHECK:      }
   scf.for %arg4 = %lb to %ub step %step {
-    %1 = memref.load %arg2[%arg4] : memref<10xf32>
-    memref.store %1, %arg3[%arg4] : memref<10xf32>
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
   return
 }
 
 // CHECK-LABEL: @loop_yield
-func.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+func.func @loop_yield(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>) {
   // CHECK: %[[LB:.*]] = spv.Constant 4 : i32
   %lb = arith.constant 4 : index
   // CHECK: %[[UB:.*]] = spv.Constant 42 : i32
@@ -78,8 +78,8 @@ func.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
   // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
   // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
   // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
-  memref.store %result#0, %arg3[%lb] : memref<10xf32>
-  memref.store %result#1, %arg3[%ub] : memref<10xf32>
+  memref.store %result#0, %arg3[%lb] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+  memref.store %result#1, %arg3[%ub] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
 

diff  --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir
index f937ac6c4e06f..d8463b9db2beb 100644
--- a/mlir/test/Conversion/SCFToSPIRV/if.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir
@@ -6,7 +6,7 @@ module attributes {
 } {
 
 // CHECK-LABEL: @kernel_simple_selection
-func.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) {
+func.func @kernel_simple_selection(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : i1) {
   %value = arith.constant 0.0 : f32
   %i = arith.constant 0 : index
 
@@ -20,13 +20,13 @@ func.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) {
   // CHECK-NEXT:  spv.Return
 
   scf.if %arg3 {
-    memref.store %value, %arg2[%i] : memref<10xf32>
+    memref.store %value, %arg2[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
   return
 }
 
 // CHECK-LABEL: @kernel_nested_selection
-func.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) {
+func.func @kernel_nested_selection(%arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg4 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg5 : i1, %arg6 : i1) {
   %i = arith.constant 0 : index
   %j = arith.constant 9 : index
 
@@ -61,26 +61,26 @@ func.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32
 
   scf.if %arg5 {
     scf.if %arg6 {
-      %value = memref.load %arg3[%i] : memref<10xf32>
-      memref.store %value, %arg4[%i] : memref<10xf32>
+      %value = memref.load %arg3[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg4[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     } else {
-      %value = memref.load %arg4[%i] : memref<10xf32>
-      memref.store %value, %arg3[%i] : memref<10xf32>
+      %value = memref.load %arg4[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg3[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     }
   } else {
     scf.if %arg6 {
-      %value = memref.load %arg3[%j] : memref<10xf32>
-      memref.store %value, %arg4[%j] : memref<10xf32>
+      %value = memref.load %arg3[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg4[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     } else {
-      %value = memref.load %arg4[%j] : memref<10xf32>
-      memref.store %value, %arg3[%j] : memref<10xf32>
+      %value = memref.load %arg4[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg3[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     }
   }
   return
 }
 
 // CHECK-LABEL: @simple_if_yield
-func.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) {
+func.func @simple_if_yield(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : i1) {
   // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
   // CHECK:       spv.mlir.selection {
@@ -116,15 +116,15 @@ func.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) {
   }
   %i = arith.constant 0 : index
   %j = arith.constant 1 : index
-  memref.store %0#0, %arg2[%i] : memref<10xf32>
-  memref.store %0#1, %arg2[%j] : memref<10xf32>
+  memref.store %0#0, %arg2[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+  memref.store %0#1, %arg2[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
 
 // TODO: The transformation should only be legal if VariablePointer capability
 // is supported. This test is still useful to make sure we can handle scf op
 // result with type change.
-func.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) {
+func.func @simple_if_yield_type_change(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg4 : i1) {
   // CHECK-LABEL: @simple_if_yield_type_change
   // CHECK:       %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function>
   // CHECK:       spv.mlir.selection {
@@ -144,12 +144,12 @@ func.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10
   // CHECK:       spv.Return
   %i = arith.constant 0 : index
   %value = arith.constant 0.0 : f32
-  %0 = scf.if %arg4 -> (memref<10xf32>) {
-    scf.yield %arg2 : memref<10xf32>
+  %0 = scf.if %arg4 -> (memref<10xf32, #spv.storage_class<StorageBuffer>>) {
+    scf.yield %arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>
   } else {
-    scf.yield %arg3 : memref<10xf32>
+    scf.yield %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
-  memref.store %value, %0[%i] : memref<10xf32>
+  memref.store %value, %0[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
 

diff  --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
index 9a3c1be9e8ea3..b553d1a554039 100644
--- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
+++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
@@ -75,7 +75,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   PassManager passManager(module.getContext());
   applyPassManagerCLOptions(passManager);
   passManager.addPass(createGpuKernelOutliningPass());
-  passManager.addPass(createConvertGPUToSPIRVPass());
+  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
 
   OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
   nestedPM.addPass(spirv::createLowerABIAttributesPass());

diff  --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 7e61ea746b968..d942ef96f4c4f 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -47,10 +47,12 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
 
   passManager.addPass(createGpuKernelOutliningPass());
   passManager.addPass(memref::createFoldSubViewOpsPass());
-  passManager.addPass(createConvertGPUToSPIRVPass());
+
+  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createLowerABIAttributesPass());
   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
+
   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
   LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
   passManager.addPass(createMemRefToLLVMPass());
@@ -58,6 +60,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
   passManager.addPass(createReconcileUnrealizedCastsPass());
   passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+
   return passManager.run(module);
 }
 


        


More information about the Mlir-commits mailing list