[Mlir-commits] [mlir] ec99d6e - [mlir][spirv] Add a `spirv::InterfaceVarABIAttr`.
Denis Khalikov
llvmlistbot at llvm.org
Mon Apr 13 12:49:28 PDT 2020
Author: Denis Khalikov
Date: 2020-04-13T22:47:47+03:00
New Revision: ec99d6e62f0a3b1146bf670e90cd8f48c62be41e
URL: https://github.com/llvm/llvm-project/commit/ec99d6e62f0a3b1146bf670e90cd8f48c62be41e
DIFF: https://github.com/llvm/llvm-project/commit/ec99d6e62f0a3b1146bf670e90cd8f48c62be41e.diff
LOG: [mlir][spirv] Add a `spirv::InterfaceVarABIAttr`.
Summary:
Add a proper dialect-specific attribute for the interface variable ABI, replacing the `StructAttr`-based dictionary previously used for `spv.interface_var_abi`.
Differential Revision: https://reviews.llvm.org/D77941
Added:
Modified:
mlir/docs/Dialects/SPIR-V.md
mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/GPUToSPIRV/simple.mlir
mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
mlir/test/Dialect/SPIRV/target-and-abi.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index de896b8e7daf..76ba0b893171 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -883,14 +883,30 @@ interfaces:
* `spv.entry_point_abi` is a struct attribute that should be attached to the
entry function. It contains:
* `local_size` for specifying the local work group size for the dispatch.
-* `spv.interface_var_abi` is a struct attribute that should be attached to
- each operand and result of the entry function. It contains:
- * `descriptor_set` for specifying the descriptor set number for the
- corresponding resource variable.
- * `binding` for specifying the binding number for the corresponding
- resource variable.
- * `storage_class` for specifying the storage class for the corresponding
- resource variable.
+* `spv.interface_var_abi` is an attribute that should be attached to each
+  operand and result of the entry function. It should be of the
+  `#spv.interface_var_abi` attribute kind, which is defined as:
+
+```
+spv-storage-class ::= `StorageBuffer` | ...
+spv-descriptor-set ::= integer-literal
+spv-binding ::= integer-literal
+spv-interface-var-abi ::= `#` `spv.interface_var_abi` `<(` spv-descriptor-set
+ `,` spv-binding `)` (`,` spv-storage-class)? `>`
+```
+
+For example,
+
+```
+#spv.interface_var_abi<(0, 0), StorageBuffer>
+#spv.interface_var_abi<(0, 1)>
+```
+
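+The attribute has the following fields:
+
+* Descriptor set number for the corresponding resource variable.
+* Binding number for the corresponding resource variable.
+* Storage class for the corresponding resource variable. This field is
+  optional, and it may only be specified when the attribute is attached to a
+  scalar value.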
The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for
consuming these attributes and create SPIR-V module complying with the
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
index f7cbbbe75757..36344b41d6cb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H
#define MLIR_DIALECT_SPIRV_SPIRVATTRIBUTES_H
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Support/LLVM.h"
@@ -26,6 +27,7 @@ enum class Extension;
enum class Version : uint32_t;
namespace detail {
+struct InterfaceVarABIAttributeStorage;
struct TargetEnvAttributeStorage;
struct VerCapExtAttributeStorage;
} // namespace detail
@@ -33,11 +35,49 @@ struct VerCapExtAttributeStorage;
/// SPIR-V dialect-specific attribute kinds.
namespace AttrKind {
enum Kind {
- TargetEnv = Attribute::FIRST_SPIRV_ATTR, /// Target environment
+ InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI
+ TargetEnv, /// Target environment
VerCapExt, /// (version, extension, capability) triple
};
} // namespace AttrKind
+/// An attribute that specifies the information regarding the interface
+/// variable: descriptor set, binding, storage class.
+class InterfaceVarABIAttr
+    : public Attribute::AttrBase<InterfaceVarABIAttr, Attribute,
+                                 detail::InterfaceVarABIAttributeStorage> {
+public:
+  using Base::Base;
+
+  /// Gets an InterfaceVarABIAttr from plain values. The descriptor set and
+  /// binding are stored as 32-bit integer attributes; if `storageClass` is
+  /// None, no storage class is recorded.
+  static InterfaceVarABIAttr get(uint32_t descriptorSet, uint32_t binding,
+                                 Optional<StorageClass> storageClass,
+                                 MLIRContext *context);
+
+  /// Gets an InterfaceVarABIAttr from integer attributes. `descriptorSet` and
+  /// `binding` must be non-null; a null `storageClass` means that no storage
+  /// class is specified.
+  static InterfaceVarABIAttr get(IntegerAttr descriptorSet, IntegerAttr binding,
+                                 IntegerAttr storageClass);
+
+  /// Returns the attribute kind's name (without the 'spv.' prefix).
+  static StringRef getKindName();
+
+  /// Returns the descriptor set number.
+  uint32_t getDescriptorSet();
+
+  /// Returns the binding number.
+  uint32_t getBinding();
+
+  /// Returns the `spirv::StorageClass`, or None if none was specified.
+  Optional<StorageClass> getStorageClass();
+
+  /// Returns true if `kind` identifies an InterfaceVarABIAttr.
+  static bool kindof(unsigned kind) {
+    return kind == AttrKind::InterfaceVarABI;
+  }
+
+  /// Verifies the invariants of the attribute's components before it is
+  /// constructed.
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    IntegerAttr descriptorSet,
+                                                    IntegerAttr binding,
+                                                    IntegerAttr storageClass);
+};
+
/// An attribute that specifies the SPIR-V (version, capabilities, extensions)
/// triple.
class VerCapExtAttr
diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
index 5d08aa1f2d7c..231ec54f09f4 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
@@ -23,18 +23,6 @@
include "mlir/Dialect/SPIRV/SPIRVBase.td"
-// For arguments that eventually map to spv.globalVariable for the
-// shader interface, this attribute specifies the information regarding
-// the global variable:
-// 1) Descriptor Set.
-// 2) Binding number.
-// 3) Storage class.
-def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [
- StructFieldAttr<"descriptor_set", I32Attr>,
- StructFieldAttr<"binding", I32Attr>,
- StructFieldAttr<"storage_class", OptionalAttr<SPV_StorageClassAttr>>
-]>;
-
// For entry functions, this attribute specifies information related to entry
// points in the generated SPIR-V module:
// 1) WorkGroup Size.
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
index 4ce42a5a7ee1..b2df52b07608 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp
@@ -25,6 +25,32 @@ namespace mlir {
namespace spirv {
namespace detail {
+
+/// Uniqued storage backing spirv::InterfaceVarABIAttr. It keeps the
+/// descriptor set, binding and (possibly null) storage class attributes.
+struct InterfaceVarABIAttributeStorage : public AttributeStorage {
+  using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
+
+  InterfaceVarABIAttributeStorage(Attribute set, Attribute bindingNumber,
+                                  Attribute storage)
+      : descriptorSet(set), binding(bindingNumber), storageClass(storage) {}
+
+  /// Compares this storage against a (descriptor set, binding, storage class)
+  /// key, field by field.
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(descriptorSet, binding, storageClass);
+  }
+
+  /// Placement-constructs a new storage instance for `key` using `allocator`.
+  static InterfaceVarABIAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    Attribute set, bindingNumber, storage;
+    std::tie(set, bindingNumber, storage) = key;
+    auto *memory = allocator.allocate<InterfaceVarABIAttributeStorage>();
+    return new (memory)
+        InterfaceVarABIAttributeStorage(set, bindingNumber, storage);
+  }
+
+  Attribute descriptorSet;
+  Attribute binding;
+  Attribute storageClass;
+};
+
struct VerCapExtAttributeStorage : public AttributeStorage {
using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
@@ -72,6 +98,74 @@ struct TargetEnvAttributeStorage : public AttributeStorage {
} // namespace spirv
} // namespace mlir
+//===----------------------------------------------------------------------===//
+// InterfaceVarABIAttr
+//===----------------------------------------------------------------------===//
+
+spirv::InterfaceVarABIAttr
+spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
+                                Optional<spirv::StorageClass> storageClass,
+                                MLIRContext *context) {
+  // Every field is materialized as an i32 attribute; an absent storage class
+  // is represented by a null IntegerAttr.
+  Builder builder(context);
+  IntegerAttr storageClassAttr;
+  if (storageClass.hasValue())
+    storageClassAttr = builder.getI32IntegerAttr(
+        static_cast<uint32_t>(storageClass.getValue()));
+  return get(builder.getI32IntegerAttr(descriptorSet),
+             builder.getI32IntegerAttr(binding), storageClassAttr);
+}
+
+spirv::InterfaceVarABIAttr
+spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
+                                IntegerAttr storageClass) {
+  // Descriptor set and binding are mandatory; the storage class may be null.
+  assert(descriptorSet && binding);
+  return Base::get(descriptorSet.getContext(), spirv::AttrKind::InterfaceVarABI,
+                   descriptorSet, binding, storageClass);
+}
+
+/// The kind name used after the `spv.` dialect prefix, e.g. in
+/// `#spv.interface_var_abi<...>`.
+StringRef spirv::InterfaceVarABIAttr::getKindName() {
+  static constexpr const char kKindName[] = "interface_var_abi";
+  return kKindName;
+}
+
+uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
+  IntegerAttr setAttr = getImpl()->descriptorSet.cast<IntegerAttr>();
+  return static_cast<uint32_t>(setAttr.getInt());
+}
+
+uint32_t spirv::InterfaceVarABIAttr::getBinding() {
+  IntegerAttr bindingAttr = getImpl()->binding.cast<IntegerAttr>();
+  return static_cast<uint32_t>(bindingAttr.getInt());
+}
+
+/// Returns the stored storage class, or None when the attribute was created
+/// without one (null storage-class attribute).
+Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
+  Attribute storage = getImpl()->storageClass;
+  if (!storage)
+    return llvm::None;
+  uint64_t rawValue = storage.cast<IntegerAttr>().getValue().getZExtValue();
+  return static_cast<spirv::StorageClass>(rawValue);
+}
+
+/// Verifies the components of an InterfaceVarABIAttr: descriptor set, binding
+/// and, when present, storage class must be 32-bit signless integer
+/// attributes, and the storage class value must name a known
+/// spirv::StorageClass.
+LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants(
+    Location loc, IntegerAttr descriptorSet, IntegerAttr binding,
+    IntegerAttr storageClass) {
+  if (!descriptorSet.getType().isSignlessInteger(32))
+    return emitError(loc, "expected 32-bit integer for descriptor set");
+
+  if (!binding.getType().isSignlessInteger(32))
+    return emitError(loc, "expected 32-bit integer for binding");
+
+  // The storage class is optional; a null attribute means "not specified".
+  if (storageClass) {
+    // `storageClass` is already an IntegerAttr, so no cast is needed (a
+    // failing `cast<>` asserts rather than returning null). Check its type
+    // like the other fields: both `get` and the parser build it with
+    // getI32IntegerAttr, and only a signless i32 is meaningful here.
+    if (!storageClass.getType().isSignlessInteger(32))
+      return emitError(loc, "expected 32-bit integer for storage class");
+    if (!spirv::symbolizeStorageClass(storageClass.getInt()))
+      return emitError(loc, "unknown storage class");
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// VerCapExtAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index e811fe6ec40a..ce4e5906de95 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -118,7 +118,7 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
- addAttributes<TargetEnvAttr, VerCapExtAttr>();
+ addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
// Add SPIR-V ops.
addOperations<
@@ -649,6 +649,75 @@ static ParseResult parseKeywordList(
return success();
}
+/// Parses one integer field (descriptor set or binding) of an
+/// InterfaceVarABIAttr and returns it as an i32 IntegerAttr. Emits
+/// "missing <fieldName>" when no integer is present. Returns a null attribute
+/// on any failure.
+static IntegerAttr parseInterfaceVarABIIntegerField(DialectAsmParser &parser,
+                                                    StringRef fieldName) {
+  auto loc = parser.getCurrentLocation();
+  uint32_t value = 0;
+  auto parseResult = parser.parseOptionalInteger(value);
+  if (!parseResult.hasValue()) {
+    parser.emitError(loc, "missing ") << fieldName;
+    return {};
+  }
+  // An integer was present but could not be parsed into `value`. Per the
+  // OptionalParseResult convention the parser is expected to have reported
+  // that already, so don't add a misleading "missing ..." diagnostic.
+  if (failed(*parseResult))
+    return {};
+  return parser.getBuilder().getI32IntegerAttr(value);
+}
+
+/// Parses a spirv::InterfaceVarABIAttr body:
+///
+///   `<(` descriptor-set `,` binding `)` (`,` storage-class)? `>`
+static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
+  if (parser.parseLess() || parser.parseLParen())
+    return {};
+
+  IntegerAttr descriptorSetAttr =
+      parseInterfaceVarABIIntegerField(parser, "descriptor set");
+  if (!descriptorSetAttr || parser.parseComma())
+    return {};
+
+  IntegerAttr bindingAttr = parseInterfaceVarABIIntegerField(parser, "binding");
+  if (!bindingAttr || parser.parseRParen())
+    return {};
+
+  // The storage class is optional; when present it follows a comma as a bare
+  // keyword such as `StorageBuffer`.
+  IntegerAttr storageClassAttr;
+  if (succeeded(parser.parseOptionalComma())) {
+    auto loc = parser.getCurrentLocation();
+    StringRef storageClass;
+    if (parser.parseKeyword(&storageClass))
+      return {};
+
+    auto storageClassSymbol = spirv::symbolizeStorageClass(storageClass);
+    if (!storageClassSymbol) {
+      parser.emitError(loc, "unknown storage class: ") << storageClass;
+      return {};
+    }
+    storageClassAttr = parser.getBuilder().getI32IntegerAttr(
+        static_cast<uint32_t>(*storageClassSymbol));
+  }
+
+  if (parser.parseGreater())
+    return {};
+
+  return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
+                                         storageClassAttr);
+}
+
static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
if (parser.parseLess())
return {};
@@ -771,6 +840,8 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
return parseTargetEnvAttr(parser);
if (attrKind == spirv::VerCapExtAttr::getKindName())
return parseVerCapExtAttr(parser);
+ if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
+ return parseInterfaceVarABIAttr(parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
<< attrKind;
@@ -801,12 +872,25 @@ static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
printer << ", " << targetEnv.getResourceLimits() << ">";
}
+/// Prints an InterfaceVarABIAttr as
+/// `interface_var_abi<(set, binding)>` or
+/// `interface_var_abi<(set, binding), StorageClass>`.
+static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
+                  DialectAsmPrinter &printer) {
+  const uint32_t set = interfaceVarABIAttr.getDescriptorSet();
+  const uint32_t binding = interfaceVarABIAttr.getBinding();
+  printer << spirv::InterfaceVarABIAttr::getKindName() << "<(" << set << ", "
+          << binding << ")";
+  // The storage class is only printed when one was specified.
+  if (auto storageClass = interfaceVarABIAttr.getStorageClass())
+    printer << ", " << spirv::stringifyStorageClass(*storageClass);
+  printer << ">";
+}
+
void SPIRVDialect::printAttribute(Attribute attr,
DialectAsmPrinter &printer) const {
if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
print(targetEnv, printer);
else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
print(vceAttr, printer);
+ else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
+ print(interfaceVarABIAttr, printer);
else
llvm_unreachable("unhandled SPIR-V attribute kind");
}
@@ -866,11 +950,9 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
if (!varABIAttr)
return emitError(loc, "'")
- << symbol
- << "' attribute must be a dictionary attribute containing two or "
- "three 32-bit integer attributes: 'descriptor_set', 'binding', "
- "and optional 'storage_class'";
- if (varABIAttr.storage_class() && !valueType.isIntOrIndexOrFloat())
+ << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+ if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
return emitError(loc, "'") << symbol
<< "' attribute cannot specify storage class "
"when attaching to a non-scalar value";
diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 491fcf9a6f21..2bc99b695056 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -86,14 +86,8 @@ spirv::InterfaceVarABIAttr
spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
Optional<spirv::StorageClass> storageClass,
MLIRContext *context) {
- Type i32Type = IntegerType::get(32, context);
- auto scAttr =
- storageClass
- ? IntegerAttr::get(i32Type, static_cast<int64_t>(*storageClass))
- : IntegerAttr();
- return spirv::InterfaceVarABIAttr::get(
- IntegerAttr::get(i32Type, descriptorSet),
- IntegerAttr::get(i32Type, binding), scAttr, context);
+ return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass,
+ context);
}
StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index d6b32436c0b4..daee70976ac2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -41,10 +41,11 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
// it must already be a !spv.ptr<!spv.struct<...>>.
auto varType = funcOp.getType().getInput(argIndex);
if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
- auto storageClass =
- static_cast<spirv::StorageClass>(abiInfo.storage_class().getInt());
+ auto storageClass = abiInfo.getStorageClass();
+ if (!storageClass)
+ return nullptr;
varType =
- spirv::PointerType::get(spirv::StructType::get(varType), storageClass);
+ spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
auto varPtrType = varType.cast<spirv::PointerType>();
auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
@@ -56,8 +57,8 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
spirv::PointerType::get(varPointeeType, varPtrType.getStorageClass());
return builder.create<spirv::GlobalVariableOp>(
- funcOp.getLoc(), varType, varName, abiInfo.descriptor_set().getInt(),
- abiInfo.binding().getInt());
+ funcOp.getLoc(), varType, varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
}
/// Gets the global variables that need to be specified as interface variable
diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index 13dc621af8b7..94f7c650fa0d 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -27,13 +27,13 @@ module attributes {
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-LABEL: spv.func @load_store_kernel
- // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
- // CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = {binding = 6 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+ // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
+ // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>}
+ // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>}
+ // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>}
+ // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>}
+ // CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}
gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
// CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
index 0d0b4c891337..81b842a11c96 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -4,8 +4,8 @@ module attributes {gpu.container_module} {
gpu.module @kernels {
// CHECK: spv.module Logical GLSL450 {
// CHECK-LABEL: spv.func @basic_module_structure
- // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
- // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32{{[}][}]}}
+ // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>}
+ // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>)
attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index 187c5741d7f0..28c44bf7b936 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -14,12 +14,9 @@ spv.module Logical GLSL450 {
// CHECK: spv.func [[FN:@.*]]()
spv.func @kernel(
%arg0: f32
- {spv.interface_var_abi = {binding = 0 : i32,
- descriptor_set = 0 : i32,
- storage_class = 12 : i32}},
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>},
%arg1: !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
- {spv.interface_var_abi = {binding = 1 : i32,
- descriptor_set = 0 : i32}}) "None"
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}) "None"
attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
// CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]]
// CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index c75c4d0f979c..075ef3398d83 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -27,30 +27,19 @@ spv.module Logical GLSL450 {
// CHECK: spv.func [[FN:@.*]]()
spv.func @load_store_kernel(
%arg0: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
- {spv.interface_var_abi = {binding = 0 : i32,
- descriptor_set = 0 : i32}},
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>},
%arg1: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
- {spv.interface_var_abi = {binding = 1 : i32,
- descriptor_set = 0 : i32}},
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>},
%arg2: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
- {spv.interface_var_abi = {binding = 2 : i32,
- descriptor_set = 0 : i32}},
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>},
%arg3: i32
- {spv.interface_var_abi = {binding = 3 : i32,
- descriptor_set = 0 : i32,
- storage_class = 12 : i32}},
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>},
%arg4: i32
- {spv.interface_var_abi = {binding = 4 : i32,
- descriptor_set = 0 : i32,
- storage_class = 12 : i32}},
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>},
%arg5: i32
- {spv.interface_var_abi = {binding = 5 : i32,
- descriptor_set = 0 : i32,
- storage_class = 12 : i32}},
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>},
%arg6: i32
- {spv.interface_var_abi = {binding = 6 : i32,
- descriptor_set = 0 : i32,
- storage_class = 12 : i32}}) "None"
+ {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}) "None"
attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
// CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
// CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
diff --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
index 2c380e8ff039..8d11f4ca0c64 100644
--- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
@@ -51,34 +51,51 @@ func @spv_entry_point() attributes {
// spv.interface_var_abi
//===----------------------------------------------------------------------===//
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}}
+// expected-error @+1 {{'spv.interface_var_abi' must be a spirv::InterfaceVarABIAttr}}
func @interface_var(
%arg0 : f32 {spv.interface_var_abi = 64}
) { return }
// -----
-// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing two or three 32-bit integer attributes: 'descriptor_set', 'binding', and optional 'storage_class'}}
func @interface_var(
- %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}}
+// expected-error @+1 {{missing descriptor set}}
+ %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<()>}
) { return }
// -----
-// CHECK: {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}}
func @interface_var(
- %arg0 : f32 {spv.interface_var_abi = {binding = 0 : i32,
- descriptor_set = 0 : i32,
- storage_class = 12 : i32}}
+// expected-error @+1 {{missing binding}}
+ %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1,)>}
+) { return }
+
+// -----
+
+func @interface_var(
+// expected-error @+1 {{unknown storage class: }}
+ %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1,2), Foo>}
+) { return }
+
+// -----
+
+// CHECK: {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>}
+func @interface_var(
+ %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>}
+) { return }
+
+// -----
+
+// CHECK: {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
+func @interface_var(
+ %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
) { return }
// -----
// expected-error @+1 {{'spv.interface_var_abi' attribute cannot specify storage class when attaching to a non-scalar value}}
func @interface_var(
- %arg0 : memref<4xf32> {spv.interface_var_abi = {binding = 0 : i32,
- descriptor_set = 0 : i32,
- storage_class = 12 : i32}}
+ %arg0 : memref<4xf32> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1), Uniform>}
) { return }
// -----
More information about the Mlir-commits
mailing list