[Mlir-commits] [mlir] 461605c - [mlir] Add MemRefReinterpretCastOp definition to Standard.

Alexander Belyaev llvmlistbot at llvm.org
Thu Oct 22 06:18:00 PDT 2020


Author: Alexander Belyaev
Date: 2020-10-22T15:17:22+02:00
New Revision: 461605c418e9059aa50de65c60bbd49e8f270b4a

URL: https://github.com/llvm/llvm-project/commit/461605c418e9059aa50de65c60bbd49e8f270b4a
DIFF: https://github.com/llvm/llvm-project/commit/461605c418e9059aa50de65c60bbd49e8f270b4a.diff

LOG: [mlir] Add MemRefReinterpretCastOp definition to Standard.

Reuse most code for printing/parsing/verification from SubViewOp.

https://llvm.discourse.group/t/rfc-standard-memref-cast-ops/1454/15

Differential Revision: https://reviews.llvm.org/D89720

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6de8ace044cc..e3e1930580b4 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2217,6 +2217,52 @@ def MemRefCastOp : CastOp<"memref_cast", [
   }];
 }
 
+
+//===----------------------------------------------------------------------===//
+// MemRefReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+def MemRefReinterpretCastOp:
+    BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [
+      NoSideEffect, ViewLikeOpInterface
+    ]> {
+  let summary = "memref reinterpret cast operation";
+  let description = [{
+    Modify offset, sizes and strides of an unranked/ranked memref.
+
+    Example:
+    ```mlir
+    memref_reinterpret_cast %ranked to
+      offset: [0],
+      sizes: [%size0, 10],
+      strides: [1, %stride1]
+    : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>
+
+    memref_reinterpret_cast %unranked to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1]
+    : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<AnyRankedOrUnrankedMemRef, "", []>:$source,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let results = (outs AnyMemRef:$result);
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    // The result of the op is always a ranked memref.
+    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+    Value getViewSource() { return source(); }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefReshapeOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 7010ad5f34c9..4638ae4e9268 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -261,6 +261,126 @@ OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
                                         [](APInt a, APInt b) { return a + b; });
 }
 
+//===----------------------------------------------------------------------===//
+// BaseOpWithOffsetSizesAndStridesOp
+//===----------------------------------------------------------------------===//
+
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+static void
+printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+                              ArrayAttr arrayAttr,
+                              llvm::function_ref<bool(int64_t)> isDynamic) {
+  p << '[';
+  unsigned idx = 0;
+  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+    int64_t val = a.cast<IntegerAttr>().getInt();
+    if (isDynamic(val))
+      p << values[idx++];
+    else
+      p << val;
+  });
+  p << ']';
+}
+
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
+/// encode the position of SSA values. Add the parsed SSA values to `ssa`
+/// in-order.
+//
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+///   2. `ssa` is filled with "[%arg0, %arg42]".
+static ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+                              StringRef attrName, int64_t dynVal,
+                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+  if (failed(parser.parseLSquare()))
+    return failure();
+  // 0-D.
+  if (succeeded(parser.parseOptionalRSquare())) {
+    result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
+    return success();
+  }
+
+  SmallVector<int64_t, 4> attrVals;
+  while (true) {
+    OpAsmParser::OperandType operand;
+    auto res = parser.parseOptionalOperand(operand);
+    if (res.hasValue() && succeeded(res.getValue())) {
+      ssa.push_back(operand);
+      attrVals.push_back(dynVal);
+    } else {
+      IntegerAttr attr;
+      if (failed(parser.parseAttribute<IntegerAttr>(attr)))
+        return parser.emitError(parser.getNameLoc())
+               << "expected SSA value or integer";
+      attrVals.push_back(attr.getInt());
+    }
+
+    if (succeeded(parser.parseOptionalComma()))
+      continue;
+    if (failed(parser.parseRSquare()))
+      return failure();
+    break;
+  }
+
+  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
+  result.addAttribute(attrName, arrayAttr);
+  return success();
+}
+
+/// Verify that a particular offset/size/stride static attribute is well-formed.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
+    OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName,
+    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
+    ValueRange values) {
+  /// Check static and dynamic offsets/sizes/strides breakdown.
+  if (attr.size() != expectedNumElements)
+    return op.emitError("expected ")
+           << expectedNumElements << " " << name << " values";
+  unsigned expectedNumDynamicEntries =
+      llvm::count_if(attr.getValue(), [&](Attribute attr) {
+        return isDynamic(attr.cast<IntegerAttr>().getInt());
+      });
+  if (values.size() != expectedNumDynamicEntries)
+    return op.emitError("expected ")
+           << expectedNumDynamicEntries << " dynamic " << name << " values";
+  return success();
+}
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
+/// Verify static attributes offsets/sizes/strides.
+template <typename OpType>
+static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
+  unsigned srcRank = op.getSourceRank();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "offset", srcRank, op.getStaticOffsetsAttrName(),
+          op.static_offsets(), ShapedType::isDynamicStrideOrOffset,
+          op.offsets())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(),
+          ShapedType::isDynamic, op.sizes())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "stride", srcRank, op.getStaticStridesAttrName(),
+          op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+          op.strides())))
+    return failure();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -2145,6 +2265,169 @@ OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
   return impl::foldCastOp(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+/// Print of the form:
+/// ```
+///   `name` ssa-name to
+///       offset: `[` offset `]`
+///       sizes: `[` size-list `]`
+///       strides:`[` stride-list `]`
+///   `:` any-memref-type to strided-memref-type
+/// ```
+static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
+  int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op.getOperationName().drop_front(stdDotLen) << " " << op.source()
+    << " to offset: ";
+  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << ", sizes: ";
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p << ", strides: ";
+  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p.printOptionalAttrDict(
+      op.getAttrs(),
+      /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+                       MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+                       MemRefReinterpretCastOp::getStaticSizesAttrName(),
+                       MemRefReinterpretCastOp::getStaticStridesAttrName()});
+  p << ": " << op.source().getType() << " to " << op.getType();
+}
+
+/// Parse of the form:
+/// ```
+///   `name` ssa-name to
+///       offset: `[` offset `]`
+///       sizes: `[` size-list `]`
+///       strides:`[` stride-list `]`
+///   `:` any-memref-type to strided-memref-type
+/// ```
+static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
+                                                OperationState &result) {
+  // Parse `operand`.
+  OpAsmParser::OperandType operand;
+  if (parser.parseOperand(operand))
+    return failure();
+
+  // Parse offset.
+  SmallVector<OpAsmParser::OperandType, 1> offset;
+  if (parser.parseKeyword("to") || parser.parseKeyword("offset") ||
+      parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(),
+          ShapedType::kDynamicStrideOrOffset, offset) ||
+      parser.parseComma())
+    return failure();
+
+  // Parse `sizes`.
+  SmallVector<OpAsmParser::OperandType, 4> sizes;
+  if (parser.parseKeyword("sizes") || parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(),
+          ShapedType::kDynamicSize, sizes) ||
+      parser.parseComma())
+    return failure();
+
+  // Parse `strides`.
+  SmallVector<OpAsmParser::OperandType, 4> strides;
+  if (parser.parseKeyword("strides") || parser.parseColon() ||
+      parseListOfOperandsOrIntegers(
+          parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(),
+          ShapedType::kDynamicStrideOrOffset, strides))
+    return failure();
+
+  // Handle segment sizes.
+  auto b = parser.getBuilder();
+  SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offset.size()),
+                                      static_cast<int>(sizes.size()),
+                                      static_cast<int>(strides.size())};
+  result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(),
+
+                      b.getI32VectorAttr(segmentSizes));
+
+  // Parse types and resolve.
+  Type indexType = b.getIndexType();
+  Type operandType, resultType;
+  return failure(
+      (parser.parseOptionalAttrDict(result.attributes) ||
+       parser.parseColonType(operandType) || parser.parseKeyword("to") ||
+       parser.parseType(resultType) ||
+       parser.resolveOperand(operand, operandType, result.operands) ||
+       parser.resolveOperands(offset, indexType, result.operands) ||
+       parser.resolveOperands(sizes, indexType, result.operands) ||
+       parser.resolveOperands(strides, indexType, result.operands) ||
+       parser.addTypeToList(resultType, result.types)));
+}
+
+static LogicalResult verify(MemRefReinterpretCastOp op) {
+  // The source and result memrefs should be in the same memory space.
+  auto srcType = op.source().getType().cast<BaseMemRefType>();
+  auto resultType = op.getType().cast<MemRefType>();
+  if (srcType.getMemorySpace() != resultType.getMemorySpace())
+    return op.emitError("different memory spaces specified for source type ")
+           << srcType << " and result memref type " << resultType;
+  if (srcType.getElementType() != resultType.getElementType())
+    return op.emitError("different element types specified for source type ")
+           << srcType << " and result memref type " << resultType;
+
+  // Verify that dynamic and static offset/sizes/strides arguments/attributes
+  // are consistent.
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(),
+          ShapedType::isDynamicStrideOrOffset, op.offsets())))
+    return failure();
+  unsigned resultRank = op.getResultRank();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "size", resultRank, op.getStaticSizesAttrName(),
+          op.static_sizes(), ShapedType::isDynamic, op.sizes())))
+    return failure();
+  if (failed(verifyOpWithOffsetSizesAndStridesPart(
+          op, "stride", resultRank, op.getStaticStridesAttrName(),
+          op.static_strides(), ShapedType::isDynamicStrideOrOffset,
+          op.strides())))
+    return failure();
+
+  // Extract result offset and strides.
+  int64_t resultOffset;
+  SmallVector<int64_t, 4> resultStrides;
+  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    return failure();
+
+  // Match offset in result memref type and in static_offsets attribute.
+  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+  if (resultOffset != expectedOffset)
+    return op.emitError("expected result type with offset = ")
+           << resultOffset << " instead of " << expectedOffset;
+
+  // Match sizes in result memref type and in static_sizes attribute.
+  for (auto &en :
+       llvm::enumerate(llvm::zip(resultType.getShape(),
+                                 extractFromI64ArrayAttr(op.static_sizes())))) {
+    int64_t resultSize = std::get<0>(en.value());
+    int64_t expectedSize = std::get<1>(en.value());
+    if (resultSize != expectedSize)
+      return op.emitError("expected result type with size = ")
+             << expectedSize << " instead of " << resultSize
+             << " in dim = " << en.index();
+  }
+
+  // Match strides in result memref type and in static_strides attribute.
+  for (auto &en : llvm::enumerate(llvm::zip(
+           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+    int64_t resultStride = std::get<0>(en.value());
+    int64_t expectedStride = std::get<1>(en.value());
+    if (resultStride != expectedStride)
+      return op.emitError("expected result type with stride = ")
+             << expectedStride << " instead of " << resultStride
+             << " in dim = " << en.index();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefReshapeOp
 //===----------------------------------------------------------------------===//
@@ -2577,75 +2860,6 @@ bool UIToFPOp::areCastCompatible(Type a, Type b) {
 // SubViewOp
 //===----------------------------------------------------------------------===//
 
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void printSubViewListOfOperandsOrIntegers(
-    OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
-    llvm::function_ref<bool(int64_t)> isDynamic) {
-  p << "[";
-  unsigned idx = 0;
-  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
-    int64_t val = a.cast<IntegerAttr>().getInt();
-    if (isDynamic(val))
-      p << values[idx++];
-    else
-      p << val;
-  });
-  p << "] ";
-}
-
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-///   2. `ssa` is filled with "[%arg0, %arg1]".
-static ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
-                              StringRef attrName, int64_t dynVal,
-                              SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
-  if (failed(parser.parseLSquare()))
-    return failure();
-  // 0-D.
-  if (succeeded(parser.parseOptionalRSquare())) {
-    result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
-    return success();
-  }
-
-  SmallVector<int64_t, 4> attrVals;
-  while (true) {
-    OpAsmParser::OperandType operand;
-    auto res = parser.parseOptionalOperand(operand);
-    if (res.hasValue() && succeeded(res.getValue())) {
-      ssa.push_back(operand);
-      attrVals.push_back(dynVal);
-    } else {
-      Attribute attr;
-      NamedAttrList placeholder;
-      if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
-          !attr.isa<IntegerAttr>())
-        return parser.emitError(parser.getNameLoc())
-               << "expected SSA value or integer";
-      attrVals.push_back(attr.cast<IntegerAttr>().getInt());
-    }
-
-    if (succeeded(parser.parseOptionalComma()))
-      continue;
-    if (failed(parser.parseRSquare()))
-      return failure();
-    else
-      break;
-  }
-
-  auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
-  result.addAttribute(attrName, arrayAttr);
-  return success();
-}
-
 namespace {
 /// Helpers to write more idiomatic operations.
 namespace saturated_arith {
@@ -2733,12 +2947,15 @@ static void printOpWithOffsetsSizesAndStrides(
   p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
   p << op.source();
   printExtraOperands(p, op);
-  printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
-                                       ShapedType::isDynamicStrideOrOffset);
-  printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
-                                       ShapedType::isDynamic);
-  printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
-                                       ShapedType::isDynamicStrideOrOffset);
+  printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << ' ';
+  printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+                                ShapedType::isDynamic);
+  p << ' ';
+  printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+                                ShapedType::isDynamicStrideOrOffset);
+  p << ' ';
   p.printOptionalAttrDict(op.getAttrs(),
                           /*elidedAttrs=*/{OpType::getSpecialAttrNames()});
   p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
@@ -2878,33 +3095,6 @@ void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }
 
-/// Verify that a particular offset/size/stride static attribute is well-formed.
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
-    OpType op, StringRef name, StringRef attrName, ArrayAttr attr,
-    llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
-  /// Check static and dynamic offsets/sizes/strides breakdown.
-  if (attr.size() != op.getSourceRank())
-    return op.emitError("expected ")
-           << op.getSourceRank() << " " << name << " values";
-  unsigned expectedNumDynamicEntries =
-      llvm::count_if(attr.getValue(), [&](Attribute attr) {
-        return isDynamic(attr.cast<IntegerAttr>().getInt());
-      });
-  if (values.size() != expectedNumDynamicEntries)
-    return op.emitError("expected ")
-           << expectedNumDynamicEntries << " dynamic " << name << " values";
-  return success();
-}
-
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
-  return llvm::to_vector<4>(
-      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
-        return a.cast<IntegerAttr>().getInt();
-      }));
-}
-
 llvm::Optional<SmallVector<bool, 4>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                                ArrayRef<int64_t> reducedShape) {
@@ -3041,24 +3231,6 @@ static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
   llvm_unreachable("unexpected subview verification result");
 }
 
-template <typename OpType>
-static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
-  // Verify static attributes offsets/sizes/strides.
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
-          ShapedType::isDynamicStrideOrOffset, op.offsets())))
-    return failure();
-
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "size", op.getStaticSizesAttrName(), op.static_sizes(),
-          ShapedType::isDynamic, op.sizes())))
-    return failure();
-  if (failed(verifyOpWithOffsetSizesAndStridesPart(
-          op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
-          ShapedType::isDynamicStrideOrOffset, op.strides())))
-    return failure();
-  return success();
-}
 
 /// Verifier for SubViewOp.
 static LogicalResult verify(SubViewOp op) {

diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 8047ad94f588..5bd94bd39d91 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -105,6 +105,85 @@ func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off
 
 // -----
 
+// CHECK-LABEL: func @memref_reinterpret_cast_too_many_offsets
+func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected 1 offset values}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0, 0], sizes: [10, 10], strides: [10, 1]
+           : memref<?xf32> to memref<10x10xf32, offset: 0, strides: [10, 1]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_element_types
+func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) {
+  // expected-error @+1 {{different element types specified}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+         : memref<*xf32> to memref<10xi32, offset: 0, strides: [1]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_memory_space
+func @memref_reinterpret_cast_incompatible_memory_space(%in: memref<*xf32>) {
+  // expected-error @+1 {{different memory spaces specified}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+         : memref<*xf32> to memref<10xi32, offset: 0, strides: [1], 2>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
+func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected result type with offset = 0 instead of 1}}
+  %out = memref_reinterpret_cast %in to
+           offset: [1], sizes: [10], strides: [1]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_size_mismatch
+func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) {
+  // expected-error @+1 {{expected result type with size = 10 instead of 1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [1]
+         : memref<*xf32> to memref<1xf32, offset: 0, strides: [1]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_stride_mismatch
+func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
+  // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [0], sizes: [10], strides: [2]
+         : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_size_mismatch
+func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
+  %c0 = constant 0 : index
+  %c10 = constant 10 : index
+  // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}}
+  %out = memref_reinterpret_cast %in to
+           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
+           : memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: memref_reshape_element_type_mismatch
 func @memref_reshape_element_type_mismatch(
        %buf: memref<*xf32>, %shape: memref<1xi32>) {

diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 501ff07fd5a7..fcf2325fd2e8 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -55,6 +55,17 @@ func @atan2(%arg0 : f32, %arg1 : f32) -> f32 {
   return %result : f32
 }
 
+// CHECK-LABEL: func @memref_reinterpret_cast
+func @memref_reinterpret_cast(%in: memref<?xf32>)
+    -> memref<10x?xf32, offset: ?, strides: [?, 1]> {
+  %c0 = constant 0 : index
+  %c10 = constant 10 : index
+  %out = memref_reinterpret_cast %in to
+           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
+           : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
+  return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
+}
+
 // CHECK-LABEL: func @memref_reshape(
 func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
          %shape2: memref<2xi32>, %shape3: memref<?xi32>) -> memref<*xf32> {


        


More information about the Mlir-commits mailing list