[Mlir-commits] [mlir] aca9bea - [mlir:MemRef] Move DmaStartOp/DmaWaitOp to ODS
River Riddle
llvmlistbot at llvm.org
Fri Sep 24 12:39:50 PDT 2021
Author: River Riddle
Date: 2021-09-24T19:35:28Z
New Revision: aca9bea1992ce270d094105ae8968c703b8ffb65
URL: https://github.com/llvm/llvm-project/commit/aca9bea1992ce270d094105ae8968c703b8ffb65
DIFF: https://github.com/llvm/llvm-project/commit/aca9bea1992ce270d094105ae8968c703b8ffb65.diff
LOG: [mlir:MemRef] Move DmaStartOp/DmaWaitOp to ODS
These are among the last operations still defined explicitly in C++. I've
tried to keep this commit as close to NFC (no functional change) as possible,
but these ops definitely need a non-NFC cleanup at some point.
Differential Revision: https://reviews.llvm.org/D110440
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 8c61248567f6a..a022ffc8dbc20 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -46,206 +46,4 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc"
-namespace mlir {
-namespace memref {
-// DmaStartOp starts a non-blocking DMA operation that transfers data from a
-// source memref to a destination memref. The source and destination memref need
-// not be of the same dimensionality, but need to have the same elemental type.
-// The operands include the source and destination memref's each followed by its
-// indices, size of the data transfer in terms of the number of elements (of the
-// elemental type of the memref), a tag memref with its indices, and optionally
-// at the end, a stride and a number_of_elements_per_stride arguments. The tag
-// location is used by a DmaWaitOp to check for completion. The indices of the
-// source memref, destination memref, and the tag memref have the same
-// restrictions as any load/store. The optional stride arguments should be of
-// 'index' type, and specify a stride for the slower memory space (memory space
-// with a lower memory space id), transferring chunks of
-// number_of_elements_per_stride every stride until %num_elements are
-// transferred. Either both or no stride arguments should be specified. If the
-// source and destination locations overlap the behavior of this operation is
-// not defined.
-//
-// For example, a DmaStartOp operation that transfers 256 elements of a memref
-// '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
-// 1 at indices [%k, %l], would be specified as follows:
-//
-// %num_elements = constant 256
-// %idx = constant 0 : index
-// %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
-// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
-// memref<40 x 128 x f32>, (d0) -> (d0), 0>,
-// memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
-// memref<1 x i32>, (d0) -> (d0), 2>
-//
-// If %stride and %num_elt_per_stride are specified, the DMA is expected to
-// transfer %num_elt_per_stride elements every %stride elements apart from
-// memory space 0 until %num_elements are transferred.
-//
-// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
-// %num_elt_per_stride :
-//
-// TODO: add additional operands to allow source and destination striding, and
-// multiple stride levels.
-// TODO: Consider replacing src/dst memref indices with view memrefs.
-class DmaStartOp
- : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
-public:
- using Op::Op;
- static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
- static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
- ValueRange srcIndices, Value destMemRef,
- ValueRange destIndices, Value numElements, Value tagMemRef,
- ValueRange tagIndices, Value stride = nullptr,
- Value elementsPerStride = nullptr);
-
- // Returns the source MemRefType for this DMA operation.
- Value getSrcMemRef() { return getOperand(0); }
- // Returns the rank (number of indices) of the source MemRefType.
- unsigned getSrcMemRefRank() {
- return getSrcMemRef().getType().cast<MemRefType>().getRank();
- }
- // Returns the source memref indices for this DMA operation.
- operand_range getSrcIndices() {
- return {(*this)->operand_begin() + 1,
- (*this)->operand_begin() + 1 + getSrcMemRefRank()};
- }
-
- // Returns the destination MemRefType for this DMA operations.
- Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
- // Returns the rank (number of indices) of the destination MemRefType.
- unsigned getDstMemRefRank() {
- return getDstMemRef().getType().cast<MemRefType>().getRank();
- }
- unsigned getSrcMemorySpace() {
- return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
- }
- unsigned getDstMemorySpace() {
- return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
- }
-
- // Returns the destination memref indices for this DMA operation.
- operand_range getDstIndices() {
- return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
- (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
- getDstMemRefRank()};
- }
-
- // Returns the number of elements being transferred by this DMA operation.
- Value getNumElements() {
- return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
- }
-
- // Returns the Tag MemRef for this DMA operation.
- Value getTagMemRef() {
- return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
- }
- // Returns the rank (number of indices) of the tag MemRefType.
- unsigned getTagMemRefRank() {
- return getTagMemRef().getType().cast<MemRefType>().getRank();
- }
-
- // Returns the tag memref index for this DMA operation.
- operand_range getTagIndices() {
- unsigned tagIndexStartPos =
- 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
- return {(*this)->operand_begin() + tagIndexStartPos,
- (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
- }
-
- /// Returns true if this is a DMA from a faster memory space to a slower one.
- bool isDestMemorySpaceFaster() {
- return (getSrcMemorySpace() < getDstMemorySpace());
- }
-
- /// Returns true if this is a DMA from a slower memory space to a faster one.
- bool isSrcMemorySpaceFaster() {
- // Assumes that a lower number is for a slower memory space.
- return (getDstMemorySpace() < getSrcMemorySpace());
- }
-
- /// Given a DMA start operation, returns the operand position of either the
- /// source or destination memref depending on the one that is at the higher
- /// level of the memory hierarchy. Asserts failure if neither is true.
- unsigned getFasterMemPos() {
- assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
- return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
- }
-
- static StringRef getOperationName() { return "memref.dma_start"; }
- static ParseResult parse(OpAsmParser &parser, OperationState &result);
- void print(OpAsmPrinter &p);
- LogicalResult verify();
-
- LogicalResult fold(ArrayRef<Attribute> cstOperands,
- SmallVectorImpl<OpFoldResult> &results);
-
- bool isStrided() {
- return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
- 1 + 1 + getTagMemRefRank();
- }
-
- Value getStride() {
- if (!isStrided())
- return nullptr;
- return getOperand(getNumOperands() - 1 - 1);
- }
-
- Value getNumElementsPerStride() {
- if (!isStrided())
- return nullptr;
- return getOperand(getNumOperands() - 1);
- }
-};
-
-// DmaWaitOp blocks until the completion of a DMA operation associated with the
-// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
-// with the same restrictions as any load/store index. %num_elements is the
-// number of elements associated with the DMA operation. For example:
-//
-// dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
-// memref<2048 x f32>, (d0) -> (d0), 0>,
-// memref<256 x f32>, (d0) -> (d0), 1>
-// memref<1 x i32>, (d0) -> (d0), 2>
-// ...
-// ...
-// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
-//
-class DmaWaitOp
- : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
-public:
- using Op::Op;
- static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
- static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
- ValueRange tagIndices, Value numElements);
-
- static StringRef getOperationName() { return "memref.dma_wait"; }
-
- // Returns the Tag MemRef associated with the DMA operation being waited on.
- Value getTagMemRef() { return getOperand(0); }
-
- // Returns the tag memref index for this DMA operation.
- operand_range getTagIndices() {
- return {(*this)->operand_begin() + 1,
- (*this)->operand_begin() + 1 + getTagMemRefRank()};
- }
-
- // Returns the rank (number of indices) of the tag memref.
- unsigned getTagMemRefRank() {
- return getTagMemRef().getType().cast<MemRefType>().getRank();
- }
-
- // Returns the number of elements transferred in the associated DMA operation.
- Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
-
- static ParseResult parse(OpAsmParser &parser, OperationState &result);
- void print(OpAsmPrinter &p);
- LogicalResult fold(ArrayRef<Attribute> cstOperands,
- SmallVectorImpl<OpFoldResult> &results);
- LogicalResult verify();
-};
-} // namespace memref
-} // namespace mlir
-
#endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index dd8455a7f9190..630ab8621df2b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -284,8 +284,6 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
let verifier = ?;
}
-
-
//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//
@@ -568,6 +566,217 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// DmaStartOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
+ let summary = "non-blocking DMA operation that starts a transfer";
+ let description = [{
+ DmaStartOp starts a non-blocking DMA operation that transfers data from a
+ source memref to a destination memref. The source and destination memref
+ need not be of the same dimensionality, but need to have the same elemental
+ type. The operands include the source and destination memrefs, each followed
+ by its indices, size of the data transfer in terms of the number of elements
+ (of the elemental type of the memref), a tag memref with its indices, and
+ optionally at the end, a stride and a number_of_elements_per_stride
+ arguments. The tag location is used by a DmaWaitOp to check for completion.
+ The indices of the source memref, destination memref, and the tag memref
+ have the same restrictions as any load/store. The optional stride arguments
+ should be of 'index' type, and specify a stride for the slower memory space
+ (memory space with a lower memory space id), transferring chunks of
+ number_of_elements_per_stride every stride until %num_elements are
+ transferred. Either both or no stride arguments should be specified. If the
+ source and destination locations overlap, the behavior of this operation is
+ not defined.
+
+ For example, a DmaStartOp operation that transfers 256 elements of a memref
+ '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory
+ space 1 at indices [%k, %l], would be specified as follows:
+
+ ```mlir
+ %num_elements = constant 256 : index
+ %idx = constant 0 : index
+ %tag = memref.alloc() : memref<1 x i32, 2>
+ memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
+ memref<40 x 128 x f32>,
+ memref<2 x 1024 x f32, 1>,
+ memref<1 x i32, 2>
+ ```
+
+ If %stride and %num_elt_per_stride are specified, the DMA is expected to
+ transfer chunks of %num_elt_per_stride elements, spaced %stride elements
+ apart in memory space 0, until %num_elements have been transferred.
+
+ ```mlir
+ memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx],
+ %stride, %num_elt_per_stride : memref<40 x 128 x f32>, memref<2 x 1024 x f32, 1>, memref<1 x i32, 2>
+ ```
+
+ TODO: add additional operands to allow source and destination striding, and
+ multiple stride levels.
+ TODO: Consider replacing src/dst memref indices with view memrefs.
+ }];
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [
+ OpBuilder<(ins "Value":$srcMemRef, "ValueRange":$srcIndices,
+ "Value":$destMemRef, "ValueRange":$destIndices,
+ "Value":$numElements, "Value":$tagMemRef,
+ "ValueRange":$tagIndices, CArg<"Value", "{}">:$stride,
+ CArg<"Value", "{}">:$elementsPerStride)>
+ ];
+
+ let extraClassDeclaration = [{
+ // Returns the source memref for this DMA operation.
+ Value getSrcMemRef() { return getOperand(0); }
+ // Returns the rank (number of indices) of the source MemRefType.
+ unsigned getSrcMemRefRank() {
+ return getSrcMemRef().getType().cast<MemRefType>().getRank();
+ }
+ // Returns the source memref indices for this DMA operation.
+ operand_range getSrcIndices() {
+ return {(*this)->operand_begin() + 1,
+ (*this)->operand_begin() + 1 + getSrcMemRefRank()};
+ }
+
+ // Returns the destination memref for this DMA operation.
+ Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
+ // Returns the rank (number of indices) of the destination MemRefType.
+ unsigned getDstMemRefRank() {
+ return getDstMemRef().getType().cast<MemRefType>().getRank();
+ }
+ unsigned getSrcMemorySpace() {
+ return getSrcMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
+ }
+ unsigned getDstMemorySpace() {
+ return getDstMemRef().getType().cast<MemRefType>().getMemorySpaceAsInt();
+ }
+
+ // Returns the destination memref indices for this DMA operation.
+ operand_range getDstIndices() {
+ return {(*this)->operand_begin() + 1 + getSrcMemRefRank() + 1,
+ (*this)->operand_begin() + 1 + getSrcMemRefRank() + 1 +
+ getDstMemRefRank()};
+ }
+
+ // Returns the number of elements being transferred by this DMA operation.
+ Value getNumElements() {
+ return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
+ }
+
+ // Returns the Tag MemRef for this DMA operation.
+ Value getTagMemRef() {
+ return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
+ }
+ // Returns the rank (number of indices) of the tag MemRefType.
+ unsigned getTagMemRefRank() {
+ return getTagMemRef().getType().cast<MemRefType>().getRank();
+ }
+
+ // Returns the tag memref indices for this DMA operation.
+ operand_range getTagIndices() {
+ unsigned tagIndexStartPos =
+ 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
+ return {(*this)->operand_begin() + tagIndexStartPos,
+ (*this)->operand_begin() + tagIndexStartPos + getTagMemRefRank()};
+ }
+
+ /// Returns true if this is a DMA from a slower memory space to a faster
+ /// one.
+ bool isDestMemorySpaceFaster() {
+ return (getSrcMemorySpace() < getDstMemorySpace());
+ }
+
+ /// Returns true if this is a DMA from a faster memory space to a slower
+ /// one.
+ bool isSrcMemorySpaceFaster() {
+ // Assumes that a lower number is for a slower memory space.
+ return (getDstMemorySpace() < getSrcMemorySpace());
+ }
+
+ /// Given a DMA start operation, returns the operand position of either the
+ /// source or destination memref depending on the one that is at the higher
+ /// level of the memory hierarchy. Asserts failure if neither is true.
+ unsigned getFasterMemPos() {
+ assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
+ return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
+ }
+ /// True if operands remain after the tag indices (the stride operand pair).
+ bool isStrided() {
+ return getNumOperands() != 1 + getSrcMemRefRank() + 1 +
+ getDstMemRefRank() + 1 + 1 +
+ getTagMemRefRank();
+ }
+ /// Returns the stride operand (second-to-last), or null if not strided.
+ Value getStride() {
+ if (!isStrided())
+ return nullptr;
+ return getOperand(getNumOperands() - 1 - 1);
+ }
+ /// Returns the elements-per-stride operand (last), or null if not strided.
+ Value getNumElementsPerStride() {
+ if (!isStrided())
+ return nullptr;
+ return getOperand(getNumOperands() - 1);
+ }
+ }];
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// DmaWaitOp
+//===----------------------------------------------------------------------===//
+
+def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
+ let summary = "blocking DMA operation that waits for transfer completion";
+ let description = [{
+ DmaWaitOp blocks until the completion of a DMA operation associated with the
+ tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
+ with the same restrictions as any load/store index. %num_elements is the
+ number of elements associated with the DMA operation.
+
+ Example:
+
+ ```mlir
+ memref.dma_start %src[%i], %dst[%k], %num_elements, %tag[%index] :
+ memref<2048 x f32>,
+ memref<256 x f32, 1>,
+ memref<1 x i32, 2>
+ ...
+ ...
+ memref.dma_wait %tag[%index], %num_elements : memref<1 x i32, 2>
+ ```
+ }];
+ let arguments = (ins
+ AnyMemRef:$tagMemRef,
+ Variadic<Index>:$tagIndices,
+ Index:$numElements
+ );
+ let assemblyFormat = [{
+ $tagMemRef `[` $tagIndices `]` `,` $numElements attr-dict `:`
+ type($tagMemRef)
+ }];
+ let extraClassDeclaration = [{
+ /// Returns the Tag MemRef associated with the DMA operation being waited
+ /// on.
+ Value getTagMemRef() { return tagMemRef(); }
+
+ /// Returns the tag memref indices for this DMA operation.
+ operand_range getTagIndices() { return tagIndices(); }
+
+ /// Returns the rank (number of indices) of the tag memref.
+ unsigned getTagMemRefRank() {
+ return getTagMemRef().getType().cast<MemRefType>().getRank();
+ }
+
+ /// Returns the number of elements transferred in the associated DMA
+ /// operation.
+ Value getNumElements() { return numElements(); }
+ }];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
index 3da6a89207ed3..f403cb5b1e78c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -34,9 +34,9 @@ struct MemRefInlinerInterface : public DialectInlinerInterface {
} // end anonymous namespace
void mlir::memref::MemRefDialect::initialize() {
- addOperations<DmaStartOp, DmaWaitOp,
+ addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
- >();
+ >();
addInterfaces<MemRefInlinerInterface>();
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f80d373c41e0d..412f49232cadd 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -909,16 +909,17 @@ void DmaStartOp::build(OpBuilder &builder, OperationState &result,
result.addOperands({stride, elementsPerStride});
}
-void DmaStartOp::print(OpAsmPrinter &p) {
- p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
- << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
- << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
- if (isStrided())
- p << ", " << getStride() << ", " << getNumElementsPerStride();
+static void print(OpAsmPrinter &p, DmaStartOp op) {
+ p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
+ << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
+ << op.getNumElements() << ", " << op.getTagMemRef() << '['
+ << op.getTagIndices() << ']';
+ if (op.isStrided()) // Optional stride operands follow the tag indices.
+ p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
- p.printOptionalAttrDict((*this)->getAttrs());
- p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
- << ", " << getTagMemRef().getType();
+ p.printOptionalAttrDict(op->getAttrs());
+ p << " : " << op.getSrcMemRef().getType() << ", "
+ << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
}
// Parse DmaStartOp.
@@ -929,7 +930,8 @@ void DmaStartOp::print(OpAsmPrinter &p) {
// memref<1024 x f32, 2>,
// memref<1 x i32>
//
-ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseDmaStartOp(OpAsmParser &parser,
+ OperationState &result) {
OpAsmParser::OperandType srcMemRefInfo;
SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
OpAsmParser::OperandType dstMemRefInfo;
@@ -989,66 +991,67 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-LogicalResult DmaStartOp::verify() {
- unsigned numOperands = getNumOperands();
+static LogicalResult verify(DmaStartOp op) {
+ unsigned numOperands = op.getNumOperands();
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
// the number of elements.
if (numOperands < 4)
- return emitOpError("expected at least 4 operands");
+ return op.emitOpError("expected at least 4 operands");
// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
// 1. Source memref.
- if (!getSrcMemRef().getType().isa<MemRefType>())
- return emitOpError("expected source to be of memref type");
- if (numOperands < getSrcMemRefRank() + 4)
- return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
- << " operands";
- if (!getSrcIndices().empty() &&
- !llvm::all_of(getSrcIndices().getTypes(),
+ if (!op.getSrcMemRef().getType().isa<MemRefType>())
+ return op.emitOpError("expected source to be of memref type");
+ if (numOperands < op.getSrcMemRefRank() + 4)
+ return op.emitOpError()
+ << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
+ if (!op.getSrcIndices().empty() &&
+ !llvm::all_of(op.getSrcIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
- return emitOpError("expected source indices to be of index type");
+ return op.emitOpError("expected source indices to be of index type");
// 2. Destination memref.
- if (!getDstMemRef().getType().isa<MemRefType>())
- return emitOpError("expected destination to be of memref type");
- unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
+ if (!op.getDstMemRef().getType().isa<MemRefType>())
+ return op.emitOpError("expected destination to be of memref type");
+ unsigned numExpectedOperands =
+ op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
- return emitOpError() << "expected at least " << numExpectedOperands
- << " operands";
- if (!getDstIndices().empty() &&
- !llvm::all_of(getDstIndices().getTypes(),
+ return op.emitOpError()
+ << "expected at least " << numExpectedOperands << " operands";
+ if (!op.getDstIndices().empty() &&
+ !llvm::all_of(op.getDstIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
- return emitOpError("expected destination indices to be of index type");
+ return op.emitOpError("expected destination indices to be of index type");
// 3. Number of elements.
- if (!getNumElements().getType().isIndex())
- return emitOpError("expected num elements to be of index type");
+ if (!op.getNumElements().getType().isIndex())
+ return op.emitOpError("expected num elements to be of index type");
// 4. Tag memref.
- if (!getTagMemRef().getType().isa<MemRefType>())
- return emitOpError("expected tag to be of memref type");
- numExpectedOperands += getTagMemRefRank();
+ if (!op.getTagMemRef().getType().isa<MemRefType>())
+ return op.emitOpError("expected tag to be of memref type");
+ numExpectedOperands += op.getTagMemRefRank();
if (numOperands < numExpectedOperands)
- return emitOpError() << "expected at least " << numExpectedOperands
- << " operands";
- if (!getTagIndices().empty() &&
- !llvm::all_of(getTagIndices().getTypes(),
+ return op.emitOpError()
+ << "expected at least " << numExpectedOperands << " operands";
+ if (!op.getTagIndices().empty() &&
+ !llvm::all_of(op.getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
- return emitOpError("expected tag indices to be of index type");
+ return op.emitOpError("expected tag indices to be of index type");
// Optional stride-related operands must be either both present or both
// absent.
if (numOperands != numExpectedOperands &&
numOperands != numExpectedOperands + 2)
- return emitOpError("incorrect number of operands");
+ return op.emitOpError("incorrect number of operands");
// 5. Strides.
- if (isStrided()) {
- if (!getStride().getType().isIndex() ||
- !getNumElementsPerStride().getType().isIndex())
- return emitOpError(
+ if (op.isStrided()) {
+ if (!op.getStride().getType().isIndex() ||
+ !op.getNumElementsPerStride().getType().isIndex())
+ return op.emitOpError(
"expected stride and num elements per stride to be of type index");
}
@@ -1065,74 +1068,20 @@ LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
// DmaWaitOp
// ---------------------------------------------------------------------------
-void DmaWaitOp::build(OpBuilder &builder, OperationState &result,
- Value tagMemRef, ValueRange tagIndices,
- Value numElements) {
- result.addOperands(tagMemRef);
- result.addOperands(tagIndices);
- result.addOperands(numElements);
-}
-
-void DmaWaitOp::print(OpAsmPrinter &p) {
- p << " " << getTagMemRef() << '[' << getTagIndices() << "], "
- << getNumElements();
- p.printOptionalAttrDict((*this)->getAttrs());
- p << " : " << getTagMemRef().getType();
-}
-
-// Parse DmaWaitOp.
-// Eg:
-// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
-//
-ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType tagMemrefInfo;
- SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
- Type type;
- auto indexType = parser.getBuilder().getIndexType();
- OpAsmParser::OperandType numElementsInfo;
-
- // Parse tag memref, its indices, and dma size.
- if (parser.parseOperand(tagMemrefInfo) ||
- parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
- parser.parseComma() || parser.parseOperand(numElementsInfo) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(tagMemrefInfo, type, result.operands) ||
- parser.resolveOperands(tagIndexInfos, indexType, result.operands) ||
- parser.resolveOperand(numElementsInfo, indexType, result.operands))
- return failure();
-
- return success();
-}
-
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
return foldMemRefCast(*this);
}
-LogicalResult DmaWaitOp::verify() {
- // Mandatory non-variadic operands are tag and the number of elements.
- if (getNumOperands() < 2)
- return emitOpError() << "expected at least 2 operands";
-
- // Check types of operands. The order of these calls is important: the later
- // calls rely on some type properties to compute the operand position.
- if (!getTagMemRef().getType().isa<MemRefType>())
- return emitOpError() << "expected tag to be of memref type";
-
- if (getNumOperands() != 2 + getTagMemRefRank())
- return emitOpError() << "expected " << 2 + getTagMemRefRank()
- << " operands";
-
- if (!getTagIndices().empty() &&
- !llvm::all_of(getTagIndices().getTypes(),
- [](Type t) { return t.isIndex(); }))
- return emitOpError() << "expected tag indices to be of index type";
-
- if (!getNumElements().getType().isIndex())
- return emitOpError()
- << "expected the number of elements to be of index type";
-
+static LogicalResult verify(DmaWaitOp op) {
+ // ODS already verifies operand types; check tag index count against rank.
+ unsigned numTagIndices = op.tagIndices().size();
+ unsigned tagMemRefRank = op.getTagMemRefRank();
+ if (numTagIndices != tagMemRefRank)
+ return op.emitOpError() << "expected tagIndices to have the same number of "
+ "elements as the tagMemRef rank, expected "
+ << tagMemRefRank << ", but got " << numTagIndices;
return success();
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index b93815533119c..8d0a20bf09886 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1,5 +1,132 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+func @dma_start_not_enough_operands() {
+ // expected-error@+1 {{expected at least 4 operands}}
+ "memref.dma_start"() : () -> ()
+}
+
+// -----
+
+func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
+ // expected-error@+1 {{expected source to be of memref type}}
+ memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
+}
+
+// -----
+
+func @dma_start_not_enough_operands_for_src(
+ %src: memref<2x2x2xf32>, %idx: index) {
+ // expected-error@+1 {{expected at least 7 operands}}
+ "memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_src_index_wrong_type(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<i32,2>, %flt: f32) {
+ // expected-error@+1 {{expected source indices to be of index type}}
+ "memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
+ : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
+}
+
+// -----
+
+func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
+ %mref = memref.alloc() : memref<8 x f32>
+ // expected-error@+1 {{expected destination to be of memref type}}
+ memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
+}
+
+// -----
+
+func @dma_start_not_enough_operands_for_dst(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<i32,2>) {
+ // expected-error@+1 {{expected at least 7 operands}}
+ "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
+ : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
+}
+
+// -----
+
+func @dma_start_dst_index_wrong_type(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<i32,2>, %flt: f32) {
+ // expected-error@+1 {{expected destination indices to be of index type}}
+ "memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
+ : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
+}
+
+// -----
+
+func @dma_start_num_elements_wrong_type(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<i32,2>, %flt: f32) {
+ // expected-error@+1 {{expected num elements to be of index type}}
+ "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
+ : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
+}
+
+// -----
+
+func @dma_no_tag_memref(%tag : f32, %c0 : index) {
+ %mref = memref.alloc() : memref<8 x f32>
+ // expected-error@+1 {{expected tag to be of memref type}}
+ memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
+}
+
+// -----
+
+func @dma_start_not_enough_operands_for_tag(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<2xi32,2>) {
+ // expected-error@+1 {{expected at least 8 operands}}
+ "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
+ : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
+}
+
+// -----
+
+func @dma_start_tag_index_wrong_type(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<2xi32,2>, %flt: f32) {
+ // expected-error@+1 {{expected tag indices to be of index type}}
+ "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
+ : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
+}
+
+// -----
+
+func @dma_start_too_many_operands(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<i32,2>) {
+ // expected-error@+1 {{incorrect number of operands}}
+ "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
+ : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
+}
+
+
+// -----
+
+func @dma_start_wrong_stride_type(
+ %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
+ %tag: memref<i32,2>, %flt: f32) {
+ // expected-error@+1 {{expected stride and num elements per stride to be of type index}}
+ "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
+ : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
+}
+
+// -----
+
+func @dma_wait_wrong_num_tag_indices(%tag : memref<2x2xi32>, %idx: index) {
+ // expected-error@+1 {{expected tagIndices to have the same number of elements as the tagMemRef rank, expected 2, but got 1}}
+ "memref.dma_wait"(%tag, %idx, %idx) : (memref<2x2xi32>, index, index) -> ()
+ return
+}
+
+// -----
+
func @transpose_not_permutation(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{expected a permutation map}}
memref.transpose %v (i, j) -> (i, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 265f095fe2272..3bfc2546dc4bf 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -290,153 +290,6 @@ func @invalid_cmp_shape(%idx : () -> ()) {
// -----
-func @dma_start_not_enough_operands() {
- // expected-error@+1 {{expected at least 4 operands}}
- "memref.dma_start"() : () -> ()
-}
-
-// -----
-
-func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
- // expected-error@+1 {{expected source to be of memref type}}
- memref.dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
-}
-
-// -----
-
-func @dma_start_not_enough_operands_for_src(
- %src: memref<2x2x2xf32>, %idx: index) {
- // expected-error@+1 {{expected at least 7 operands}}
- "memref.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
-}
-
-// -----
-
-func @dma_start_src_index_wrong_type(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<i32,2>, %flt: f32) {
- // expected-error@+1 {{expected source indices to be of index type}}
- "memref.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
- : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
-}
-
-// -----
-
-func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
- %mref = memref.alloc() : memref<8 x f32>
- // expected-error@+1 {{expected destination to be of memref type}}
- memref.dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x f32>, f32, f32
-}
-
-// -----
-
-func @dma_start_not_enough_operands_for_dst(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<i32,2>) {
- // expected-error@+1 {{expected at least 7 operands}}
- "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
- : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
-}
-
-// -----
-
-func @dma_start_dst_index_wrong_type(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<i32,2>, %flt: f32) {
- // expected-error@+1 {{expected destination indices to be of index type}}
- "memref.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
- : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
-}
-
-// -----
-
-func @dma_start_dst_index_wrong_type(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<i32,2>, %flt: f32) {
- // expected-error@+1 {{expected num elements to be of index type}}
- "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
- : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
-}
-
-// -----
-
-func @dma_no_tag_memref(%tag : f32, %c0 : index) {
- %mref = memref.alloc() : memref<8 x f32>
- // expected-error@+1 {{expected tag to be of memref type}}
- memref.dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32
-}
-
-// -----
-
-func @dma_start_not_enough_operands_for_tag(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<2xi32,2>) {
- // expected-error@+1 {{expected at least 8 operands}}
- "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
- : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
-}
-
-// -----
-
-func @dma_start_dst_index_wrong_type(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<2xi32,2>, %flt: f32) {
- // expected-error@+1 {{expected tag indices to be of index type}}
- "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
- : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
-}
-
-// -----
-
-func @dma_start_too_many_operands(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<i32,2>) {
- // expected-error@+1 {{incorrect number of operands}}
- "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
- : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
-}
-
-
-// -----
-
-func @dma_start_wrong_stride_type(
- %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
- %tag: memref<i32,2>, %flt: f32) {
- // expected-error@+1 {{expected stride and num elements per stride to be of type index}}
- "memref.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
- : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
-}
-
-// -----
-
-func @dma_wait_not_enough_operands() {
- // expected-error@+1 {{expected at least 2 operands}}
- "memref.dma_wait"() : () -> ()
-}
-
-// -----
-
-func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
- // expected-error@+1 {{expected tag to be of memref type}}
- "memref.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
-}
-
-// -----
-
-func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
- // expected-error@+1 {{expected tag indices to be of index type}}
- "memref.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
-}
-
-// -----
-
-func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
- // expected-error@+1 {{expected the number of elements to be of index type}}
- "memref.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
-}
-
-// -----
-
func @invalid_cmp_attr(%idx : i32) {
// expected-error at +1 {{expected string or keyword containing one of the following enum values}}
%cmp = cmpi i1, %idx, %idx : i32
More information about the Mlir-commits
mailing list