[Mlir-commits] [mlir] ec99d6e - [mlir][spirv] Add a `spirv::InterfaceVarABIAttr`.

Denis Khalikov llvmlistbot at llvm.org
Mon Apr 13 12:49:28 PDT 2020


Author: Denis Khalikov
Date: 2020-04-13T22:47:47+03:00
New Revision: ec99d6e62f0a3b1146bf670e90cd8f48c62be41e

URL: https://github.com/llvm/llvm-project/commit/ec99d6e62f0a3b1146bf670e90cd8f48c62be41e
DIFF: https://github.com/llvm/llvm-project/commit/ec99d6e62f0a3b1146bf670e90cd8f48c62be41e.diff

LOG: [mlir][spirv] Add a `spirv::InterfaceVarABIAttr`.

Summary:
Add a proper dialect-specific attribute for interface variable ABI.

Differential Revision: https://reviews.llvm.org/D77941

Added: 
    

Modified: 
    mlir/docs/Dialects/SPIR-V.md
    mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
    mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
    mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/test/Conversion/GPUToSPIRV/load-store.mlir
    mlir/test/Conversion/GPUToSPIRV/simple.mlir
    mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
    mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
    mlir/test/Dialect/SPIRV/target-and-abi.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index de896b8e7daf..76ba0b893171 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -883,14 +883,30 @@ interfaces:
 *   `spv.entry_point_abi` is a struct attribute that should be attached to the
     entry function. It contains:
     *   `local_size` for specifying the local work group size for the dispatch.
-*   `spv.interface_var_abi` is a struct attribute that should be attached to
-    each operand and result of the entry function. It contains:
-    *   `descriptor_set` for specifying the descriptor set number for the
-        corresponding resource variable.
-    *   `binding` for specifying the binding number for the corresponding
-        resource variable.
-    *   `storage_class` for specifying the storage class for the corresponding
-        resource variable.
+*   `spv.interface_var_abi` is an attribute that should be attached to each
+    operand and result of the entry function. It should be of
+    `#spv.interface_var_abi` attribute kind, which is defined as:
+
+```
+spv-storage-class     ::= `StorageBuffer` | ...
+spv-descriptor-set    ::= integer-literal
+spv-binding           ::= integer-literal
+spv-interface-var-abi ::= `#` `spv.interface_var_abi` `<(` spv-descriptor-set
+                          `,` spv-binding `)` (`,` spv-storage-class)? `>`
+```
+
+For example,
+
+```
+#spv.interface_var_abi<(0, 0), StorageBuffer>
+#spv.interface_var_abi<(0, 1)>
+```
+
+The attribute has a few fields:
+
+*   Descriptor set number for the corresponding resource variable.
+*   Binding number for the corresponding resource variable.
+*   Storage class for the corresponding resource variable.
 
 The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for
 consuming these attributes and create SPIR-V module complying with the

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
index f7cbbbe75757..36344b41d6cb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H
 #define MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H
 
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/Support/LLVM.h"
 
@@ -26,6 +27,7 @@ enum class Extension;
 enum class Version : uint32_t;
 
 namespace detail {
+struct InterfaceVarABIAttributeStorage;
 struct TargetEnvAttributeStorage;
 struct VerCapExtAttributeStorage;
 } // namespace detail
@@ -33,11 +35,49 @@ struct VerCapExtAttributeStorage;
 /// SPIR-V dialect-specific attribute kinds.
 namespace AttrKind {
 enum Kind {
-  TargetEnv = Attribute::FIRST_SPIRV_ATTR, /// Target environment
+  InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
+  TargetEnv,                                     /// Target environment
   VerCapExt, /// (version, extension, capability) triple
 };
 } // namespace AttrKind
 
+/// An attribute that specifies the information regarding the interface
+/// variable: descriptor set, binding, storage class.
+class InterfaceVarABIAttr
+    : public Attribute::AttrBase<InterfaceVarABIAttr, Attribute,
+                                 detail::InterfaceVarABIAttributeStorage> {
+public:
+  using Base::Base;
+
+  /// Gets an InterfaceVarABIAttr; the storage class may be left unset.
+  static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding,
+                                 Optional<StorageClass> storageClass,
+                                 MLIRContext *context);
+  static InterfaceVarABIAttr get(IntegerAttr descriptorSet, IntegerAttr binding,
+                                 IntegerAttr storageClass);
+
+  /// Returns the attribute kind's name (without the 'spv.' prefix).
+  static StringRef getKindName();
+
+  /// Returns the descriptor set number of the interface variable.
+  uint32_t getDescriptorSet();
+
+  /// Returns the binding number of the interface variable.
+  uint32_t getBinding();
+
+  /// Returns the storage class, or None if it was not specified.
+  Optional<StorageClass> getStorageClass();
+
+  static bool kindof(unsigned kind) {
+    return kind == AttrKind::InterfaceVarABI;
+  }
+
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    IntegerAttr descriptorSet,
+                                                    IntegerAttr binding,
+                                                    IntegerAttr storageClass);
+};
+
 /// An attribute that specifies the SPIR-V (version, capabilities, extensions)
 /// triple.
 class VerCapExtAttr

diff  --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
index 5d08aa1f2d7c..231ec54f09f4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
@@ -23,18 +23,6 @@
 
 include "mlir/Dialect/SPIRV/SPIRVBase.td"
 
-// For arguments that eventually map to spv.globalVariable for the
-// shader interface, this attribute specifies the information regarding
-// the global variable:
-// 1) Descriptor Set.
-// 2) Binding number.
-// 3) Storage class.
-def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [
-    StructFieldAttr<"descriptor_set", I32Attr>,
-    StructFieldAttr<"binding", I32Attr>,
-    StructFieldAttr<"storage_class", OptionalAttr<SPV_StorageClassAttr>>
-]>;
-
 // For entry functions, this attribute specifies information related to entry
 // points in the generated SPIR-V module:
 // 1) WorkGroup Size.

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
index 4ce42a5a7ee1..b2df52b07608 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
@@ -25,6 +25,32 @@ namespace mlir {
 
 namespace spirv {
 namespace detail {
+
+struct InterfaceVarABIAttributeStorage : public AttributeStorage {
+  using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
+
+  InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
+                                  Attribute storageClass)
+      : descriptorSet(descriptorSet), binding(binding),
+        storageClass(storageClass) {}
+
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
+           std::get<2>(key) == storageClass;
+  }
+
+  static InterfaceVarABIAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
+        InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
+                                        std::get<2>(key));
+  }
+
+  Attribute descriptorSet;
+  Attribute binding;
+  Attribute storageClass;
+};
+
 struct VerCapExtAttributeStorage : public AttributeStorage {
   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
 
@@ -72,6 +98,74 @@ struct TargetEnvAttributeStorage : public AttributeStorage {
 } // namespace spirv
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// InterfaceVarABIAttr
+//===----------------------------------------------------------------------===//
+
+spirv::InterfaceVarABIAttr
+spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
+                                Optional<spirv::StorageClass> storageClass,
+                                MLIRContext *context) {
+  Builder b(context);
+  auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
+  auto bindingAttr = b.getI32IntegerAttr(binding);
+  auto storageClassAttr =
+      storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
+                   : IntegerAttr();
+  return get(descriptorSetAttr, bindingAttr, storageClassAttr);
+}
+
+spirv::InterfaceVarABIAttr
+spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
+                                IntegerAttr storageClass) {
+  assert(descriptorSet && binding);
+  MLIRContext *context = descriptorSet.getContext();
+  return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet,
+                   binding, storageClass);
+}
+
+StringRef spirv::InterfaceVarABIAttr::getKindName() {
+  return "interface_var_abi";
+}
+
+uint32_t spirv::InterfaceVarABIAttr::getBinding() {
+  return getImpl()->binding.cast<IntegerAttr>().getInt();
+}
+
+uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
+  return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
+}
+
+Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
+  if (getImpl()->storageClass)
+    return static_cast<spirv::StorageClass>(
+        getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
+  return llvm::None;
+}
+
+LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants(
+    Location loc, IntegerAttr descriptorSet, IntegerAttr binding,
+    IntegerAttr storageClass) {
+  if (!descriptorSet.getType().isSignlessInteger(32))
+    return emitError(loc, "expected 32-bit integer for descriptor set");
+
+  if (!binding.getType().isSignlessInteger(32))
+    return emitError(loc, "expected 32-bit integer for binding");
+
+  // The storage class is optional; a null attribute means it is unset.
+  if (!storageClass)
+    return success();
+
+  // `storageClass` is already an IntegerAttr; like the other fields it must
+  // be a signless i32, and its value must name a known StorageClass.
+  if (!storageClass.getType().isSignlessInteger(32))
+    return emitError(loc, "expected 32-bit integer for storage class");
+  if (!spirv::symbolizeStorageClass(storageClass.getInt()))
+    return emitError(loc, "unknown storage class");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // VerCapExtAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index e811fe6ec40a..ce4e5906de95 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -118,7 +118,7 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
 
-  addAttributes<TargetEnvAttr, VerCapExtAttr>();
+  addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
 
   // Add SPIR-V ops.
   addOperations<
@@ -649,6 +649,75 @@ static ParseResult parseKeywordList(
   return success();
 }
 
+/// Parses a spirv::InterfaceVarABIAttr.
+static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return {};
+
+  Builder &builder = parser.getBuilder();
+
+  if (parser.parseLParen())
+    return {};
+
+  IntegerAttr descriptorSetAttr;
+  {
+    auto loc = parser.getCurrentLocation();
+    uint32_t descriptorSet = 0;
+    auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
+
+    if (!descriptorSetParseResult.hasValue() ||
+        failed(*descriptorSetParseResult)) {
+      parser.emitError(loc, "missing descriptor set");
+      return {};
+    }
+    descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
+  }
+
+  if (parser.parseComma())
+    return {};
+
+  IntegerAttr bindingAttr;
+  {
+    auto loc = parser.getCurrentLocation();
+    uint32_t binding = 0;
+    auto bindingParseResult = parser.parseOptionalInteger(binding);
+
+    if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) {
+      parser.emitError(loc, "missing binding");
+      return {};
+    }
+    bindingAttr = builder.getI32IntegerAttr(binding);
+  }
+
+  if (parser.parseRParen())
+    return {};
+
+  IntegerAttr storageClassAttr;
+  {
+    if (succeeded(parser.parseOptionalComma())) {
+      auto loc = parser.getCurrentLocation();
+      StringRef storageClass;
+      if (parser.parseKeyword(&storageClass))
+        return {};
+
+      if (auto storageClassSymbol =
+              spirv::symbolizeStorageClass(storageClass)) {
+        storageClassAttr = builder.getI32IntegerAttr(
+            static_cast<uint32_t>(*storageClassSymbol));
+      } else {
+        parser.emitError(loc, "unknown storage class: ") << storageClass;
+        return {};
+      }
+    }
+  }
+
+  if (parser.parseGreater())
+    return {};
+
+  return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
+                                         storageClassAttr);
+}
+
 static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
   if (parser.parseLess())
     return {};
@@ -771,6 +840,8 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
     return parseTargetEnvAttr(parser);
   if (attrKind == spirv::VerCapExtAttr::getKindName())
     return parseVerCapExtAttr(parser);
+  if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
+    return parseInterfaceVarABIAttr(parser);
 
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
       << attrKind;
@@ -801,12 +872,25 @@ static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
   printer << ", " << targetEnv.getResourceLimits() << ">";
 }
 
+static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
+                  DialectAsmPrinter &printer) {
+  printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
+          << interfaceVarABIAttr.getDescriptorSet() << ", "
+          << interfaceVarABIAttr.getBinding() << ")";
+  auto storageClass = interfaceVarABIAttr.getStorageClass();
+  if (storageClass)
+    printer << ", " << spirv::stringifyStorageClass(*storageClass);
+  printer << ">";
+}
+
 void SPIRVDialect::printAttribute(Attribute attr,
                                   DialectAsmPrinter &printer) const {
   if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
     print(targetEnv, printer);
   else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
     print(vceAttr, printer);
+  else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
+    print(interfaceVarABIAttr, printer);
   else
     llvm_unreachable("unhandled SPIR-V attribute kind");
 }
@@ -866,11 +950,9 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
   auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
   if (!varABIAttr)
     return emitError(loc, "'")
-           << symbol
-           << "' attribute must be a dictionary attribute containing two or "
-              "three 32-bit integer attributes: 'descriptor_set', 'binding', "
-              "and optional 'storage_class'";
-  if (varABIAttr.storage_class() && !valueType.isIntOrIndexOrFloat())
+           << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
     return emitError(loc, "'") << symbol
                                << "' attribute cannot specify storage class "
                                   "when attaching to a non-scalar value";

diff  --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 491fcf9a6f21..2bc99b695056 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -86,14 +86,8 @@ spirv::InterfaceVarABIAttr
 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
                               Optional<spirv::StorageClass> storageClass,
                               MLIRContext *context) {
-  Type i32Type = IntegerType::get(32, context);
-  auto scAttr =
-      storageClass
-          ? IntegerAttr::get(i32Type, static_cast<int64_t>(*storageClass))
-          : IntegerAttr();
-  return spirv::InterfaceVarABIAttr::get(
-      IntegerAttr::get(i32Type, descriptorSet),
-      IntegerAttr::get(i32Type, binding), scAttr, context);
+  return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
+                                         context);
 }
 
 StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index d6b32436c0b4..daee70976ac2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -41,10 +41,11 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
   // it must already be a !spv.ptr<!spv.struct<...>>.
   auto varType = funcOp.getType().getInput(argIndex);
   if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
-    auto storageClass =
-        static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
+    auto storageClass = abiInfo.getStorageClass();
+    if (!storageClass)
+      return nullptr;
     varType =
-        spirv::PointerType::get(spirv::StructType::get(varType), storageClass);
+        spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
   }
   auto varPtrType = varType.cast<spirv::PointerType>();
   auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
@@ -56,8 +57,8 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
       spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
 
   return builder.create<spirv::GlobalVariableOp>(
-      funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(),
-      abiInfo.binding().getInt());
+      funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
+      abiInfo.getBinding());
 }
 
 /// Gets the global variables that need to be specified as interface variable

diff  --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index 13dc621af8b7..94f7c650fa0d 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -27,13 +27,13 @@ module attributes {
     // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-LABEL:    spv.func @load_store_kernel
-    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = {binding = 6 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
+    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>}
+    // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>}
+    // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>}
+    // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>}
+    // CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}
     gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
       attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
       // CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]

diff  --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
index 0d0b4c891337..81b842a11c96 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -4,8 +4,8 @@ module attributes {gpu.container_module} {
   gpu.module @kernels {
     // CHECK:       spv.module Logical GLSL450 {
     // CHECK-LABEL: spv.func @basic_module_structure
-    // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
+    // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>}
+    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
     // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
     gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>)
       attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index 187c5741d7f0..28c44bf7b936 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -14,12 +14,9 @@ spv.module Logical GLSL450 {
   // CHECK:    spv.func [[FN:@.*]]()
   spv.func @kernel(
     %arg0: f32
-           {spv.interface_var_abi = {binding = 0 : i32,
-                                     descriptor_set = 0 : i32,
-                                     storage_class = 12 : i32}},
+           {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>},
     %arg1: !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
-           {spv.interface_var_abi = {binding = 1 : i32,
-                                     descriptor_set = 0 : i32}}) "None"
+           {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}) "None"
   attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
     // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]]
     // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index c75c4d0f979c..075ef3398d83 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -27,30 +27,19 @@ spv.module Logical GLSL450 {
   // CHECK: spv.func [[FN:@.*]]()
   spv.func @load_store_kernel(
     %arg0: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
-    {spv.interface_var_abi = {binding = 0 : i32,
-                              descriptor_set = 0 : i32}},
+    {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>},
     %arg1: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
-    {spv.interface_var_abi = {binding = 1 : i32,
-                              descriptor_set = 0 : i32}},
+    {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>},
     %arg2: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
-    {spv.interface_var_abi = {binding = 2 : i32,
-                              descriptor_set = 0 : i32}},
+    {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>},
     %arg3: i32
-    {spv.interface_var_abi = {binding = 3 : i32,
-                              descriptor_set = 0 : i32,
-                              storage_class = 12 : i32}},
+    {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>},
     %arg4: i32
-    {spv.interface_var_abi = {binding = 4 : i32,
-                              descriptor_set = 0 : i32,
-                              storage_class = 12 : i32}},
+    {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>},
     %arg5: i32
-    {spv.interface_var_abi = {binding = 5 : i32,
-                              descriptor_set = 0 : i32,
-                              storage_class = 12 : i32}},
+    {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>},
     %arg6: i32
-    {spv.interface_var_abi = {binding = 6 : i32,
-                              descriptor_set = 0 : i32,
-                              storage_class = 12 : i32}}) "None"
+    {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}) "None"
   attributes  {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
     // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
     // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32

diff  --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
index 2c380e8ff039..8d11f4ca0c64 100644
--- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
@@ -51,34 +51,51 @@ func @spv_entry_point() attributes {
 // spv.interface_var_abi
 //===----------------------------------------------------------------------===//
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' must be a spirv::InterfaceVarABIAttr}}
 func @interface_var(
   %arg0 : f32 {spv.interface_var_abi = 64}
 ) { return }
 
 // -----
 
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}}
 func @interface_var(
-  %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}}
+// expected-error @+1 {{missing descriptor set}}
+  %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<()>}
 ) { return }
 
 // -----
 
-// CHECK: {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}}
 func @interface_var(
-  %arg0 : f32 {spv.interface_var_abi = {binding = 0 : i32,
-                                        descriptor_set = 0 : i32,
-                                        storage_class = 12 : i32}}
+// expected-error @+1 {{missing binding}}
+  %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1,)>}
+) { return }
+
+// -----
+
+func @interface_var(
+// expected-error @+1 {{unknown storage class: }}
+  %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1,2), Foo>}
+) { return }
+
+// -----
+
+// CHECK: {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>}
+func @interface_var(
+    %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>}
+) { return }
+
+// -----
+
+// CHECK: {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
+func @interface_var(
+    %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
 ) { return }
 
 // -----
 
 // expected-error @+1 {{'spv.interface_var_abi' attribute cannot specify storage class when attaching to a non-scalar value}}
 func @interface_var(
-  %arg0 : memref<4xf32> {spv.interface_var_abi = {binding = 0 : i32,
-                                                  descriptor_set = 0 : i32,
-                                                  storage_class = 12 : i32}}
+  %arg0 : memref<4xf32> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>}
 ) { return }
 
 // -----


        


More information about the Mlir-commits mailing list