[Mlir-commits] [mlir] 80d133b - [mlir] Revisit std.subview handling of static information.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon May 11 14:48:44 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-11T17:44:24-04:00
New Revision: 80d133b24f77d1b9d351251315606441c971ef9b
URL: https://github.com/llvm/llvm-project/commit/80d133b24f77d1b9d351251315606441c971ef9b
DIFF: https://github.com/llvm/llvm-project/commit/80d133b24f77d1b9d351251315606441c971ef9b.diff
LOG: [mlir] Revisit std.subview handling of static information.
Summary:
The main objective of this revision is to change the way static information is represented, propagated and canonicalized in the SubViewOp.
In the current implementation, canonicalization may strictly lose information: static offsets are combined irrecoverably into the result type in order to fit the strided memref representation.
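A small worked example of why this folding is lossy. This is a sketch: `linearizeOffset` is a hypothetical helper, not part of MLIR. A strided layout stores a single linear offset, so per-dimension offsets that linearize to the same value cannot be told apart once they are folded into the type. With source strides `[16, 1]` and dynamic sizes (so bounds cannot rule either choice out), offsets `[1, 0]` and `[0, 16]` both produce offset 16.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: folds per-dimension offsets into the single linear
// offset of a strided layout, i.e. base + sum_i(offsets[i] * strides[i]).
int64_t linearizeOffset(int64_t baseOffset,
                        const std::vector<int64_t> &offsets,
                        const std::vector<int64_t> &strides) {
  assert(offsets.size() == strides.size());
  int64_t linear = baseOffset;
  for (std::size_t i = 0; i < offsets.size(); ++i)
    linear += offsets[i] * strides[i];
  return linear;
}
```

Both `linearizeOffset(0, {1, 0}, {16, 1})` and `linearizeOffset(0, {0, 16}, {16, 1})` return 16, so a result type that only records `offset: 16` no longer says which subview produced it.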
The core semantics of the op do not change but the parser and printer do: the op always requires `rank` offsets, sizes and strides. These quantities can now be either SSA values or static integer attributes.
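A minimal sketch of the mixed static/dynamic representation. It works on plain strings instead of `OpAsmParser`, and it assumes a single illustrative sentinel `kDynamic` for all three lists (the op itself uses `ShapedType::kDynamicStrideOrOffset` for offsets and strides and `ShapedType::kDynamicSize` for sizes). Each list keeps exactly one static entry per dimension; a sentinel marks a position whose value is taken from the next SSA operand. The well-formedness check mirrors what the verifier enforces: `rank` entries, and one SSA value per sentinel. All names here are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

// Illustrative sentinel marking a dynamic entry (assumed value).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct MixedList {
  std::vector<int64_t> staticValues;  // One entry per dimension.
  std::vector<std::string> ssaValues; // Only the dynamic entries, in order.
};

// Parses the body of a list such as "%arg0, 7, 42, %arg42" (no brackets).
MixedList parseMixedList(const std::string &body) {
  MixedList result;
  std::stringstream ss(body);
  std::string token;
  while (std::getline(ss, token, ',')) {
    std::size_t first = token.find_first_not_of(' ');
    if (first == std::string::npos)
      continue; // Skip empty tokens.
    std::size_t last = token.find_last_not_of(' ');
    token = token.substr(first, last - first + 1);
    if (token[0] == '%') {
      result.ssaValues.push_back(token);
      result.staticValues.push_back(kDynamic);
    } else {
      result.staticValues.push_back(std::stoll(token));
    }
  }
  return result;
}

// Prints the list back, consuming one SSA name per sentinel.
// Assumes the list is well formed (see isWellFormed).
std::string printMixedList(const MixedList &list) {
  std::string out = "[";
  std::size_t ssaIdx = 0;
  for (std::size_t i = 0; i < list.staticValues.size(); ++i) {
    if (i != 0)
      out += ", ";
    if (list.staticValues[i] == kDynamic)
      out += list.ssaValues[ssaIdx++];
    else
      out += std::to_string(list.staticValues[i]);
  }
  return out + "]";
}

// Verifier-style check: `rank` entries and one SSA value per sentinel.
bool isWellFormed(const MixedList &list, std::size_t rank) {
  std::size_t numDynamic = 0;
  for (int64_t v : list.staticValues)
    if (v == kDynamic)
      ++numDynamic;
  return list.staticValues.size() == rank &&
         list.ssaValues.size() == numDynamic;
}
```

Parsing `"%arg0, 7, 42, %arg42"` yields the static list `[kDynamic, 7, 42, kDynamic]` and the SSA list `[%arg0, %arg42]`, and printing restores the original text.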
The result type is now automatically deduced from the static information, which enables more powerful canonicalizations (as powerful as the representation with sentinel `?` values allows). Previously, static information was inferred on a best-effort basis by looking at the source and destination types.
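The deduction can be sketched with saturating arithmetic, where any dynamic operand yields a dynamic result: the result offset is `sourceOffset + sum_i(offset_i * sourceStride_i)` and each result stride is `sourceStride_i * stride_i`. The names (`kDyn`, `satAdd`, `satMul`, `inferLayout`, `foldConstants`) are hypothetical, and operands are modeled as `std::optional<int64_t>` (a known constant, or unknown) rather than SSA values. Folding a newly constant operand into the static list and re-running the inference is the kind of canonicalization this representation makes possible.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <utility>
#include <vector>

// Illustrative sentinel for a dynamic offset or stride (assumed value).
constexpr int64_t kDyn = std::numeric_limits<int64_t>::min();

// Saturating arithmetic: a dynamic operand makes the result dynamic.
int64_t satAdd(int64_t a, int64_t b) {
  return (a == kDyn || b == kDyn) ? kDyn : a + b;
}
int64_t satMul(int64_t a, int64_t b) {
  return (a == kDyn || b == kDyn) ? kDyn : a * b;
}

struct StridedLayout {
  int64_t offset;
  std::vector<int64_t> strides;
};

// offset   = srcOffset + sum_i(offsets[i] * srcStrides[i])
// stride_i = srcStrides[i] * steps[i]
StridedLayout inferLayout(const StridedLayout &src,
                          const std::vector<int64_t> &offsets,
                          const std::vector<int64_t> &steps) {
  assert(offsets.size() == src.strides.size());
  assert(steps.size() == src.strides.size());
  StridedLayout result{src.offset, {}};
  for (std::size_t i = 0; i < offsets.size(); ++i) {
    result.offset = satAdd(result.offset, satMul(offsets[i], src.strides[i]));
    result.strides.push_back(satMul(src.strides[i], steps[i]));
  }
  return result;
}

// Each sentinel in `staticValues` is backed, in order, by one entry of
// `operands`. Known constants replace their sentinel and are dropped from
// `operands`; unknown ones stay dynamic.
void foldConstants(std::vector<int64_t> &staticValues,
                   std::vector<std::optional<int64_t>> &operands) {
  std::vector<std::optional<int64_t>> remaining;
  std::size_t opIdx = 0;
  for (int64_t &v : staticValues) {
    if (v != kDyn)
      continue; // Already static.
    assert(opIdx < operands.size());
    const std::optional<int64_t> &operand = operands[opIdx++];
    if (operand)
      v = *operand; // Newly static.
    else
      remaining.push_back(operand); // Still dynamic.
  }
  operands = std::move(remaining);
}
```

For a `memref<8x16x4xf32>` source with strides `[64, 4, 1]`, offsets `[0, 2, 0]` with unit steps give offset 8 and strides `[64, 4, 1]`, which matches the `d0 * 64 + d1 * 4 + d2 + 8` layout of Example 3 in the updated op documentation.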
Relevant tests are rewritten to use the idiomatic `offset: x, strides: [...]` form. Along the way, this corrects bugs that were not trivially visible in the flattened strided memref form.
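For reference, a sketch of how a strided memref type reads in the `offset: ..., strides: [...]` shorthand, with `?` for dynamic entries. The helper names and the sentinel values are assumptions chosen for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// Illustrative sentinels (assumed values, one per kind of entry).
constexpr int64_t kDynamicSize = -1;
constexpr int64_t kDynamicStrideOrOffset = std::numeric_limits<int64_t>::min();

// Prints `v`, or "?" when it equals the dynamic sentinel.
std::string formatEntry(int64_t v, int64_t sentinel) {
  return v == sentinel ? "?" : std::to_string(v);
}

// Builds e.g. "memref<4x4xf32, offset: 8, strides: [64, 1]>".
std::string formatStridedMemRef(const std::vector<int64_t> &shape,
                                const std::string &elementType,
                                int64_t offset,
                                const std::vector<int64_t> &strides) {
  std::string out = "memref<";
  for (int64_t dim : shape)
    out += formatEntry(dim, kDynamicSize) + "x";
  out += elementType + ", offset: " +
         formatEntry(offset, kDynamicStrideOrOffset) + ", strides: [";
  for (std::size_t i = 0; i < strides.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += formatEntry(strides[i], kDynamicStrideOrOffset);
  }
  return out + "]>";
}
```

For instance, the Example 3 result type reads `memref<4x4x4xf32, offset: 8, strides: [64, 4, 1]>` in this form, where the per-dimension strides and the offset are visible directly instead of being buried in an affine expression.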
Whether a better result type representation would be a nicer alternative remains an open question that needs a longer discussion. For now, the subview op carries the required semantics.
Reviewers: ftynse, mravishankar, antiagainst, rriddle!, andydavis1, timshen, asaadaldien, stellaraccident
Reviewed By: mravishankar
Subscribers: aartbik, bondhugula, mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, bader, grosul1, frgossen, Kayjukh, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D79662
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/Conversion/StandardToLLVM/invalid.mlir
mlir/test/Conversion/StandardToSPIRV/legalization.mlir
mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
mlir/test/Dialect/Affine/ops.mlir
mlir/test/Dialect/Linalg/promote.mlir
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index f79e955437dd..813dd7db5e9e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2504,7 +2504,7 @@ def SubIOp : IntArithmeticOp<"subi"> {
//===----------------------------------------------------------------------===//
def SubViewOp : Std_Op<"subview", [
- AttrSizedOperandSegments,
+ AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
NoSideEffect,
]> {
@@ -2516,17 +2516,14 @@ def SubViewOp : Std_Op<"subview", [
The SubView operation supports the following arguments:
*) Memref: the "base" memref on which to create a "view" memref.
- *) Offsets: zero or memref-rank number of dynamic offsets into the "base"
- memref at which to create the "view" memref.
- *) Sizes: zero or memref-rank dynamic size operands which specify the
- dynamic sizes of the result "view" memref type.
- *) Strides: zero or memref-rank number of dynamic strides which are applied
- multiplicatively to the base memref strides in each dimension.
-
- Note on the number of operands for offsets, sizes and strides: For
- each of these, the number of operands must either be same as the
- memref-rank number or empty. For the latter, those values will be
- treated as constants.
+ *) Offsets: memref-rank number of dynamic offsets or static integer
+ attributes into the "base" memref at which to create the "view"
+ memref.
+ *) Sizes: memref-rank number of dynamic sizes or static integer attributes
+ which specify the sizes of the result "view" memref type.
+ *) Strides: memref-rank number of dynamic strides or static integer
+ attributes that are applied multiplicatively to the base memref strides
+ in each dimension.
Example 1:
@@ -2537,7 +2534,7 @@ def SubViewOp : Std_Op<"subview", [
// dynamic sizes for each dimension, and stride arguments '%c1'.
%1 = subview %0[%c0, %c0][%size0, %size1][%c1, %c1]
: memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1) > to
- memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>
+ memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
```
Example 2:
@@ -2564,9 +2561,9 @@ def SubViewOp : Std_Op<"subview", [
%0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
// Subview with constant offsets, sizes and strides.
- %1 = subview %0[][][]
+ %1 = subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
: memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
- memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
+ memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>
```
Example 4:
@@ -2608,7 +2605,7 @@ def SubViewOp : Std_Op<"subview", [
// #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)
//
// where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1.
- %1 = subview %0[%i, %j][][%x, %y] :
+ %1 = subview %0[%i, %j][4, 4][%x, %y]
: memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> to
memref<4x4xf32, (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)>
@@ -2624,24 +2621,25 @@ def SubViewOp : Std_Op<"subview", [
AnyMemRef:$source,
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
- Variadic<Index>:$strides
+ Variadic<Index>:$strides,
+ I64ArrayAttr:$static_offsets,
+ I64ArrayAttr:$static_sizes,
+ I64ArrayAttr:$static_strides
);
let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{
- $source `[` $offsets `]` `[` $sizes `]` `[` $strides `]` attr-dict `:`
- type($source) `to` type($result)
- }];
-
let builders = [
+ // Build a SubViewOp with mixed static and dynamic entries.
OpBuilder<
"OpBuilder &b, OperationState &result, Value source, "
- "ValueRange offsets, ValueRange sizes, "
- "ValueRange strides, Type resultType = Type(), "
- "ArrayRef<NamedAttribute> attrs = {}">,
+ "ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,"
+ "ArrayRef<int64_t> staticStrides, ValueRange offsets, ValueRange sizes, "
+ "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
+ // Build a SubViewOp with all dynamic entries.
OpBuilder<
- "OpBuilder &builder, OperationState &result, "
- "Type resultType, Value source">
+ "OpBuilder &b, OperationState &result, Value source, "
+ "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+ "ArrayRef<NamedAttribute> attrs = {}">
];
let extraClassDeclaration = [{
@@ -2670,13 +2668,34 @@ def SubViewOp : Std_Op<"subview", [
/// operands could not be retrieved.
LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides);
- // Auxiliary range data structure and helper function that unpacks the
- // offset, size and stride operands of the SubViewOp into a list of triples.
- // Such a list of triple is sometimes more convenient to manipulate.
+ /// Auxiliary range data structure and helper function that unpacks the
+ /// offset, size and stride operands of the SubViewOp into a list of triples.
+ /// Such a list of triples is sometimes more convenient to manipulate.
struct Range {
Value offset, size, stride;
};
SmallVector<Range, 8> getRanges();
+
+ /// Return the rank of the result MemRefType.
+ unsigned getRank() { return getType().getRank(); }
+
+ static StringRef getStaticOffsetsAttrName() {
+ return "static_offsets";
+ }
+ static StringRef getStaticSizesAttrName() {
+ return "static_sizes";
+ }
+ static StringRef getStaticStridesAttrName() {
+ return "static_strides";
+ }
+ static ArrayRef<StringRef> getSpecialAttrNames() {
+ static SmallVector<StringRef, 4> names{
+ getStaticOffsetsAttrName(),
+ getStaticSizesAttrName(),
+ getStaticStridesAttrName(),
+ getOperandSegmentSizeAttr()};
+ return names;
+ }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 553be944ab30..39dc5d203a61 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1293,7 +1293,7 @@ static LogicalResult verify(DimOp op) {
auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
return op.emitOpError("requires an integer attribute named 'index'");
- int64_t index = indexAttr.getValue().getSExtValue();
+ int64_t index = indexAttr.getInt();
auto type = op.getOperand().getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
@@ -2183,59 +2183,272 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
// SubViewOp
//===----------------------------------------------------------------------===//
-// Returns a MemRefType with dynamic sizes and offset and the same stride as the
-// `memRefType` passed as argument.
-// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep
-// sizes and offset static.
-static Type inferSubViewResultType(MemRefType memRefType) {
- auto rank = memRefType.getRank();
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- auto res = getStridesAndOffset(memRefType, strides, offset);
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+static void printSubViewListOfOperandsOrIntegers(
+ OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ p << "[";
+ unsigned idx = 0;
+ llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+ int64_t val = a.cast<IntegerAttr>().getInt();
+ if (isDynamic(val))
+ p << values[idx++];
+ else
+ p << val;
+ });
+ p << "] ";
+}
+
+/// Parse a mixed list with either (1) static integer values or (2) SSA values.
+/// Fill `result` with the integer ArrayAttr named `attrName`, where `dynVal`
+/// encodes the position of SSA values. Add the parsed SSA values to `ssa`
+/// in order.
+///
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+/// 2. `ssa` is filled with "[%arg0, %arg42]".
+static ParseResult
+parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
+ StringRef attrName, int64_t dynVal,
+ SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+ if (failed(parser.parseLSquare()))
+ return failure();
+ // 0-D.
+ if (succeeded(parser.parseOptionalRSquare()))
+ return success();
+
+ SmallVector<int64_t, 4> attrVals;
+ while (true) {
+ OpAsmParser::OperandType operand;
+ auto res = parser.parseOptionalOperand(operand);
+ if (res.hasValue() && succeeded(res.getValue())) {
+ ssa.push_back(operand);
+ attrVals.push_back(dynVal);
+ } else {
+ Attribute attr;
+ NamedAttrList placeholder;
+ if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
+ !attr.isa<IntegerAttr>())
+ return parser.emitError(parser.getNameLoc())
+ << "expected SSA value or integer";
+ attrVals.push_back(attr.cast<IntegerAttr>().getInt());
+ }
+
+ if (succeeded(parser.parseOptionalComma()))
+ continue;
+ if (failed(parser.parseRSquare()))
+ return failure();
+ else
+ break;
+ }
+
+ auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
+ result.addAttribute(attrName, arrayAttr);
+ return success();
+}
+
+namespace {
+/// Helpers for saturating arithmetic on static offsets and strides: if either
+/// operand is dynamic, the result is dynamic.
+namespace saturated_arith {
+struct Wrapper {
+ explicit Wrapper(int64_t v) : v(v) {}
+ operator int64_t() { return v; }
+ int64_t v;
+};
+Wrapper operator+(Wrapper a, int64_t b) {
+ if (ShapedType::isDynamicStrideOrOffset(a) ||
+ ShapedType::isDynamicStrideOrOffset(b))
+ return Wrapper(ShapedType::kDynamicStrideOrOffset);
+ return Wrapper(a.v + b);
+}
+Wrapper operator*(Wrapper a, int64_t b) {
+ if (ShapedType::isDynamicStrideOrOffset(a) ||
+ ShapedType::isDynamicStrideOrOffset(b))
+ return Wrapper(ShapedType::kDynamicStrideOrOffset);
+ return Wrapper(a.v * b);
+}
+} // end namespace saturated_arith
+} // end namespace
+
+/// A subview result type can be fully inferred from the source type and the
+/// static representation of offsets, sizes and strides. Special sentinels
+/// encode the dynamic case.
+static Type inferSubViewResultType(MemRefType sourceMemRefType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides) {
+ unsigned rank = sourceMemRefType.getRank();
+ (void)rank;
+ assert(staticOffsets.size() == rank &&
+ "unexpected staticOffsets size mismatch");
+ assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
+ assert(staticStrides.size() == rank &&
+ "unexpected staticStrides size mismatch");
+
+ // Extract source offset and strides.
+ int64_t sourceOffset;
+ SmallVector<int64_t, 4> sourceStrides;
+ auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
assert(succeeded(res) && "SubViewOp expected strided memref type");
(void)res;
- // Assume sizes and offset are fully dynamic for now until canonicalization
- // occurs on the ranges. Typed strides don't change though.
- offset = MemRefType::getDynamicStrideOrOffset();
- // Overwrite strides because verifier will not pass.
- // TODO(b/144419106): don't force degrade the strides to fully dynamic.
- for (auto &stride : strides)
- stride = MemRefType::getDynamicStrideOrOffset();
- auto stridedLayout =
- makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
- SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
- return MemRefType::Builder(memRefType)
- .setShape(sizes)
- .setAffineMaps(stridedLayout);
+ // Compute target offset whose value is:
+ // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
+ int64_t targetOffset = sourceOffset;
+ for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
+ auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
+ using namespace saturated_arith;
+ targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
+ }
+
+ // Compute target stride whose value is:
+ // `sourceStrides_i * staticStrides_i`.
+ SmallVector<int64_t, 4> targetStrides;
+ targetStrides.reserve(staticOffsets.size());
+ for (auto it : llvm::zip(sourceStrides, staticStrides)) {
+ auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
+ using namespace saturated_arith;
+ targetStrides.push_back(Wrapper(sourceStride) * staticStride);
+ }
+
+ // The type is now known.
+ return MemRefType::get(
+ staticSizes, sourceMemRefType.getElementType(),
+ makeStridedLinearLayoutMap(targetStrides, targetOffset,
+ sourceMemRefType.getContext()),
+ sourceMemRefType.getMemorySpace());
+}
+
+/// Print SubViewOp in the form:
+/// ```
+/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+/// `:` strided-memref-type `to` strided-memref-type
+/// ```
+static void print(OpAsmPrinter &p, SubViewOp op) {
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
+ p << op.getOperand(0);
+ printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+ ShapedType::isDynamicStrideOrOffset);
+ printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+ ShapedType::isDynamic);
+ printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+ ShapedType::isDynamicStrideOrOffset);
+ p.printOptionalAttrDict(op.getAttrs(),
+ /*elided=*/{SubViewOp::getSpecialAttrNames()});
+ p << " : " << op.getOperand(0).getType() << " to " << op.getType();
+}
+
+/// Parse SubViewOp of the form:
+/// ```
+/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+/// `:` strided-memref-type `to` strided-memref-type
+/// ```
+static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType srcInfo;
+ SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
+ auto indexType = parser.getBuilder().getIndexType();
+ Type srcType, dstType;
+ if (parser.parseOperand(srcInfo))
+ return failure();
+ if (parseListOfOperandsOrIntegers(
+ parser, result, SubViewOp::getStaticOffsetsAttrName(),
+ ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
+ parseListOfOperandsOrIntegers(parser, result,
+ SubViewOp::getStaticSizesAttrName(),
+ ShapedType::kDynamicSize, sizesInfo) ||
+ parseListOfOperandsOrIntegers(
+ parser, result, SubViewOp::getStaticStridesAttrName(),
+ ShapedType::kDynamicStrideOrOffset, stridesInfo))
+ return failure();
+
+ auto b = parser.getBuilder();
+ SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
+ static_cast<int>(sizesInfo.size()),
+ static_cast<int>(stridesInfo.size())};
+ result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
+ b.getI32VectorAttr(segmentSizes));
+
+ return failure(
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonType(srcType) ||
+ parser.resolveOperand(srcInfo, srcType, result.operands) ||
+ parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
+ parser.resolveOperands(sizesInfo, indexType, result.operands) ||
+ parser.resolveOperands(stridesInfo, indexType, result.operands) ||
+ parser.parseKeywordType("to", dstType) ||
+ parser.addTypeToList(dstType, result.types));
}
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
- ValueRange offsets, ValueRange sizes,
- ValueRange strides, Type resultType,
+ ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides, ValueRange offsets,
+ ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
- if (!resultType)
- resultType = inferSubViewResultType(source.getType().cast<MemRefType>());
- build(b, result, resultType, source, offsets, sizes, strides);
+ auto sourceMemRefType = source.getType().cast<MemRefType>();
+ auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
+ staticSizes, staticStrides);
+ build(b, result, resultType, source, offsets, sizes, strides,
+ b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+ b.getI64ArrayAttr(staticStrides));
result.addAttributes(attrs);
}
-void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
- Type resultType, Value source) {
- build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
- resultType);
+/// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
+/// and `staticStrides` are automatically filled with source-memref-rank
+/// sentinel values that encode dynamic entries.
+void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
+ ValueRange offsets, ValueRange sizes,
+ ValueRange strides,
+ ArrayRef<NamedAttribute> attrs) {
+ auto sourceMemRefType = source.getType().cast<MemRefType>();
+ unsigned rank = sourceMemRefType.getRank();
+ SmallVector<int64_t, 4> staticOffsetsVector;
+ staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+ SmallVector<int64_t, 4> staticSizesVector;
+ staticSizesVector.assign(rank, ShapedType::kDynamicSize);
+ SmallVector<int64_t, 4> staticStridesVector;
+ staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+ build(b, result, source, staticOffsetsVector, staticSizesVector,
+ staticStridesVector, offsets, sizes, strides, attrs);
+}
+
+/// Verify that a particular offset/size/stride static attribute is well-formed.
+static LogicalResult
+verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
+ ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
+ ValueRange values) {
+ /// Check static and dynamic offsets/sizes/strides breakdown.
+ if (attr.size() != op.getRank())
+ return op.emitError("expected ")
+ << op.getRank() << " " << name << " values";
+ unsigned expectedNumDynamicEntries =
+ llvm::count_if(attr.getValue(), [&](Attribute attr) {
+ return isDynamic(attr.cast<IntegerAttr>().getInt());
+ });
+ if (values.size() != expectedNumDynamicEntries)
+ return op.emitError("expected ")
+ << expectedNumDynamicEntries << " dynamic " << name << " values";
+ return success();
+}
+
+/// Helper function that extracts int64_t values from `attr`, which is assumed
+/// to be an ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+ return llvm::to_vector<4>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+ return a.cast<IntegerAttr>().getInt();
+ }));
}
+/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
auto baseType = op.getBaseMemRefType().cast<MemRefType>();
auto subViewType = op.getType();
- // The rank of the base and result subview must match.
- if (baseType.getRank() != subViewType.getRank()) {
- return op.emitError(
- "expected rank of result type to match rank of base type ");
- }
-
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
return op.emitError("different memory spaces specified for base memref "
@@ -2243,96 +2456,32 @@ static LogicalResult verify(SubViewOp op) {
<< baseType << " and subview memref type " << subViewType;
// Verify that the base memref type has a strided layout map.
- int64_t baseOffset;
- SmallVector<int64_t, 4> baseStrides;
- if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset)))
- return op.emitError("base type ") << subViewType << " is not strided";
-
- // Verify that the result memref type has a strided layout map.
- int64_t subViewOffset;
- SmallVector<int64_t, 4> subViewStrides;
- if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset)))
- return op.emitError("result type ") << subViewType << " is not strided";
-
- // Num offsets should either be zero or rank of memref.
- if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) {
- return op.emitError("expected number of dynamic offsets specified to match "
- "the rank of the result type ")
- << subViewType;
- }
-
- // Num sizes should either be zero or rank of memref.
- if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) {
- return op.emitError("expected number of dynamic sizes specified to match "
- "the rank of the result type ")
- << subViewType;
- }
-
- // Num strides should either be zero or rank of memref.
- if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) {
- return op.emitError("expected number of dynamic strides specified to match "
- "the rank of the result type ")
- << subViewType;
- }
-
- // Verify that if the shape of the subview type is static, then sizes are not
- // dynamic values, and vice versa.
- if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
- (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
- return op.emitError("invalid to specify dynamic sizes when subview result "
- "type is statically shaped and viceversa");
- }
+ if (!isStrided(baseType))
+ return op.emitError("base type ") << baseType << " is not strided";
- // Verify that if dynamic sizes are specified, then the result memref type
- // have full dynamic dimensions.
- if (op.getNumSizes() > 0) {
- if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
- return dim != ShapedType::kDynamicSize;
- })) {
- // TODO: This is based on the assumption that number of size arguments are
- // either 0, or the rank of the result type. It is possible to have more
- // fine-grained verification where only particular dimensions are
- // dynamic. That probably needs further changes to the shape op
- // specification.
- return op.emitError("expected shape of result type to be fully dynamic "
- "when sizes are specified");
- }
- }
+ // Verify static attributes offsets/sizes/strides.
+ if (failed(verifySubViewOpPart(
+ op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
+ ShapedType::isDynamicStrideOrOffset, op.offsets())))
+ return failure();
- // Verify that if dynamic offsets are specified or base memref has dynamic
- // offset or base memref has dynamic strides, then the subview offset is
- // dynamic.
- if ((op.getNumOffsets() > 0 ||
- baseOffset == MemRefType::getDynamicStrideOrOffset() ||
- llvm::is_contained(baseStrides,
- MemRefType::getDynamicStrideOrOffset())) &&
- subViewOffset != MemRefType::getDynamicStrideOrOffset()) {
- return op.emitError(
- "expected result memref layout map to have dynamic offset");
- }
+ if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
+ op.static_sizes(), ShapedType::isDynamic,
+ op.sizes())))
+ return failure();
+ if (failed(verifySubViewOpPart(
+ op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
+ ShapedType::isDynamicStrideOrOffset, op.strides())))
+ return failure();
- // For now, verify that if dynamic strides are specified, then all the result
- // memref type have dynamic strides.
- if (op.getNumStrides() > 0) {
- if (llvm::any_of(subViewStrides, [](int64_t stride) {
- return stride != MemRefType::getDynamicStrideOrOffset();
- })) {
- return op.emitError("expected result type to have dynamic strides");
- }
- }
+ // Verify result type against inferred type.
+ auto expectedType = inferSubViewResultType(
+ op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
+ extractFromI64ArrayAttr(op.static_sizes()),
+ extractFromI64ArrayAttr(op.static_strides()));
+ if (op.getType() != expectedType)
+ return op.emitError("expected result type to be ") << expectedType;
- // If any of the base memref has dynamic stride, then the corresponding
- // stride of the subview must also have dynamic stride.
- assert(baseStrides.size() == subViewStrides.size());
- for (auto stride : enumerate(baseStrides)) {
- if (stride.value() == MemRefType::getDynamicStrideOrOffset() &&
- subViewStrides[stride.index()] !=
- MemRefType::getDynamicStrideOrOffset()) {
- return op.emitError(
- "expected result type to have dynamic stride along a dimension if "
- "the base memref type has dynamic stride along that dimension");
- }
- }
return success();
}
@@ -2353,37 +2502,9 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
- // If the strides are dynamic return failure.
- if (getNumStrides())
- return failure();
-
- // When static, the stride operands can be retrieved by taking the strides of
- // the result of the subview op, and dividing the strides of the base memref.
- int64_t resultOffset, baseOffset;
- SmallVector<int64_t, 2> resultStrides, baseStrides;
- if (failed(
- getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) ||
- llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) ||
- failed(getStridesAndOffset(getType(), resultStrides, resultOffset)))
+ if (!strides().empty())
return failure();
-
- assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() &&
- baseStrides.size() == resultStrides.size() &&
- "base and result memrefs must have the same rank");
- assert(!llvm::is_contained(resultStrides,
- MemRefType::getDynamicStrideOrOffset()) &&
- "strides of subview op must be static, when there are no dynamic "
- "strides specified");
- staticStrides.resize(getType().getRank());
- for (auto resultStride : enumerate(resultStrides)) {
- auto baseStride = baseStrides[resultStride.index()];
- // The result stride is expected to be a multiple of the base stride. Abort
- // if that is not the case.
- if (resultStride.value() < baseStride ||
- resultStride.value() % baseStride != 0)
- return failure();
- staticStrides[resultStride.index()] = resultStride.value() / baseStride;
- }
+ staticStrides = extractFromI64ArrayAttr(static_strides());
return success();
}
@@ -2391,136 +2512,80 @@ Value SubViewOp::getViewSource() { return source(); }
namespace {
-/// Pattern to rewrite a subview op with constant size arguments.
-class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
-public:
- using OpRewritePattern<SubViewOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(SubViewOp subViewOp,
- PatternRewriter &rewriter) const override {
- MemRefType subViewType = subViewOp.getType();
- // Follow all or nothing approach for shapes for now. If all the operands
- // for sizes are constants then fold it into the type of the result memref.
- if (subViewType.hasStaticShape() ||
- llvm::any_of(subViewOp.sizes(), [](Value operand) {
- return !matchPattern(operand, m_ConstantIndex());
- })) {
- return failure();
- }
- SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
- for (auto size : llvm::enumerate(subViewOp.sizes())) {
- auto defOp = size.value().getDefiningOp();
- assert(defOp);
- staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
+/// Take a list of `values` that may contain new constants to extract, and a
+/// list of `constantValues` containing `values.size()` sentinels for which
+/// `isDynamic` evaluates to true.
+/// Detects the `values` produced by a ConstantIndexOp and places the new
+/// constant in place of the corresponding sentinel value.
+void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
+ SmallVectorImpl<int64_t> &constantValues,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ bool hasNewStaticValue = llvm::any_of(
+ values, [](Value val) { return matchPattern(val, m_ConstantIndex()); });
+ if (hasNewStaticValue) {
+ for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size();
+ cstIdx != e; ++cstIdx) {
+ // Was already static, skip.
+ if (!isDynamic(constantValues[cstIdx]))
+ continue;
+ // Newly static, move from Value to constant.
+ if (matchPattern(values[valIdx], m_ConstantIndex())) {
+ constantValues[cstIdx] =
+ cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue();
+ // Erase for impl. simplicity. Reverse iterator if we really must.
+ values.erase(std::next(values.begin(), valIdx));
+ continue;
+ }
+ // Remains dynamic, move to next value.
+ ++valIdx;
}
- MemRefType newMemRefType =
- MemRefType::Builder(subViewType).setShape(staticShape);
- auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
- ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
- // Insert a memref_cast for compatibility of the uses of the op.
- rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
- subViewOp.getType());
- return success();
}
-};
+}
-// Pattern to rewrite a subview op with constant stride arguments.
-class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
+/// Pattern to rewrite a subview op with constant arguments.
+class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
- if (subViewOp.getNumStrides() == 0) {
- return failure();
- }
- // Follow all or nothing approach for strides for now. If all the operands
- // for strides are constants then fold it into the strides of the result
- // memref.
- int64_t baseOffset, resultOffset;
- SmallVector<int64_t, 4> baseStrides, resultStrides;
- MemRefType subViewType = subViewOp.getType();
- if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
- baseOffset)) ||
- failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
- llvm::is_contained(baseStrides,
- MemRefType::getDynamicStrideOrOffset()) ||
- llvm::any_of(subViewOp.strides(), [](Value stride) {
- return !matchPattern(stride, m_ConstantIndex());
- })) {
+ // No constant operand, just return;
+ if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
return failure();
- }
- SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
- for (auto stride : llvm::enumerate(subViewOp.strides())) {
- auto defOp = stride.value().getDefiningOp();
- assert(defOp);
- assert(baseStrides[stride.index()] > 0);
- staticStrides[stride.index()] =
- cast<ConstantIndexOp>(defOp).getValue() * baseStrides[stride.index()];
- }
- AffineMap layoutMap = makeStridedLinearLayoutMap(
- staticStrides, resultOffset, rewriter.getContext());
- MemRefType newMemRefType =
- MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
+ // At least one of offsets/sizes/strides is a new constant.
+ // Form the new list of operands and constant attributes from the existing.
+ SmallVector<Value, 8> newOffsets(subViewOp.offsets());
+ SmallVector<int64_t, 8> newStaticOffsets =
+ extractFromI64ArrayAttr(subViewOp.static_offsets());
+ assert(newStaticOffsets.size() == subViewOp.getRank());
+ canonicalizeSubViewPart(newOffsets, newStaticOffsets,
+ ShapedType::isDynamicStrideOrOffset);
+
+ SmallVector<Value, 8> newSizes(subViewOp.sizes());
+ SmallVector<int64_t, 8> newStaticSizes =
+ extractFromI64ArrayAttr(subViewOp.static_sizes());
+ assert(newStaticSizes.size() == subViewOp.getRank());
+ canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
+
+ SmallVector<Value, 8> newStrides(subViewOp.strides());
+ SmallVector<int64_t, 8> newStaticStrides =
+ extractFromI64ArrayAttr(subViewOp.static_strides());
+ assert(newStaticStrides.size() == subViewOp.getRank());
+ canonicalizeSubViewPart(newStrides, newStaticStrides,
+ ShapedType::isDynamicStrideOrOffset);
+
+ // Create the new op in canonical form.
auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
- subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
- // Insert a memref_cast for compatibility of the uses of the op.
- rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
- subViewOp.getType());
- return success();
- }
-};
+ subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
+ newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);
-// Pattern to rewrite a subview op with constant offset arguments.
-class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
-public:
- using OpRewritePattern<SubViewOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(SubViewOp subViewOp,
- PatternRewriter &rewriter) const override {
- if (subViewOp.getNumOffsets() == 0) {
- return failure();
- }
- // Follow all or nothing approach for offsets for now. If all the operands
- // for offsets are constants then fold it into the offset of the result
- // memref.
- int64_t baseOffset, resultOffset;
- SmallVector<int64_t, 4> baseStrides, resultStrides;
- MemRefType subViewType = subViewOp.getType();
- if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
- baseOffset)) ||
- failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
- llvm::is_contained(baseStrides,
- MemRefType::getDynamicStrideOrOffset()) ||
- baseOffset == MemRefType::getDynamicStrideOrOffset() ||
- llvm::any_of(subViewOp.offsets(), [](Value stride) {
- return !matchPattern(stride, m_ConstantIndex());
- })) {
- return failure();
- }
-
- auto staticOffset = baseOffset;
- for (auto offset : llvm::enumerate(subViewOp.offsets())) {
- auto defOp = offset.value().getDefiningOp();
- assert(defOp);
- assert(baseStrides[offset.index()] > 0);
- staticOffset +=
- cast<ConstantIndexOp>(defOp).getValue() * baseStrides[offset.index()];
- }
-
- AffineMap layoutMap = makeStridedLinearLayoutMap(
- resultStrides, staticOffset, rewriter.getContext());
- MemRefType newMemRefType =
- MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
- auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
- subViewOp.sizes(), subViewOp.strides(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
+
return success();
}
};
@@ -2633,8 +2698,7 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
- SubViewOpOffsetFolder>(context);
+ results.insert<SubViewOpFolder>(context);
}
//===----------------------------------------------------------------------===//
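The consolidated `SubViewOpFolder` folds any offset, size, or stride operand produced by a constant index op into the matching static attribute. The C++ sketch below models only that bookkeeping, outside of MLIR. `Operand` is a hypothetical stand-in for an SSA value (engaged when it would match `m_ConstantIndex()`), and `kDynamic` is a hypothetical sentinel for the `?` entries. The loop header of `canonicalizeSubViewPart` is not part of this excerpt, so the sketch assumes it walks the static array and consumes one operand per dynamic entry.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical sentinel standing in for a `?` entry in a static array.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical model of an SSA index operand: engaged when the value is
// produced by a constant index op, empty when it is truly dynamic.
using Operand = std::optional<int64_t>;

// Folds constant operands into `staticValues`. Each kDynamic entry is backed,
// in order, by one element of `values`; folded operands are erased.
void canonicalizePart(std::vector<Operand> &values,
                      std::vector<int64_t> &staticValues) {
  std::size_t valIdx = 0;
  for (int64_t &cst : staticValues) {
    if (cst != kDynamic)
      continue; // Already static: no operand backs this entry.
    assert(valIdx < values.size() && "one operand per dynamic entry");
    if (values[valIdx].has_value()) {
      // Newly static: move the constant into the static array.
      cst = *values[valIdx];
      // Erasing shifts the next operand into slot valIdx, so do not advance.
      values.erase(values.begin() + static_cast<std::ptrdiff_t>(valIdx));
      continue;
    }
    // Remains dynamic: move to the next operand.
    ++valIdx;
  }
}

// The rewrite pattern bails out early when no operand is constant.
bool hasConstantOperand(const std::vector<Operand> &values) {
  for (const Operand &v : values)
    if (v.has_value())
      return true;
  return false;
}
```

Erasing in place keeps `valIdx` pointing at the next unconsumed operand, which is why the constant branch continues without incrementing it.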
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 4cc8e11294d7..41ed5315ab1c 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -839,7 +839,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG2:.*]]: !llvm.i32)
-func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -883,7 +883,8 @@ func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg
// CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
%1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
- memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
+ memref<64x4xf32, offset: 0, strides: [4, 1]>
+ to memref<?x?xf32, offset: ?, strides: [?, ?]>
return
}
@@ -899,7 +900,7 @@ func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg
// CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG2:.*]]: !llvm.i32)
-func @subview_non_zero_addrspace(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1], 3>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -943,13 +944,14 @@ func @subview_non_zero_addrspace(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d
// CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
%1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
- memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>, 3>
+ memref<64x4xf32, offset: 0, strides: [4, 1], 3>
+ to memref<?x?xf32, offset: ?, strides: [?, ?], 3>
return
}
// CHECK-LABEL: func @subview_const_size(
// CHECK32-LABEL: func @subview_const_size(
-func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -996,14 +998,15 @@ func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 +
// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
// CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
// CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
- %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] :
- memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<4x2xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
+ %1 = subview %0[%arg0, %arg1][4, 2][%arg0, %arg1] :
+ memref<64x4xf32, offset: 0, strides: [4, 1]>
+ to memref<4x2xf32, offset: ?, strides: [?, ?]>
return
}
// CHECK-LABEL: func @subview_const_stride(
// CHECK32-LABEL: func @subview_const_stride(
-func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -1046,14 +1049,15 @@ func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4
// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
// CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
// CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
- %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] :
- memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + d1 * 2 + s0)>>
+ %1 = subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] :
+ memref<64x4xf32, offset: 0, strides: [4, 1]>
+ to memref<?x?xf32, offset: ?, strides: [4, 2]>
return
}
// CHECK-LABEL: func @subview_const_stride_and_offset(
// CHECK32-LABEL: func @subview_const_stride_and_offset(
-func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) {
+func @subview_const_stride_and_offset(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -1092,8 +1096,9 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1)
// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
// CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
// CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
- %1 = subview %0[][][] :
- memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
+ %1 = subview %0[0, 8][62, 3][1, 1] :
+ memref<64x4xf32, offset: 0, strides: [4, 1]>
+ to memref<62x3xf32, offset: 8, strides: [4, 1]>
return
}
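The rewritten tests spell out the result layout implied by the static information. The expected types above are consistent with the following rule, stated in the base memref's element units:

- result offset = base offset + sum of (offset_i * baseStride_i)
- result stride_i = baseStride_i * stride_i
- any dynamic input makes the corresponding result dynamic

The sketch below models that rule. It is not MLIR's implementation, and `kDynamic` and `inferSubViewLayout` are hypothetical names. For example, `[0, 8][62, 3][1, 1]` on `strides: [4, 1]` gives `offset: 8, strides: [4, 1]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for a dynamic (`?`) offset or stride.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

struct StridedLayout {
  int64_t offset;
  std::vector<int64_t> strides;
};

// Infers the result layout of a subview from the base layout and the op's
// static offsets/strides (kDynamic wherever an SSA value was used).
StridedLayout inferSubViewLayout(const StridedLayout &base,
                                 const std::vector<int64_t> &offsets,
                                 const std::vector<int64_t> &strides) {
  assert(offsets.size() == base.strides.size());
  assert(strides.size() == base.strides.size());
  StridedLayout result{base.offset, {}};
  for (std::size_t i = 0; i < base.strides.size(); ++i) {
    const int64_t baseStride = base.strides[i];
    // Offset: accumulate offset_i * baseStride_i; any unknown term poisons it.
    if (result.offset == kDynamic || offsets[i] == kDynamic ||
        baseStride == kDynamic)
      result.offset = kDynamic;
    else
      result.offset += offsets[i] * baseStride;
    // Stride: scale the base stride by the subview stride.
    if (baseStride == kDynamic || strides[i] == kDynamic)
      result.strides.push_back(kDynamic);
    else
      result.strides.push_back(baseStride * strides[i]);
  }
  return result;
}
```

The offset rule is conservative: it becomes dynamic as soon as any base stride is dynamic. That also matches `core-ops.mlir` below, where `subview %7[0, 0][4, 4][1, 1]` on `memref<?x?xf32>` yields `offset: ?, strides: [?, 1]`.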
diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
index bb9c2728dcb8..1be148707458 100644
--- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -7,7 +7,7 @@ func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
%c0 = constant 0 : index
// expected-error@+1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}}
%5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
- %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
+ %25 = std.subview %5[%c0, %c0][%c1, %c1][1, 1] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
return
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
index 3540a101c55b..d3b339e82a88 100644
--- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
@@ -11,7 +11,7 @@ func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : in
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return %1 : f32
}
@@ -25,7 +25,8 @@ func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : i
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
+ memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
return %1 : f32
}
@@ -41,7 +42,8 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: store [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] :
+ memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return
}
@@ -55,7 +57,8 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 :
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: store [[ARG7]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
+ memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
return
}
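The CHECK lines in these legalization tests show a load or store through a subview rewritten into an access on the source memref. Each source index is `offset_i + index_i * stride_i`; for example, `addi ARG2, (muli ARG4, C3)` for the second dimension. The strides here are the op's stride operands (`[2, 3]`), not the result layout strides (`[64, 3]`). Below is a minimal sketch of that arithmetic. Plain integers stand in for the SSA values, and the function name is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Maps indices into a subview back to indices into its source memref:
// sourceIndex_i = offset_i + index_i * stride_i, mirroring the addi/muli
// pairs the CHECK lines expect.
std::vector<int64_t> foldSubViewIndices(const std::vector<int64_t> &offsets,
                                        const std::vector<int64_t> &strides,
                                        const std::vector<int64_t> &indices) {
  assert(offsets.size() == indices.size() && strides.size() == indices.size());
  std::vector<int64_t> sourceIndices;
  sourceIndices.reserve(indices.size());
  for (std::size_t i = 0; i < indices.size(); ++i)
    sourceIndices.push_back(offsets[i] + indices[i] * strides[i]);
  return sourceIndices;
}
```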
diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
index c0d904adb51c..2e28079b4f15 100644
--- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
@@ -28,7 +28,7 @@ func @fold_static_stride_subview
// CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[C3]]
// CHECK: %[[T9:.*]] = addi %[[ARG2]], %[[T8]]
// CHECK store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
- %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
%2 = sqrt %1 : f32
store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 52eddfcd69f8..5ca6de5023eb 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -103,8 +103,8 @@ func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) {
affine.for %arg4 = 0 to %13 step 264 {
%18 = dim %0, 0 : memref<?x?xf32>
%20 = std.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref<?x?xf32>
- to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
- %24 = dim %20, 0 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
+ to memref<?x?xf32, offset : ?, strides : [?, ?]>
+ %24 = dim %20, 0 : memref<?x?xf32, offset : ?, strides : [?, ?]>
affine.for %arg5 = 0 to %24 step 768 {
"foo"() : () -> ()
}
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index d5148b3424c4..bd6a3e7d7033 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -23,9 +23,9 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
loop.for %arg4 = %c0 to %6 step %c2 {
loop.for %arg5 = %c0 to %8 step %c3 {
loop.for %arg6 = %c0 to %7 step %c4 {
- %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
- %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
- %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
linalg.matmul(%11, %14, %17) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
}
}
@@ -88,9 +88,9 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
loop.for %arg4 = %c0 to %6 step %c2 {
loop.for %arg5 = %c0 to %8 step %c3 {
loop.for %arg6 = %c0 to %7 step %c4 {
- %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
- %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
- %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
+ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
+ %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
+ %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
linalg.matmul(%11, %14, %17) : memref<?x?xf64, offset: ?, strides: [?, 1]>, memref<?x?xf64, offset: ?, strides: [?, 1]>, memref<?x?xf64, offset: ?, strides: [?, 1]>
}
}
@@ -153,9 +153,9 @@ func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
loop.for %arg4 = %c0 to %6 step %c2 {
loop.for %arg5 = %c0 to %8 step %c3 {
loop.for %arg6 = %c0 to %7 step %c4 {
- %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
- %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
- %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+ %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+ %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
linalg.matmul(%11, %14, %17) : memref<?x?xi32, offset: ?, strides: [?, 1]>, memref<?x?xi32, offset: ?, strides: [?, 1]>, memref<?x?xi32, offset: ?, strides: [?, 1]>
}
}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index da098a82b1b0..41172aa22527 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -10,15 +10,14 @@
// CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
// CHECK-DAG: #[[BASE_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
-// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
// CHECK-DAG: #[[BASE_MAP1:map[0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK-DAG: #[[BASE_MAP2:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 22 + d1)>
-// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
-// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
-// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>
+// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)>
// CHECK-LABEL: func @func_with_ops(%arg0: f32) {
@@ -708,41 +707,56 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%c1 = constant 1 : index
%0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
- // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP0]]>
+ // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+ // CHECK-SAME: to memref<?x?x?xf32, #[[BASE_MAP3]]>
%1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
%2 = alloc()[%arg2] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
- // CHECK: subview %2[%c1] [%arg0] [%c1] : memref<64xf32, #[[BASE_MAP1]]> to memref<?xf32, #[[SUBVIEW_MAP1]]>
+ // CHECK: subview %2[%c1] [%arg0] [%c1] :
+ // CHECK-SAME: memref<64xf32, #[[BASE_MAP1]]>
+ // CHECK-SAME: to memref<?xf32, #[[SUBVIEW_MAP1]]>
%3 = subview %2[%c1][%arg0][%c1]
: memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to
memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
%4 = alloc() : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>>
- // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] : memref<64x22xf32, #[[BASE_MAP2]]> to memref<?x?xf32, #[[SUBVIEW_MAP2]]>
+ // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] :
+ // CHECK-SAME: memref<64x22xf32, #[[BASE_MAP2]]>
+ // CHECK-SAME: to memref<?x?xf32, #[[SUBVIEW_MAP2]]>
%5 = subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0]
- : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>> to
- memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
+ : memref<64x22xf32, offset:0, strides: [22, 1]> to
+ memref<?x?xf32, offset:?, strides: [?, ?]>
- // CHECK: subview %0[] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]>
- %6 = subview %0[][][]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>>
+ // CHECK: subview %0[0, 2, 0] [4, 4, 4] [1, 1, 1] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+ // CHECK-SAME: to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]>
+ %6 = subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
+ : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to
+ memref<4x4x4xf32, offset:8, strides: [64, 4, 1]>
%7 = alloc(%arg1, %arg2) : memref<?x?xf32>
- // CHECK: subview {{%.*}}[] [] [] : memref<?x?xf32> to memref<4x4xf32, #[[SUBVIEW_MAP4]]>
- %8 = subview %7[][][]
- : memref<?x?xf32> to memref<4x4xf32, offset: ?, strides:[?, ?]>
+ // CHECK: subview {{%.*}}[0, 0] [4, 4] [1, 1] :
+ // CHECK-SAME: memref<?x?xf32>
+ // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP4]]>
+ %8 = subview %7[0, 0][4, 4][1, 1]
+ : memref<?x?xf32> to memref<4x4xf32, offset: ?, strides:[?, 1]>
%9 = alloc() : memref<16x4xf32>
- // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [{{%.*}}, {{%.*}}] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP4]]
- %10 = subview %9[%arg1, %arg1][][%arg2, %arg2]
+ // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [{{%.*}}, {{%.*}}] :
+ // CHECK-SAME: memref<16x4xf32>
+ // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP2]]
+ %10 = subview %9[%arg1, %arg1][4, 4][%arg2, %arg2]
: memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[?, ?]>
- // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP5]]
- %11 = subview %9[%arg1, %arg2][][]
+
+ // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [2, 2] :
+ // CHECK-SAME: memref<16x4xf32>
+ // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP5]]
+ %11 = subview %9[%arg1, %arg2][4, 4][2, 2]
: memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[8, 2]>
+
return
}
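For fully static layouts, `offset: o, strides: [s0, ..., sn]` denotes the affine map `(d0, ..., dn) -> (d0 * s0 + ... + dn * sn + o)`. The sketch below renders that map using a hypothetical helper. It assumes positive static strides and rank at least 1, and it does not reproduce the printer's ordering of symbols for dynamic layouts. With it, the corrected `SUBVIEW_MAP3` expectation is easy to check: taking `[0, 2, 0][4, 4, 4][1, 1, 1]` of a `strides: [64, 4, 1]` memref keeps the base strides and shifts the offset by 2 * 4 = 8.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Renders a fully static strided layout as the affine map text it denotes.
std::string renderStridedMap(const std::vector<int64_t> &strides,
                             int64_t offset) {
  std::string dims, expr;
  for (std::size_t i = 0; i < strides.size(); ++i) {
    const std::string d = "d" + std::to_string(i);
    dims += (i ? ", " : "") + d;
    // A unit stride prints as the bare dimension.
    const std::string term =
        strides[i] == 1 ? d : d + " * " + std::to_string(strides[i]);
    expr += (i ? " + " : "") + term;
  }
  // A zero offset is omitted from the expression.
  if (offset != 0)
    expr += " + " + std::to_string(offset);
  return "(" + dims + ") -> (" + expr + ")";
}
```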
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 0f9fb3ccada5..b0535047874f 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -976,33 +976,22 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
- %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2>
+ %0 = alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2>
// expected-error@+1 {{different memory spaces}}
- %1 = subview %0[][%arg2][]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2> to
+ %1 = subview %0[0, 0, 0][%arg2][1, 1, 1]
+ : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to
memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>>
return
}
// -----
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
- %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
- // expected-error@+1 {{is not strided}}
- %1 = subview %0[][%arg2][]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0, d1, d2)>>
- return
-}
-
-// -----
-
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>>
// expected-error@+1 {{is not strided}}
- %1 = subview %0[][%arg2][]
+ %1 = subview %0[0, 0, 0][%arg2][1, 1, 1]
: memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to
- memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>>
+ memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]>
return
}
@@ -1010,8 +999,8 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<8x16x4xf32>
- // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}}
- %1 = subview %0[%arg0, %arg1][%arg2][]
+ // expected-error@+1 {{expected 3 offset values}}
+ %1 = subview %0[%arg0, %arg1][%arg2][1, 1, 1]
: memref<8x16x4xf32> to
memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]>
return
@@ -1021,7 +1010,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<8x16x4xf32>
- // expected-error@+1 {{expected result type to have dynamic strides}}
+ // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>'}}
%1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
: memref<8x16x4xf32> to
memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
@@ -1030,106 +1019,6 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
-func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
- %0 = alloc() : memref<8x16x4xf32>
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
- %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1]
- : memref<8x16x4xf32> to
- memref<?x?x?xf32, offset: 0, strides: [?, ?, ?]>
- return
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
- // expected-error@+1 {{expected rank of result type to match rank of base type}}
- %0 = subview %arg1[%arg0, %arg0][][%arg0, %arg0] : memref<?x?xf32> to memref<?xf32>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
- // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}}
- %0 = subview %arg1[%arg0][][] : memref<?x?xf32> to memref<4x4xf32, offset: ?, strides: [4, 1]>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
- // expected-error@+1 {{expected number of dynamic sizes specified to match the rank of the result type}}
- %0 = subview %arg1[][%arg0][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
- // expected-error@+1 {{expected number of dynamic strides specified to match the rank of the result type}}
- %0 = subview %arg1[][][%arg0] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
- // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}}
- %0 = subview %arg1[][%arg0, %arg0][] : memref<?x?xf32> to memref<4x8xf32, offset: ?, strides: [?, ?]>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
- // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}}
- %0 = subview %arg1[][][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) {
- // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
- %0 = subview %arg1[%arg0, %arg0][][] : memref<16x4xf32> to memref<4x2xf32>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: ?, strides: [4, 1]>) {
- // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
- %0 = subview %arg1[][][] : memref<16x4xf32, offset: ?, strides: [4, 1]> to memref<4x2xf32>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 8, strides:[?, 1]>) {
- // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
- %0 = subview %arg1[][][] : memref<16x4xf32, offset: 8, strides:[?, 1]> to memref<4x2xf32>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) {
- // expected-error@+1 {{expected result type to have dynamic strides}}
- %0 = subview %arg1[][][%arg0, %arg0] : memref<16x4xf32> to memref<4x2xf32>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 0, strides:[?, ?]>) {
- // expected-error@+1 {{expected result type to have dynamic stride along a dimension if the base memref type has dynamic stride along that dimension}}
- %0 = subview %arg1[][][] : memref<16x4xf32, offset: 0, strides:[?, ?]> to memref<4x2xf32, offset:?, strides:[2, 1]>
-}
-
-// -----
-
-func @invalid_subview(%arg0 : index, %arg1 : memref<?x8x?xf32>) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- // expected-error@+1 {{expected shape of result type to be fully dynamic when sizes are specified}}
- %0 = subview %arg1[%c0, %c0, %c0][%c1, %arg0, %c1][%c1, %c1, %c1] : memref<?x8x?xf32> to memref<?x8x?xf32, offset:?, strides:[?, ?, ?]>
- return
-}
-
-// -----
-
func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
// expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
%0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index e4090ccd6073..dfcf086c73de 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -427,7 +427,7 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
return %c, %d : memref<? x ? x i32>, memref<? x ? x f32>
}
-#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#map2 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)>
// CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index,
@@ -684,106 +684,138 @@ func @view(%arg0 : index) -> (f32, f32, f32, f32) {
// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 128 + s0 + d1 * 28 + d2 * 11)>
// CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2 + 79)>
-// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
-// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)>
+// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2 * 2)>
+// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
+// CHECK-DAG: #[[SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)>
+
// CHECK-LABEL: func @subview
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
// CHECK: %[[C0:.*]] = constant 0 : index
%c0 = constant 0 : index
- // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK-NOT: constant 1 : index
%c1 = constant 1 : index
- // CHECK: %[[C2:.*]] = constant 2 : index
+ // CHECK-NOT: constant 2 : index
%c2 = constant 2 : index
+ // Folded but reappears after subview folding into dim.
// CHECK: %[[C7:.*]] = constant 7 : index
%c7 = constant 7 : index
+ // Folded but reappears after subview folding into dim.
// CHECK: %[[C11:.*]] = constant 11 : index
%c11 = constant 11 : index
+ // CHECK-NOT: constant 15 : index
%c15 = constant 15 : index
// CHECK: %[[ALLOC0:.*]] = alloc()
- %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
+ %0 = alloc() : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]>
// Test: subview with constant base memref and constant operands is folded.
// Note that the subview uses the base memrefs layout map because it used
// zero offset and unit stride arguments.
- // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[BASE_MAP0]]>
+ // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+ // CHECK-SAME: to memref<7x11x2xf32, #[[BASE_MAP0]]>
%1 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
- %v0 = load %1[%c0, %c0, %c0] : memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
-
- // Test: subview with one dynamic operand should not be folded.
- // CHECK: subview %[[ALLOC0]][%[[C0]], %[[ARG0]], %[[C0]]] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
+ : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+ memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ %v0 = load %1[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+
+ // Test: subview with one dynamic operand can also be folded.
+ // CHECK: subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+ // CHECK-SAME: to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
%2 = subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
- store %v0, %2[%c0, %c0, %c0] : memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+ memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ store %v0, %2[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
// CHECK: %[[ALLOC1:.*]] = alloc(%[[ARG0]])
- %3 = alloc(%arg0) : memref<?x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
+ %3 = alloc(%arg0) : memref<?x16x4xf32, offset : 0, strides : [64, 4, 1]>
// Test: subview with constant operands but dynamic base memref is folded as long as the strides and offset of the base memref are static.
- // CHECK: subview %[[ALLOC1]][] [] [] : memref<?x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[BASE_MAP0]]>
+ // CHECK: subview %[[ALLOC1]][0, 0, 0] [7, 11, 15] [1, 1, 1] :
+ // CHECK-SAME: memref<?x16x4xf32, #[[BASE_MAP0]]>
+ // CHECK-SAME: to memref<7x11x15xf32, #[[BASE_MAP0]]>
%4 = subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
- : memref<?x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
- store %v0, %4[%c0, %c0, %c0] : memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ : memref<?x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+ memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ store %v0, %4[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
// Test: subview offset operands are folded correctly w.r.t. base strides.
- // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
+ // CHECK: subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+ // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
%5 = subview %0[%c1, %c2, %c7] [%c7, %c11, %c2] [%c1, %c1, %c1]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
- store %v0, %5[%c0, %c0, %c0] : memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+ memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ store %v0, %5[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
// Test: subview stride operands are folded correctly w.r.t. base strides.
- // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
+ // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [2, 7, 11] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
+ // CHECK-SAME: to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
%6 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c2, %c7, %c11]
- : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
- memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
- store %v0, %6[%c0, %c0, %c0] : memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
+ memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ store %v0, %6[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
// Test: subview shape are folded, but offsets and strides are not even if base memref is static
- // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
- %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- store %v0, %10[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+ // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
+ %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] :
+ memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ store %v0, %10[%arg1, %arg1, %arg1] :
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// Test: subview strides are folded, but offsets and shape are not even if base memref is static
- // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP4]]
- %11 = subview %0[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c7, %c11] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- store %v0, %11[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 7, 11] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+ // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP4]]
+ %11 = subview %0[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c7, %c11] :
+ memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ store %v0, %11[%arg0, %arg0, %arg0] :
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// Test: subview offsets are folded, but strides and shape are not even if base memref is static
- // CHECK: subview %[[ALLOC0]][] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP5]]
- %13 = subview %0[%c1, %c2, %c7] [%arg1, %arg1, %arg1] [%arg0, %arg0, %arg0] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- store %v0, %13[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC0]][1, 2, 7] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] :
+ // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
+ // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP5]]
+ %13 = subview %0[%c1, %c2, %c7] [%arg1, %arg1, %arg1] [%arg0, %arg0, %arg0] :
+ memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ store %v0, %13[%arg1, %arg1, %arg1] :
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// CHECK: %[[ALLOC2:.*]] = alloc(%[[ARG0]], %[[ARG0]], %[[ARG1]])
%14 = alloc(%arg0, %arg0, %arg1) : memref<?x?x?xf32>
// Test: subview shape are folded, even if base memref is not static
- // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<?x?x?xf32> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
- %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref<?x?x?xf32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
+ // CHECK-SAME: memref<?x?x?xf32> to
+ // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
+ %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] :
+ memref<?x?x?xf32> to
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
store %v0, %15[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- // TEST: subview strides are not folded when the base memref is not static
- // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[C2]], %[[C2]], %[[C2]]] : memref<?x?x?xf32> to memref<?x?x?xf32, #[[SUBVIEW_MAP3]]
- %16 = subview %14[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c2, %c2] : memref<?x?x?xf32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // TEST: subview strides are folded, in the type only the most minor stride is folded.
+ // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 2, 2] :
+ // CHECK-SAME: memref<?x?x?xf32> to
+ // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP6]]
+ %16 = subview %14[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c2, %c2] :
+ memref<?x?x?xf32> to
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
store %v0, %16[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- // TEST: subview offsets are not folded when the base memref is not static
- // CHECK: subview %[[ALLOC2]][%[[C1]], %[[C1]], %[[C1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<?x?x?xf32> to memref<?x?x?xf32, #[[SUBVIEW_MAP3]]
- %17 = subview %14[%c1, %c1, %c1] [%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] : memref<?x?x?xf32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // TEST: subview offsets are folded but the type offset remains dynamic, when the base memref is not static
+ // CHECK: subview %[[ALLOC2]][1, 1, 1] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
+ // CHECK-SAME: memref<?x?x?xf32> to
+ // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP3]]
+ %17 = subview %14[%c1, %c1, %c1] [%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] :
+ memref<?x?x?xf32> to
+ memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
store %v0, %17[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// CHECK: %[[ALLOC3:.*]] = alloc() : memref<12x4xf32>
@@ -791,20 +823,26 @@ func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
%c4 = constant 4 : index
// TEST: subview strides are maintained when sizes are folded
- // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [] [] : memref<12x4xf32> to memref<2x4xf32, #[[SUBVIEW_MAP6]]>
- %19 = subview %18[%arg1, %arg1] [%c2, %c4] [] : memref<12x4xf32> to memref<?x?xf32, offset: ?, strides:[4, 1]>
+ // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [2, 4] [1, 1] :
+ // CHECK-SAME: memref<12x4xf32> to
+ // CHECK-SAME: memref<2x4xf32, #[[SUBVIEW_MAP7]]>
+ %19 = subview %18[%arg1, %arg1] [%c2, %c4] [1, 1] :
+ memref<12x4xf32> to
+ memref<?x?xf32, offset: ?, strides:[4, 1]>
store %v0, %19[%arg1, %arg1] : memref<?x?xf32, offset: ?, strides:[4, 1]>
// TEST: subview strides and sizes are maintained when offsets are folded
- // CHECK: subview %[[ALLOC3]][] [] [] : memref<12x4xf32> to memref<12x4xf32, #[[SUBVIEW_MAP7]]>
- %20 = subview %18[%c2, %c4] [] [] : memref<12x4xf32> to memref<12x4xf32, offset: ?, strides:[4, 1]>
+ // CHECK: subview %[[ALLOC3]][2, 4] [12, 4] [1, 1] :
+ // CHECK-SAME: memref<12x4xf32> to
+ // CHECK-SAME: memref<12x4xf32, #[[SUBVIEW_MAP8]]>
+ %20 = subview %18[%c2, %c4] [12, 4] [1, 1] :
+ memref<12x4xf32> to
+ memref<12x4xf32, offset: ?, strides:[4, 1]>
store %v0, %20[%arg1, %arg1] : memref<12x4xf32, offset: ?, strides:[4, 1]>
// Test: dim on subview is rewritten to size operand.
- %7 = dim %4, 0 : memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
- %8 = dim %4, 1 : memref<?x?x?xf32,
- affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ %7 = dim %4, 0 : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ %8 = dim %4, 1 : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
// CHECK: return %[[C7]], %[[C11]]
return %7, %8 : index, index
@@ -891,15 +929,3 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// CHECK: return %[[ARG]]
return %res : tensor<4x5xi32>
}
-
-// -----
-
-// CHECK-LABEL: func @memref_cast_folding_subview
-func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
- %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
- // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
- %1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset:? , strides: [?, ?]>
- // CHECK-NEXT: return %{{.*}}
- return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
-}
-
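The new CHECK lines encode how static offsets, sizes and strides flow into the subview result type: result stride_i = base_stride_i * stride_i, and result offset = base_offset + sum_i(offset_i * base_stride_i), with any dynamic factor (`?`) making the product dynamic. For example, offsets [1, 2, 7] on base strides [64, 4, 1] give 64 + 8 + 7 = 79, and strides [2, 7, 11] give [128, 28, 11], matching SUBVIEW_MAP5 and SUBVIEW_MAP4. The Python below is a minimal sketch of that arithmetic for illustration only, not MLIR's C++ implementation. The function names `infer_subview_layout` and `fmt` are hypothetical, `None` stands in for the `?` sentinel, and a base `memref<?x?x?xf32>` is assumed to have the identity layout, i.e. offset 0 and strides [?, ?, 1].

```python
from typing import List, Optional, Tuple

# None models MLIR's `?` sentinel for a dynamic offset, size or stride.
Dim = Optional[int]


def _mul(a: Dim, b: Dim) -> Dim:
    """Product of two possibly-dynamic values; dynamic if either factor is."""
    if a is None or b is None:
        return None
    return a * b


def infer_subview_layout(
    base_offset: Dim,
    base_strides: List[Dim],
    offsets: List[Dim],
    sizes: List[Dim],
    strides: List[Dim],
) -> Tuple[List[Dim], Dim, List[Dim]]:
    """Hypothetical helper: return (shape, offset, strides) of the result type."""
    rank = len(base_strides)
    if not (len(offsets) == len(sizes) == len(strides) == rank):
        raise ValueError("subview requires `rank` offsets, sizes and strides")
    # Result offset = base_offset + sum_i(offset_i * base_stride_i).
    result_offset = base_offset
    for off, base_stride in zip(offsets, base_strides):
        term = _mul(off, base_stride)
        if result_offset is None or term is None:
            result_offset = None
        else:
            result_offset += term
    # Result stride_i = base_stride_i * stride_i.
    result_strides = [_mul(b, s) for b, s in zip(base_strides, strides)]
    # Static sizes become static dimensions of the result shape.
    return list(sizes), result_offset, result_strides


def fmt(shape: List[Dim], offset: Dim, strides: List[Dim]) -> str:
    """Hypothetical printer for the `offset: x, strides: [...]` form."""
    def q(v: Dim) -> str:
        return "?" if v is None else str(v)
    dims = "x".join(q(d) for d in shape)
    strs = ", ".join(q(s) for s in strides)
    return f"memref<{dims}xf32, offset: {q(offset)}, strides: [{strs}]>"


# %20 above: offsets [2, 4] into memref<12x4xf32> (strides [4, 1]).
print(fmt(*infer_subview_layout(0, [4, 1], [2, 4], [12, 4], [1, 1])))
```

Running it prints `memref<12x4xf32, offset: 12, strides: [4, 1]>`, the layout of SUBVIEW_MAP8. Propagating `?` conservatively is why `%17` keeps a dynamic type offset even though its offsets fold to [1, 1, 1]: the base strides multiplying them are themselves dynamic.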