[Mlir-commits] [mlir] [mlir][spirv] Support function argument decorations for ptr in the PhysicalStorageBuffer (PR #76353)

Kohei Yamaguchi llvmlistbot at llvm.org
Thu Dec 28 21:16:48 PST 2023


https://github.com/sott0n updated https://github.com/llvm/llvm-project/pull/76353

>From ebd9634057e9417905d7fcd27bac829c5d0889e0 Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Fri, 22 Dec 2023 17:22:25 +0000
Subject: [PATCH 1/3] [mlir][spirv] Support function argument decorations for
 ptr in the PhysicalStorageBuffer

---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     |  8 +++
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    | 37 ++++++----
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 69 ++++++++++++++++++-
 .../spirv-storage-class-mapping.mlir          |  2 +-
 mlir/test/Dialect/SPIRV/IR/cast-ops.mlir      |  2 +-
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 42 +++++++++++
 .../SPIRV/Transforms/vce-deduction.mlir       |  2 +-
 mlir/test/Target/SPIRV/cast-ops.mlir          |  2 +-
 8 files changed, 145 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..0afe508b4db013 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,11 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     This op itself takes no operands and generates no results. Its region
     can take zero or more arguments and return zero or one values.
 
+    From `SPV_KHR_physical_storage_buffer`:
+    If a function parameter is
+    - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `Aliased` or `Restrict`.
+    - a pointer (or contains a pointer) and the type it points to is a pointer in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
     <!-- End of AutoGen section -->
 
     ```
@@ -280,6 +285,9 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     ```mlir
     spirv.func @foo() -> () "None" { ... }
     spirv.func @bar() -> () "Inline|Pure" { ... }
+
+    spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased|Restrict>}) -> () "None" { ... }
+    spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer|RestrictPointer>}) -> () "None" { ... }
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..66ec520cfeca31 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -992,19 +992,25 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
   StringRef symbol = attribute.getName().strref();
   Attribute attr = attribute.getValue();
 
-  if (symbol != spirv::getInterfaceVarABIAttrName())
+  if (symbol == spirv::getInterfaceVarABIAttrName()) {
+    auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+    if (!varABIAttr)
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+    if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+      return emitError(loc, "'") << symbol
+                                 << "' attribute cannot specify storage class "
+                                    "when attaching to a non-scalar value";
+  } else if (symbol == spirv::DecorationAttr::name) {
+    auto decAttr = llvm::dyn_cast<spirv::DecorationAttr>(attr);
+    if (!decAttr)
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::DecorationAttr";
+  } else {
     return emitError(loc, "found unsupported '")
            << symbol << "' attribute on region argument";
-
-  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
-  if (!varABIAttr)
-    return emitError(loc, "'")
-           << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
-  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
-    return emitError(loc, "'") << symbol
-                               << "' attribute cannot specify storage class "
-                                  "when attaching to a non-scalar value";
+  }
 
   return success();
 }
@@ -1013,9 +1019,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
                                                      unsigned regionIndex,
                                                      unsigned argIndex,
                                                      NamedAttribute attribute) {
-  return verifyRegionAttribute(
-      op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
-      attribute);
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+  Type argType = funcOp.getArgumentTypes()[argIndex];
+
+  return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..d6064f446b4454 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,75 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
-  if (getFunctionType().getNumResults() > 1)
+  FunctionType fnType = getFunctionType();
+  if (fnType.getNumResults() > 1)
     return emitOpError("cannot have more than one result");
+
+  auto hasDecorationAttr = [op = getOperation()](spirv::Decoration decoration,
+                                                 unsigned argIndex) {
+    if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+      for (auto argAttr : funcOp.getArgAttrs(argIndex))
+        if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+          return decAttr.getValue() == decoration;
+    }
+    return false;
+  };
+
+  auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
+  unsigned numArgs = funcOp.getNumArguments();
+  if (numArgs < 1)
+    return success();
+
+  for (unsigned i = 0; i < numArgs; ++i) {
+    auto param = fnType.getInputs()[i];
+    auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+    if (!inputPtrType)
+      continue;
+
+    auto pointeePtrType =
+        dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+    if (pointeePtrType) {
+      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+      // > If an OpFunctionParameter is a pointer (or contains a pointer)
+      // > and the type it points to is a pointer in the PhysicalStorageBuffer
+      // > storage class, the function parameter must be decorated with exactly
+      // > one of AliasedPointer or RestrictPointer.
+      if (pointeePtrType.getStorageClass() ==
+          spirv::StorageClass::PhysicalStorageBuffer) {
+        bool hasAliasedPtr =
+            hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+        bool hasRestrictPtr =
+            hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+
+        if (!hasAliasedPtr && !hasRestrictPtr)
+          return emitOpError()
+                 << "with a pointer points to a physical buffer pointer must "
+                    "be decorated either 'AliasedPointer' or 'RestrictPointer'";
+      }
+    } else {
+      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+      // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+      // > the PhysicalStorageBuffer storage class, the function parameter must
+      // > be decorated with exactly one of Aliased or Restrict.
+      if (auto pointeeArrayType =
+              dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+        pointeePtrType =
+            dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+      } else {
+        pointeePtrType = inputPtrType;
+      }
+      if (pointeePtrType && pointeePtrType.getStorageClass() ==
+                                spirv::StorageClass::PhysicalStorageBuffer) {
+        bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+        bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+        if (!hasAliased && !hasRestrict)
+          return emitOpError()
+                 << "with physical buffer pointer must be decorated "
+                    "either 'Aliased' or 'Restrict'";
+      }
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
index b3991cbdbe8af1..b9c56a3fcffd04 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -81,7 +81,7 @@ spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>)
 spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
 
 // CHECK-ALL:            llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
-spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None"
 
 // CHECK-ALL:            llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
 spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..e289dbf28ad284 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
 // -----
 
 spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
-  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
     // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     spirv.Return
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..45461ac68e6c9f 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -318,6 +318,48 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer>}) "None" {
+    spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+    spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with a pointer points to a physical buffer pointer must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Input>) "None" {
+    spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GlobalVariable
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4eaa21d2f94ef6..931034f3d5f6ea 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -66,7 +66,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
 } {
-  spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer>) "None" {
+  spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
     spirv.Return
   }
 }
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 7fe0969497c3ec..ede0bf30511ef4 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -115,7 +115,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
 // -----
 
 spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {  
-  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>} ) "None" {
     // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     spirv.Return

>From 2b64b3e0f1dd8498e372beb9580b2b4980cc96d2 Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Wed, 27 Dec 2023 15:34:00 +0000
Subject: [PATCH 2/3] Support serializer and deserializer

---
 .../SPIRV/Deserialization/Deserializer.cpp    | 52 ++++++++++++++-
 .../SPIRV/Deserialization/Deserializer.h      | 12 ++++
 .../SPIRV/Serialization/SerializeOps.cpp      | 54 +++++++++------
 .../Target/SPIRV/Serialization/Serializer.cpp | 65 ++++++++++++-------
 .../Target/SPIRV/Serialization/Serializer.h   |  3 +
 .../Target/SPIRV/function-decorations.mlir    | 36 +++++++++-
 6 files changed, 173 insertions(+), 49 deletions(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..24748007bbb175 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   if (decorationName.empty()) {
     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
   }
-  auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
-  auto symbol = opBuilder.getStringAttr(attrName);
+  auto symbol = getSymbolDecoration(decorationName);
   switch (static_cast<spirv::Decoration>(words[1])) {
   case spirv::Decoration::FPFastMathMode:
     if (words.size() != 3) {
@@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
     break;
   }
   case spirv::Decoration::Aliased:
+  case spirv::Decoration::AliasedPointer:
   case spirv::Decoration::Block:
   case spirv::Decoration::BufferBlock:
   case spirv::Decoration::Flat:
@@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
   case spirv::Decoration::Restrict:
+  case spirv::Decoration::RestrictPointer:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
@@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
   return success();
 }
 
+void spirv::Deserializer::setArgAttrs(uint32_t argID) {
+  if (!decorations.count(argID)) {
+    argAttrs.push_back(DictionaryAttr::get(context, {}));
+    return;
+  }
+
+  // Replace a decoration stored as a UnitAttr with a DecorationAttr for the
+  // physical buffer pointer in the function parameter
+  // (e.g. "aliased" -> "spirv.decoration = #spirv.decoration<Aliased>").
+  for (auto decAttr : decorations[argID]) {
+    if (decAttr.getName() ==
+        getSymbolDecoration(stringifyDecoration(spirv::Decoration::Aliased))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(
+          spirv::DecorationAttr::name,
+          spirv::DecorationAttr::get(context, spirv::Decoration::Aliased));
+    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+                                        spirv::Decoration::Restrict))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(
+          spirv::DecorationAttr::name,
+          spirv::DecorationAttr::get(context, spirv::Decoration::Restrict));
+    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+                                        spirv::Decoration::AliasedPointer))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(spirv::DecorationAttr::name,
+                             spirv::DecorationAttr::get(
+                                 context, spirv::Decoration::AliasedPointer));
+    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
+                                        spirv::Decoration::RestrictPointer))) {
+      decorations[argID].erase(decAttr.getName());
+      decorations[argID].set(spirv::DecorationAttr::name,
+                             spirv::DecorationAttr::get(
+                                 context, spirv::Decoration::RestrictPointer));
+    }
+  }
+
+  argAttrs.push_back(decorations[argID].getDictionary(context));
+}
+
 LogicalResult
 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   if (curFunction) {
@@ -463,11 +504,18 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
         return emitError(unknownLoc, "duplicate definition of result <id> ")
                << operands[1];
       }
+      setArgAttrs(operands[1]);
       auto argValue = funcOp.getArgument(i);
       valueMap[operands[1]] = argValue;
     }
   }
 
+  if (llvm::any_of(argAttrs, [](Attribute attr) {
+        auto argAttr = cast<DictionaryAttr>(attr);
+        return !argAttr.empty();
+      }))
+    funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
+
   // entryBlock is needed to access the arguments, Once that is done, we can
   // erase the block for functions with 'Import' LinkageAttributes, since these
   // are essentially function declarations, so they have no body.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 69be47851ef3c5..115addd49f949a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -233,6 +233,15 @@ class Deserializer {
     return globalVariableMap.lookup(id);
   }
 
+  /// Sets the argument's attributes with the given argument <id>.
+  void setArgAttrs(uint32_t argID);
+
+  /// Gets the symbol name from the name of decoration.
+  StringAttr getSymbolDecoration(StringRef decorationName) {
+    auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
+    return opBuilder.getStringAttr(attrName);
+  }
+
   //===--------------------------------------------------------------------===//
   // Type
   //===--------------------------------------------------------------------===//
@@ -605,6 +614,9 @@ class Deserializer {
   /// A list of all structs which have unresolved member types.
   SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
 
+  /// A list of argument attributes of function.
+  SmallVector<Attribute, 0> argAttrs;
+
 #ifndef NDEBUG
   /// A logger used to emit information during the deserialzation process.
   llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..2efb0ee64c9253 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -177,6 +177,35 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   return success();
 }
 
+LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
+  unsigned numArgs = op.getNumArguments();
+  if (numArgs != 0) {
+    for (unsigned i = 0; i < numArgs; ++i) {
+      auto arg = op.getArgument(i);
+      uint32_t argTypeID = 0;
+      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+        return failure();
+      }
+      auto argValueID = getNextID();
+
+      // Process decoration attributes of arguments.
+      auto funcOp = cast<FunctionOpInterface>(*op);
+      for (auto argAttr : funcOp.getArgAttrs(i)) {
+        if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+          if (failed(processDecorationAttr(op->getLoc(), argValueID,
+                                           decAttr.getValue(), decAttr)))
+            return failure();
+        }
+      }
+
+      valueIDMap[arg] = argValueID;
+      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+                            {argTypeID, argValueID});
+    }
+  }
+  return success();
+}
+
 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
   assert(functionHeader.empty() && functionBody.empty());
@@ -229,32 +258,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
     // is going to return false for this function from now on)
     // Hence, we'll remove the body once we are done with the serialization.
     op.addEntryBlock();
-    for (auto arg : op.getArguments()) {
-      uint32_t argTypeID = 0;
-      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
-        return failure();
-      }
-      auto argValueID = getNextID();
-      valueIDMap[arg] = argValueID;
-      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
-                            {argTypeID, argValueID});
-    }
+    if (failed(processFuncParameter(op)))
+      return failure();
     // Don't need to process the added block, there is nothing to process,
     // the fake body was added just to get the arguments, remove the body,
     // since it's use is done.
     op.eraseBody();
   } else {
-    // Declare the parameters.
-    for (auto arg : op.getArguments()) {
-      uint32_t argTypeID = 0;
-      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
-        return failure();
-      }
-      auto argValueID = getNextID();
-      valueIDMap[arg] = argValueID;
-      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
-                            {argTypeID, argValueID});
-    }
+    if (failed(processFuncParameter(op)))
+      return failure();
 
     // Some instructions (e.g., OpVariable) in a function must be in the first
     // block in the function. These instructions will be put in
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..6c3aaa35457470 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) {
   return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
 }
 
-LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
-                                            NamedAttribute attr) {
-  auto attrName = attr.getName().strref();
-  auto decorationName = getDecorationName(attrName);
-  auto decoration = spirv::symbolizeDecoration(decorationName);
-  if (!decoration) {
-    return emitError(
-               loc, "non-argument attributes expected to have snake-case-ified "
-                    "decoration name, unhandled attribute with name : ")
-           << attrName;
-  }
+LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+                                                Decoration decoration,
+                                                Attribute attr) {
   SmallVector<uint32_t, 1> args;
-  switch (*decoration) {
+  switch (decoration) {
   case spirv::Decoration::LinkageAttributes: {
     // Get the value of the Linkage Attributes
     // e.g., LinkageAttributes=["linkageName", linkageType].
-    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
     auto linkageName = linkageAttr.getLinkageName();
     auto linkageType = linkageAttr.getLinkageType().getValue();
     // Encode the Linkage Name (string literal to uint32_t).
@@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
     break;
   }
   case spirv::Decoration::FPFastMathMode:
-    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
       args.push_back(static_cast<uint32_t>(intAttr.getValue()));
       break;
     }
     return emitError(loc, "expected FPFastMathModeAttr attribute for ")
-           << attrName;
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
       args.push_back(intAttr.getValue().getZExtValue());
       break;
     }
-    return emitError(loc, "expected integer attribute for ") << attrName;
+    return emitError(loc, "expected integer attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::BuiltIn:
-    if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
+    if (auto strAttr = dyn_cast<StringAttr>(attr)) {
       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
       if (enumVal) {
         args.push_back(static_cast<uint32_t>(*enumVal));
         break;
       }
       return emitError(loc, "invalid ")
-             << attrName << " attribute " << strAttr.getValue();
+             << stringifyDecoration(decoration) << " decoration attribute "
+             << strAttr.getValue();
     }
-    return emitError(loc, "expected string attribute for ") << attrName;
+    return emitError(loc, "expected string attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Aliased:
+  case spirv::Decoration::AliasedPointer:
   case spirv::Decoration::Flat:
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
@@ -275,14 +271,33 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
   case spirv::Decoration::Restrict:
-    // For unit attributes, the args list has no values so we do nothing
-    if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
+  case spirv::Decoration::RestrictPointer:
+    // For unit attributes and decoration attributes, the args list
+    // has no values so we do nothing.
+    if (isa<UnitAttr>(attr) || isa<DecorationAttr>(attr))
       break;
-    return emitError(loc, "expected unit attribute for ") << attrName;
+    return emitError(loc,
+                     "expected unit attribute or decoration attribute for ")
+           << stringifyDecoration(decoration);
   default:
-    return emitError(loc, "unhandled decoration ") << decorationName;
+    return emitError(loc, "unhandled decoration ")
+           << stringifyDecoration(decoration);
+  }
+  return emitDecoration(resultID, decoration, args);
+}
+
+LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
+                                            NamedAttribute attr) {
+  auto attrName = attr.getName().strref();
+  auto decorationName = getDecorationName(attrName);
+  auto decoration = spirv::symbolizeDecoration(decorationName);
+  if (!decoration) {
+    return emitError(
+               loc, "non-argument attributes expected to have snake-case-ified "
+                    "decoration name, unhandled attribute with name : ")
+           << attrName;
   }
-  return emitDecoration(resultID, *decoration, args);
+  return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
 }
 
 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 4b2ebf610bd723..9edb0f4af008dd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -127,6 +127,7 @@ class Serializer {
 
   /// Processes a SPIR-V function op.
   LogicalResult processFuncOp(spirv::FuncOp op);
+  LogicalResult processFuncParameter(spirv::FuncOp op);
 
   LogicalResult processVariableOp(spirv::VariableOp op);
 
@@ -134,6 +135,8 @@ class Serializer {
   LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
 
   /// Process attributes that translate to decorations on the result <id>
+  LogicalResult processDecorationAttr(Location loc, uint32_t resultID,
+                                      Decoration decoration, Attribute attr);
   LogicalResult processDecoration(Location loc, uint32_t resultID,
                                   NamedAttribute attr);
 
diff --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir
index b0f6705df9ca41..3696d37532d20e 100644
--- a/mlir/test/Target/SPIRV/function-decorations.mlir
+++ b/mlir/test/Target/SPIRV/function-decorations.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     spirv.func @linkage_attr_test_kernel()  "DontInline"  attributes {}  {
@@ -17,3 +17,37 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     }
     spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return}
 }
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+    [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+
+  // CHECK: spirv.func @func_arg_decoration_aliased(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
+  spirv.func @func_arg_decoration_aliased(
+      %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>}
+  ) "None" {
+    spirv.Return
+  }
+
+  // CHECK: spirv.func @func_arg_decoration_restrict(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>})
+  spirv.func @func_arg_decoration_restrict(
+      %arg0 : !spirv.ptr<i32,PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}
+  ) "None" {
+    spirv.Return
+  }
+
+  // CHECK: spirv.func @func_arg_decoration_aliased_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>})
+  spirv.func @func_arg_decoration_aliased_pointer(
+      %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer>}
+  ) "None" {
+    spirv.Return
+  }
+
+  // CHECK: spirv.func @func_arg_decoration_restrict_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>})
+  spirv.func @func_arg_decoration_restrict_pointer(
+      %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer>}
+  ) "None" {
+    spirv.Return
+  }
+}

>From d23a46d54c60d7344bf3355cf3e39a13a6accfe9 Mon Sep 17 00:00:00 2001
From: Kohei Yamaguchi <fix7211 at gmail.com>
Date: Fri, 29 Dec 2023 14:16:21 +0900
Subject: [PATCH 3/3] addressed comments

---
 .../Dialect/SPIRV/IR/SPIRVStructureOps.td     | 23 +++++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    | 14 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 65 +++++++++----------
 .../SPIRV/Deserialization/Deserializer.cpp    | 40 ++++--------
 .../SPIRV/Serialization/SerializeOps.cpp      | 40 +++++-------
 .../Target/SPIRV/Serialization/Serializer.cpp |  9 +--
 6 files changed, 94 insertions(+), 97 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 0afe508b4db013..fbf750d643031f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -269,8 +269,12 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
 
     From `SPV_KHR_physical_storage_buffer`:
     If a parameter of function is
-    - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `Aliased` or `Restrict`.
-    - a pointer (or contains a pointer) and the type it points to is a pointer in the PhysicalStorageBuffer storage class, the function parameter must be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+    - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage
+      class, the function parameter must be decorated with exactly one of
+      `Aliased` or `Restrict`.
+    - a pointer (or contains a pointer) and the type it points to is a pointer
+      in the PhysicalStorageBuffer storage class, the function parameter must
+      be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
 
     <!-- End of AutoGen section -->
 
@@ -286,8 +290,19 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     spirv.func @foo() -> () "None" { ... }
     spirv.func @bar() -> () "Inline|Pure" { ... }
 
-    spirv.func @baz(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased|Restrict>}) -> () "None" { ... }
-    spirv.func @qux(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer|RestrictPointer>}) "None)
+    spirv.func @aliased_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>
+        { spirv.decoration = #spirv.decoration<Aliased> }) -> () "None" { ... }
+
+    spirv.func @restrict_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>
+        { spirv.decoration = #spirv.decoration<Restrict> }) -> () "None" { ... }
+
+    spirv.func @aliased_pointee(%arg0: !spirv.ptr<!spirv.ptr<i32,
+        PhysicalStorageBuffer>, Generic> { spirv.decoration =
+        #spirv.decoration<AliasedPointer> }) -> () "None" { ... }
+
+    spirv.func @restrict_pointee(%arg0: !spirv.ptr<!spirv.ptr<i32,
+        PhysicalStorageBuffer>, Generic> { spirv.decoration =
+        #spirv.decoration<RestrictPointer> }) -> () "None" { ... }
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 66ec520cfeca31..d7944d600b0a2b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1002,17 +1002,17 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
       return emitError(loc, "'") << symbol
                                  << "' attribute cannot specify storage class "
                                     "when attaching to a non-scalar value";
-  } else if (symbol == spirv::DecorationAttr::name) {
-    auto decAttr = llvm::dyn_cast<spirv::DecorationAttr>(attr);
-    if (!decAttr)
+    return success();
+  }
+  if (symbol == spirv::DecorationAttr::name) {
+    if (!isa<spirv::DecorationAttr>(attr))
       return emitError(loc, "'")
              << symbol << "' must be a spirv::DecorationAttr";
-  } else {
-    return emitError(loc, "found unsupported '")
-           << symbol << "' attribute on region argument";
+    return success();
   }
 
-  return success();
+  return emitError(loc, "found unsupported '")
+         << symbol << "' attribute on region argument";
 }
 
 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index d6064f446b4454..995e1b2a33d278 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -976,22 +976,19 @@ LogicalResult spirv::FuncOp::verifyType() {
   if (fnType.getNumResults() > 1)
     return emitOpError("cannot have more than one result");
 
-  auto hasDecorationAttr = [op = getOperation()](spirv::Decoration decoration,
-                                                 unsigned argIndex) {
-    if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-      for (auto argAttr : funcOp.getArgAttrs(argIndex))
-        if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
-          return decAttr.getValue() == decoration;
+  auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
+
+  auto hasDecorationAttr = [&](spirv::Decoration decoration,
+                               unsigned argIndex) {
+    for (auto argAttr :
+         cast<FunctionOpInterface>(*funcOp).getArgAttrs(argIndex)) {
+      if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+        return decAttr.getValue() == decoration;
     }
     return false;
   };
 
-  auto funcOp = dyn_cast<spirv::FuncOp>(getOperation());
-  unsigned numArgs = funcOp.getNumArguments();
-  if (numArgs < 1)
-    return success();
-
-  for (unsigned i = 0; i < numArgs; ++i) {
+  for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) {
     auto param = fnType.getInputs()[i];
     auto inputPtrType = dyn_cast<spirv::PointerType>(param);
     if (!inputPtrType)
@@ -1005,18 +1002,18 @@ LogicalResult spirv::FuncOp::verifyType() {
       // > and the type it points to is a pointer in the PhysicalStorageBuffer
       // > storage class, the function parameter must be decorated with exactly
       // > one of AliasedPointer or RestrictPointer.
-      if (pointeePtrType.getStorageClass() ==
-          spirv::StorageClass::PhysicalStorageBuffer) {
-        bool hasAliasedPtr =
-            hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
-        bool hasRestrictPtr =
-            hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
-
-        if (!hasAliasedPtr && !hasRestrictPtr)
-          return emitOpError()
-                 << "with a pointer points to a physical buffer pointer must "
-                    "be decorated either 'AliasedPointer' or 'RestrictPointer'";
-      }
+      if (pointeePtrType.getStorageClass() !=
+          spirv::StorageClass::PhysicalStorageBuffer)
+        continue;
+
+      bool hasAliasedPtr =
+          hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+      bool hasRestrictPtr =
+          hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+      if (!hasAliasedPtr && !hasRestrictPtr)
+        return emitOpError()
+               << "with a pointer points to a physical buffer pointer must "
+                  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
     } else {
       // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
       // > If an OpFunctionParameter is a pointer (or contains a pointer) in
@@ -1029,15 +1026,17 @@ LogicalResult spirv::FuncOp::verifyType() {
       } else {
         pointeePtrType = inputPtrType;
       }
-      if (pointeePtrType && pointeePtrType.getStorageClass() ==
-                                spirv::StorageClass::PhysicalStorageBuffer) {
-        bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
-        bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
-        if (!hasAliased && !hasRestrict)
-          return emitOpError()
-                 << "with physical buffer pointer must be decorated "
-                    "either 'Aliased' or 'Restrict'";
-      }
+
+      if (!pointeePtrType || pointeePtrType.getStorageClass() !=
+                                 spirv::StorageClass::PhysicalStorageBuffer)
+        continue;
+
+      bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+      bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+      if (!hasAliased && !hasRestrict)
+        return emitOpError()
+               << "with physical buffer pointer must be decorated "
+                  "either 'Aliased' or 'Restrict'";
     }
   }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 24748007bbb175..5cf429068610cc 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -371,39 +371,25 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
 }
 
 void spirv::Deserializer::setArgAttrs(uint32_t argID) {
-  if (!decorations.count(argID)) {
+  if (!decorations.contains(argID)) {
     argAttrs.push_back(DictionaryAttr::get(context, {}));
     return;
   }
 
   // Replace a decoration as UnitAttr with DecorationAttr for the physical
   // buffer pointer in the function parameter.
-  // e.g. "aliased" -> "spirv.decoration = #spirv.decoration<Aliased>").
-  for (auto decAttr : decorations[argID]) {
-    if (decAttr.getName() ==
-        getSymbolDecoration(stringifyDecoration(spirv::Decoration::Aliased))) {
-      decorations[argID].erase(decAttr.getName());
-      decorations[argID].set(
-          spirv::DecorationAttr::name,
-          spirv::DecorationAttr::get(context, spirv::Decoration::Aliased));
-    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
-                                        spirv::Decoration::Restrict))) {
-      decorations[argID].erase(decAttr.getName());
-      decorations[argID].set(
-          spirv::DecorationAttr::name,
-          spirv::DecorationAttr::get(context, spirv::Decoration::Restrict));
-    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
-                                        spirv::Decoration::AliasedPointer))) {
-      decorations[argID].erase(decAttr.getName());
-      decorations[argID].set(spirv::DecorationAttr::name,
-                             spirv::DecorationAttr::get(
-                                 context, spirv::Decoration::AliasedPointer));
-    } else if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
-                                        spirv::Decoration::RestrictPointer))) {
-      decorations[argID].erase(decAttr.getName());
-      decorations[argID].set(spirv::DecorationAttr::name,
-                             spirv::DecorationAttr::get(
-                                 context, spirv::Decoration::RestrictPointer));
+  // e.g. `aliased` -> `spirv.decoration = #spirv.decoration<Aliased>`.
+  for (NamedAttribute decAttr : decorations[argID]) {
+    for (auto decoration :
+         {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
+          spirv::Decoration::AliasedPointer,
+          spirv::Decoration::RestrictPointer}) {
+      if (decAttr.getName() ==
+          getSymbolDecoration(stringifyDecoration(decoration))) {
+        decorations[argID].erase(decAttr.getName());
+        decorations[argID].set(spirv::DecorationAttr::name,
+                               spirv::DecorationAttr::get(context, decoration));
+      }
     }
   }
 
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 2efb0ee64c9253..cc3c4237ee3443 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -178,30 +178,26 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
 }
 
 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
-  unsigned numArgs = op.getNumArguments();
-  if (numArgs != 0) {
-    for (unsigned i = 0; i < numArgs; ++i) {
-      auto arg = op.getArgument(i);
-      uint32_t argTypeID = 0;
-      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
-        return failure();
-      }
-      auto argValueID = getNextID();
-
-      // Process decoration attributes of arguments.
-      auto funcOp = cast<FunctionOpInterface>(*op);
-      for (auto argAttr : funcOp.getArgAttrs(i)) {
-        if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
-          if (failed(processDecorationAttr(op->getLoc(), argValueID,
-                                           decAttr.getValue(), decAttr)))
-            return failure();
-        }
+  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
+    uint32_t argTypeID = 0;
+    if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+      return failure();
+    }
+    auto argValueID = getNextID();
+
+    // Process decoration attributes of arguments.
+    auto funcOp = cast<FunctionOpInterface>(*op);
+    for (auto argAttr : funcOp.getArgAttrs(idx)) {
+      if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+        if (failed(processDecorationAttr(op->getLoc(), argValueID,
+                                         decAttr.getValue(), decAttr)))
+          return failure();
       }
-
-      valueIDMap[arg] = argValueID;
-      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
-                            {argTypeID, argValueID});
     }
+
+    valueIDMap[arg] = argValueID;
+    encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+                          {argTypeID, argValueID});
   }
   return success();
 }
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 6c3aaa35457470..1029fb933175fd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -274,7 +274,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
   case spirv::Decoration::RestrictPointer:
     // For unit attributes and decoration attributes, the args list
     // has no values so we do nothing.
-    if (isa<UnitAttr>(attr) || isa<DecorationAttr>(attr))
+    if (isa<UnitAttr, DecorationAttr>(attr))
       break;
     return emitError(loc,
                      "expected unit attribute or decoration attribute for ")
@@ -288,9 +288,10 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
 
 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                             NamedAttribute attr) {
-  auto attrName = attr.getName().strref();
-  auto decorationName = getDecorationName(attrName);
-  auto decoration = spirv::symbolizeDecoration(decorationName);
+  StringRef attrName = attr.getName().strref();
+  std::string decorationName = getDecorationName(attrName);
+  std::optional<Decoration> decoration =
+      spirv::symbolizeDecoration(decorationName);
   if (!decoration) {
     return emitError(
                loc, "non-argument attributes expected to have snake-case-ified "



More information about the Mlir-commits mailing list