[Mlir-commits] [mlir] aca9bea - [mlir:MemRef] Move DmaStartOp/DmaWaitOp to ODS

River Riddle llvmlistbot at llvm.org
Fri Sep 24 12:39:50 PDT 2021


Author: River Riddle
Date: 2021-09-24T19:35:28Z
New Revision: aca9bea1992ce270d094105ae8968c703b8ffb65

URL: https://github.com/llvm/llvm-project/commit/aca9bea1992ce270d094105ae8968c703b8ffb65
DIFF: https://github.com/llvm/llvm-project/commit/aca9bea1992ce270d094105ae8968c703b8ffb65.diff

LOG: [mlir:MemRef] Move DmaStartOp/DmaWaitOp to ODS

These are among the last operations still defined explicitly in C++. I've
tried to keep this commit as NFC as possible, but these ops
definitely need a non-NFC cleanup at some point.

Differential Revision: https://reviews.llvm.org/D110440

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 8c61248567f6a..a022ffc8dbc20 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -46,206 +46,4 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
 #define GET_OP_CLASSES
 #include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc"
 
-namespace mlir {
-namespace memref {
-// DmaStartOp starts a non-blocking DMA operation that transfers data from a
-// source memref to a destination memref. The source and destination memref need
-// not be of the same dimensionality, but need to have the same elemental type.
-// The operands include the source and destination memref's each followed by its
-// indices, size of the data transfer in terms of the number of elements (of the
-// elemental type of the memref), a tag memref with its indices, and optionally
-// at the end, a stride and a number_of_elements_per_stride arguments. The tag
-// location is used by a DmaWaitOp to check for completion. The indices of the
-// source memref, destination memref, and the tag memref have the same
-// restrictions as any load/store. The optional stride arguments should be of
-// 'index' type, and specify a stride for the slower memory space (memory space
-// with a lower memory space id), transferring chunks of
-// number_of_elements_per_stride every stride until %num_elements are
-// transferred. Either both or no stride arguments should be specified. If the
-// source and destination locations overlap the behavior of this operation is
-// not defined.
-//
-// For example, a DmaStartOp operation that transfers 256 elements of a memref
-// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
-// 1 at indices [%k, %l], would be specified as follows:
-//
-//   %num_elements = constant 256
-//   %idx = constant 0 : index
-//   %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
-//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
-//     memref<40 x 128 x f32>, (d0) -> (d0), 0>,
-//     memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
-//     memref<1 x i32>, (d0) -> (d0), 2>
-//
-//   If %stride and %num_elt_per_stride are specified, the DMA is expected to
-//   transfer %num_elt_per_stride elements every %stride elements apart from
-//   memory space 0 until %num_elements are transferred.
-//
-//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
-//             %num_elt_per_stride :
-//
-// TODO: add additional operands to allow source and destination striding, and
-// multiple stride levels.
-// TODO: Consider replacing src/dst memref indices with view memrefs.
-class DmaStartOp
-    : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
-public:
-  using Op::Op;
-  static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
-  static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
-                    ValueRange srcIndices, Value destMemRef,
-                    ValueRange destIndices, Value numElements, Value tagMemRef,
-                    ValueRange tagIndices, Value stride = nullptr,
-                    Value elementsPerStride = nullptr);
-
-  // Returns the source MemRefType for this DMA operation.
-  Value getSrcMemRef() { return getOperand(0); }
-  // Returns the rank (number of indices) of the source MemRefType.
-  unsigned getSrcMemRefRank() {
-    return getSrcMemRef().getType().cast<MemRefType>().getRank();
-  }
-  // Returns the source memref indices for this DMA operation.
-  operand_range getSrcIndices() {
-    return {(*this)->operand_begin() + 1,
-            (*this)->operand_begin() + 1 + getSrcMemRefRank()};
-  }
-
-  // Returns the destination MemRefType for this DMA operations.
-  Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
-  // Returns the rank (number of indices) of the destination MemRefType.
-  unsigned getDstMemRefRank() {
-    return getDstMemRef().getType().cast<MemRefType>().getRank();
-  }
-  unsigned getSrcMemorySpace() {
-    return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
-  }
-  unsigned getDstMemorySpace() {
-    return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
-  }
-
-  // Returns the destination memref indices for this DMA operation.
-  operand_range getDstIndices() {
-    return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
-            (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
-                getDstMemRefRank()};
-  }
-
-  // Returns the number of elements being transferred by this DMA operation.
-  Value getNumElements() {
-    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
-  }
-
-  // Returns the Tag MemRef for this DMA operation.
-  Value getTagMemRef() {
-    return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
-  }
-  // Returns the rank (number of indices) of the tag MemRefType.
-  unsigned getTagMemRefRank() {
-    return getTagMemRef().getType().cast<MemRefType>().getRank();
-  }
-
-  // Returns the tag memref index for this DMA operation.
-  operand_range getTagIndices() {
-    unsigned tagIndexStartPos =
-        1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
-    return {(*this)->operand_begin() + tagIndexStartPos,
-            (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
-  }
-
-  /// Returns true if this is a DMA from a faster memory space to a slower one.
-  bool isDestMemorySpaceFaster() {
-    return (getSrcMemorySpace() < getDstMemorySpace());
-  }
-
-  /// Returns true if this is a DMA from a slower memory space to a faster one.
-  bool isSrcMemorySpaceFaster() {
-    // Assumes that a lower number is for a slower memory space.
-    return (getDstMemorySpace() < getSrcMemorySpace());
-  }
-
-  /// Given a DMA start operation, returns the operand position of either the
-  /// source or destination memref depending on the one that is at the higher
-  /// level of the memory hierarchy. Asserts failure if neither is true.
-  unsigned getFasterMemPos() {
-    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
-    return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
-  }
-
-  static StringRef getOperationName() { return "memref.dma_start"; }
-  static ParseResult parse(OpAsmParser &parser, OperationState &result);
-  void print(OpAsmPrinter &p);
-  LogicalResult verify();
-
-  LogicalResult fold(ArrayRef<Attribute> cstOperands,
-                     SmallVectorImpl<OpFoldResult> &results);
-
-  bool isStrided() {
-    return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
-                                   1 + 1 + getTagMemRefRank();
-  }
-
-  Value getStride() {
-    if (!isStrided())
-      return nullptr;
-    return getOperand(getNumOperands() - 1 - 1);
-  }
-
-  Value getNumElementsPerStride() {
-    if (!isStrided())
-      return nullptr;
-    return getOperand(getNumOperands() - 1);
-  }
-};
-
-// DmaWaitOp blocks until the completion of a DMA operation associated with the
-// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
-// with the same restrictions as any load/store index. %num_elements is the
-// number of elements associated with the DMA operation. For example:
-//
-//   dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
-//     memref<2048 x f32>, (d0) -> (d0), 0>,
-//     memref<256 x f32>, (d0) -> (d0), 1>
-//     memref<1 x i32>, (d0) -> (d0), 2>
-//   ...
-//   ...
-//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
-//
-class DmaWaitOp
-    : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
-public:
-  using Op::Op;
-  static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
-  static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
-                    ValueRange tagIndices, Value numElements);
-
-  static StringRef getOperationName() { return "memref.dma_wait"; }
-
-  // Returns the Tag MemRef associated with the DMA operation being waited on.
-  Value getTagMemRef() { return getOperand(0); }
-
-  // Returns the tag memref index for this DMA operation.
-  operand_range getTagIndices() {
-    return {(*this)->operand_begin() + 1,
-            (*this)->operand_begin() + 1 + getTagMemRefRank()};
-  }
-
-  // Returns the rank (number of indices) of the tag memref.
-  unsigned getTagMemRefRank() {
-    return getTagMemRef().getType().cast<MemRefType>().getRank();
-  }
-
-  // Returns the number of elements transferred in the associated DMA operation.
-  Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
-
-  static ParseResult parse(OpAsmParser &parser, OperationState &result);
-  void print(OpAsmPrinter &p);
-  LogicalResult fold(ArrayRef<Attribute> cstOperands,
-                     SmallVectorImpl<OpFoldResult> &results);
-  LogicalResult verify();
-};
-} // namespace memref
-} // namespace mlir
-
 #endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index dd8455a7f9190..630ab8621df2b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -284,8 +284,6 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
   let verifier = ?;
 }
 
-
-
 //===----------------------------------------------------------------------===//
 // BufferCastOp
 //===----------------------------------------------------------------------===//
@@ -568,6 +566,217 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// DmaStartOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
+  let summary = "non-blocking DMA operation that starts a transfer";
+  let description = [{
+    DmaStartOp starts a non-blocking DMA operation that transfers data from a
+    source memref to a destination memref. The source and destination memref
+    need not be of the same dimensionality, but need to have the same elemental
+    type. The operands include the source and destination memrefs, each followed
+    by its indices, size of the data transfer in terms of the number of elements
+    (of the elemental type of the memref), a tag memref with its indices, and
+    optionally at the end, a stride and a number_of_elements_per_stride
+    arguments. The tag location is used by a DmaWaitOp to check for completion.
+    The indices of the source memref, destination memref, and the tag memref
+    have the same restrictions as any load/store. The optional stride arguments
+    should be of 'index' type, and specify a stride for the slower memory space
+    (memory space with a lower memory space id), transferring chunks of
+    number_of_elements_per_stride every stride until %num_elements are
+    transferred. Either both or no stride arguments should be specified. If the
+    source and destination locations overlap the behavior of this operation is
+    not defined.
+
+    For example, a DmaStartOp operation that transfers 256 elements of a memref
+    '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory
+    space 1 at indices [%k, %l], would be specified as follows:
+
+    ```mlir
+    %num_elements = constant 256 : index
+    %idx = constant 0 : index
+    %tag = memref.alloc() : memref<1 x i32, 2>
+    memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
+      memref<40 x 128 x f32, 0>,
+      memref<2 x 1024 x f32, 1>,
+      memref<1 x i32, 2>
+    ```
+
+    If %stride and %num_elt_per_stride are specified, the DMA is expected to
+    transfer %num_elt_per_stride elements every %stride elements apart from
+    memory space 0 until %num_elements are transferred.
+
+    ```mlir
+    memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx],
+      %stride, %num_elt_per_stride :
+      memref<40 x 128 x f32, 0>, memref<2 x 1024 x f32, 1>, memref<1 x i32, 2>
+    ```
+
+    TODO: add operands for source/destination striding and multi-level strides.
+    TODO: Consider replacing src/dst memref indices with view memrefs.
+  }];
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [
+    OpBuilder<(ins "Value":$srcMemRef, "ValueRange":$srcIndices,
+                   "Value":$destMemRef, "ValueRange":$destIndices,
+                   "Value":$numElements, "Value":$tagMemRef,
+                   "ValueRange":$tagIndices, CArg<"Value", "{}">:$stride,
+                   CArg<"Value", "{}">:$elementsPerStride)>
+  ];
+
+  let extraClassDeclaration = [{
+    // Returns the source memref for this DMA operation.
+    Value getSrcMemRef() { return getOperand(0); }
+    // Returns the rank (number of indices) of the source MemRefType.
+    unsigned getSrcMemRefRank() {
+      return getSrcMemRef().getType().cast<MemRefType>().getRank();
+    }
+    // Returns the source memref indices for this DMA operation.
+    operand_range getSrcIndices() {
+      return {(*this)->operand_begin() + 1,
+              (*this)->operand_begin() + 1 + getSrcMemRefRank()};
+    }
+
+    // Returns the destination memref for this DMA operation.
+    Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
+    // Returns the rank (number of indices) of the destination MemRefType.
+    unsigned getDstMemRefRank() {
+      return getDstMemRef().getType().cast<MemRefType>().getRank();
+    }
+    unsigned getSrcMemorySpace() {
+      return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
+    }
+    unsigned getDstMemorySpace() {
+      return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
+    }
+
+    // Returns the destination memref indices for this DMA operation.
+    operand_range getDstIndices() {
+      return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
+              (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
+                  getDstMemRefRank()};
+    }
+
+    // Returns the number of elements being transferred by this DMA operation.
+    Value getNumElements() {
+      return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
+    }
+
+    // Returns the Tag MemRef for this DMA operation.
+    Value getTagMemRef() {
+      return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
+    }
+    // Returns the rank (number of indices) of the tag MemRefType.
+    unsigned getTagMemRefRank() {
+      return getTagMemRef().getType().cast<MemRefType>().getRank();
+    }
+
+    // Returns the tag memref indices for this DMA operation.
+    operand_range getTagIndices() {
+      unsigned tagIndexStartPos =
+          1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
+      return {(*this)->operand_begin() + tagIndexStartPos,
+              (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
+    }
+
+    /// Returns true if this is a DMA from a slower memory space to a faster
+    /// one (i.e. the destination has the higher memory space id).
+    bool isDestMemorySpaceFaster() {
+      return (getSrcMemorySpace() < getDstMemorySpace());
+    }
+
+    /// Returns true if this is a DMA from a faster memory space to a slower
+    /// one (i.e. the source has the higher memory space id).
+    bool isSrcMemorySpaceFaster() {
+      // Assumes that a lower number is for a slower memory space.
+      return (getDstMemorySpace() < getSrcMemorySpace());
+    }
+
+    /// Given a DMA start operation, returns the operand position of either the
+    /// source or destination memref depending on the one that is at the higher
+    /// level of the memory hierarchy. Asserts failure if neither is true.
+    unsigned getFasterMemPos() {
+      assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
+      return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
+    }
+
+    bool isStrided() {
+      return getNumOperands() != 1 + getSrcMemRefRank() + 1 +
+                                 getDstMemRefRank() + 1 + 1 +
+                                 getTagMemRefRank();
+    }
+
+    Value getStride() {
+      if (!isStrided())
+        return nullptr;
+      return getOperand(getNumOperands() - 1 - 1);
+    }
+
+    Value getNumElementsPerStride() {
+      if (!isStrided())
+        return nullptr;
+      return getOperand(getNumOperands() - 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// DmaWaitOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
+  let summary = "blocking DMA operation that waits for transfer completion";
+  let description = [{
+    DmaWaitOp blocks until the completion of a DMA operation associated with the
+    tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
+    with the same restrictions as any load/store index. %num_elements is the
+    number of elements associated with the DMA operation.
+
+    Example:
+
+    ```mlir
+    memref.dma_start %src[%i], %dst[%k], %num_elements, %tag[%index] :
+      memref<2048 x f32, 0>,
+      memref<256 x f32, 1>,
+      memref<1 x i32, 2>
+    ...
+    ...
+    memref.dma_wait %tag[%index], %num_elements : memref<1 x i32, 2>
+    ```
+  }];
+  let arguments = (ins
+    AnyMemRef:$tagMemRef,
+    Variadic<Index>:$tagIndices,
+    Index:$numElements
+  );
+  let assemblyFormat = [{
+    $tagMemRef `[` $tagIndices `]` `,` $numElements attr-dict `:`
+    type($tagMemRef)
+  }];
+  let extraClassDeclaration = [{
+    /// Returns the Tag MemRef associated with the DMA operation being waited
+    /// on.
+    Value getTagMemRef() { return tagMemRef(); }
+
+    /// Returns the tag memref indices for this DMA operation.
+    operand_range getTagIndices() { return tagIndices(); }
+
+    /// Returns the rank (number of indices) of the tag memref.
+    unsigned getTagMemRefRank() {
+      return getTagMemRef().getType().cast<MemRefType>().getRank();
+    }
+
+    /// Returns the number of elements transferred in the associated DMA
+    /// operation.
+    Value getNumElements() { return numElements(); }
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GetGlobalOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
index 3da6a89207ed3..f403cb5b1e78c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -34,9 +34,9 @@ struct MemRefInlinerInterface : public DialectInlinerInterface {
 } // end anonymous namespace
 
 void mlir::memref::MemRefDialect::initialize() {
-  addOperations<DmaStartOp, DmaWaitOp,
+  addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
-                >();
+      >();
   addInterfaces<MemRefInlinerInterface>();
 }

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f80d373c41e0d..412f49232cadd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -909,16 +909,17 @@ void DmaStartOp::build(OpBuilder &builder, OperationState &result,
     result.addOperands({stride, elementsPerStride});
 }
 
-void DmaStartOp::print(OpAsmPrinter &p) {
-  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
-    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
-    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
-  if (isStrided())
-    p << ", " << getStride() << ", " << getNumElementsPerStride();
+static void print(OpAsmPrinter &p, DmaStartOp op) {
+  p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
+    << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
+    << op.getNumElements() << ", " << op.getTagMemRef() << '['
+    << op.getTagIndices() << ']';
+  if (op.isStrided())
+    p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
 
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
-    << ", " << getTagMemRef().getType();
+  p.printOptionalAttrDict(op->getAttrs());
+  p << " : " << op.getSrcMemRef().getType() << ", "
+    << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
 }
 
 // Parse DmaStartOp.
@@ -929,7 +930,8 @@ void DmaStartOp::print(OpAsmPrinter &p) {
 //                       memref<1024 x f32, 2>,
 //                       memref<1 x i32>
 //
-ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseDmaStartOp(OpAsmParser &parser,
+                                   OperationState &result) {
   OpAsmParser::OperandType srcMemRefInfo;
   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
   OpAsmParser::OperandType dstMemRefInfo;
@@ -989,66 +991,67 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-LogicalResult DmaStartOp::verify() {
-  unsigned numOperands = getNumOperands();
+static LogicalResult verify(DmaStartOp op) {
+  unsigned numOperands = op.getNumOperands();
 
   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
   // the number of elements.
   if (numOperands < 4)
-    return emitOpError("expected at least 4 operands");
+    return op.emitOpError("expected at least 4 operands");
 
   // Check types of operands. The order of these calls is important: the later
   // calls rely on some type properties to compute the operand position.
   // 1. Source memref.
-  if (!getSrcMemRef().getType().isa<MemRefType>())
-    return emitOpError("expected source to be of memref type");
-  if (numOperands < getSrcMemRefRank() + 4)
-    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
-                         << " operands";
-  if (!getSrcIndices().empty() &&
-      !llvm::all_of(getSrcIndices().getTypes(),
+  if (!op.getSrcMemRef().getType().isa<MemRefType>())
+    return op.emitOpError("expected source to be of memref type");
+  if (numOperands < op.getSrcMemRefRank() + 4)
+    return op.emitOpError()
+           << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
+  if (!op.getSrcIndices().empty() &&
+      !llvm::all_of(op.getSrcIndices().getTypes(),
                     [](Type t) { return t.isIndex(); }))
-    return emitOpError("expected source indices to be of index type");
+    return op.emitOpError("expected source indices to be of index type");
 
   // 2. Destination memref.
-  if (!getDstMemRef().getType().isa<MemRefType>())
-    return emitOpError("expected destination to be of memref type");
-  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
+  if (!op.getDstMemRef().getType().isa<MemRefType>())
+    return op.emitOpError("expected destination to be of memref type");
+  unsigned numExpectedOperands =
+      op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
   if (numOperands < numExpectedOperands)
-    return emitOpError() << "expected at least " << numExpectedOperands
-                         << " operands";
-  if (!getDstIndices().empty() &&
-      !llvm::all_of(getDstIndices().getTypes(),
+    return op.emitOpError()
+           << "expected at least " << numExpectedOperands << " operands";
+  if (!op.getDstIndices().empty() &&
+      !llvm::all_of(op.getDstIndices().getTypes(),
                     [](Type t) { return t.isIndex(); }))
-    return emitOpError("expected destination indices to be of index type");
+    return op.emitOpError("expected destination indices to be of index type");
 
   // 3. Number of elements.
-  if (!getNumElements().getType().isIndex())
-    return emitOpError("expected num elements to be of index type");
+  if (!op.getNumElements().getType().isIndex())
+    return op.emitOpError("expected num elements to be of index type");
 
   // 4. Tag memref.
-  if (!getTagMemRef().getType().isa<MemRefType>())
-    return emitOpError("expected tag to be of memref type");
-  numExpectedOperands += getTagMemRefRank();
+  if (!op.getTagMemRef().getType().isa<MemRefType>())
+    return op.emitOpError("expected tag to be of memref type");
+  numExpectedOperands += op.getTagMemRefRank();
   if (numOperands < numExpectedOperands)
-    return emitOpError() << "expected at least " << numExpectedOperands
-                         << " operands";
-  if (!getTagIndices().empty() &&
-      !llvm::all_of(getTagIndices().getTypes(),
+    return op.emitOpError()
+           << "expected at least " << numExpectedOperands << " operands";
+  if (!op.getTagIndices().empty() &&
+      !llvm::all_of(op.getTagIndices().getTypes(),
                     [](Type t) { return t.isIndex(); }))
-    return emitOpError("expected tag indices to be of index type");
+    return op.emitOpError("expected tag indices to be of index type");
 
   // Optional stride-related operands must be either both present or both
   // absent.
   if (numOperands != numExpectedOperands &&
       numOperands != numExpectedOperands + 2)
-    return emitOpError("incorrect number of operands");
+    return op.emitOpError("incorrect number of operands");
 
   // 5. Strides.
-  if (isStrided()) {
-    if (!getStride().getType().isIndex() ||
-        !getNumElementsPerStride().getType().isIndex())
-      return emitOpError(
+  if (op.isStrided()) {
+    if (!op.getStride().getType().isIndex() ||
+        !op.getNumElementsPerStride().getType().isIndex())
+      return op.emitOpError(
           "expected stride and num elements per stride to be of type index");
   }
 
@@ -1065,74 +1068,20 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
 // DmaWaitOp
 // ---------------------------------------------------------------------------
 
-void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
-                      Value tagMemRef, ValueRange tagIndices,
-                      Value numElements) {
-  result.addOperands(tagMemRef);
-  result.addOperands(tagIndices);
-  result.addOperands(numElements);
-}
-
-void DmaWaitOp::print(OpAsmPrinter &p) {
-  p << " " << getTagMemRef() << '[' << getTagIndices() << "], "
-    << getNumElements();
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : " << getTagMemRef().getType();
-}
-
-// Parse DmaWaitOp.
-// Eg:
-//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
-//
-ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType tagMemrefInfo;
-  SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
-  Type type;
-  auto indexType = parser.getBuilder().getIndexType();
-  OpAsmParser::OperandType numElementsInfo;
-
-  // Parse tag memref, its indices, and dma size.
-  if (parser.parseOperand(tagMemrefInfo) ||
-      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
-      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
-      parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
-      parser.resolveOperand(numElementsInfo, indexType, result.operands))
-    return failure();
-
-  return success();
-}
-
 LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dma_wait(memrefcast) -> dma_wait
   return foldMemRefCast(*this);
 }
 
-LogicalResult DmaWaitOp::verify() {
-  // Mandatory non-variadic operands are tag and the number of elements.
-  if (getNumOperands() < 2)
-    return emitOpError() << "expected at least 2 operands";
-
-  // Check types of operands. The order of these calls is important: the later
-  // calls rely on some type properties to compute the operand position.
-  if (!getTagMemRef().getType().isa<MemRefType>())
-    return emitOpError() << "expected tag to be of memref type";
-
-  if (getNumOperands() != 2 + getTagMemRefRank())
-    return emitOpError() << "expected " << 2 + getTagMemRefRank()
-                         << " operands";
-
-  if (!getTagIndices().empty() &&
-      !llvm::all_of(getTagIndices().getTypes(),
-                    [](Type t) { return t.isIndex(); }))
-    return emitOpError() << "expected tag indices to be of index type";
-
-  if (!getNumElements().getType().isIndex())
-    return emitOpError()
-           << "expected the number of elements to be of index type";
-
+static LogicalResult verify(DmaWaitOp op) {
+  // Check that the number of tag indices matches the tagMemRef rank.
+  unsigned numTagIndices = op.tagIndices().size();
+  unsigned tagMemRefRank = op.getTagMemRefRank();
+  if (numTagIndices != tagMemRefRank)
+    return op.emitOpError() << "expected tagIndices to have the same number of "
+                               "elements as the tagMemRef rank, expected "
+                            << tagMemRefRank << ", but got " << numTagIndices;
   return success();
 }
 

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index b93815533119c..8d0a20bf09886 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1,5 +1,132 @@
 // RUN: mlir-opt -split-input-file %s -verify-diagnostics
 
+func @dma_start_not_enough_operands() {
+  // expected-error@+1 {{expected at least 4 operands}}
+  "memref.dma_start"() : () -> ()
+}
+
+// -----
+
+func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
+  // expected-error@+1 {{expected source to be of memref type}}
+  memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
+}
+
+// -----
+
+func @dma_start_not_enough_operands_for_src(
+    %src: memref<2x2x2xf32>, %idx: index) {
+  // expected-error@+1 {{expected at least 7 operands}}
+  "memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_src_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected source indices to be of index type}}
+  "memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
+      : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
+}
+
+// -----
+
+func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
+  %mref = memref.alloc() : memref<8 x f32>
+  // expected-error@+1 {{expected destination to be of memref type}}
+  memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
+}
+
+// -----
+
+func @dma_start_not_enough_operands_for_dst(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{expected at least 7 operands}}
+  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_dst_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected destination indices to be of index type}}
+  "memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
+}
+
+// -----
+
+func @dma_start_dst_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected num elements to be of index type}}
+  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
+}
+
+// -----
+
+func @dma_no_tag_memref(%tag : f32, %c0 : index) {
+  %mref = memref.alloc() : memref<8 x f32>
+  // expected-error@+1 {{expected tag to be of memref type}}
+  memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
+}
+
+// -----
+
+func @dma_start_not_enough_operands_for_tag(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<2xi32,2>) {
+  // expected-error@+1 {{expected at least 8 operands}}
+  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
+}
+
+// -----
+
+func @dma_start_dst_index_wrong_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<2xi32,2>, %flt: f32) {
+  // expected-error@+1 {{expected tag indices to be of index type}}
+  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
+}
+
+// -----
+
+func @dma_start_too_many_operands(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>) {
+  // expected-error@+1 {{incorrect number of operands}}
+  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
+}
+
+
+// -----
+
+func @dma_start_wrong_stride_type(
+    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+    %tag: memref<i32,2>, %flt: f32) {
+  // expected-error@+1 {{expected stride and num elements per stride to be of type index}}
+  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
+      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
+}
+
+// -----
+
+func @dma_wait_wrong_index_type(%tag : memref<2x2xi32>, %idx: index, %flt: index) {
+  // expected-error@+1 {{expected tagIndices to have the same number of elements as the tagMemRef rank, expected 2, but got 1}}
+  "memref.dma_wait"(%tag, %flt, %idx) : (memref<2x2xi32>, index, index) -> ()
+  return
+}
+
+// -----
+
 func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
   // expected-error @+1 {{expected a permutation map}}
   memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 265f095fe2272..3bfc2546dc4bf 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -290,153 +290,6 @@ func @invalid_cmp_shape(%idx : () -> ()) {
 
 // -----
 
-func @dma_start_not_enough_operands() {
-  // expected-error@+1 {{expected at least 4 operands}}
-  "memref.dma_start"() : () -> ()
-}
-
-// -----
-
-func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
-  // expected-error@+1 {{expected source to be of memref type}}
-  memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
-}
-
-// -----
-
-func @dma_start_not_enough_operands_for_src(
-    %src: memref<2x2x2xf32>, %idx: index) {
-  // expected-error@+1 {{expected at least 7 operands}}
-  "memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
-}
-
-// -----
-
-func @dma_start_src_index_wrong_type(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<i32,2>, %flt: f32) {
-  // expected-error@+1 {{expected source indices to be of index type}}
-  "memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
-      : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
-}
-
-// -----
-
-func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
-  %mref = memref.alloc() : memref<8 x f32>
-  // expected-error@+1 {{expected destination to be of memref type}}
-  memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
-}
-
-// -----
-
-func @dma_start_not_enough_operands_for_dst(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<i32,2>) {
-  // expected-error@+1 {{expected at least 7 operands}}
-  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
-      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
-}
-
-// -----
-
-func @dma_start_dst_index_wrong_type(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<i32,2>, %flt: f32) {
-  // expected-error@+1 {{expected destination indices to be of index type}}
-  "memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
-      : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
-}
-
-// -----
-
-func @dma_start_dst_index_wrong_type(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<i32,2>, %flt: f32) {
-  // expected-error@+1 {{expected num elements to be of index type}}
-  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
-      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
-}
-
-// -----
-
-func @dma_no_tag_memref(%tag : f32, %c0 : index) {
-  %mref = memref.alloc() : memref<8 x f32>
-  // expected-error@+1 {{expected tag to be of memref type}}
-  memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
-}
-
-// -----
-
-func @dma_start_not_enough_operands_for_tag(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<2xi32,2>) {
-  // expected-error@+1 {{expected at least 8 operands}}
-  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
-      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
-}
-
-// -----
-
-func @dma_start_dst_index_wrong_type(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<2xi32,2>, %flt: f32) {
-  // expected-error@+1 {{expected tag indices to be of index type}}
-  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
-      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
-}
-
-// -----
-
-func @dma_start_too_many_operands(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<i32,2>) {
-  // expected-error@+1 {{incorrect number of operands}}
-  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
-      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
-}
-
-
-// -----
-
-func @dma_start_wrong_stride_type(
-    %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
-    %tag: memref<i32,2>, %flt: f32) {
-  // expected-error@+1 {{expected stride and num elements per stride to be of type index}}
-  "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
-      : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
-}
-
-// -----
-
-func @dma_wait_not_enough_operands() {
-  // expected-error@+1 {{expected at least 2 operands}}
-  "memref.dma_wait"() : () -> ()
-}
-
-// -----
-
-func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
-  // expected-error@+1 {{expected tag to be of memref type}}
-  "memref.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
-}
-
-// -----
-
-func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
-  // expected-error@+1 {{expected tag indices to be of index type}}
-  "memref.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
-}
-
-// -----
-
-func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
-  // expected-error@+1 {{expected the number of elements to be of index type}}
-  "memref.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
-}
-
-// -----
-
 func @invalid_cmp_attr(%idx : i32) {
   // expected-error at +1 {{expected string or keyword containing one of the following enum values}}
   %cmp = cmpi i1, %idx, %idx : i32


        


More information about the Mlir-commits mailing list