[Mlir-commits] [mlir] [mlir][spirv] Support function argument decorations for ptr in the PhysicalStorageBuffer (PR #76353)
Kohei Yamaguchi
llvmlistbot at llvm.org
Wed Dec 27 21:27:56 PST 2023
https://github.com/sott0n updated https://github.com/llvm/llvm-project/pull/76353
>From 2a9b9186976a5ab920b07194f92c22ca14680e16 Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Fri, 22 Dec 2023 17:22:25 +0000
Subject: [PATCH 1/2] [mlir][spirv] Support function argument decorations for
ptr in the PhysicalStorageBuffer
---
.../Dialect/SPIRV/IR/SPIRVStructureOps.td | 8 +++
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 37 ++++++----
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 69 ++++++++++++++++++-
.../spirv-storage-class-mapping.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 42 +++++++++++
.../SPIRV/Transforms/vce-deduction.mlir | 2 +-
mlir/test/Target/SPIRV/cast-ops.mlir | 2 +-
8 files changed, 145 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..0afe508b4db013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,11 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
This op itself takes no operands and generates no results. Its region
can take zero or more arguments and return zero or one values.
+ From `SPV_KHR_physical_storage_buffer`:
+ If a function parameter is
+ - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `Aliased` or `Restrict`.
+ - a pointer (or contains a pointer) and the type it points to is a pointer in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
<!-- End of AutoGen section -->
```
@@ -280,6 +285,9 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
```mlir
spirv.func @foo() -> () "None" { ... }
spirv.func @bar() -> () "Inline|Pure" { ... }
+
+ spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased|Restrict>}) -> () "None" { ... }
+ spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer|RestrictPointer>}) "None" { ... }
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..66ec520cfeca31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -992,19 +992,25 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
StringRef symbol = attribute.getName().strref();
Attribute attr = attribute.getValue();
- if (symbol != spirv::getInterfaceVarABIAttrName())
+ if (symbol == spirv::getInterfaceVarABIAttrName()) {
+ auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+ if (!varABIAttr)
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+ if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+ return emitError(loc, "'") << symbol
+ << "' attribute cannot specify storage class "
+ "when attaching to a non-scalar value";
+ } else if (symbol == spirv::DecorationAttr::name) {
+ auto decAttr = llvm::dyn_cast<spirv::DecorationAttr>(attr);
+ if (!decAttr)
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::DecorationAttr";
+ } else {
return emitError(loc, "found unsupported '")
<< symbol << "' attribute on region argument";
-
- auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
- if (!varABIAttr)
- return emitError(loc, "'")
- << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
- if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
- return emitError(loc, "'") << symbol
- << "' attribute cannot specify storage class "
- "when attaching to a non-scalar value";
+ }
return success();
}
@@ -1013,9 +1019,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
unsigned regionIndex,
unsigned argIndex,
NamedAttribute attribute) {
- return verifyRegionAttribute(
- op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
- attribute);
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return success();
+ Type argType = funcOp.getArgumentTypes()[argIndex];
+
+ return verifyRegionAttribute(op->getLoc(), argType, attribute);
}
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..d6064f446b4454 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,75 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::FuncOp::verifyType() {
- if (getFunctionType().getNumResults() > 1)
+ FunctionType fnType = getFunctionType();
+ if (fnType.getNumResults() > 1)
return emitOpError("cannot have more than one result");
+
+ auto hasDecorationAttr = [op = getOperation()](spirv::Decoration decoration,
+ unsigned argIndex) {
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ for (auto argAttr : funcOp.getArgAttrs(argIndex))
+ if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+ return decAttr.getValue() == decoration;
+ }
+ return false;
+ };
+
+ auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
+ unsigned numArgs = funcOp.getNumArguments();
+ if (numArgs < 1)
+ return success();
+
+ for (unsigned i = 0; i < numArgs; ++i) {
+ auto param = fnType.getInputs()[i];
+ auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+ if (!inputPtrType)
+ continue;
+
+ auto pointeePtrType =
+ dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+ if (pointeePtrType) {
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer)
+ // > and the type it points to is a pointer in the PhysicalStorageBuffer
+ // > storage class, the function parameter must be decorated with exactly
+ // > one of AliasedPointer or RestrictPointer.
+ if (pointeePtrType.getStorageClass() ==
+ spirv::StorageClass::PhysicalStorageBuffer) {
+ bool hasAliasedPtr =
+ hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+ bool hasRestrictPtr =
+ hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+
+ if (!hasAliasedPtr && !hasRestrictPtr)
+ return emitOpError()
+ << "with a pointer points to a physical buffer pointer must "
+ "be decorated either 'AliasedPointer' or 'RestrictPointer'";
+ }
+ } else {
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+ // > the PhysicalStorageBuffer storage class, the function parameter must
+ // > be decorated with exactly one of Aliased or Restrict.
+ if (auto pointeeArrayType =
+ dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+ pointeePtrType =
+ dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+ } else {
+ pointeePtrType = inputPtrType;
+ }
+ if (pointeePtrType && pointeePtrType.getStorageClass() ==
+ spirv::StorageClass::PhysicalStorageBuffer) {
+ bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+ bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+ if (!hasAliased && !hasRestrict)
+ return emitOpError()
+ << "with physical buffer pointer must be decorated "
+ "either 'Aliased' or 'Restrict'";
+ }
+ }
+ }
+
return success();
}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
index b3991cbdbe8af1..b9c56a3fcffd04 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -81,7 +81,7 @@ spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>)
spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
// CHECK-ALL: llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
-spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None"
// CHECK-ALL: llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..e289dbf28ad284 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
// -----
spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
- spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+ spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
// CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
%0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
spirv.Return
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..45461ac68e6c9f 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -318,6 +318,48 @@ spirv.module Logical GLSL450 {
// -----
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>}) "None" {
+ spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}) "None" {
+ spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer>}) "None" {
+ spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer>}) "None" {
+ spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+ spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with a pointer points to a physical buffer pointer must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Input>) "None" {
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GlobalVariable
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4eaa21d2f94ef6..931034f3d5f6ea 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -66,7 +66,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
} {
- spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer>) "None" {
+ spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
spirv.Return
}
}
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 7fe0969497c3ec..ede0bf30511ef4 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -115,7 +115,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
// -----
spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
- spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+ spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>} ) "None" {
// CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
%0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
spirv.Return
>From 50200c4448898be66264dc15ebca08b66117ada0 Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Wed, 27 Dec 2023 15:34:00 +0000
Subject: [PATCH 2/2] Support serializer and deserializer
---
.../SPIRV/Deserialization/Deserializer.cpp | 52 ++++++++++++++-
.../SPIRV/Deserialization/Deserializer.h | 12 ++++
.../SPIRV/Serialization/SerializeOps.cpp | 54 +++++++++------
.../Target/SPIRV/Serialization/Serializer.cpp | 65 ++++++++++++-------
.../Target/SPIRV/Serialization/Serializer.h | 3 +
.../Target/SPIRV/function-decorations.mlir | 36 +++++++++-
6 files changed, 173 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..24748007bbb175 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
if (decorationName.empty()) {
return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
}
- auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
- auto symbol = opBuilder.getStringAttr(attrName);
+ auto symbol = getSymbolDecoration(decorationName);
switch (static_cast<spirv::Decoration>(words[1])) {
case spirv::Decoration::FPFastMathMode:
if (words.size() != 3) {
@@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
break;
}
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Block:
case spirv::Decoration::BufferBlock:
case spirv::Decoration::Flat:
@@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
+ case spirv::Decoration::RestrictPointer:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
@@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
return success();
}
+void spirv::Deserializer::setArgAttrs(uint32_t argID) {
+ if (!decorations.count(argID)) {
+ argAttrs.push_back(DictionaryAttr::get(context, {}));
+ return;
+ }
+
+ // Replace a decoration stored as a UnitAttr with a DecorationAttr for a
+ // physical buffer pointer used as a function parameter,
+ // e.g. "aliased" -> "spirv.decoration = #spirv.decoration<Aliased>".
+ for (auto decAttr : decorations[argID]) {
+ if (decAttr.getName() ==
+ getSymbolDecoration(stringifyDecoration(spirv::Decoration::Aliased))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(
+ spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(context, spirv::Decoration::Aliased));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::Restrict))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(
+ spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(context, spirv::Decoration::Restrict));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::AliasedPointer))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(
+ context, spirv::Decoration::AliasedPointer));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::RestrictPointer))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(
+ context, spirv::Decoration::RestrictPointer));
+ }
+ }
+
+ argAttrs.push_back(decorations[argID].getDictionary(context));
+}
+
LogicalResult
spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
if (curFunction) {
@@ -463,11 +504,18 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "duplicate definition of result <id> ")
<< operands[1];
}
+ setArgAttrs(operands[1]);
auto argValue = funcOp.getArgument(i);
valueMap[operands[1]] = argValue;
}
}
+ if (llvm::any_of(argAttrs, [](Attribute attr) {
+ auto argAttr = cast<DictionaryAttr>(attr);
+ return !argAttr.empty();
+ }))
+ funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
+
// entryBlock is needed to access the arguments, Once that is done, we can
// erase the block for functions with 'Import' LinkageAttributes, since these
// are essentially function declarations, so they have no body.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 69be47851ef3c5..115addd49f949a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -233,6 +233,15 @@ class Deserializer {
return globalVariableMap.lookup(id);
}
+ /// Sets the argument's attributes with the given argument <id>.
+ void setArgAttrs(uint32_t argID);
+
+ /// Gets the attribute symbol name for the given decoration name.
+ StringAttr getSymbolDecoration(StringRef decorationName) {
+ auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
+ return opBuilder.getStringAttr(attrName);
+ }
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
@@ -605,6 +614,9 @@ class Deserializer {
/// A list of all structs which have unresolved member types.
SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
+ /// A list of argument attributes of function.
+ SmallVector<Attribute, 0> argAttrs;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..2efb0ee64c9253 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -177,6 +177,35 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
return success();
}
+LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
+ unsigned numArgs = op.getNumArguments();
+ if (numArgs != 0) {
+ for (unsigned i = 0; i < numArgs; ++i) {
+ auto arg = op.getArgument(i);
+ uint32_t argTypeID = 0;
+ if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+ return failure();
+ }
+ auto argValueID = getNextID();
+
+ // Process decoration attributes of arguments.
+ auto funcOp = cast<FunctionOpInterface>(*op);
+ for (auto argAttr : funcOp.getArgAttrs(i)) {
+ if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+ if (failed(processDecorationAttr(op->getLoc(), argValueID,
+ decAttr.getValue(), decAttr)))
+ return failure();
+ }
+ }
+
+ valueIDMap[arg] = argValueID;
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+ {argTypeID, argValueID});
+ }
+ }
+ return success();
+}
+
LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
assert(functionHeader.empty() && functionBody.empty());
@@ -229,32 +258,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
// is going to return false for this function from now on)
// Hence, we'll remove the body once we are done with the serialization.
op.addEntryBlock();
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Don't need to process the added block, there is nothing to process,
// the fake body was added just to get the arguments, remove the body,
// since it's use is done.
op.eraseBody();
} else {
- // Declare the parameters.
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Some instructions (e.g., OpVariable) in a function must be in the first
// block in the function. These instructions will be put in
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..6c3aaa35457470 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) {
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
-LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
- NamedAttribute attr) {
- auto attrName = attr.getName().strref();
- auto decorationName = getDecorationName(attrName);
- auto decoration = spirv::symbolizeDecoration(decorationName);
- if (!decoration) {
- return emitError(
- loc, "non-argument attributes expected to have snake-case-ified "
- "decoration name, unhandled attribute with name : ")
- << attrName;
- }
+LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+ Decoration decoration,
+ Attribute attr) {
SmallVector<uint32_t, 1> args;
- switch (*decoration) {
+ switch (decoration) {
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+ auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
@@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
break;
}
case spirv::Decoration::FPFastMathMode:
- if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
break;
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
- << attrName;
+ << stringifyDecoration(decoration);
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
- if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
args.push_back(intAttr.getValue().getZExtValue());
break;
}
- return emitError(loc, "expected integer attribute for ") << attrName;
+ return emitError(loc, "expected integer attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::BuiltIn:
- if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
+ if (auto strAttr = dyn_cast<StringAttr>(attr)) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) {
args.push_back(static_cast<uint32_t>(*enumVal));
break;
}
return emitError(loc, "invalid ")
- << attrName << " attribute " << strAttr.getValue();
+ << stringifyDecoration(decoration) << " decoration attribute "
+ << strAttr.getValue();
}
- return emitError(loc, "expected string attribute for ") << attrName;
+ return emitError(loc, "expected string attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Flat:
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
@@ -275,14 +271,33 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
- // For unit attributes, the args list has no values so we do nothing
- if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
+ case spirv::Decoration::RestrictPointer:
+ // For unit attributes and decoration attributes, the args list
+ // has no values so we do nothing.
+ if (isa<UnitAttr>(attr) || isa<DecorationAttr>(attr))
break;
- return emitError(loc, "expected unit attribute for ") << attrName;
+ return emitError(loc,
+ "expected unit attribute or decoration attribute for ")
+ << stringifyDecoration(decoration);
default:
- return emitError(loc, "unhandled decoration ") << decorationName;
+ return emitError(loc, "unhandled decoration ")
+ << stringifyDecoration(decoration);
+ }
+ return emitDecoration(resultID, decoration, args);
+}
+
+LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
+ NamedAttribute attr) {
+ auto attrName = attr.getName().strref();
+ auto decorationName = getDecorationName(attrName);
+ auto decoration = spirv::symbolizeDecoration(decorationName);
+ if (!decoration) {
+ return emitError(
+ loc, "non-argument attributes expected to have snake-case-ified "
+ "decoration name, unhandled attribute with name : ")
+ << attrName;
}
- return emitDecoration(resultID, *decoration, args);
+ return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
}
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 4b2ebf610bd723..9edb0f4af008dd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -127,6 +127,7 @@ class Serializer {
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(spirv::FuncOp op);
+ LogicalResult processFuncParameter(spirv::FuncOp op);
LogicalResult processVariableOp(spirv::VariableOp op);
@@ -134,6 +135,8 @@ class Serializer {
LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
/// Process attributes that translate to decorations on the result <id>
+ LogicalResult processDecorationAttr(Location loc, uint32_t resultID,
+ Decoration decoration, Attribute attr);
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
diff --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir
index b0f6705df9ca41..3696d37532d20e 100644
--- a/mlir/test/Target/SPIRV/function-decorations.mlir
+++ b/mlir/test/Target/SPIRV/function-decorations.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} {
@@ -17,3 +17,37 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
}
spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return}
}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+ [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+
+ // check: spirv.func @func_arg_decoration_aliased(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
+ spirv.func @func_arg_decoration_aliased(
+ %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>}
+ ) "None" {
+ spirv.Return
+ }
+
+ // check: spirv.func @func_arg_decoration_restrict(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>})
+ spirv.func @func_arg_decoration_restrict(
+ %arg0 : !spirv.ptr<i32,PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}
+ ) "None" {
+ spirv.Return
+ }
+
+ // check: spirv.func @func_arg_decoration_aliased_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>})
+ spirv.func @func_arg_decoration_aliased_pointer(
+ %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer>}
+ ) "None" {
+ spirv.Return
+ }
+
+ // check: spirv.func @func_arg_decoration_restrict_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>})
+ spirv.func @func_arg_decoration_restrict_pointer(
+ %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer>}
+ ) "None" {
+ spirv.Return
+ }
+}
More information about the Mlir-commits
mailing list