[Mlir-commits] [mlir] d6485ed - [MLIR][SPIRV] Add support for OpCopyMemory.

Lei Zhang llvmlistbot at llvm.org
Fri Jun 26 06:44:05 PDT 2020


Author: ergawy
Date: 2020-06-26T09:43:53-04:00
New Revision: d6485ed3a7701364650bffabcbc277733f37eaa7

URL: https://github.com/llvm/llvm-project/commit/d6485ed3a7701364650bffabcbc277733f37eaa7
DIFF: https://github.com/llvm/llvm-project/commit/d6485ed3a7701364650bffabcbc277733f37eaa7.diff

LOG: [MLIR][SPIRV] Add support for OpCopyMemory.

This patch adds support for 'spv.CopyMemory'. The following changes are
introduced:
- 'CopyMemory' op is added to SPIRVOps.td.
- Custom parse and print methods are introduced.
- A few round-tripping tests are added.

Differential Revision: https://reviews.llvm.org/D82384

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
    mlir/test/Dialect/SPIRV/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 6bff480ab83b..832171e92336 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3135,6 +3135,7 @@ def SPV_OC_OpFunctionCall              : I32EnumAttrCase<"OpFunctionCall", 57>;
 def SPV_OC_OpVariable                  : I32EnumAttrCase<"OpVariable", 59>;
 def SPV_OC_OpLoad                      : I32EnumAttrCase<"OpLoad", 61>;
 def SPV_OC_OpStore                     : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpCopyMemory                : I32EnumAttrCase<"OpCopyMemory", 63>; // For spv.CopyMemory; also listed in SPV_OpcodeAttr.
 def SPV_OC_OpAccessChain               : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate                  : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>;
@@ -3264,23 +3265,23 @@ def SPV_OpcodeAttr :
       SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
       SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
       SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
-      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
-      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
-      SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
-      SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
-      SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
-      SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
-      SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
-      SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
-      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
+      SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
+      SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
+      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
+      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
+      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
+      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 8b3a25037078..c92af561faf7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -173,6 +173,58 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
 
 // -----
 
+def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
+  let summary = [{
+    Copy from the memory pointed to by Source to the memory pointed to by
+    Target. Both operands must be non-void pointers and having the same <id>
+    Type operand in their OpTypePointer type declaration.  Matching Storage
+    Class is not required.  The amount of memory copied is the size of the
+    type pointed to. The copied type must have a fixed size; i.e., it cannot
+    be, nor include, any OpTypeRuntimeArray types.
+  }];
+
+  let description = [{
+    If present, any Memory Operands must begin with a memory operand
+    literal. If not present, it is the same as specifying the memory operand
+    None. Before version 1.4, at most one memory operands mask can be
+    provided. Starting with version 1.4 two masks can be provided, as
+    described in Memory Operands. If no masks or only one mask is present,
+    it applies to both Source and Target. If two masks are present, the
+    first applies to Target and cannot include MakePointerVisible, and the
+    second applies to Source and cannot include MakePointerAvailable.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use
+                       `, ` storage-class ssa-use
+                       (`[` memory-access `]`)?
+                       ` : ` spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    spv.CopyMemory "Function" %0, "Function" %1 : f32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$target,
+    SPV_AnyPtr:$source,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+    OptionalAttr<I32Attr>:$alignment
+  );
+
+  let results = (outs);
+
+  let verifier = [{ return verifyCopyMemory(*this); }];
+}
+
+// -----
+
 def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
   let summary = "Declare an execution mode for an entry point.";
 

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6415218f74c2..8368a8e1857b 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -183,17 +183,17 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
   return parser.parseRSquare();
 }
 
-template <typename LoadStoreOpTy>
+template <typename MemoryOpTy>
 static void
-printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
+printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
                            SmallVectorImpl<StringRef> &elidedAttrs) {
   // Print optional memory access attribute.
-  if (auto memAccess = loadStoreOp.memory_access()) {
+  if (auto memAccess = memoryOp.memory_access()) {
     elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
     // Print integer alignment attribute.
-    if (auto alignment = loadStoreOp.alignment()) {
+    if (auto alignment = memoryOp.alignment()) {
       elidedAttrs.push_back(kAlignmentAttrName);
       printer << ", " << alignment;
     }
@@ -243,18 +243,18 @@ static LogicalResult verifyCastOp(Operation *op,
   return success();
 }
 
-template <typename LoadStoreOpTy>
-static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
+template <typename MemoryOpTy>
+static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // ODS checks for attributes values. Just need to verify that if the
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
-  auto *op = loadStoreOp.getOperation();
+  auto *op = memoryOp.getOperation();
   auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
     if (op->getAttr(kAlignmentAttrName)) {
-      return loadStoreOp.emitOpError(
+      return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
     }
@@ -265,17 +265,17 @@ static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
 
   if (!memAccess) {
-    return loadStoreOp.emitOpError("invalid memory access specifier: ")
+    return memoryOp.emitOpError("invalid memory access specifier: ")
            << memAccessVal;
   }
 
   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kAlignmentAttrName)) {
-      return loadStoreOp.emitOpError("missing alignment value");
+      return memoryOp.emitOpError("missing alignment value");
     }
   } else {
     if (op->getAttr(kAlignmentAttrName)) {
-      return loadStoreOp.emitOpError(
+      return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
     }
@@ -2752,8 +2752,7 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
 static LogicalResult
 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
   if (op.c().getType() != op.result().getType())
-    return op.emitOpError(
-        "result and third operand must have the same type");
+    return op.emitOpError("result and third operand must have the same type");
   auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
   auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
   auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
@@ -2812,9 +2811,89 @@ static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
                             "have the same size");
     }
   }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CopyMemory
+//===----------------------------------------------------------------------===//
+
+static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
+  auto *op = copyMemory.getOperation();
+  // Every segment printed below starts with its own leading space, so the op
+  // name is emitted without a trailing separator (a separator here and after
+  // the comma would yield double spaces in the custom form).
+  printer << spirv::CopyMemoryOp::getOperationName();
+
+  // Target operand: ` "<storage-class>" %target,`
+  auto targetPtrType = copyMemory.target().getType().cast<spirv::PointerType>();
+  printer << " \"" << stringifyStorageClass(targetPtrType.getStorageClass())
+          << "\" " << copyMemory.target() << ",";
+
+  // Source operand: ` "<storage-class>" %source`
+  auto sourcePtrType = copyMemory.source().getType().cast<spirv::PointerType>();
+  printer << " \"" << stringifyStorageClass(sourcePtrType.getStorageClass())
+          << "\" " << copyMemory.source();
+
+  // Optional ` [...]` memory-access operands; the helper records the
+  // attribute names it printed so the attribute dict does not repeat them.
+  SmallVector<StringRef, 4> elidedAttrs;
+  printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+
+  // Only one type is printed: the verifier requires both pointee types to
+  // match, and the parser rebuilds each operand's pointer type from it and
+  // that operand's storage class.
+  printer << " : " << targetPtrType.getPointeeType();
+}
+
+static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
+                                     OperationState &state) {
+  // Custom form: "<sc>" %target, "<sc>" %source ([memory-access])? attr-dict?
+  //              : pointee-type
+  spirv::StorageClass targetSC, sourceSC;
+  OpAsmParser::OperandType targetOperand, sourceOperand;
+  Type pointeeType;
+
+  // Consume each piece in textual order, stopping at the first failure.
+  bool parseFailed =
+      parseEnumStrAttr(targetSC, parser) ||
+      parser.parseOperand(targetOperand) || parser.parseComma() ||
+      parseEnumStrAttr(sourceSC, parser) ||
+      parser.parseOperand(sourceOperand) ||
+      parseMemoryAccessAttributes(parser, state) ||
+      parser.parseOptionalAttrDict(state.attributes) ||
+      parser.parseColon() || parser.parseType(pointeeType);
+  if (parseFailed)
+    return failure();
+
+  // A single pointee type is written, so each pointer type is rebuilt from it
+  // and the operand's own storage class. Target is resolved before source to
+  // keep the operand list in (target, source) order.
+  Type targetType = spirv::PointerType::get(pointeeType, targetSC);
+  Type sourceType = spirv::PointerType::get(pointeeType, sourceSC);
+  if (parser.resolveOperand(targetOperand, targetType, state.operands) ||
+      parser.resolveOperand(sourceOperand, sourceType, state.operands))
+    return failure();
   return success();
 }
 
+static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
+  // Storage classes may differ, but both pointers must share a pointee type.
+  // NOTE(review): the op summary also forbids OpTypeRuntimeArray in the copied
+  // type; that rule is not enforced here.
+  auto pointeeOf = [](Value ptr) -> Type {
+    return ptr.getType().cast<spirv::PointerType>().getPointeeType();
+  };
+  if (pointeeOf(copyMemory.target()) != pointeeOf(copyMemory.source()))
+    return copyMemory.emitOpError(
+        "both operands must be pointers to the same type");
+
+  // Remaining checks: memory-access / alignment attribute consistency.
+  return verifyMemoryAccessAttribute(copyMemory);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Transpose
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index 26584a479dec..25b54c055394 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -57,3 +57,43 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     spv.Return
   }
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_simple() "None" {
+    %dst = spv.Variable : !spv.ptr<f32, Function>
+    %src = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} : f32
+    spv.CopyMemory "Function" %dst, "Function" %src : f32
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_different_storage_classes(%in : !spv.ptr<!spv.array<4xf32>, Input>, %out : !spv.ptr<!spv.array<4xf32>, Output>) "None" {
+    // CHECK: spv.CopyMemory "Output" %{{.*}}, "Input" %{{.*}} : !spv.array<4 x f32>
+    spv.CopyMemory "Output" %out, "Input" %in : !spv.array<4xf32>
+    spv.Return
+  }
+}
+
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_with_access_operands() "None" {
+    %dst = spv.Variable : !spv.ptr<f32, Function>
+    %src = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
+    spv.CopyMemory "Function" %dst, "Function" %src ["Aligned", 4] : f32
+    // An access mask without Aligned carries no alignment literal.
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
+    spv.CopyMemory "Function" %dst, "Function" %src ["Volatile"] : f32
+    // Terminate the function.
+    spv.Return
+  }
+}
+

diff  --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index 7dea7942d426..c42619100362 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -1244,3 +1244,38 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () {
   %0 = spv.Variable : !spv.ptr<f32, Generic>
   return
 }
+
+// -----
+
+func @copy_memory_incompatible_ptrs() -> () {
+  %dst = spv.Variable : !spv.ptr<f32, Function>
+  %src = spv.Variable : !spv.ptr<i32, Function>
+  // expected-error @+1 {{both operands must be pointers to the same type}}
+  "spv.CopyMemory"(%dst, %src) : (!spv.ptr<f32, Function>, !spv.ptr<i32, Function>) -> ()
+  spv.Return
+}
+
+// -----
+
+func @copy_memory_invalid_maa() -> () {
+  %dst = spv.Variable : !spv.ptr<f32, Function>
+  %src = spv.Variable : !spv.ptr<f32, Function>
+  // expected-error @+1 {{missing alignment value}}
+  "spv.CopyMemory"(%dst, %src) {memory_access = 2 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  spv.Return
+}
+
+// -----
+
+func @copy_memory_print_maa() -> () {
+  %dst = spv.Variable : !spv.ptr<f32, Function>
+  %src = spv.Variable : !spv.ptr<f32, Function>
+  // Generic-form ops below must print back in the custom assembly form.
+  // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
+  "spv.CopyMemory"(%dst, %src) {memory_access = 1 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+
+  // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
+  "spv.CopyMemory"(%dst, %src) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+
+  spv.Return
+}


        


More information about the Mlir-commits mailing list