[Mlir-commits] [mlir] 461605c - [mlir] Add MemRefReinterpretCastOp definition to Standard.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Oct 22 06:18:00 PDT 2020
Author: Alexander Belyaev
Date: 2020-10-22T15:17:22+02:00
New Revision: 461605c418e9059aa50de65c60bbd49e8f270b4a
URL: https://github.com/llvm/llvm-project/commit/461605c418e9059aa50de65c60bbd49e8f270b4a
DIFF: https://github.com/llvm/llvm-project/commit/461605c418e9059aa50de65c60bbd49e8f270b4a.diff
LOG: [mlir] Add MemRefReinterpretCastOp definition to Standard.
Reuse most code for printing/parsing/verification from SubViewOp.
https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15
Differential Revision: https://reviews.llvm.org/D89720
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/Dialect/Standard/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6de8ace044cc..e3e1930580b4 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2217,6 +2217,52 @@ def MemRefCastOp : CastOp<"memref_cast", [
}];
}
+
+//===----------------------------------------------------------------------===//
+// MemRefReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+// Reinterprets the `source` memref (ranked or unranked) as a ranked memref
+// whose offset, sizes and strides are given by the op. Each of the three lists
+// is stored as a static I64 array attribute plus a variadic list of index
+// operands that supply the entries marked as dynamic.
+def MemRefReinterpretCastOp:
+ BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [
+ NoSideEffect, ViewLikeOpInterface
+ ]> {
+ let summary = "memref reinterpret cast operation";
+ let description = [{
+ Modify offset, sizes and strides of an unranked/ranked memref.
+
+ Example:
+ ```mlir
+ memref_reinterpret_cast %ranked to
+ offset: [0],
+ sizes: [%size0, 10],
+ strides: [1, %stride1]
+ : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>
+
+ memref_reinterpret_cast %unranked to
+ offset: [%offset],
+ sizes: [%size0, %size1],
+ strides: [%stride0, %stride1]
+ : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ ```
+ }];
+
+ // Dynamic entries of `offsets`/`sizes`/`strides` are SSA index operands; the
+ // matching `static_*` attribute holds either the constant value or a
+ // sentinel that marks the position of a dynamic operand.
+ let arguments = (ins
+ Arg<AnyRankedOrUnrankedMemRef, "", []>:$source,
+ Variadic<Index>:$offsets,
+ Variadic<Index>:$sizes,
+ Variadic<Index>:$strides,
+ I64ArrayAttr:$static_offsets,
+ I64ArrayAttr:$static_sizes,
+ I64ArrayAttr:$static_strides
+ );
+ let results = (outs AnyMemRef:$result);
+ let extraClassDeclaration = extraBaseClassDeclaration # [{
+ // The result of the op is always a ranked memref.
+ MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+ Value getViewSource() { return source(); }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// MemRefReshapeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 7010ad5f34c9..4638ae4e9268 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -261,6 +261,126 @@ OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a + b; });
}
+//===----------------------------------------------------------------------===//
+// BaseOpWithOffsetSizesAndStridesOp
+//===----------------------------------------------------------------------===//
+
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+/// `values` must hold one value per dynamic entry; no bounds check is done
+/// here. No separator is printed after the closing `]`; callers emit their own.
+static void
+printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+ ArrayAttr arrayAttr,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ p << '[';
+ // Position of the next SSA value in `values`; advanced only for entries
+ // that `isDynamic` classifies as dynamic.
+ unsigned idx = 0;
+ llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+ int64_t val = a.cast<IntegerAttr>().getInt();
+ if (isDynamic(val))
+ p << values[idx++];
+ else
+ p << val;
+ });
+ p << ']';
+}
+
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
+/// encodes the position of SSA values. Add the parsed SSA values to `ssa`
+/// in-order.
+///
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+///   2. `ssa` is filled with "[%arg0, %arg42]".
+///
+/// An empty list "[]" yields an empty ArrayAttr and adds no SSA values.
+static ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+                              StringRef attrName, int64_t dynVal,
+                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+  if (failed(parser.parseLSquare()))
+    return failure();
+  // 0-D: an empty list.
+  if (succeeded(parser.parseOptionalRSquare())) {
+    result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
+    return success();
+  }
+
+  SmallVector<int64_t, 4> attrVals;
+  while (true) {
+    OpAsmParser::OperandType operand;
+    auto res = parser.parseOptionalOperand(operand);
+    if (res.hasValue()) {
+      // An SSA value was started; if it was malformed the parser has already
+      // diagnosed it, so do not fall back to parsing an integer attribute.
+      if (failed(res.getValue()))
+        return failure();
+      // Record the SSA value and mark its slot with the dynamic sentinel.
+      ssa.push_back(operand);
+      attrVals.push_back(dynVal);
+    } else {
+      IntegerAttr attr;
+      if (failed(parser.parseAttribute<IntegerAttr>(attr)))
+        return parser.emitError(parser.getNameLoc())
+               << "expected SSA value or integer";
+      attrVals.push_back(attr.getInt());
+    }
+
+    if (succeeded(parser.parseOptionalComma()))
+      continue;
+    if (failed(parser.parseRSquare()))
+      return failure();
+    break;
+  }
+
+  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
+  result.addAttribute(attrName, arrayAttr);
+  return success();
+}
+
+/// Verify that a particular offset/size/stride static attribute is well-formed:
+/// `attr` must have `expectedNumElements` entries, and `values` must contain
+/// exactly one SSA value per entry that `isDynamic` classifies as dynamic.
+/// `name` is used in the diagnostics; `attrName` is currently unused.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
+    OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName,
+    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
+    ValueRange values) {
+  // Check static and dynamic offsets/sizes/strides breakdown.
+  if (attr.size() != expectedNumElements)
+    return op.emitError("expected ")
+           << expectedNumElements << " " << name << " values";
+  // Named `a` so it does not shadow the `attr` parameter.
+  unsigned expectedNumDynamicEntries =
+      llvm::count_if(attr.getValue(), [&](Attribute a) {
+        return isDynamic(a.cast<IntegerAttr>().getInt());
+      });
+  if (values.size() != expectedNumDynamicEntries)
+    return op.emitError("expected ")
+           << expectedNumDynamicEntries << " dynamic " << name << " values";
+  return success();
+}
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+/// Both assumptions are enforced only via `cast<>` (which asserts on a
+/// mismatch in assertion-enabled builds), so callers must pass such an attr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+ return llvm::to_vector<4>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+ return a.cast<IntegerAttr>().getInt();
+ }));
+}
+
+/// Verify static attributes offsets/sizes/strides.
+/// Each of the three lists must have one entry per source dimension
+/// (`op.getSourceRank()`), with one SSA operand per dynamic entry. Checking
+/// stops at the first list that fails.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
+ unsigned srcRank = op.getSourceRank();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "offset", srcRank, op.getStaticOffsetsAttrName(),
+ op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
+ op.offsets())))
+ return failure();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(),
+ ShapedType::isDynamic, op.sizes())))
+ return failure();
+ if (failed(verifyOpWithOffsetSizesAndStridesPart(
+ op, "stride", srcRank, op.getStaticStridesAttrName(),
+ op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+ op.strides())))
+ return failure();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
@@ -2145,6 +2265,169 @@ OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
+//===----------------------------------------------------------------------===//
+// MemRefReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+/// Print of the form:
+/// ```
+///   `name` ssa-name to
+///     offset: `[` offset `]` `,`
+///     sizes: `[` size-list `]` `,`
+///     strides: `[` stride-list `]`
+///     attr-dict `:` any-memref-type to strided-memref-type
+/// ```
+static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
+  // Drop the "<dialect namespace>." prefix from the printed op name.
+  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op.getOperationName().drop_front(stdDotLen) << " " << op.source()
+    << " to offset: ";
+  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << ", sizes: ";
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p << ", strides: ";
+  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+                                ShapedType::isDynamicStrideOrOffset);
+  // The static lists are printed inline above and the parser recomputes the
+  // operand segment sizes, so none of them belong in the attr-dict.
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+                       MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+                       MemRefReinterpretCastOp::getStaticSizesAttrName(),
+                       MemRefReinterpretCastOp::getStaticStridesAttrName()});
+  // Separate the colon from the preceding `]` or attr-dict, matching the
+  // " : " used by the other offset/sizes/strides printers.
+  p << " : " << op.source().getType() << " to " << op.getType();
+}
+
+/// Parse of the form:
+/// ```
+///   `name` ssa-name to
+///     offset: `[` offset `]` `,`
+///     sizes: `[` size-list `]` `,`
+///     strides: `[` stride-list `]`
+///     attr-dict `:` any-memref-type to strided-memref-type
+/// ```
+static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
+                                                OperationState &result) {
+  // Parse the source `operand`.
+  OpAsmParser::OperandType operand;
+  if (parser.parseOperand(operand))
+    return failure();
+
+  // Parse `to offset: [...]` and the comma that follows it.
+  SmallVector<OpAsmParser::OperandType, 1> offset;
+  if (parser.parseKeyword("to") || parser.parseKeyword("offset") ||
+      parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+          ShapedType::kDynamicStrideOrOffset, offset) ||
+      parser.parseComma())
+    return failure();
+
+  // Parse `sizes: [...]` and the comma that follows it.
+  SmallVector<OpAsmParser::OperandType, 4> sizes;
+  if (parser.parseKeyword("sizes") || parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(),
+          ShapedType::kDynamicSize, sizes) ||
+      parser.parseComma())
+    return failure();
+
+  // Parse `strides: [...]`.
+  SmallVector<OpAsmParser::OperandType, 4> strides;
+  if (parser.parseKeyword("strides") || parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(),
+          ShapedType::kDynamicStrideOrOffset, strides))
+    return failure();
+
+  // Record how many operands belong to each group: the source, then the
+  // dynamic offsets, sizes and strides, in that order.
+  auto b = parser.getBuilder();
+  SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offset.size()),
+                                      static_cast<int>(sizes.size()),
+                                      static_cast<int>(strides.size())};
+  result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+                      b.getI32VectorAttr(segmentSizes));
+
+  // Parse the attr-dict and types, then resolve the operands in the same
+  // order as the segments recorded above.
+  Type indexType = b.getIndexType();
+  Type operandType, resultType;
+  return failure(
+      (parser.parseOptionalAttrDict(result.attributes) ||
+       parser.parseColonType(operandType) || parser.parseKeyword("to") ||
+       parser.parseType(resultType) ||
+       parser.resolveOperand(operand, operandType, result.operands) ||
+       parser.resolveOperands(offset, indexType, result.operands) ||
+       parser.resolveOperands(sizes, indexType, result.operands) ||
+       parser.resolveOperands(strides, indexType, result.operands) ||
+       parser.addTypeToList(resultType, result.types)));
+}
+
+/// Verify a MemRefReinterpretCastOp. The source and result must agree on
+/// memory space and element type, the offset/size/stride lists must be
+/// well-formed for the result rank, and every static entry must match the
+/// corresponding value in the result type's layout.
+static LogicalResult verify(MemRefReinterpretCastOp op) {
+  // The source and result memrefs should be in the same memory space and
+  // have the same element type.
+  auto srcType = op.source().getType().cast<BaseMemRefType>();
+  // getType() already returns the result as a MemRefType.
+  MemRefType resultType = op.getType();
+  if (srcType.getMemorySpace() != resultType.getMemorySpace())
+    return op.emitError("different memory spaces specified for source type ")
+           << srcType << " and result memref type " << resultType;
+  if (srcType.getElementType() != resultType.getElementType())
+    return op.emitError("different element types specified for source type ")
+           << srcType << " and result memref type " << resultType;
+
+  // Verify that dynamic and static offset/sizes/strides arguments/attributes
+  // are consistent: exactly one offset, and one size and one stride per
+  // result dimension.
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(),
+          ShapedType::isDynamicStrideOrOffset, op.offsets())))
+    return failure();
+  unsigned resultRank = op.getResultRank();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "size", resultRank, op.getStaticSizesAttrName(),
+          op.static_sizes(), ShapedType::isDynamic, op.sizes())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "stride", resultRank, op.getStaticStridesAttrName(),
+          op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+          op.strides())))
+    return failure();
+
+  // Extract the result offset and strides from its layout. A non-strided
+  // layout cannot be matched against the op's offset/strides; report it
+  // rather than failing verification without any diagnostic.
+  int64_t resultOffset;
+  SmallVector<int64_t, 4> resultStrides;
+  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    return op.emitError("expected result type to have a strided layout, got ")
+           << resultType;
+
+  // Match offset in result memref type and in static_offsets attribute. The
+  // offset list has exactly one entry (verified above).
+  // NOTE(review): unlike the size/stride messages below, this message prints
+  // the result-type value first; invalid.mlir pins the current wording.
+  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+  if (resultOffset != expectedOffset)
+    return op.emitError("expected result type with offset = ")
+           << resultOffset << " instead of " << expectedOffset;
+
+  // Match sizes in result memref type and in static_sizes attribute.
+  for (auto &en :
+       llvm::enumerate(llvm::zip(resultType.getShape(),
+                                 extractFromI64ArrayAttr(op.static_sizes())))) {
+    int64_t resultSize = std::get<0>(en.value());
+    int64_t expectedSize = std::get<1>(en.value());
+    if (resultSize != expectedSize)
+      return op.emitError("expected result type with size = ")
+             << expectedSize << " instead of " << resultSize
+             << " in dim = " << en.index();
+  }
+
+  // Match strides in result memref type and in static_strides attribute.
+  for (auto &en : llvm::enumerate(llvm::zip(
+           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+    int64_t resultStride = std::get<0>(en.value());
+    int64_t expectedStride = std::get<1>(en.value());
+    if (resultStride != expectedStride)
+      return op.emitError("expected result type with stride = ")
+             << expectedStride << " instead of " << resultStride
+             << " in dim = " << en.index();
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// MemRefReshapeOp
//===----------------------------------------------------------------------===//
@@ -2577,75 +2860,6 @@ bool UIToFPOp::areCastCompatible(Type a, Type b) {
// SubViewOp
//===----------------------------------------------------------------------===//
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void printSubViewListOfOperandsOrIntegers(
- OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- p << "[";
- unsigned idx = 0;
- llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
- int64_t val = a.cast<IntegerAttr>().getInt();
- if (isDynamic(val))
- p << values[idx++];
- else
- p << val;
- });
- p << "] ";
-}
-
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-/// 2. `ssa` is filled with "[%arg0, %arg1]".
-static ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
- StringRef attrName, int64_t dynVal,
- SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
- if (failed(parser.parseLSquare()))
- return failure();
- // 0-D.
- if (succeeded(parser.parseOptionalRSquare())) {
- result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
- return success();
- }
-
- SmallVector<int64_t, 4> attrVals;
- while (true) {
- OpAsmParser::OperandType operand;
- auto res = parser.parseOptionalOperand(operand);
- if (res.hasValue() && succeeded(res.getValue())) {
- ssa.push_back(operand);
- attrVals.push_back(dynVal);
- } else {
- Attribute attr;
- NamedAttrList placeholder;
- if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
- !attr.isa<IntegerAttr>())
- return parser.emitError(parser.getNameLoc())
- << "expected SSA value or integer";
- attrVals.push_back(attr.cast<IntegerAttr>().getInt());
- }
-
- if (succeeded(parser.parseOptionalComma()))
- continue;
- if (failed(parser.parseRSquare()))
- return failure();
- else
- break;
- }
-
- auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
- result.addAttribute(attrName, arrayAttr);
- return success();
-}
-
namespace {
/// Helpers to write more idiomatic operations.
namespace saturated_arith {
@@ -2733,12 +2947,15 @@ static void printOpWithOffsetsSizesAndStrides(
p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.source();
printExtraOperands(p, op);
- printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset);
- printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
- ShapedType::isDynamic);
- printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset);
+ printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+ ShapedType::isDynamicStrideOrOffset);
+ p << ' ';
+ printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+ ShapedType::isDynamic);
+ p << ' ';
+ printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+ ShapedType::isDynamicStrideOrOffset);
+ p << ' ';
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{OpType::getSpecialAttrNames()});
p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
@@ -2878,33 +3095,6 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
/// For ViewLikeOpInterface.
Value SubViewOp::getViewSource() { return source(); }
-/// Verify that a particular offset/size/stride static attribute is well-formed.
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
- OpType op, StringRef name, StringRef attrName, ArrayAttr attr,
- llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
- /// Check static and dynamic offsets/sizes/strides breakdown.
- if (attr.size() != op.getSourceRank())
- return op.emitError("expected ")
- << op.getSourceRank() << " " << name << " values";
- unsigned expectedNumDynamicEntries =
- llvm::count_if(attr.getValue(), [&](Attribute attr) {
- return isDynamic(attr.cast<IntegerAttr>().getInt());
- });
- if (values.size() != expectedNumDynamicEntries)
- return op.emitError("expected ")
- << expectedNumDynamicEntries << " dynamic " << name << " values";
- return success();
-}
-
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
-}
-
llvm::Optional<SmallVector<bool, 4>>
mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
ArrayRef<int64_t> reducedShape) {
@@ -3041,24 +3231,6 @@ static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
llvm_unreachable("unexpected subview verification result");
}
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
- // Verify static attributes offsets/sizes/strides.
- if (failed(verifyOpWithOffsetSizesAndStridesPart(
- op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset, op.offsets())))
- return failure();
-
- if (failed(verifyOpWithOffsetSizesAndStridesPart(
- op, "size", op.getStaticSizesAttrName(), op.static_sizes(),
- ShapedType::isDynamic, op.sizes())))
- return failure();
- if (failed(verifyOpWithOffsetSizesAndStridesPart(
- op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset, op.strides())))
- return failure();
- return success();
-}
/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 8047ad94f588..5bd94bd39d91 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -105,6 +105,85 @@ func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off
// -----
+// The offset list must contain exactly one entry, regardless of result rank.
+// CHECK-LABEL: func @memref_reinterpret_cast_too_many_offsets
+func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
+ // expected-error @+1 {{expected 1 offset values}}
+ %out = memref_reinterpret_cast %in to
+ offset: [0, 0], sizes: [10, 10], strides: [10, 1]
+ : memref<?xf32> to memref<10x10xf32, offset: 0, strides: [10, 1]>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_element_types
+func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) {
+  // expected-error @+1 {{different element types specified}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+           : memref<*xf32> to memref<10xi32, offset: 0, strides: [1]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_memory_space
+func @memref_reinterpret_cast_incompatible_memory_space(%in: memref<*xf32>) {
+  // Element types match so that only the memory space differs.
+  // expected-error @+1 {{different memory spaces specified}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+           : memref<*xf32> to memref<10xf32, offset: 0, strides: [1], 2>
+  return
+}
+
+// -----
+
+// The static offset (1) disagrees with the offset of the result type's
+// default layout (0).
+// CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
+func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
+ // expected-error @+1 {{expected result type with offset = 0 instead of 1}}
+ %out = memref_reinterpret_cast %in to
+ offset: [1], sizes: [10], strides: [1]
+ : memref<?xf32> to memref<10xf32>
+ return
+}
+
+// -----
+
+// The static size (10) disagrees with the static result shape (1).
+// CHECK-LABEL: func @memref_reinterpret_cast_size_mismatch
+func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) {
+ // expected-error @+1 {{expected result type with size = 10 instead of 1 in dim = 0}}
+ %out = memref_reinterpret_cast %in to
+ offset: [0], sizes: [10], strides: [1]
+ : memref<*xf32> to memref<1xf32, offset: 0, strides: [1]>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_stride_mismatch
+func @memref_reinterpret_cast_stride_mismatch(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [2]
+           : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_size_mismatch
+func @memref_reinterpret_cast_dynamic_size_mismatch(%in: memref<?xf32>) {
+  %c0 = constant 0 : index
+  %c10 = constant 10 : index
+  // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
+           : memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  return
+}
+
+// -----
+
// CHECK-LABEL: memref_reshape_element_type_mismatch
func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 501ff07fd5a7..fcf2325fd2e8 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -55,6 +55,17 @@ func @atan2(%arg0 : f32, %arg1 : f32) -> f32 {
return %result : f32
}
+// A valid cast mixing a dynamic offset with static and dynamic sizes/strides.
+// CHECK-LABEL: func @memref_reinterpret_cast
+func @memref_reinterpret_cast(%in: memref<?xf32>)
+ -> memref<10x?xf32, offset: ?, strides: [?, 1]> {
+ %c0 = constant 0 : index
+ %c10 = constant 10 : index
+ %out = memref_reinterpret_cast %in to
+ offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
+ : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
+ return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
+}
+
// CHECK-LABEL: func @memref_reshape(
func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
%shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {
More information about the Mlir-commits
mailing list