[Mlir-commits] [mlir] [mlir][spirv] Support function argument decorations for ptr in the PhysicalStorageBuffer (PR #76353)
Kohei Yamaguchi
llvmlistbot at llvm.org
Sun Dec 24 23:53:25 PST 2023
https://github.com/sott0n created https://github.com/llvm/llvm-project/pull/76353
Closes #76106
>From 52210454a1b9cf51ad28a4829942c0a860ef6aaf Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Fri, 22 Dec 2023 17:22:25 +0000
Subject: [PATCH] [mlir][spirv] Support function argument decorations for ptr
in the PhysicalStorageBuffer
---
.../Dialect/SPIRV/IR/SPIRVStructureOps.td | 8 ++
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 43 +++++++----
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 74 ++++++++++++++++++-
.../spirv-storage-class-mapping.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 42 +++++++++++
.../SPIRV/Transforms/vce-deduction.mlir | 2 +-
mlir/test/Target/SPIRV/cast-ops.mlir | 2 +-
8 files changed, 156 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..0afe508b4db013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,11 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
This op itself takes no operands and generates no results. Its region
can take zero or more arguments and return zero or one values.
+ From `SPV_KHR_physical_storage_buffer`, a function parameter that is
+ - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class must be decorated with exactly one of `Aliased` or `Restrict`;
+ - a pointer (or contains a pointer) whose pointee type is a pointer in the PhysicalStorageBuffer storage class must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+ Such decorations are attached with the `spirv.decoration` argument attribute.
+
<!-- End of AutoGen section -->
```
@@ -280,6 +285,9 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
```mlir
spirv.func @foo() -> () "None" { ... }
spirv.func @bar() -> () "Inline|Pure" { ... }
+
+ spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) -> () "None" { ... }
+ spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer> }) -> () "None" { ... }
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..0933ebea65a28c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -960,6 +960,14 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
   return builder.create<spirv::ConstantOp>(loc, type, value);
 }
 
+//===----------------------------------------------------------------------===//
+// Function Argument Decorations
+//===----------------------------------------------------------------------===//
+
+/// Returns the name of the discardable argument attribute that attaches a
+/// single `#spirv.decoration<...>` value to a function argument.
+static StringRef getDecorationAttrName() { return "spirv.decoration"; }
+
 //===----------------------------------------------------------------------===//
 // Shader Interface ABI
 //===----------------------------------------------------------------------===//
@@ -992,6 +998,16 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
   StringRef symbol = attribute.getName().strref();
   Attribute attr = attribute.getValue();
 
+  // `spirv.decoration` only has to hold a #spirv.decoration value here; which
+  // decoration a given argument type requires is checked by the op verifier
+  // (e.g. spirv.func's verifyType).
+  if (symbol == getDecorationAttrName()) {
+    if (!llvm::isa<spirv::DecorationAttr>(attr))
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::DecorationAttr";
+    return success();
+  }
+
   if (symbol != spirv::getInterfaceVarABIAttrName())
     return emitError(loc, "found unsupported '")
            << symbol << "' attribute on region argument";
@@ -1013,9 +1025,16 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
                                                      unsigned regionIndex,
                                                      unsigned argIndex,
                                                      NamedAttribute attribute) {
-  return verifyRegionAttribute(
-      op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
-      attribute);
+  // Prefer the function signature: an external function declaration has no
+  // entry block, so its region arguments cannot be queried. Other ops keep
+  // being verified against their region arguments, as before.
+  Type argType;
+  if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
+    argType = funcOp.getArgumentTypes()[argIndex];
+  else
+    argType = op->getRegion(regionIndex).getArgument(argIndex).getType();
+
+  return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..2b2a5436bd5f5d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,69 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
-  if (getFunctionType().getNumResults() > 1)
+  FunctionType fnType = getFunctionType();
+  if (fnType.getNumResults() > 1)
     return emitOpError("cannot have more than one result");
+
+  // Returns true if argument `argIndex` carries an attribute whose value is
+  // the spirv::DecorationAttr `decoration`.
+  auto hasDecorationAttr = [this](spirv::Decoration decoration,
+                                  unsigned argIndex) {
+    auto argAttrs = getArgAttrs();
+    if (!argAttrs)
+      return false;
+    auto argAttrDict = cast<DictionaryAttr>((*argAttrs)[argIndex]);
+    for (NamedAttribute argAttr : argAttrDict.getValue()) {
+      if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+        return decAttr.getValue() == decoration;
+    }
+    return false;
+  };
+
+  // Returns true if `type` is a pointer in the PhysicalStorageBuffer storage
+  // class.
+  auto isPhysicalStorageBufferPtr = [](Type type) {
+    auto ptrType = dyn_cast<spirv::PointerType>(type);
+    return ptrType && ptrType.getStorageClass() ==
+                          spirv::StorageClass::PhysicalStorageBuffer;
+  };
+
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    auto inputPtrType = dyn_cast<spirv::PointerType>(fnType.getInput(i));
+    if (!inputPtrType)
+      continue;
+    Type pointeeType = inputPtrType.getPointeeType();
+
+    // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+    // > If an OpFunctionParameter is a pointer (or contains a pointer)
+    // > and the type it points to is a pointer in the PhysicalStorageBuffer
+    // > storage class, the function parameter must be decorated with exactly
+    // > one of AliasedPointer or RestrictPointer.
+    if (isPhysicalStorageBufferPtr(pointeeType)) {
+      if (!hasDecorationAttr(spirv::Decoration::AliasedPointer, i) &&
+          !hasDecorationAttr(spirv::Decoration::RestrictPointer, i))
+        return emitOpError()
+               << "with physical buffer pointer must be decorated "
+                  "either 'AliasedPointer' or 'RestrictPointer'";
+      continue;
+    }
+
+    // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+    // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+    // > the PhysicalStorageBuffer storage class, the function parameter must
+    // > be decorated with exactly one of Aliased or Restrict.
+    // This applies when the parameter itself is such a pointer, whatever its
+    // pointee (scalar, struct, array, or non-PSB pointer), and when it points
+    // to an array of such pointers.
+    bool needsAliasing = isPhysicalStorageBufferPtr(inputPtrType);
+    if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(pointeeType))
+      needsAliasing |=
+          isPhysicalStorageBufferPtr(pointeeArrayType.getElementType());
+    if (needsAliasing && !hasDecorationAttr(spirv::Decoration::Aliased, i) &&
+        !hasDecorationAttr(spirv::Decoration::Restrict, i))
+      return emitOpError() << "with physical buffer pointer must be decorated "
+                              "either 'Aliased' or 'Restrict'";
+  }
 
   return success();
 }
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
index b3991cbdbe8af1..b9c56a3fcffd04 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -81,7 +81,7 @@ spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>)
spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
// CHECK-ALL: llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
-spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None"
// CHECK-ALL: llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..e289dbf28ad284 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
// -----
spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
- spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None" {
// CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
%0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
spirv.Return
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..ebba60bd72b1f0 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -318,6 +318,55 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+// CHECK: spirv.func @arg_decoration_aliased(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None"
+spirv.func @arg_decoration_aliased(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>}) "None" {
+  spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_restrict(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>}) "None"
+spirv.func @arg_decoration_restrict(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}) "None" {
+  spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_aliased_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>}) "None"
+spirv.func @arg_decoration_aliased_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer>}) "None" {
+  spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_restrict_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>}) "None"
+spirv.func @arg_decoration_restrict_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer>}) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
+spirv.func @no_arg_decoration_pointer_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Input>) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.decoration' must be a spirv::DecorationAttr}}
+spirv.func @arg_decoration_not_decoration_attr(%arg0: i32 { spirv.decoration = 1 : i32 }) "None" {
+  spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GlobalVariable
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4eaa21d2f94ef6..931034f3d5f6ea 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -66,7 +66,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
} {
- spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer>) "None" {
+  spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None" {
spirv.Return
}
}
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 7fe0969497c3ec..ede0bf30511ef4 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -115,7 +115,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
// -----
spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
- spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None" {
// CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
%0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
spirv.Return
More information about the Mlir-commits
mailing list