[Mlir-commits] [mlir] [mlir][spirv] Support function argument decorations for ptr in the PhysicalStorageBuffer (PR #76353)

Kohei Yamaguchi llvmlistbot at llvm.org
Sun Dec 24 23:53:25 PST 2023


https://github.com/sott0n created https://github.com/llvm/llvm-project/pull/76353

Closes #76106

>From 52210454a1b9cf51ad28a4829942c0a860ef6aaf Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Fri, 22 Dec 2023 17:22:25 +0000
Subject: [PATCH] [mlir][spirv] Support function argument decorations for ptr
 in the PhysicalStorageBuffer

---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     |  8 ++
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    | 43 +++++++----
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 74 ++++++++++++++++++-
 .../spirv-storage-class-mapping.mlir          |  2 +-
 mlir/test/Dialect/SPIRV/IR/cast-ops.mlir      |  2 +-
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 42 +++++++++++
 .../SPIRV/Transforms/vce-deduction.mlir       |  2 +-
 mlir/test/Target/SPIRV/cast-ops.mlir          |  2 +-
 8 files changed, 156 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..0afe508b4db013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,11 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     This op itself takes no operands and generates no results. Its region
     can take zero or more arguments and return zero or one values.
 
+    From `SPV_KHR_physical_storage_buffer`:
+    If a function parameter is
+    - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `Aliased` or `Restrict`.
+    - a pointer (or contains a pointer) and the type it points to is a pointer in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
     <!-- End of AutoGen section -->
 
     ```
@@ -280,6 +285,9 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     ```mlir
     spirv.func @foo() -> () "None" { ... }
     spirv.func @bar() -> () "Inline|Pure" { ... }
+
+    spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased|Restrict>}) -> () "None" { ... }
+    spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer|RestrictPointer>}) -> () "None" { ... }
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..0933ebea65a28c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -960,6 +960,12 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
   return builder.create<spirv::ConstantOp>(loc, type, value);
 }
 
+//===----------------------------------------------------------------------===//
+// TODO: Move other
+//===----------------------------------------------------------------------===//
+
+StringRef getDecorationAttrName() { return "spirv.decoration"; }
+
 //===----------------------------------------------------------------------===//
 // Shader Interface ABI
 //===----------------------------------------------------------------------===//
@@ -992,19 +998,25 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
   StringRef symbol = attribute.getName().strref();
   Attribute attr = attribute.getValue();
 
-  if (symbol != spirv::getInterfaceVarABIAttrName())
+  if (symbol == spirv::getInterfaceVarABIAttrName()) {
+    auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+    if (!varABIAttr)
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+    if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+      return emitError(loc, "'") << symbol
+                                 << "' attribute cannot specify storage class "
+                                    "when attaching to a non-scalar value";
+  } else if (symbol == getDecorationAttrName()) {
+    auto decAttr = llvm::dyn_cast<spirv::DecorationAttr>(attr);
+    if (!decAttr)
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::DecorationAttr";
+  } else {
     return emitError(loc, "found unsupported '")
            << symbol << "' attribute on region argument";
-
-  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
-  if (!varABIAttr)
-    return emitError(loc, "'")
-           << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
-  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
-    return emitError(loc, "'") << symbol
-                               << "' attribute cannot specify storage class "
-                                  "when attaching to a non-scalar value";
+  }
 
   return success();
 }
@@ -1013,9 +1025,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
                                                      unsigned regionIndex,
                                                      unsigned argIndex,
                                                      NamedAttribute attribute) {
-  return verifyRegionAttribute(
-      op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
-      attribute);
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+  Type argType = funcOp.getArgumentTypes()[argIndex];
+
+  return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..2b2a5436bd5f5d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,80 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
-  if (getFunctionType().getNumResults() > 1)
+  FunctionType fnType = getFunctionType();
+  if (fnType.getNumResults() > 1)
     return emitOpError("cannot have more than one result");
+
+  auto hasDecorationAttr = [op = getOperation()](spirv::Decoration decoration,
+                                                 unsigned i) {
+    auto funcOp = cast<spirv::FuncOp>(op);
+    auto argAttrs = funcOp.getArgAttrs();
+    if (argAttrs) {
+      auto argAttr = cast<DictionaryAttr>((*argAttrs)[i]);
+      for (auto attrValue : argAttr.getValue()) {
+        auto decAttr = dyn_cast<spirv::DecorationAttr>(attrValue.getValue());
+        if (decAttr)
+          return decAttr.getValue() == decoration;
+      }
+    }
+    return false;
+  };
+
+  auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
+  unsigned numArgs = funcOp.getNumArguments();
+  if (numArgs < 1)
+    return success();
+
+  for (unsigned i = 0; i < numArgs; ++i) {
+    auto param = fnType.getInputs()[i];
+    auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+    if (!inputPtrType)
+      continue;
+
+    auto pointeePtrType =
+        dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+    if (pointeePtrType) {
+      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+      // > If an OpFunctionParameter is a pointer (or contains a pointer)
+      // > and the type it points to is a pointer in the PhysicalStorageBuffer
+      // > storage class, the function parameter must be decorated with exactly
+      // > one of AliasedPointer or RestrictPointer.
+      if (pointeePtrType.getStorageClass() ==
+          spirv::StorageClass::PhysicalStorageBuffer) {
+        bool hasAliasedPtr =
+            hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+        bool hasRestrictPtr =
+            hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+
+        if (!hasAliasedPtr && !hasRestrictPtr)
+          return emitOpError()
+                 << " with physical buffer pointer must be decorated "
+                    "either 'AliasedPointer' or 'RestrictPointer'";
+      }
+    } else {
+      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+      // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+      // > the PhysicalStorageBuffer storage class, the function parameter must
+      // > be decorated with exactly one of Aliased or Restrict.
+      if (auto pointeeArrayType =
+              dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+        pointeePtrType =
+            dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+      } else {
+        pointeePtrType = inputPtrType;
+      }
+      if (pointeePtrType && pointeePtrType.getStorageClass() ==
+                                spirv::StorageClass::PhysicalStorageBuffer) {
+        bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+        bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+        if (!hasAliased && !hasRestrict)
+          return emitOpError()
+                 << " with physical buffer pointer must be decorated "
+                    "either 'Aliased' or 'Restrict'";
+      }
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
index b3991cbdbe8af1..b9c56a3fcffd04 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -81,7 +81,7 @@ spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>)
 spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
 
 // CHECK-ALL:            llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
-spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None"
 
 // CHECK-ALL:            llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
 spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..e289dbf28ad284 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
 // -----
 
 spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
-  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
     // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     spirv.Return
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..ebba60bd72b1f0 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -318,6 +318,48 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op  with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+    spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op  with physical buffer pointer must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Input>) "None" {
+    spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GlobalVariable
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4eaa21d2f94ef6..931034f3d5f6ea 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -66,7 +66,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
 } {
-  spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer>) "None" {
+  spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
     spirv.Return
   }
 }
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 7fe0969497c3ec..ede0bf30511ef4 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -115,7 +115,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
 // -----
 
 spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {  
-  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>} ) "None" {
     // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     spirv.Return



More information about the Mlir-commits mailing list