[Mlir-commits] [mlir] ef2f46e - [MLIR][SPIRV] Support two memory access attributes in OpCopyMemory.

Lei Zhang llvmlistbot at llvm.org
Thu Jul 2 10:20:10 PDT 2020


Author: ergawy
Date: 2020-07-02T13:17:22-04:00
New Revision: ef2f46e1f6a63040734c48ed53893298df14b6fa

URL: https://github.com/llvm/llvm-project/commit/ef2f46e1f6a63040734c48ed53893298df14b6fa
DIFF: https://github.com/llvm/llvm-project/commit/ef2f46e1f6a63040734c48ed53893298df14b6fa.diff

LOG: [MLIR][SPIRV] Support two memory access attributes in OpCopyMemory.

This commit augments spv.CopyMemory's implementation to support 2 memory
access operands, hence more closely following the spec. The following
changes are introduced:

- Customize logic for spv.CopyMemory serialization and deserialization.
- Add 2 additional attributes for the source memory access operand.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D82710

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
    mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
    mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
    mlir/test/Dialect/SPIRV/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index c92af561faf7..e20bde5904bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -198,7 +198,7 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
     ```
     copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use
                        storage-class ssa-use
-                       (`[` memory-access `]`)?
+                       (`[` memory-access `]` (`, [` memory-access `]`)?)?
                        ` : ` spirv-element-type
     ```
 
@@ -215,12 +215,16 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
     SPV_AnyPtr:$target,
     SPV_AnyPtr:$source,
     OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
+    OptionalAttr<I32Attr>:$alignment,
+    OptionalAttr<SPV_MemoryAccessAttr>:$source_memory_access,
+    OptionalAttr<I32Attr>:$source_alignment
   );
 
   let results = (outs);
 
   let verifier = [{ return verifyCopyMemory(*this); }];
+
+  let autogenSerialization = 0;
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 1ac6a1e6d75b..9637ea3ec0c6 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -28,7 +28,11 @@
 using namespace mlir;
 
 // TODO(antiagainst): generate these strings using ODS.
+static constexpr const char kMemoryAccessAttrName[] = "memory_access";
+static constexpr const char kSourceMemoryAccessAttrName[] =
+    "source_memory_access";
 static constexpr const char kAlignmentAttrName[] = "alignment";
+static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
 static constexpr const char kBranchWeightAttrName[] = "branch_weights";
 static constexpr const char kCallee[] = "callee";
 static constexpr const char kClusterSize[] = "cluster_size";
@@ -157,6 +161,8 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
   return success();
 }
 
+template <const char memoryAccessAttrName[] = kMemoryAccessAttrName,
+          const char alignmentAttrName[] = kAlignmentAttrName>
 static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                                OperationState &state) {
   // Parse an optional list of attributes staring with '['
@@ -166,7 +172,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
-  if (parseEnumStrAttr(memoryAccessAttr, parser, state)) {
+  if (parseEnumStrAttr(memoryAccessAttr, parser, state, memoryAccessAttrName)) {
     return failure();
   }
 
@@ -175,7 +181,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
     Attribute alignmentAttr;
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
                               state.attributes)) {
       return failure();
     }
@@ -183,19 +189,33 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
   return parser.parseRSquare();
 }
 
-template <typename MemoryOpTy>
-static void
-printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
-                           SmallVectorImpl<StringRef> &elidedAttrs) {
+template <typename MemoryOpTy,
+          const char memoryAccessAttrName[] = kMemoryAccessAttrName,
+          const char alignmentAttrName[] = kAlignmentAttrName,
+          bool first = true>
+static void printMemoryAccessAttribute(
+    MemoryOpTy memoryOp, OpAsmPrinter &printer,
+    SmallVectorImpl<StringRef> &elidedAttrs,
+    Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
+    Optional<llvm::APInt> alignmentAttrValue = None) {
   // Print optional memory access attribute.
-  if (auto memAccess = memoryOp.memory_access()) {
-    elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
+  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
+                                              : memoryOp.memory_access())) {
+    elidedAttrs.push_back(memoryAccessAttrName);
+
+    if (!first) {
+      printer << ", ";
+    }
+
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
-    // Print integer alignment attribute.
-    if (auto alignment = memoryOp.alignment()) {
-      elidedAttrs.push_back(kAlignmentAttrName);
-      printer << ", " << alignment;
+    if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+      // Print integer alignment attribute.
+      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
+                                               : memoryOp.alignment())) {
+        elidedAttrs.push_back(alignmentAttrName);
+        printer << ", " << alignment;
+      }
     }
     printer << "]";
   }
@@ -243,17 +263,19 @@ static LogicalResult verifyCastOp(Operation *op,
   return success();
 }
 
-template <typename MemoryOpTy>
+template <typename MemoryOpTy,
+          const char memoryAccessAttrName[] = kMemoryAccessAttrName,
+          const char alignmentAttrName[] = kAlignmentAttrName>
 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // ODS checks for attributes values. Just need to verify that if the
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
+  auto memAccessAttr = op->getAttr(memoryAccessAttrName);
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(alignmentAttrName)) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -270,11 +292,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   }
 
   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kAlignmentAttrName)) {
+    if (!op->getAttr(alignmentAttrName)) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(alignmentAttrName)) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -2839,6 +2861,10 @@ static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
 
   SmallVector<StringRef, 4> elidedAttrs;
   printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
+  printMemoryAccessAttribute<decltype(copyMemory), kSourceMemoryAccessAttrName,
+                             kSourceAlignmentAttrName, false>(
+      copyMemory, printer, elidedAttrs, copyMemory.source_memory_access(),
+      copyMemory.source_alignment());
 
   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
@@ -2861,9 +2887,23 @@ static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
       parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
       parseEnumStrAttr(sourceStorageClass, parser) ||
       parser.parseOperand(sourcePtrInfo) ||
-      parseMemoryAccessAttributes(parser, state) ||
-      parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
-      parser.parseType(elementType)) {
+      parseMemoryAccessAttributes(parser, state)) {
+    return failure();
+  }
+
+  if (!parser.parseOptionalComma()) {
+    // Parse 2nd memory access attributes.
+    if (parseMemoryAccessAttributes<kSourceMemoryAccessAttrName,
+                                    kSourceAlignmentAttrName>(parser, state)) {
+      return failure();
+    }
+  }
+
+  if (parser.parseColon() || parser.parseType(elementType)) {
+    return failure();
+  }
+
+  if (parser.parseOptionalAttrDict(state.attributes)) {
     return failure();
   }
 
@@ -2890,7 +2930,21 @@ static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
         "both operands must be pointers to the same type");
   }
 
-  return verifyMemoryAccessAttribute(copyMemory);
+  if (failed(verifyMemoryAccessAttribute(copyMemory))) {
+    return failure();
+  }
+
+  // TODO (ergawy): According to the spec:
+  //
+  // If two masks are present, the first applies to Target and cannot include
+  // MakePointerVisible, and the second applies to Source and cannot include
+  // MakePointerAvailable.
+  //
+  // Add such verification here.
+
+  return verifyMemoryAccessAttribute<decltype(copyMemory),
+                                     kSourceMemoryAccessAttrName,
+                                     kSourceAlignmentAttrName>(copyMemory);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 4bb9e6d26b97..56787385f36a 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -400,7 +400,8 @@ class Deserializer {
   /// Method to deserialize an operation in the SPIR-V dialect that is a mirror
   /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
   /// == 1 and autogenSerialization == 1 in ODS.
-  template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
+  template <typename OpTy>
+  LogicalResult processOp(ArrayRef<uint32_t> words) {
     return emitError(unknownLoc, "unsupported deserialization for ")
            << OpTy::getOperationName() << " op";
   }
@@ -1566,8 +1567,8 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
     return success();
   }
 
-    return emitError(unknownLoc, "unsupported OpConstantNull type: ")
-           << resultType;
+  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
+         << resultType;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2509,6 +2510,76 @@ Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+template <>
+LogicalResult
+Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
+  SmallVector<Type, 1> resultTypes;
+  size_t wordIndex = 0;
+  SmallVector<Value, 4> operands;
+  SmallVector<NamedAttribute, 4> attributes;
+
+  if (wordIndex < words.size()) {
+    auto arg = getValue(words[wordIndex]);
+
+    if (!arg) {
+      return emitError(unknownLoc, "unknown result <id> : ")
+             << words[wordIndex];
+    }
+
+    operands.push_back(arg);
+    wordIndex++;
+  }
+
+  if (wordIndex < words.size()) {
+    auto arg = getValue(words[wordIndex]);
+
+    if (!arg) {
+      return emitError(unknownLoc, "unknown result <id> : ")
+             << words[wordIndex];
+    }
+
+    operands.push_back(arg);
+    wordIndex++;
+  }
+
+  bool isAlignedAttr = false;
+
+  if (wordIndex < words.size()) {
+    auto attrValue = words[wordIndex++];
+    attributes.push_back(opBuilder.getNamedAttr(
+        "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
+    isAlignedAttr = (attrValue == 2);
+  }
+
+  if (isAlignedAttr && wordIndex < words.size()) {
+    attributes.push_back(opBuilder.getNamedAttr(
+        "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
+  }
+
+  if (wordIndex < words.size()) {
+    attributes.push_back(opBuilder.getNamedAttr(
+        "source_memory_access",
+        opBuilder.getI32IntegerAttr(words[wordIndex++])));
+  }
+
+  if (wordIndex < words.size()) {
+    attributes.push_back(opBuilder.getNamedAttr(
+        "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
+  }
+
+  if (wordIndex != words.size()) {
+    return emitError(unknownLoc,
+                     "found more operands than expected when deserializing "
+                     "spirv::CopyMemoryOp, only ")
+           << wordIndex << " of " << words.size() << " processed";
+  }
+
+  Location loc = createFileLineColLoc(opBuilder);
+  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
+
+  return success();
+}
+
 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
 // various Deserializer::processOp<...>() specializations.
 #define GET_DESERIALIZATION_FNS

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index f8641873fd95..6011dd5adfdc 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -364,7 +364,8 @@ class Serializer {
   /// Method to serialize an operation in the SPIR-V dialect that is a mirror of
   /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
   /// 1 and autogenSerialization == 1 in ODS.
-  template <typename OpTy> LogicalResult processOp(OpTy op) {
+  template <typename OpTy>
+  LogicalResult processOp(OpTy op) {
     return op.emitError("unsupported op serialization");
   }
 
@@ -1904,6 +1905,51 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
                                operands);
 }
 
+template <>
+LogicalResult
+Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
+  SmallVector<uint32_t, 4> operands;
+  SmallVector<StringRef, 2> elidedAttrs;
+
+  for (Value operand : op.getOperation()->getOperands()) {
+    auto id = getValueID(operand);
+    assert(id && "use before def!");
+    operands.push_back(id);
+  }
+
+  if (auto attr = op.getAttr("memory_access")) {
+    operands.push_back(static_cast<uint32_t>(
+        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+  }
+
+  elidedAttrs.push_back("memory_access");
+
+  if (auto attr = op.getAttr("alignment")) {
+    operands.push_back(static_cast<uint32_t>(
+        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+  }
+
+  elidedAttrs.push_back("alignment");
+
+  if (auto attr = op.getAttr("source_memory_access")) {
+    operands.push_back(static_cast<uint32_t>(
+        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+  }
+
+  elidedAttrs.push_back("source_memory_access");
+
+  if (auto attr = op.getAttr("source_alignment")) {
+    operands.push_back(static_cast<uint32_t>(
+        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+  }
+
+  elidedAttrs.push_back("source_alignment");
+  emitDebugLine(functionBody, op.getLoc());
+  encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
+
+  return success();
+}
+
 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
 // various Serializer::processOp<...>() specializations.
 #define GET_SERIALIZATION_FNS

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index 25b54c055394..0e18c1ea37bf 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -93,6 +93,18 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
     spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32
 
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Volatile"] : f32
+    spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Volatile"] : f32
+
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32
+    spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4], ["Volatile"] : f32
+
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Aligned", 4] : f32
+    spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Aligned", 4] : f32
+
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 8], ["Aligned", 4] : f32
+    spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 8], ["Aligned", 4] : f32
+
     spv.Return
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index c42619100362..c91a81fe239c 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -1247,7 +1247,7 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () {
 
 // -----
 
-func @copy_memory_incompatible_ptrs() -> () {
+func @copy_memory_incompatible_ptrs() {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   %1 = spv.Variable : !spv.ptr<i32, Function>
   // expected-error @+1 {{both operands must be pointers to the same type}}
@@ -1257,7 +1257,7 @@ func @copy_memory_incompatible_ptrs() -> () {
 
 // -----
 
-func @copy_memory_invalid_maa() -> () {
+func @copy_memory_invalid_maa() {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   %1 = spv.Variable : !spv.ptr<f32, Function>
   // expected-error @+1 {{missing alignment value}}
@@ -1267,7 +1267,27 @@ func @copy_memory_invalid_maa() -> () {
 
 // -----
 
-func @copy_memory_print_maa() -> () {
+func @copy_memory_invalid_source_maa() {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  %1 = spv.Variable : !spv.ptr<f32, Function>
+  // expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}}
+  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  spv.Return
+}
+
+// -----
+
+func @copy_memory_invalid_source_maa2() {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  %1 = spv.Variable : !spv.ptr<f32, Function>
+  // expected-error @+1 {{missing alignment value}}
+  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  spv.Return
+}
+
+// -----
+
+func @copy_memory_print_maa() {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   %1 = spv.Variable : !spv.ptr<f32, Function>
 
@@ -1277,5 +1297,11 @@ func @copy_memory_print_maa() -> () {
   // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
   "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
 
+  // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32
+  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+
+  // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32
+  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+
   spv.Return
 }


        


More information about the Mlir-commits mailing list