[Mlir-commits] [mlir] 747d8fb - [mlir][spirv] Support alias/restrict function argument decorations (#76353)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 6 11:51:27 PST 2024
Author: Kohei Yamaguchi
Date: 2024-01-06T11:51:23-08:00
New Revision: 747d8fb01c2417546ebaa774874ff8c3005e058a
URL: https://github.com/llvm/llvm-project/commit/747d8fb01c2417546ebaa774874ff8c3005e058a
DIFF: https://github.com/llvm/llvm-project/commit/747d8fb01c2417546ebaa774874ff8c3005e058a.diff
LOG: [mlir][spirv] Support alias/restrict function argument decorations (#76353)
Closes #76106
---------
Co-authored-by: Lei Zhang <antiagainst at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.h
mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
mlir/test/Target/SPIRV/cast-ops.mlir
mlir/test/Target/SPIRV/function-decorations.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..fbf750d643031f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,15 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
This op itself takes no operands and generates no results. Its region
can take zero or more arguments and return zero or one values.
+ From `SPV_KHR_physical_storage_buffer`:
+ If a function parameter is
+ - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage
+ class, the function parameter must be decorated with exactly one of
+ `Aliased` or `Restrict`.
+ - a pointer (or contains a pointer) and the type it points to is a pointer
+ in the PhysicalStorageBuffer storage class, the function parameter must
+ be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
<!-- End of AutoGen section -->
```
@@ -280,6 +289,20 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
```mlir
spirv.func @foo() -> () "None" { ... }
spirv.func @bar() -> () "Inline|Pure" { ... }
+
+ spirv.func @aliased_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>
+ { spirv.decoration = #spirv.decoration<Aliased> }) -> () "None" { ... }
+
+ spirv.func @restrict_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>
+ { spirv.decoration = #spirv.decoration<Restrict> }) -> () "None" { ... }
+
+ spirv.func @aliased_pointee(%arg0: !spirv.ptr<!spirv.ptr<i32,
+ PhysicalStorageBuffer>, Generic> { spirv.decoration =
+ #spirv.decoration<AliasedPointer> }) -> () "None" { ... }
+
+ spirv.func @restrict_pointee(%arg0: !spirv.ptr<!spirv.ptr<i32,
+ PhysicalStorageBuffer>, Generic> { spirv.decoration =
+ #spirv.decoration<RestrictPointer> }) -> () "None" { ... }
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..d7944d600b0a2b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -992,30 +992,39 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
StringRef symbol = attribute.getName().strref();
Attribute attr = attribute.getValue();
- if (symbol != spirv::getInterfaceVarABIAttrName())
- return emitError(loc, "found unsupported '")
- << symbol << "' attribute on region argument";
-
- auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
- if (!varABIAttr)
- return emitError(loc, "'")
- << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
- if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
- return emitError(loc, "'") << symbol
- << "' attribute cannot specify storage class "
- "when attaching to a non-scalar value";
+ if (symbol == spirv::getInterfaceVarABIAttrName()) {
+ auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+ if (!varABIAttr)
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+ if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+ return emitError(loc, "'") << symbol
+ << "' attribute cannot specify storage class "
+ "when attaching to a non-scalar value";
+ return success();
+ }
+ if (symbol == spirv::DecorationAttr::name) {
+ if (!isa<spirv::DecorationAttr>(attr))
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::DecorationAttr";
+ return success();
+ }
- return success();
+ return emitError(loc, "found unsupported '")
+ << symbol << "' attribute on region argument";
}
LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
unsigned regionIndex,
unsigned argIndex,
NamedAttribute attribute) {
- return verifyRegionAttribute(
- op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
- attribute);
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return success();
+ Type argType = funcOp.getArgumentTypes()[argIndex];
+
+ return verifyRegionAttribute(op->getLoc(), argType, attribute);
}
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5343a12132a912..3b159030cab75c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,73 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::FuncOp::verifyType() {
- if (getFunctionType().getNumResults() > 1)
+ FunctionType fnType = getFunctionType();
+ if (fnType.getNumResults() > 1)
return emitOpError("cannot have more than one result");
+
+ auto hasDecorationAttr = [&](spirv::Decoration decoration,
+ unsigned argIndex) {
+ auto func = llvm::cast<FunctionOpInterface>(getOperation());
+ for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
+ if (argAttr.getName() != spirv::DecorationAttr::name)
+ continue;
+ if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+ return decAttr.getValue() == decoration;
+ }
+ return false;
+ };
+
+ for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
+ Type param = fnType.getInputs()[i];
+ auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+ if (!inputPtrType)
+ continue;
+
+ auto pointeePtrType =
+ dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+ if (pointeePtrType) {
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer)
+ // > and the type it points to is a pointer in the PhysicalStorageBuffer
+ // > storage class, the function parameter must be decorated with exactly
+ // > one of AliasedPointer or RestrictPointer.
+ if (pointeePtrType.getStorageClass() !=
+ spirv::StorageClass::PhysicalStorageBuffer)
+ continue;
+
+ bool hasAliasedPtr =
+ hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+ bool hasRestrictPtr =
+ hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+ if (!hasAliasedPtr && !hasRestrictPtr)
+ return emitOpError()
+ << "with a pointer points to a physical buffer pointer must "
+ "be decorated either 'AliasedPointer' or 'RestrictPointer'";
+ continue;
+ }
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+ // > the PhysicalStorageBuffer storage class, the function parameter must
+ // > be decorated with exactly one of Aliased or Restrict.
+ if (auto pointeeArrayType =
+ dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+ pointeePtrType =
+ dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+ } else {
+ pointeePtrType = inputPtrType;
+ }
+
+ if (!pointeePtrType || pointeePtrType.getStorageClass() !=
+ spirv::StorageClass::PhysicalStorageBuffer)
+ continue;
+
+ bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+ bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+ if (!hasAliased && !hasRestrict)
+ return emitOpError() << "with physical buffer pointer must be decorated "
+ "either 'Aliased' or 'Restrict'";
+ }
+
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 00645d2c45519e..0c521adb11332b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
if (decorationName.empty()) {
return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
}
- auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
- auto symbol = opBuilder.getStringAttr(attrName);
+ auto symbol = getSymbolDecoration(decorationName);
switch (static_cast<spirv::Decoration>(words[1])) {
case spirv::Decoration::FPFastMathMode:
if (words.size() != 3) {
@@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
break;
}
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Block:
case spirv::Decoration::BufferBlock:
case spirv::Decoration::Flat:
@@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
+ case spirv::Decoration::RestrictPointer:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
@@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
return success();
}
+LogicalResult spirv::Deserializer::setFunctionArgAttrs(
+ uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
+ if (!decorations.contains(argID)) {
+ argAttrs[argIndex] = DictionaryAttr::get(context, {});
+ return success();
+ }
+
+ spirv::DecorationAttr foundDecorationAttr;
+ for (NamedAttribute decAttr : decorations[argID]) {
+ for (auto decoration :
+ {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
+ spirv::Decoration::AliasedPointer,
+ spirv::Decoration::RestrictPointer}) {
+
+ if (decAttr.getName() !=
+ getSymbolDecoration(stringifyDecoration(decoration)))
+ continue;
+
+ if (foundDecorationAttr)
+ return emitError(unknownLoc,
+ "more than one Aliased/Restrict decorations for "
+ "function argument with result <id> ")
+ << argID;
+
+ foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
+ break;
+ }
+ }
+
+ if (!foundDecorationAttr)
+ return emitError(unknownLoc, "unimplemented decoration support for "
+ "function argument with result <id> ")
+ << argID;
+
+ NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
+ foundDecorationAttr);
+ argAttrs[argIndex] = DictionaryAttr::get(context, attr);
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
if (curFunction) {
@@ -430,6 +471,9 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
logger.indent();
});
+ SmallVector<Attribute> argAttrs;
+ argAttrs.resize(functionType.getNumInputs());
+
// Parse the op argument instructions
if (functionType.getNumInputs()) {
for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
@@ -463,11 +507,21 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "duplicate definition of result <id> ")
<< operands[1];
}
+ if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
+ return failure();
+ }
+
auto argValue = funcOp.getArgument(i);
valueMap[operands[1]] = argValue;
}
}
+ if (llvm::any_of(argAttrs, [](Attribute attr) {
+ auto argAttr = cast<DictionaryAttr>(attr);
+ return !argAttr.empty();
+ }))
+ funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
+
// entryBlock is needed to access the arguments, Once that is done, we can
// erase the block for functions with 'Import' LinkageAttributes, since these
// are essentially function declarations, so they have no body.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 69be47851ef3c5..fc9a8f5f9364b2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -233,6 +233,19 @@ class Deserializer {
return globalVariableMap.lookup(id);
}
+ /// Sets the function argument's attributes. |argID| is the function
+ /// argument's result <id>, and |argIndex| is its index in the function's
+ /// argument list.
+ LogicalResult setFunctionArgAttrs(uint32_t argID,
+ SmallVectorImpl<Attribute> &argAttrs,
+ size_t argIndex);
+
+ /// Gets the snake_case attribute symbol name for the given decoration name.
+ StringAttr getSymbolDecoration(StringRef decorationName) {
+ auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
+ return opBuilder.getStringAttr(attrName);
+ }
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 7bfcca5b4dcdca..41d2c0310d0008 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -177,6 +177,34 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
return success();
}
+LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
+ for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
+ uint32_t argTypeID = 0;
+ if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+ return failure();
+ }
+ auto argValueID = getNextID();
+
+ // Process decoration attributes of arguments.
+ auto funcOp = cast<FunctionOpInterface>(*op);
+ for (auto argAttr : funcOp.getArgAttrs(idx)) {
+ if (argAttr.getName() != DecorationAttr::name)
+ continue;
+
+ if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+ if (failed(processDecorationAttr(op->getLoc(), argValueID,
+ decAttr.getValue(), decAttr)))
+ return failure();
+ }
+ }
+
+ valueIDMap[arg] = argValueID;
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+ {argTypeID, argValueID});
+ }
+ return success();
+}
+
LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
assert(functionHeader.empty() && functionBody.empty());
@@ -229,32 +257,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
// is going to return false for this function from now on)
// Hence, we'll remove the body once we are done with the serialization.
op.addEntryBlock();
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Don't need to process the added block, there is nothing to process,
// the fake body was added just to get the arguments, remove the body,
// since it's use is done.
op.eraseBody();
} else {
- // Declare the parameters.
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Some instructions (e.g., OpVariable) in a function must be in the first
// block in the function. These instructions will be put in
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..1029fb933175fd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) {
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
-LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
- NamedAttribute attr) {
- auto attrName = attr.getName().strref();
- auto decorationName = getDecorationName(attrName);
- auto decoration = spirv::symbolizeDecoration(decorationName);
- if (!decoration) {
- return emitError(
- loc, "non-argument attributes expected to have snake-case-ified "
- "decoration name, unhandled attribute with name : ")
- << attrName;
- }
+LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+ Decoration decoration,
+ Attribute attr) {
SmallVector<uint32_t, 1> args;
- switch (*decoration) {
+ switch (decoration) {
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+ auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
@@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
break;
}
case spirv::Decoration::FPFastMathMode:
- if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
break;
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
- << attrName;
+ << stringifyDecoration(decoration);
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
- if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
args.push_back(intAttr.getValue().getZExtValue());
break;
}
- return emitError(loc, "expected integer attribute for ") << attrName;
+ return emitError(loc, "expected integer attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::BuiltIn:
- if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
+ if (auto strAttr = dyn_cast<StringAttr>(attr)) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) {
args.push_back(static_cast<uint32_t>(*enumVal));
break;
}
return emitError(loc, "invalid ")
- << attrName << " attribute " << strAttr.getValue();
+ << stringifyDecoration(decoration) << " decoration attribute "
+ << strAttr.getValue();
}
- return emitError(loc, "expected string attribute for ") << attrName;
+ return emitError(loc, "expected string attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Flat:
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
@@ -275,14 +271,34 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
- // For unit attributes, the args list has no values so we do nothing
- if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
+ case spirv::Decoration::RestrictPointer:
+ // For unit attributes and decoration attributes, the args list
+ // has no values so we do nothing.
+ if (isa<UnitAttr, DecorationAttr>(attr))
break;
- return emitError(loc, "expected unit attribute for ") << attrName;
+ return emitError(loc,
+ "expected unit attribute or decoration attribute for ")
+ << stringifyDecoration(decoration);
default:
- return emitError(loc, "unhandled decoration ") << decorationName;
+ return emitError(loc, "unhandled decoration ")
+ << stringifyDecoration(decoration);
+ }
+ return emitDecoration(resultID, decoration, args);
+}
+
+LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
+ NamedAttribute attr) {
+ StringRef attrName = attr.getName().strref();
+ std::string decorationName = getDecorationName(attrName);
+ std::optional<Decoration> decoration =
+ spirv::symbolizeDecoration(decorationName);
+ if (!decoration) {
+ return emitError(
+ loc, "non-argument attributes expected to have snake-case-ified "
+ "decoration name, unhandled attribute with name : ")
+ << attrName;
}
- return emitDecoration(resultID, *decoration, args);
+ return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
}
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 4b2ebf610bd723..9edb0f4af008dd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -127,6 +127,7 @@ class Serializer {
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(spirv::FuncOp op);
+ LogicalResult processFuncParameter(spirv::FuncOp op);
LogicalResult processVariableOp(spirv::VariableOp op);
@@ -134,6 +135,8 @@ class Serializer {
LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
/// Process attributes that translate to decorations on the result <id>
+ LogicalResult processDecorationAttr(Location loc, uint32_t resultID,
+ Decoration decoration, Attribute attr);
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
index b3991cbdbe8af1..b9c56a3fcffd04 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -81,7 +81,7 @@ spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>)
spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
// CHECK-ALL: llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
-spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None"
// CHECK-ALL: llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..e289dbf28ad284 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
// -----
spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
- spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+ spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
// CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
%0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
spirv.Return
diff --git a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
index 2e39421df13cca..d915f8820c4f40 100644
--- a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
@@ -17,3 +17,59 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
}
spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return}
}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
+ spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict> }) "None" {
+ spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer> }) "None" {
+ spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer> }) "None" {
+ spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+ spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with a pointer points to a physical buffer pointer must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Function>) "None" {
+ spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_decoration_name_attr(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { random_attr = #spirv.decoration<Aliased> }) "None" {
+ spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op arguments may only have dialect attributes}}
+spirv.func @no_decoration_name_attr(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>, random_attr = #spirv.decoration<Aliased> }) "None" {
+ spirv.Return
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4eaa21d2f94ef6..931034f3d5f6ea 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -66,7 +66,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
} {
- spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer>) "None" {
+ spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
spirv.Return
}
}
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 7fe0969497c3ec..ede0bf30511ef4 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -115,7 +115,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
// -----
spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
- spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+ spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>} ) "None" {
// CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
%0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
spirv.Return
diff --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir
index b0f6705df9ca41..117d4ca628f76a 100644
--- a/mlir/test/Target/SPIRV/function-decorations.mlir
+++ b/mlir/test/Target/SPIRV/function-decorations.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file -verify-diagnostics %s | FileCheck %s
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} {
@@ -17,3 +17,72 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
}
spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return}
}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+ [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ // CHECK-LABEL: spirv.func @func_arg_decoration_aliased(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
+ spirv.func @func_arg_decoration_aliased(
+ %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }
+ ) "None" {
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+ [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ // CHECK-LABEL: spirv.func @func_arg_decoration_restrict(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>})
+ spirv.func @func_arg_decoration_restrict(
+ %arg0 : !spirv.ptr<i32,PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict> }
+ ) "None" {
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+ [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ // CHECK-LABEL: spirv.func @func_arg_decoration_aliased_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>})
+ spirv.func @func_arg_decoration_aliased_pointer(
+ %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer> }
+ ) "None" {
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+ [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ // CHECK-LABEL: spirv.func @func_arg_decoration_restrict_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>})
+ spirv.func @func_arg_decoration_restrict_pointer(
+ %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer> }
+ ) "None" {
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+ [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ // CHECK-LABEL: spirv.func @fn1(%{{.*}}: i32, %{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
+ spirv.func @fn1(
+ %arg0: i32,
+ %arg1: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }
+ ) "None" {
+ spirv.Return
+ }
+
+ // CHECK-LABEL: spirv.func @fn2(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}, %{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>})
+ spirv.func @fn2(
+ %arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> },
+ %arg1: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}
+ ) "None" {
+ spirv.Return
+ }
+}
More information about the Mlir-commits
mailing list