[Mlir-commits] [mlir] 3b35f9d - [mlir][spirv] Use memref memory space for storage class

Lei Zhang llvmlistbot at llvm.org
Wed Mar 18 17:13:20 PDT 2020


Author: Lei Zhang
Date: 2020-03-18T20:11:04-04:00
New Revision: 3b35f9d8b51018c0a4301c0d7a5b81bbe33863ee

URL: https://github.com/llvm/llvm-project/commit/3b35f9d8b51018c0a4301c0d7a5b81bbe33863ee
DIFF: https://github.com/llvm/llvm-project/commit/3b35f9d8b51018c0a4301c0d7a5b81bbe33863ee.diff

LOG: [mlir][spirv] Use memref memory space for storage class

Previously, SPIRVTypeConverter always converted memref types
to StorageBuffer regardless of their memory spaces. This commit
fixes that by letting the conversion look into the memory space
properly. For this purpose, a mapping between SPIR-V storage classes
and memref memory spaces is introduced. The mapping is arbitrarily
decided at the moment; the hope is that we can leverage
string (symbolic) memory spaces later to make it clearer.

Now spv.interface_var_abi cannot contain a storage class unless it is
attached to a scalar value, where the storage class is needed as side-channel
information. Verification and tests are adjusted accordingly.

Differential Revision: https://reviews.llvm.org/D76241

Added: 
    mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
    mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
    mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
    mlir/include/mlir/IR/Types.h
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/test/Conversion/GPUToSPIRV/load-store.mlir
    mlir/test/Conversion/GPUToSPIRV/simple.mlir
    mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
    mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
    mlir/test/Dialect/SPIRV/target-and-abi.mlir

Removed: 
    mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index 85b42eeea291..b29c62e67116 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -31,6 +31,15 @@ class SPIRVTypeConverter : public TypeConverter {
 
   /// Gets the SPIR-V correspondence for the standard index type.
   static Type getIndexType(MLIRContext *context);
+
+  /// Returns the corresponding memory space for memref given a SPIR-V storage
+  /// class.
+  static unsigned getMemorySpaceForStorageClass(spirv::StorageClass);
+
+  /// Returns the SPIR-V storage class given a memory space for memref. Return
+  /// llvm::None if the memory space does not map to any SPIR-V storage class.
+  static Optional<spirv::StorageClass>
+  getStorageClassForMemorySpace(unsigned space);
 };
 
 /// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
index 5ffd00c530c6..bf9c51e5b110 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -28,7 +28,7 @@ StringRef getInterfaceVarABIAttrName();
 /// Gets the InterfaceVarABIAttr given its fields.
 InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet,
                                            unsigned binding,
-                                           StorageClass storageClass,
+                                           Optional<StorageClass> storageClass,
                                            MLIRContext *context);
 
 /// Returns the attribute name for specifying entry point information.

diff  --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
index a463f0e8da95..5d08aa1f2d7c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
@@ -32,7 +32,7 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td"
 def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [
     StructFieldAttr<"descriptor_set", I32Attr>,
     StructFieldAttr<"binding", I32Attr>,
-    StructFieldAttr<"storage_class", SPV_StorageClassAttr>
+    StructFieldAttr<"storage_class", OptionalAttr<SPV_StorageClassAttr>>
 ]>;
 
 // For entry functions, this attribute specifies information related to entry

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index eccc90cdae0c..e45fa9037470 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,8 +169,11 @@ class Type {
   /// Return true of this is a signless integer or a float type.
   bool isSignlessIntOrFloat();
 
-  /// Return true of this is an integer(of any signedness) or a float type.
+  /// Return true if this is an integer (of any signedness) or a float type.
   bool isIntOrFloat();
+  /// Return true if this is an integer (of any signedness), index, or float
+  /// type.
+  bool isIntOrIndexOrFloat();
 
   /// Print the current type.
   void print(raw_ostream &os);

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 533ef7f53b92..5483c2330c20 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -349,10 +349,15 @@ LogicalResult GPUFuncOpConversion::matchAndRewrite(
   if (!gpu::GPUDialect::isKernel(funcOp))
     return failure();
 
+  // TODO(antiagainst): we are dictating the ABI by ourselves here; it should be
+  // specified outside.
   SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
-  for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
-    argABI.push_back(spirv::getInterfaceVarABIAttr(
-        0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext()));
+  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    Optional<spirv::StorageClass> sc;
+    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
+      sc = spirv::StorageClass::StorageBuffer;
+    argABI.push_back(
+        spirv::getInterfaceVarABIAttr(0, argIndex, sc, rewriter.getContext()));
   }
 
   auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index 50ecf9ef7cbd..f2868a34f076 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -61,7 +61,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
                        BlockAndValueMapping &) const final {
     // Return true here when inlining into spv.func, spv.selection, and
     // spv.loop operations.
-    auto op = dest->getParentOp();
+    auto *op = dest->getParentOp();
     return isa<spirv::FuncOp>(op) || isa<spirv::SelectionOp>(op) ||
            isa<spirv::LoopOp>(op);
   }
@@ -383,7 +383,8 @@ namespace {
 // parseAndVerify does the actual parsing and verification of individual
 // elements. This is a functor since parsing the last element of the list
 // (termination condition) needs partial specialization.
-template <typename ParseType, typename... Args> struct parseCommaSeparatedList {
+template <typename ParseType, typename... Args>
+struct ParseCommaSeparatedList {
   Optional<std::tuple<ParseType, Args...>>
   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
     auto parseVal = parseAndVerify<ParseType>(dialect, parser);
@@ -393,7 +394,7 @@ template <typename ParseType, typename... Args> struct parseCommaSeparatedList {
     auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
     if (numArgs != 0 && failed(parser.parseComma()))
       return llvm::None;
-    auto remainingValues = parseCommaSeparatedList<Args...>{}(dialect, parser);
+    auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
     if (!remainingValues)
       return llvm::None;
     return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
@@ -403,7 +404,8 @@ template <typename ParseType, typename... Args> struct parseCommaSeparatedList {
 
 // Partial specialization of the function to parse a comma separated list of
 // specs to parse the last element of the list.
-template <typename ParseType> struct parseCommaSeparatedList<ParseType> {
+template <typename ParseType>
+struct ParseCommaSeparatedList<ParseType> {
   Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
                                              DialectAsmParser &parser) const {
     if (auto value = parseAndVerify<ParseType>(dialect, parser))
@@ -434,7 +436,7 @@ static Type parseImageType(SPIRVDialect const &dialect,
     return Type();
 
   auto value =
-      parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
+      ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                               ImageSamplingInfo, ImageSamplerUseInfo,
                               ImageFormat>{}(dialect, parser);
   if (!value)
@@ -597,10 +599,10 @@ static void print(StructType type, DialectAsmPrinter &os) {
         if (!decorations.empty())
           os << ", ";
       }
-      auto each_fn = [&os](spirv::Decoration decoration) {
+      auto eachFn = [&os](spirv::Decoration decoration) {
         os << stringifyDecoration(decoration);
       };
-      interleaveComma(decorations, os, each_fn);
+      interleaveComma(decorations, os, eachFn);
       os << "]";
     }
   };
@@ -865,39 +867,44 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
   return success();
 }
 
-// Verifies the given SPIR-V `attribute` attached to a region's argument or
-// result and reports error to the given location if invalid.
-static LogicalResult
-verifyRegionAttribute(Location loc, NamedAttribute attribute, bool forArg) {
+/// Verifies the given SPIR-V `attribute` attached to a value of the given
+/// `valueType` is valid.
+static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
+                                           NamedAttribute attribute) {
   StringRef symbol = attribute.first.strref();
   Attribute attr = attribute.second;
 
   if (symbol != spirv::getInterfaceVarABIAttrName())
     return emitError(loc, "found unsupported '")
-           << symbol << "' attribute on region "
-           << (forArg ? "argument" : "result");
+           << symbol << "' attribute on region argument";
 
-  if (!attr.isa<spirv::InterfaceVarABIAttr>())
+  auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
+  if (!varABIAttr)
     return emitError(loc, "'")
            << symbol
-           << "' attribute must be a dictionary attribute containing three "
-              "32-bit integer attributes: 'descriptor_set', 'binding', and "
-              "'storage_class'";
+           << "' attribute must be a dictionary attribute containing two or "
+              "three 32-bit integer attributes: 'descriptor_set', 'binding', "
+              "and optional 'storage_class'";
+  if (varABIAttr.storage_class() && !valueType.isIntOrIndexOrFloat())
+    return emitError(loc, "'") << symbol
+                               << "' attribute cannot specify storage class "
+                                  "when attaching to a non-scalar value";
 
   return success();
 }
 
 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
-                                                     unsigned /*regionIndex*/,
-                                                     unsigned /*argIndex*/,
+                                                     unsigned regionIndex,
+                                                     unsigned argIndex,
                                                      NamedAttribute attribute) {
-  return verifyRegionAttribute(op->getLoc(), attribute,
-                               /*forArg=*/true);
+  return verifyRegionAttribute(
+      op->getLoc(),
+      op->getRegion(regionIndex).front().getArgument(argIndex).getType(),
+      attribute);
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
     Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
     NamedAttribute attribute) {
-  return verifyRegionAttribute(op->getLoc(), attribute,
-                               /*forArg=*/false);
+  return op->emitError("cannot attach SPIR-V attributes to region result");
 }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 4adabdaa597e..e9250c56a1d2 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -38,6 +38,63 @@ Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
   return IntegerType::get(32, context);
 }
 
+/// Mapping between SPIR-V storage classes to memref memory spaces.
+///
+/// Note: memref does not have a defined smenatics for each memory space; it
+/// depends on the context where it is used. There are no particular reasons
+/// behind the number assigments; we try to follow NVVM conventions and largely
+/// give common storage classes a smaller number. The hope is use symbolic
+/// memory space representation eventually after memref supports it.
+// TODO(antiagainst): swap Generic and StorageBuffer assignment to be more akin
+// to NVVM.
+#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
+  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
+  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
+  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
+  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
+  MAP_FN(spirv::StorageClass::Private, 5)                                      \
+  MAP_FN(spirv::StorageClass::Function, 6)                                     \
+  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
+  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
+  MAP_FN(spirv::StorageClass::Input, 9)                                        \
+  MAP_FN(spirv::StorageClass::Output, 10)                                      \
+  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
+  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
+  MAP_FN(spirv::StorageClass::Image, 13)                                       \
+  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
+  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
+  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
+  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
+  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
+  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
+  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)
+
+unsigned
+SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
+#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
+  case storage:                                                                \
+    return space;
+
+  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
+#undef STORAGE_SPACE_MAP_FN
+}
+
+Optional<spirv::StorageClass>
+SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
+#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
+  case space:                                                                  \
+    return storage;
+
+  switch (space) {
+    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
+  default:
+    return llvm::None;
+  }
+#undef STORAGE_SPACE_MAP_FN
+}
+
+#undef STORAGE_SPACE_MAP_LIST
+
 // TODO(ravishankarm): This is a utility function that should probably be
 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t> getTypeNumBytes(Type t) {
@@ -110,14 +167,6 @@ SPIRVTypeConverter::SPIRVTypeConverter() {
     return SPIRVTypeConverter::getIndexType(indexType.getContext());
   });
   addConversion([this](MemRefType memRefType) -> Type {
-    // TODO(ravishankarm): For now only support default memory space. The memory
-    // space description is not set is stone within MLIR, i.e. it depends on the
-    // context it is being used. To map this to SPIR-V storage classes, we
-    // should rely on the ABI attributes, and not on the memory space. This is
-    // still evolving, and needs to be revisited when there is more clarity.
-    if (memRefType.getMemorySpace())
-      return Type();
-
     auto elementType = convertType(memRefType.getElementType());
     if (!elementType)
       return Type();
@@ -135,11 +184,12 @@ SPIRVTypeConverter::SPIRVTypeConverter() {
       auto arrayType = spirv::ArrayType::get(
           elementType, arraySize.getValue() / elementSize.getValue(),
           elementSize.getValue());
+
+      // Wrap in a struct to satisfy Vulkan interface requirements.
       auto structType = spirv::StructType::get(arrayType, 0);
-      // For now initialize the storage class to StorageBuffer. This will be
-      // updated later based on whats passed in w.r.t to the ABI attributes.
-      return spirv::PointerType::get(structType,
-                                     spirv::StorageClass::StorageBuffer);
+      if (auto sc = getStorageClassForMemorySpace(memRefType.getMemorySpace()))
+        return spirv::PointerType::get(structType, *sc);
+      return Type();
     }
     return Type();
   });

diff  --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 88f3037ccc1e..16f79349ac1e 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -21,13 +21,16 @@ StringRef spirv::getInterfaceVarABIAttrName() {
 
 spirv::InterfaceVarABIAttr
 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
-                              spirv::StorageClass storageClass,
+                              Optional<spirv::StorageClass> storageClass,
                               MLIRContext *context) {
   Type i32Type = IntegerType::get(32, context);
+  auto scAttr =
+      storageClass
+          ? IntegerAttr::get(i32Type, static_cast<int64_t>(*storageClass))
+          : IntegerAttr();
   return spirv::InterfaceVarABIAttr::get(
       IntegerAttr::get(i32Type, descriptorSet),
-      IntegerAttr::get(i32Type, binding),
-      IntegerAttr::get(i32Type, static_cast<int64_t>(storageClass)), context);
+      IntegerAttr::get(i32Type, binding), scAttr, context);
 }
 
 StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 4dbc54ecfca2..cb986fd8b282 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -29,25 +29,25 @@ static bool isScalarOrVectorType(Type type) {
 
 /// Creates a global variable for an argument based on the ABI info.
 static spirv::GlobalVariableOp
-createGlobalVariableForArg(spirv::FuncOp funcOp, OpBuilder &builder,
-                           unsigned argNum,
-                           spirv::InterfaceVarABIAttr abiInfo) {
+createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
+                                     unsigned argIndex,
+                                     spirv::InterfaceVarABIAttr abiInfo) {
   auto spirvModule = funcOp.getParentOfType<spirv::ModuleOp>();
-  if (!spirvModule) {
+  if (!spirvModule)
     return nullptr;
-  }
+
   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
   builder.setInsertionPoint(funcOp.getOperation());
   std::string varName =
-      funcOp.getName().str() + "_arg_" + std::to_string(argNum);
+      funcOp.getName().str() + "_arg_" + std::to_string(argIndex);
 
   // Get the type of variable. If this is a scalar/vector type and has an ABI
-  // info create a variable of type !spv.ptr<!spv.struct<elementTYpe>>. If not
+  // info create a variable of type !spv.ptr<!spv.struct<elementType>>. If not
   // it must already be a !spv.ptr<!spv.struct<...>>.
-  auto varType = funcOp.getType().getInput(argNum);
-  auto storageClass =
-      static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
+  auto varType = funcOp.getType().getInput(argIndex);
   if (isScalarOrVectorType(varType)) {
+    auto storageClass =
+        static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
     varType =
         spirv::PointerType::get(spirv::StructType::get(varType), storageClass);
   }
@@ -84,9 +84,18 @@ getInterfaceVariables(spirv::FuncOp funcOp,
   funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
     auto var =
         module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
-    if (var.type().cast<spirv::PointerType>().getStorageClass() !=
-        spirv::StorageClass::StorageBuffer) {
+    // TODO(antiagainst): Per SPIR-V spec: "Before version 1.4, the interface’s
+    // storage classes are limited to the Input and Output storage classes.
+    // Starting with version 1.4, the interface’s storage classes are all
+    // storage classes used in declaring all global variables referenced by the
+    // entry point’s call tree." We should consider the target environment here.
+    switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
+    case spirv::StorageClass::Input:
+    case spirv::StorageClass::Output:
       interfaceVarSet.insert(var.getOperation());
+      break;
+    default:
+      break;
     }
   });
   for (auto &var : interfaceVarSet) {
@@ -173,11 +182,10 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
       // produce an error.
       return failure();
     }
-    auto var =
-        createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
-    if (!var) {
+    spirv::GlobalVariableOp var = createGlobalVarForEntryPointArgument(
+        rewriter, funcOp, argType.index(), abiInfo);
+    if (!var)
       return failure();
-    }
 
     OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
     rewriter.setInsertionPointToStart(&funcOp.front());

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 488601cdb16b..1e7d9f38a2ce 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -86,6 +86,8 @@ bool Type::isSignlessIntOrFloat() {
 
 bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
 
+bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
+
 //===----------------------------------------------------------------------===//
 // Integer Type
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index 6588de870057..d0224fd16e02 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -21,9 +21,9 @@ module attributes {gpu.container_module} {
     // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-LABEL:    spv.func @load_store_kernel
-    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}}
     // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}

diff  --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
index d9b32a6e571b..3076cd04b9fe 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -5,7 +5,7 @@ module attributes {gpu.container_module} {
     // CHECK:       spv.module Logical GLSL450 {
     // CHECK-LABEL: spv.func @basic_module_structure
     // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
     // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
     gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>)
       attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 341df27460a0..9b8d695af422 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -313,6 +313,22 @@ func @memref_type(%arg0: memref<3xi1>) {
   return
 }
 
+// CHECK-LABEL: func @memref_mem_space
+// CHECK-SAME: StorageBuffer
+// CHECK-SAME: Uniform
+// CHECK-SAME: Workgroup
+// CHECK-SAME: PushConstant
+// CHECK-SAME: Private
+// CHECK-SAME: Function
+func @memref_mem_space(
+    %arg0: memref<4xf32, 0>,
+    %arg1: memref<4xf32, 4>,
+    %arg2: memref<4xf32, 3>,
+    %arg3: memref<4xf32, 7>,
+    %arg4: memref<4xf32, 5>,
+    %arg5: memref<4xf32, 6>
+) { return }
+
 // CHECK-LABEL: @load_store_zero_rank_float
 // CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>,
 // CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>)

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
similarity index 65%
rename from mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir
rename to mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index edc66c41591c..a1f662300412 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -5,14 +5,14 @@ spv.module Logical GLSL450 {
   // CHECK-DAG:    spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
   // CHECK-DAG:    spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
   // CHECK:    spv.func [[FN:@.*]]()
-  spv.func @kernel(%arg0: f32
-                {spv.interface_var_abi = {binding = 0 : i32,
-                                          descriptor_set = 0 : i32,
-                                          storage_class = 12 : i32}},
-                 %arg1: !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
-                 {spv.interface_var_abi = {binding = 1 : i32,
-                                           descriptor_set = 0 : i32,
-                                           storage_class = 12 : i32}}) "None"
+  spv.func @kernel(
+    %arg0: f32
+           {spv.interface_var_abi = {binding = 0 : i32,
+                                     descriptor_set = 0 : i32,
+                                     storage_class = 12 : i32}},
+    %arg1: !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
+           {spv.interface_var_abi = {binding = 1 : i32,
+                                     descriptor_set = 0 : i32}}) "None"
   attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
     // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]]
     // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index d8af9fa82607..f3158e310d79 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -21,16 +21,13 @@ spv.module Logical GLSL450 {
   spv.func @load_store_kernel(
     %arg0: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
     {spv.interface_var_abi = {binding = 0 : i32,
-                              descriptor_set = 0 : i32,
-                              storage_class = 12 : i32}},
+                              descriptor_set = 0 : i32}},
     %arg1: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
     {spv.interface_var_abi = {binding = 1 : i32,
-                              descriptor_set = 0 : i32,
-                              storage_class = 12 : i32}},
+                              descriptor_set = 0 : i32}},
     %arg2: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
     {spv.interface_var_abi = {binding = 2 : i32,
-                              descriptor_set = 0 : i32,
-                              storage_class = 12 : i32}},
+                              descriptor_set = 0 : i32}},
     %arg3: i32
     {spv.interface_var_abi = {binding = 3 : i32,
                               descriptor_set = 0 : i32,

diff  --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
index a28ca29e0ab9..2c380e8ff039 100644
--- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
@@ -14,7 +14,7 @@ func @unknown_attr_on_region(%arg: i32 {spv.something}) {
 
 // -----
 
-// expected-error @+1 {{found unsupported 'spv.something' attribute on region result}}
+// expected-error @+1 {{cannot attach SPIR-V attributes to region result}}
 func @unknown_attr_on_region() -> (i32 {spv.something}) {
   %0 = constant 10.0 : f32
   return %0: f32
@@ -51,14 +51,14 @@ func @spv_entry_point() attributes {
 // spv.interface_var_abi
 //===----------------------------------------------------------------------===//
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}}
 func @interface_var(
   %arg0 : f32 {spv.interface_var_abi = 64}
 ) { return }
 
 // -----
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}}
 func @interface_var(
   %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}}
 ) { return }
@@ -74,31 +74,12 @@ func @interface_var(
 
 // -----
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
-func @interface_var() -> (f32 {spv.interface_var_abi = 64})
-{
-  %0 = constant 10.0 : f32
-  return %0: f32
-}
-
-// -----
-
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
-func @interface_var() -> (f32 {spv.interface_var_abi = {binding = 0: i32}})
-{
-  %0 = constant 10.0 : f32
-  return %0: f32
-}
-
-// -----
-
-// CHECK: {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}}
-func @interface_var() -> (f32 {spv.interface_var_abi = {
-    binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}})
-{
-  %0 = constant 10.0 : f32
-  return %0: f32
-}
+// expected-error @+1 {{'spv.interface_var_abi' attribute cannot specify storage class when attaching to a non-scalar value}}
+func @interface_var(
+  %arg0 : memref<4xf32> {spv.interface_var_abi = {binding = 0 : i32,
+                                                  descriptor_set = 0 : i32,
+                                                  storage_class = 12 : i32}}
+) { return }
 
 // -----
 


        


More information about the Mlir-commits mailing list