[Mlir-commits] [mlir] [mlir][spirv] Support function argument decorations for ptr in the PhysicalStorageBuffer (PR #76353)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 28 01:25:16 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Kohei Yamaguchi (sott0n)
<details>
<summary>Changes</summary>
Closes #<!-- -->76106
---
Patch is 31.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76353.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td (+8)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+23-14)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+68-1)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+50-2)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+12)
- (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+33-21)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+40-25)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+3)
- (modified) mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir (+1-1)
- (modified) mlir/test/Dialect/SPIRV/IR/cast-ops.mlir (+1-1)
- (modified) mlir/test/Dialect/SPIRV/IR/structure-ops.mlir (+42)
- (modified) mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir (+1-1)
- (modified) mlir/test/Target/SPIRV/cast-ops.mlir (+1-1)
- (modified) mlir/test/Target/SPIRV/function-decorations.mlir (+35-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..0afe508b4db013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,11 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
This op itself takes no operands and generates no results. Its region
can take zero or more arguments and return zero or one values.
+ From `SPV_KHR_physical_storage_buffer`:
+ If a function parameter is
+ - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `Aliased` or `Restrict`.
+ - a pointer (or contains a pointer) and the type it points to is a pointer in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
<!-- End of AutoGen section -->
```
@@ -280,6 +285,9 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
```mlir
spirv.func @foo() -> () "None" { ... }
spirv.func @bar() -> () "Inline|Pure" { ... }
+
+ spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased|Restrict>}) -> () "None" { ... }
+ spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer|RestrictPointer>}) "None" { ... }
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..66ec520cfeca31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -992,19 +992,25 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
StringRef symbol = attribute.getName().strref();
Attribute attr = attribute.getValue();
- if (symbol != spirv::getInterfaceVarABIAttrName())
+ if (symbol == spirv::getInterfaceVarABIAttrName()) {
+ auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+ if (!varABIAttr)
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+ if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+ return emitError(loc, "'") << symbol
+ << "' attribute cannot specify storage class "
+ "when attaching to a non-scalar value";
+ } else if (symbol == spirv::DecorationAttr::name) {
+ auto decAttr = llvm::dyn_cast<spirv::DecorationAttr>(attr);
+ if (!decAttr)
+ return emitError(loc, "'")
+ << symbol << "' must be a spirv::DecorationAttr";
+ } else {
return emitError(loc, "found unsupported '")
<< symbol << "' attribute on region argument";
-
- auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
- if (!varABIAttr)
- return emitError(loc, "'")
- << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
- if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
- return emitError(loc, "'") << symbol
- << "' attribute cannot specify storage class "
- "when attaching to a non-scalar value";
+ }
return success();
}
@@ -1013,9 +1019,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
unsigned regionIndex,
unsigned argIndex,
NamedAttribute attribute) {
- return verifyRegionAttribute(
- op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
- attribute);
+ auto funcOp = dyn_cast<FunctionOpInterface>(op);
+ if (!funcOp)
+ return success();
+ Type argType = funcOp.getArgumentTypes()[argIndex];
+
+ return verifyRegionAttribute(op->getLoc(), argType, attribute);
}
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..d6064f446b4454 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,75 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::FuncOp::verifyType() {
- if (getFunctionType().getNumResults() > 1)
+ FunctionType fnType = getFunctionType();
+ if (fnType.getNumResults() > 1)
return emitOpError("cannot have more than one result");
+
+ auto hasDecorationAttr = [op = getOperation()](spirv::Decoration decoration,
+ unsigned argIndex) {
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ for (auto argAttr : funcOp.getArgAttrs(argIndex))
+ if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+ return decAttr.getValue() == decoration;
+ }
+ return false;
+ };
+
+ auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
+ unsigned numArgs = funcOp.getNumArguments();
+ if (numArgs < 1)
+ return success();
+
+ for (unsigned i = 0; i < numArgs; ++i) {
+ auto param = fnType.getInputs()[i];
+ auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+ if (!inputPtrType)
+ continue;
+
+ auto pointeePtrType =
+ dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+ if (pointeePtrType) {
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer)
+ // > and the type it points to is a pointer in the PhysicalStorageBuffer
+ // > storage class, the function parameter must be decorated with exactly
+ // > one of AliasedPointer or RestrictPointer.
+ if (pointeePtrType.getStorageClass() ==
+ spirv::StorageClass::PhysicalStorageBuffer) {
+ bool hasAliasedPtr =
+ hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+ bool hasRestrictPtr =
+ hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+
+ if (!hasAliasedPtr && !hasRestrictPtr)
+ return emitOpError()
+ << "with a pointer points to a physical buffer pointer must "
+ "be decorated either 'AliasedPointer' or 'RestrictPointer'";
+ }
+ } else {
+ // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+ // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+ // > the PhysicalStorageBuffer storage class, the function parameter must
+ // > be decorated with exactly one of Aliased or Restrict.
+ if (auto pointeeArrayType =
+ dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+ pointeePtrType =
+ dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+ } else {
+ pointeePtrType = inputPtrType;
+ }
+ if (pointeePtrType && pointeePtrType.getStorageClass() ==
+ spirv::StorageClass::PhysicalStorageBuffer) {
+ bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+ bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+ if (!hasAliased && !hasRestrict)
+ return emitOpError()
+ << "with physical buffer pointer must be decorated "
+ "either 'Aliased' or 'Restrict'";
+ }
+ }
+ }
+
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..24748007bbb175 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
if (decorationName.empty()) {
return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
}
- auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
- auto symbol = opBuilder.getStringAttr(attrName);
+ auto symbol = getSymbolDecoration(decorationName);
switch (static_cast<spirv::Decoration>(words[1])) {
case spirv::Decoration::FPFastMathMode:
if (words.size() != 3) {
@@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
break;
}
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Block:
case spirv::Decoration::BufferBlock:
case spirv::Decoration::Flat:
@@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NoUnsignedWrap:
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
+ case spirv::Decoration::RestrictPointer:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
@@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
return success();
}
+void spirv::Deserializer::setArgAttrs(uint32_t argID) {
+ if (!decorations.count(argID)) {
+ argAttrs.push_back(DictionaryAttr::get(context, {}));
+ return;
+ }
+
+ // Replace a decoration as UnitAttr with DecorationAttr for the physical
+ // buffer pointer in the function parameter.
+ // e.g. "aliased" -> "spirv.decoration = #spirv.decoration<Aliased>").
+ for (auto decAttr : decorations[argID]) {
+ if (decAttr.getName() ==
+ getSymbolDecoration(stringifyDecoration(spirv::Decoration::Aliased))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(
+ spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(context, spirv::Decoration::Aliased));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::Restrict))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(
+ spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(context, spirv::Decoration::Restrict));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::AliasedPointer))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(
+ context, spirv::Decoration::AliasedPointer));
+ } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+ spirv::Decoration::RestrictPointer))) {
+ decorations[argID].erase(decAttr.getName());
+ decorations[argID].set(spirv::DecorationAttr::name,
+ spirv::DecorationAttr::get(
+ context, spirv::Decoration::RestrictPointer));
+ }
+ }
+
+ argAttrs.push_back(decorations[argID].getDictionary(context));
+}
+
LogicalResult
spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
if (curFunction) {
@@ -463,11 +504,18 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "duplicate definition of result <id> ")
<< operands[1];
}
+ setArgAttrs(operands[1]);
auto argValue = funcOp.getArgument(i);
valueMap[operands[1]] = argValue;
}
}
+ if (llvm::any_of(argAttrs, [](Attribute attr) {
+ auto argAttr = cast<DictionaryAttr>(attr);
+ return !argAttr.empty();
+ }))
+ funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
+
// entryBlock is needed to access the arguments, Once that is done, we can
// erase the block for functions with 'Import' LinkageAttributes, since these
// are essentially function declarations, so they have no body.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 69be47851ef3c5..115addd49f949a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -233,6 +233,15 @@ class Deserializer {
return globalVariableMap.lookup(id);
}
+ /// Sets the argument's attributes with the given argument <id>.
+ void setArgAttrs(uint32_t argID);
+
+ /// Gets the symbol name from the name of decoration.
+ StringAttr getSymbolDecoration(StringRef decorationName) {
+ auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
+ return opBuilder.getStringAttr(attrName);
+ }
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
@@ -605,6 +614,9 @@ class Deserializer {
/// A list of all structs which have unresolved member types.
SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
+ /// A list of argument attributes of function.
+ SmallVector<Attribute, 0> argAttrs;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..2efb0ee64c9253 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -177,6 +177,35 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
return success();
}
+LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
+ unsigned numArgs = op.getNumArguments();
+ if (numArgs != 0) {
+ for (unsigned i = 0; i < numArgs; ++i) {
+ auto arg = op.getArgument(i);
+ uint32_t argTypeID = 0;
+ if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+ return failure();
+ }
+ auto argValueID = getNextID();
+
+ // Process decoration attributes of arguments.
+ auto funcOp = cast<FunctionOpInterface>(*op);
+ for (auto argAttr : funcOp.getArgAttrs(i)) {
+ if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+ if (failed(processDecorationAttr(op->getLoc(), argValueID,
+ decAttr.getValue(), decAttr)))
+ return failure();
+ }
+ }
+
+ valueIDMap[arg] = argValueID;
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+ {argTypeID, argValueID});
+ }
+ }
+ return success();
+}
+
LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
assert(functionHeader.empty() && functionBody.empty());
@@ -229,32 +258,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
// is going to return false for this function from now on)
// Hence, we'll remove the body once we are done with the serialization.
op.addEntryBlock();
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Don't need to process the added block, there is nothing to process,
// the fake body was added just to get the arguments, remove the body,
// since it's use is done.
op.eraseBody();
} else {
- // Declare the parameters.
- for (auto arg : op.getArguments()) {
- uint32_t argTypeID = 0;
- if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
- return failure();
- }
- auto argValueID = getNextID();
- valueIDMap[arg] = argValueID;
- encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
- {argTypeID, argValueID});
- }
+ if (failed(processFuncParameter(op)))
+ return failure();
// Some instructions (e.g., OpVariable) in a function must be in the first
// block in the function. These instructions will be put in
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..6c3aaa35457470 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) {
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
-LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
- NamedAttribute attr) {
- auto attrName = attr.getName().strref();
- auto decorationName = getDecorationName(attrName);
- auto decoration = spirv::symbolizeDecoration(decorationName);
- if (!decoration) {
- return emitError(
- loc, "non-argument attributes expected to have snake-case-ified "
- "decoration name, unhandled attribute with name : ")
- << attrName;
- }
+LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+ Decoration decoration,
+ Attribute attr) {
SmallVector<uint32_t, 1> args;
- switch (*decoration) {
+ switch (decoration) {
case spirv::Decoration::LinkageAttributes: {
// Get the value of the Linkage Attributes
// e.g., LinkageAttributes=["linkageName", linkageType].
- auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+ auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
auto linkageName = linkageAttr.getLinkageName();
auto linkageType = linkageAttr.getLinkageType().getValue();
// Encode the Linkage Name (string literal to uint32_t).
@@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
break;
}
case spirv::Decoration::FPFastMathMode:
- if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
break;
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
- << attrName;
+ << stringifyDecoration(decoration);
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
- if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
args.push_back(intAttr.getValue().getZExtValue());
break;
}
- return emitError(loc, "expected integer attribute for ") << attrName;
+ return emitError(loc, "expected integer attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::BuiltIn:
- if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
+ if (auto strAttr = dyn_cast<StringAttr>(attr)) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) {
args.push_back(static_cast<uint32_t>(*enumVal));
break;
}
return emitError(loc, "invalid ")
- << attrName << " attribute " << strAttr.getValue();
+ << stringifyDecoration(decoration) << " decoration attribute "
+ << strAttr.getValue();
}
- return emitError(loc, "expected string attribute for ") << attrName;
+ return emitError(loc, "expected string attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::Aliased:
+ case spirv::Decoration::AliasedPointer:
case spirv::Decoration::Flat:
case spirv::Decoration::NonReadable:
case spirv::Decoration::NonWritable:
@@ -275,14 +271,33 @@ LogicalResult Serializer::processDecoration(Location loc, uin...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/76353
More information about the Mlir-commits mailing list