[Mlir-commits] [mlir] 747d8fb - [mlir][spirv] Support alias/restrict function argument decorations (#76353)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 6 11:51:27 PST 2024


Author: Kohei Yamaguchi
Date: 2024-01-06T11:51:23-08:00
New Revision: 747d8fb01c2417546ebaa774874ff8c3005e058a

URL: https://github.com/llvm/llvm-project/commit/747d8fb01c2417546ebaa774874ff8c3005e058a
DIFF: https://github.com/llvm/llvm-project/commit/747d8fb01c2417546ebaa774874ff8c3005e058a.diff

LOG: [mlir][spirv] Support alias/restrict function argument decorations (#76353)

Closes #76106

---------

Co-authored-by: Lei Zhang <antiagainst at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.h
    mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
    mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
    mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
    mlir/test/Target/SPIRV/cast-ops.mlir
    mlir/test/Target/SPIRV/function-decorations.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 5fd25e3b576f2a..fbf750d643031f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -267,6 +267,15 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     This op itself takes no operands and generates no results. Its region
     can take zero or more arguments and return zero or one values.
 
+    From `SPV_KHR_physical_storage_buffer`:
+    If a function parameter is
+    - a pointer (or contains a pointer) in the PhysicalStorageBuffer storage
+      class, the function parameter must be decorated with exactly one of
+      `Aliased` or `Restrict`.
+    - a pointer (or contains a pointer) and the type it points to is a pointer
+      in the PhysicalStorageBuffer storage class, the function parameter must
+      be decorated with exactly one of `AliasedPointer` or `RestrictPointer`.
+
     <!-- End of AutoGen section -->
 
     ```
@@ -280,6 +289,20 @@ def SPIRV_FuncOp : SPIRV_Op<"func", [
     ```mlir
     spirv.func @foo() -> () "None" { ... }
     spirv.func @bar() -> () "Inline|Pure" { ... }
+
+    spirv.func @aliased_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>
+        { spirv.decoration = #spirv.decoration<Aliased> }) -> () "None" { ... }
+
+    spirv.func @restrict_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>
+        { spirv.decoration = #spirv.decoration<Restrict> }) -> () "None" { ... }
+
+    spirv.func @aliased_pointee(%arg0: !spirv.ptr<!spirv.ptr<i32,
+        PhysicalStorageBuffer>, Generic> { spirv.decoration =
+        #spirv.decoration<AliasedPointer> }) -> () "None" { ... }
+
+    spirv.func @restrict_pointee(%arg0: !spirv.ptr<!spirv.ptr<i32,
+        PhysicalStorageBuffer>, Generic> { spirv.decoration =
+        #spirv.decoration<RestrictPointer> }) -> () "None" { ... }
     ```
   }];
 

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 8a68decc5878c8..d7944d600b0a2b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -992,30 +992,39 @@ static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
   StringRef symbol = attribute.getName().strref();
   Attribute attr = attribute.getValue();
 
-  if (symbol != spirv::getInterfaceVarABIAttrName())
-    return emitError(loc, "found unsupported '")
-           << symbol << "' attribute on region argument";
-
-  auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
-  if (!varABIAttr)
-    return emitError(loc, "'")
-           << symbol << "' must be a spirv::InterfaceVarABIAttr";
-
-  if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
-    return emitError(loc, "'") << symbol
-                               << "' attribute cannot specify storage class "
-                                  "when attaching to a non-scalar value";
+  if (symbol == spirv::getInterfaceVarABIAttrName()) {
+    auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
+    if (!varABIAttr)
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::InterfaceVarABIAttr";
+
+    if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
+      return emitError(loc, "'") << symbol
+                                 << "' attribute cannot specify storage class "
+                                    "when attaching to a non-scalar value";
+    return success();
+  }
+  if (symbol == spirv::DecorationAttr::name) {
+    if (!isa<spirv::DecorationAttr>(attr))
+      return emitError(loc, "'")
+             << symbol << "' must be a spirv::DecorationAttr";
+    return success();
+  }
 
-  return success();
+  return emitError(loc, "found unsupported '")
+         << symbol << "' attribute on region argument";
 }
 
 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
                                                      unsigned regionIndex,
                                                      unsigned argIndex,
                                                      NamedAttribute attribute) {
-  return verifyRegionAttribute(
-      op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
-      attribute);
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+  Type argType = funcOp.getArgumentTypes()[argIndex];
+
+  return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5343a12132a912..3b159030cab75c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -972,8 +972,73 @@ void spirv::FuncOp::print(OpAsmPrinter &printer) {
 }
 
 LogicalResult spirv::FuncOp::verifyType() {
-  if (getFunctionType().getNumResults() > 1)
+  FunctionType fnType = getFunctionType();
+  if (fnType.getNumResults() > 1)
     return emitOpError("cannot have more than one result");
+
+  auto hasDecorationAttr = [&](spirv::Decoration decoration,
+                               unsigned argIndex) {
+    auto func = llvm::cast<FunctionOpInterface>(getOperation());
+    for (auto argAttr : cast<FunctionOpInterface>(func).getArgAttrs(argIndex)) {
+      if (argAttr.getName() != spirv::DecorationAttr::name)
+        continue;
+      if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
+        return decAttr.getValue() == decoration;
+    }
+    return false;
+  };
+
+  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
+    Type param = fnType.getInputs()[i];
+    auto inputPtrType = dyn_cast<spirv::PointerType>(param);
+    if (!inputPtrType)
+      continue;
+
+    auto pointeePtrType =
+        dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
+    if (pointeePtrType) {
+      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+      // > If an OpFunctionParameter is a pointer (or contains a pointer)
+      // > and the type it points to is a pointer in the PhysicalStorageBuffer
+      // > storage class, the function parameter must be decorated with exactly
+      // > one of AliasedPointer or RestrictPointer.
+      if (pointeePtrType.getStorageClass() !=
+          spirv::StorageClass::PhysicalStorageBuffer)
+        continue;
+
+      bool hasAliasedPtr =
+          hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
+      bool hasRestrictPtr =
+          hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
+      if (!hasAliasedPtr && !hasRestrictPtr)
+        return emitOpError()
+               << "with a pointer points to a physical buffer pointer must "
+                  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
+      continue;
+    }
+    // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
+    // > If an OpFunctionParameter is a pointer (or contains a pointer) in
+    // > the PhysicalStorageBuffer storage class, the function parameter must
+    // > be decorated with exactly one of Aliased or Restrict.
+    if (auto pointeeArrayType =
+            dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
+      pointeePtrType =
+          dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
+    } else {
+      pointeePtrType = inputPtrType;
+    }
+
+    if (!pointeePtrType || pointeePtrType.getStorageClass() !=
+                               spirv::StorageClass::PhysicalStorageBuffer)
+      continue;
+
+    bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
+    bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
+    if (!hasAliased && !hasRestrict)
+      return emitOpError() << "with physical buffer pointer must be decorated "
+                              "either 'Aliased' or 'Restrict'";
+  }
+
   return success();
 }
 

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 00645d2c45519e..0c521adb11332b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -239,8 +239,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   if (decorationName.empty()) {
     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
   }
-  auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
-  auto symbol = opBuilder.getStringAttr(attrName);
+  auto symbol = getSymbolDecoration(decorationName);
   switch (static_cast<spirv::Decoration>(words[1])) {
   case spirv::Decoration::FPFastMathMode:
     if (words.size() != 3) {
@@ -298,6 +297,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
     break;
   }
   case spirv::Decoration::Aliased:
+  case spirv::Decoration::AliasedPointer:
   case spirv::Decoration::Block:
   case spirv::Decoration::BufferBlock:
   case spirv::Decoration::Flat:
@@ -308,6 +308,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
   case spirv::Decoration::Restrict:
+  case spirv::Decoration::RestrictPointer:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
@@ -369,6 +370,46 @@ LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::setFunctionArgAttrs(
+    uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
+  if (!decorations.contains(argID)) {
+    argAttrs[argIndex] = DictionaryAttr::get(context, {});
+    return success();
+  }
+
+  spirv::DecorationAttr foundDecorationAttr;
+  for (NamedAttribute decAttr : decorations[argID]) {
+    for (auto decoration :
+         {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
+          spirv::Decoration::AliasedPointer,
+          spirv::Decoration::RestrictPointer}) {
+
+      if (decAttr.getName() !=
+          getSymbolDecoration(stringifyDecoration(decoration)))
+        continue;
+
+      if (foundDecorationAttr)
+        return emitError(unknownLoc,
+                         "more than one Aliased/Restrict decorations for "
+                         "function argument with result <id> ")
+               << argID;
+
+      foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
+      break;
+    }
+  }
+
+  if (!foundDecorationAttr)
+    return emitError(unknownLoc, "unimplemented decoration support for "
+                                 "function argument with result <id> ")
+           << argID;
+
+  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
+                      foundDecorationAttr);
+  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   if (curFunction) {
@@ -430,6 +471,9 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
     logger.indent();
   });
 
+  SmallVector<Attribute> argAttrs;
+  argAttrs.resize(functionType.getNumInputs());
+
   // Parse the op argument instructions
   if (functionType.getNumInputs()) {
     for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
@@ -463,11 +507,21 @@ spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
         return emitError(unknownLoc, "duplicate definition of result <id> ")
                << operands[1];
       }
+      if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
+        return failure();
+      }
+
       auto argValue = funcOp.getArgument(i);
       valueMap[operands[1]] = argValue;
     }
   }
 
+  if (llvm::any_of(argAttrs, [](Attribute attr) {
+        auto argAttr = cast<DictionaryAttr>(attr);
+        return !argAttr.empty();
+      }))
+    funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
+
   // entryBlock is needed to access the arguments, Once that is done, we can
   // erase the block for functions with 'Import' LinkageAttributes, since these
   // are essentially function declarations, so they have no body.

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 69be47851ef3c5..fc9a8f5f9364b2 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -233,6 +233,19 @@ class Deserializer {
     return globalVariableMap.lookup(id);
   }
 
+  /// Fills argAttrs[argIndex] with the attribute dictionary of the function
+  /// argument whose result <id> is |argID|. Fails unless an argument that has
+  /// decorations carries exactly one Aliased/Restrict-family decoration.
+  LogicalResult setFunctionArgAttrs(uint32_t argID,
+                                    SmallVectorImpl<Attribute> &argAttrs,
+                                    size_t argIndex);
+
+  /// Converts a CamelCase decoration name into its snake_case symbol StringAttr.
+  StringAttr getSymbolDecoration(StringRef decorationName) {
+    auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
+    return opBuilder.getStringAttr(attrName);
+  }
+
   //===--------------------------------------------------------------------===//
   // Type
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 7bfcca5b4dcdca..41d2c0310d0008 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -177,6 +177,34 @@ LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   return success();
 }
 
+LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
+  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
+    uint32_t argTypeID = 0;
+    if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+      return failure();
+    }
+    auto argValueID = getNextID();
+
+    // Process decoration attributes of arguments.
+    auto funcOp = cast<FunctionOpInterface>(*op);
+    for (auto argAttr : funcOp.getArgAttrs(idx)) {
+      if (argAttr.getName() != DecorationAttr::name)
+        continue;
+
+      if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
+        if (failed(processDecorationAttr(op->getLoc(), argValueID,
+                                         decAttr.getValue(), decAttr)))
+          return failure();
+      }
+    }
+
+    valueIDMap[arg] = argValueID;
+    encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
+                          {argTypeID, argValueID});
+  }
+  return success();
+}
+
 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
   assert(functionHeader.empty() && functionBody.empty());
@@ -229,32 +257,15 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
     // is going to return false for this function from now on)
     // Hence, we'll remove the body once we are done with the serialization.
     op.addEntryBlock();
-    for (auto arg : op.getArguments()) {
-      uint32_t argTypeID = 0;
-      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
-        return failure();
-      }
-      auto argValueID = getNextID();
-      valueIDMap[arg] = argValueID;
-      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
-                            {argTypeID, argValueID});
-    }
+    if (failed(processFuncParameter(op)))
+      return failure();
     // Don't need to process the added block, there is nothing to process,
     // the fake body was added just to get the arguments, remove the body,
     // since it's use is done.
     op.eraseBody();
   } else {
-    // Declare the parameters.
-    for (auto arg : op.getArguments()) {
-      uint32_t argTypeID = 0;
-      if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
-        return failure();
-      }
-      auto argValueID = getNextID();
-      valueIDMap[arg] = argValueID;
-      encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
-                            {argTypeID, argValueID});
-    }
+    if (failed(processFuncParameter(op)))
+      return failure();
 
     // Some instructions (e.g., OpVariable) in a function must be in the first
     // block in the function. These instructions will be put in

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 9e9a16456cc102..1029fb933175fd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -215,23 +215,15 @@ static std::string getDecorationName(StringRef attrName) {
   return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
 }
 
-LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
-                                            NamedAttribute attr) {
-  auto attrName = attr.getName().strref();
-  auto decorationName = getDecorationName(attrName);
-  auto decoration = spirv::symbolizeDecoration(decorationName);
-  if (!decoration) {
-    return emitError(
-               loc, "non-argument attributes expected to have snake-case-ified "
-                    "decoration name, unhandled attribute with name : ")
-           << attrName;
-  }
+LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
+                                                Decoration decoration,
+                                                Attribute attr) {
   SmallVector<uint32_t, 1> args;
-  switch (*decoration) {
+  switch (decoration) {
   case spirv::Decoration::LinkageAttributes: {
     // Get the value of the Linkage Attributes
     // e.g., LinkageAttributes=["linkageName", linkageType].
-    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
+    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
     auto linkageName = linkageAttr.getLinkageName();
     auto linkageType = linkageAttr.getLinkageType().getValue();
     // Encode the Linkage Name (string literal to uint32_t).
@@ -241,32 +233,36 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
     break;
   }
   case spirv::Decoration::FPFastMathMode:
-    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
       args.push_back(static_cast<uint32_t>(intAttr.getValue()));
       break;
     }
     return emitError(loc, "expected FPFastMathModeAttr attribute for ")
-           << attrName;
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
       args.push_back(intAttr.getValue().getZExtValue());
       break;
     }
-    return emitError(loc, "expected integer attribute for ") << attrName;
+    return emitError(loc, "expected integer attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::BuiltIn:
-    if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
+    if (auto strAttr = dyn_cast<StringAttr>(attr)) {
       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
       if (enumVal) {
         args.push_back(static_cast<uint32_t>(*enumVal));
         break;
       }
       return emitError(loc, "invalid ")
-             << attrName << " attribute " << strAttr.getValue();
+             << stringifyDecoration(decoration) << " decoration attribute "
+             << strAttr.getValue();
     }
-    return emitError(loc, "expected string attribute for ") << attrName;
+    return emitError(loc, "expected string attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Aliased:
+  case spirv::Decoration::AliasedPointer:
   case spirv::Decoration::Flat:
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
@@ -275,14 +271,34 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
   case spirv::Decoration::Restrict:
-    // For unit attributes, the args list has no values so we do nothing
-    if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
+  case spirv::Decoration::RestrictPointer:
+    // For unit attributes and decoration attributes, the args list
+    // has no values so we do nothing.
+    if (isa<UnitAttr, DecorationAttr>(attr))
       break;
-    return emitError(loc, "expected unit attribute for ") << attrName;
+    return emitError(loc,
+                     "expected unit attribute or decoration attribute for ")
+           << stringifyDecoration(decoration);
   default:
-    return emitError(loc, "unhandled decoration ") << decorationName;
+    return emitError(loc, "unhandled decoration ")
+           << stringifyDecoration(decoration);
+  }
+  return emitDecoration(resultID, decoration, args);
+}
+
+LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
+                                            NamedAttribute attr) {
+  StringRef attrName = attr.getName().strref();
+  std::string decorationName = getDecorationName(attrName);
+  std::optional<Decoration> decoration =
+      spirv::symbolizeDecoration(decorationName);
+  if (!decoration) {
+    return emitError(
+               loc, "non-argument attributes expected to have snake-case-ified "
+                    "decoration name, unhandled attribute with name : ")
+           << attrName;
   }
-  return emitDecoration(resultID, *decoration, args);
+  return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
 }
 
 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 4b2ebf610bd723..9edb0f4af008dd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -127,6 +127,7 @@ class Serializer {
 
   /// Processes a SPIR-V function op.
   LogicalResult processFuncOp(spirv::FuncOp op);
+  LogicalResult processFuncParameter(spirv::FuncOp op);
 
   LogicalResult processVariableOp(spirv::VariableOp op);
 
@@ -134,6 +135,8 @@ class Serializer {
   LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
 
   /// Process attributes that translate to decorations on the result <id>
+  LogicalResult processDecorationAttr(Location loc, uint32_t resultID,
+                                      Decoration decoration, Attribute attr);
   LogicalResult processDecoration(Location loc, uint32_t resultID,
                                   NamedAttribute attr);
 

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
index b3991cbdbe8af1..b9c56a3fcffd04 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -81,7 +81,7 @@ spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>)
 spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
 
 // CHECK-ALL:            llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
-spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None"
 
 // CHECK-ALL:            llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
 spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"

diff  --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4f4a72da7c050a..e289dbf28ad284 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -414,7 +414,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
 // -----
 
 spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {
-  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
     // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     spirv.Return

diff  --git a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
index 2e39421df13cca..d915f8820c4f40 100644
--- a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir
@@ -17,3 +17,59 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     }
     spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return}
 }
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
+  spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict> }) "None" {
+  spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer> }) "None" {
+  spirv.Return
+}
+
+// -----
+
+// CHECK: spirv.func @arg_decoration_pointer(%{{.+}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>}) "None"
+spirv.func @arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer> }) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with a pointer points to a physical buffer pointer must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
+spirv.func @no_arg_decoration_pointer(%arg0: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Function>) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op with physical buffer pointer must be decorated either 'Aliased' or 'Restrict'}}
+spirv.func @no_decoration_name_attr(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { random_attr = #spirv.decoration<Aliased> }) "None" {
+  spirv.Return
+}
+
+// -----
+
+// expected-error @+1 {{'spirv.func' op arguments may only have dialect attributes}}
+spirv.func @no_decoration_name_attr(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>, random_attr = #spirv.decoration<Aliased> }) "None" {
+  spirv.Return
+}

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4eaa21d2f94ef6..931034f3d5f6ea 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -66,7 +66,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes {
   spirv.target_env = #spirv.target_env<
     #spirv.vce<v1.0, [Shader, PhysicalStorageBufferAddresses], [SPV_EXT_physical_storage_buffer]>, #spirv.resource_limits<>>
 } {
-  spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer>) "None" {
+  spirv.func @physical_ptr(%val : !spirv.ptr<f32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }) "None" {
     spirv.Return
   }
 }

diff  --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 7fe0969497c3ec..ede0bf30511ef4 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -115,7 +115,7 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses], []
 // -----
 
 spirv.module PhysicalStorageBuffer64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, PhysicalStorageBufferAddresses], []> {  
-  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer>) "None" {
+  spirv.func @covert_ptr_to_u_PhysicalStorageBuffer(%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased>} ) "None" {
     // CHECK: {{%.*}} = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     %0 = spirv.ConvertPtrToU %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> to i32
     spirv.Return

diff  --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir
index b0f6705df9ca41..117d4ca628f76a 100644
--- a/mlir/test/Target/SPIRV/function-decorations.mlir
+++ b/mlir/test/Target/SPIRV/function-decorations.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file -verify-diagnostics %s | FileCheck %s
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     spirv.func @linkage_attr_test_kernel()  "DontInline"  attributes {}  {
@@ -17,3 +17,72 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     }
     spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return}
 }
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+    [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+  // CHECK-LABEL: spirv.func @func_arg_decoration_aliased(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
+  spirv.func @func_arg_decoration_aliased(
+      %arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }
+  ) "None" {
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+    [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+  // CHECK-LABEL: spirv.func @func_arg_decoration_restrict(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>})
+  spirv.func @func_arg_decoration_restrict(
+      %arg0 : !spirv.ptr<i32,PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict> }
+  ) "None" {
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+    [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+  // CHECK-LABEL: spirv.func @func_arg_decoration_aliased_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>})
+  spirv.func @func_arg_decoration_aliased_pointer(
+      %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer> }
+  ) "None" {
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+    [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+  // CHECK-LABEL: spirv.func @func_arg_decoration_restrict_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>})
+  spirv.func @func_arg_decoration_restrict_pointer(
+      %arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer> }
+  ) "None" {
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
+    [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+  // CHECK-LABEL: spirv.func @fn1(%{{.*}}: i32, %{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
+  spirv.func @fn1(
+      %arg0: i32,
+      %arg1: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }
+  ) "None" {
+    spirv.Return
+  }
+
+  // CHECK-LABEL: spirv.func @fn2(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>}, %{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>})
+  spirv.func @fn2(
+      %arg0: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> },
+      %arg1: !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict>}
+  ) "None" {
+    spirv.Return
+  }
+}


        


More information about the Mlir-commits mailing list