[Mlir-commits] [mlir] 691e826 - Revert "[mlir] Revisit std.subview handling of static information."
Sam McCall
llvmlistbot at llvm.org
Tue May 12 06:23:34 PDT 2020
Author: Sam McCall
Date: 2020-05-12T15:18:50+02:00
New Revision: 691e82699591d8f336cd6be52436eeff2417fab9
URL: https://github.com/llvm/llvm-project/commit/691e82699591d8f336cd6be52436eeff2417fab9
DIFF: https://github.com/llvm/llvm-project/commit/691e82699591d8f336cd6be52436eeff2417fab9.diff
LOG: Revert "[mlir] Revisit std.subview handling of static information."
This reverts commit 80d133b24f77d1b9d351251315606441c971ef9b.
Per Stephan Herhut: The canonicalizer pattern that was added creates
forms of the subview op that cannot be lowered.
This is shown by failing TensorFlow XLA tests such as:
tensorflow/compiler/xla/service/mlir_gpu/tests:abs.hlo.test
More details will be provided offline; they rely on logs from private CI.
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/Conversion/StandardToLLVM/invalid.mlir
mlir/test/Conversion/StandardToSPIRV/legalization.mlir
mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
mlir/test/Dialect/Affine/ops.mlir
mlir/test/Dialect/Linalg/promote.mlir
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 813dd7db5e9e..f79e955437dd 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2504,7 +2504,7 @@ def SubIOp : IntArithmeticOp<"subi"> {
//===----------------------------------------------------------------------===//
def SubViewOp : Std_Op<"subview", [
- AttrSizedOperandSegments,
+ AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ViewLikeOpInterface>,
NoSideEffect,
]> {
@@ -2516,14 +2516,17 @@ def SubViewOp : Std_Op<"subview", [
The SubView operation supports the following arguments:
*) Memref: the "base" memref on which to create a "view" memref.
- *) Offsets: memref-rank number of dynamic offsets or static integer
- attributes into the "base" memref at which to create the "view"
- memref.
- *) Sizes: memref-rank number of dynamic sizes or static integer attributes
- which specify the sizes of the result "view" memref type.
- *) Strides: memref-rank number of dynamic strides or static integer
- attributes multiplicatively to the base memref strides in each
- dimension.
+ *) Offsets: zero or memref-rank number of dynamic offsets into the "base"
+ memref at which to create the "view" memref.
+ *) Sizes: zero or memref-rank dynamic size operands which specify the
+ dynamic sizes of the result "view" memref type.
+ *) Strides: zero or memref-rank number of dynamic strides which are applied
+ multiplicatively to the base memref strides in each dimension.
+
+ Note on the number of operands for offsets, sizes and strides: For
+ each of these, the number of operands must either be same as the
+ memref-rank number or empty. For the latter, those values will be
+ treated as constants.
Example 1:
@@ -2534,7 +2537,7 @@ def SubViewOp : Std_Op<"subview", [
// dynamic sizes for each dimension, and stride arguments '%c1'.
%1 = subview %0[%c0, %c0][%size0, %size1][%c1, %c1]
: memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1) > to
- memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+ memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0)>
```
Example 2:
@@ -2561,9 +2564,9 @@ def SubViewOp : Std_Op<"subview", [
%0 = alloc() : memref<8x16x4xf32, (d0, d1, d1) -> (d0 * 64 + d1 * 4 + d2)>
// Subview with constant offsets, sizes and strides.
- %1 = subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
+ %1 = subview %0[][][]
: memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to
- memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>
+ memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
```
Example 4:
@@ -2605,7 +2608,7 @@ def SubViewOp : Std_Op<"subview", [
// #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)
//
// where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1.
- %1 = subview %0[%i, %j][4, 4][%x, %y] :
+ %1 = subview %0[%i, %j][][%x, %y] :
: memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> to
memref<4x4xf32, (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)>
@@ -2621,25 +2624,24 @@ def SubViewOp : Std_Op<"subview", [
AnyMemRef:$source,
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
- Variadic<Index>:$strides,
- I64ArrayAttr:$static_offsets,
- I64ArrayAttr:$static_sizes,
- I64ArrayAttr:$static_strides
+ Variadic<Index>:$strides
);
let results = (outs AnyMemRef:$result);
+ let assemblyFormat = [{
+ $source `[` $offsets `]` `[` $sizes `]` `[` $strides `]` attr-dict `:`
+ type($source) `to` type($result)
+ }];
+
let builders = [
- // Build a SubViewOp with mized static and dynamic entries.
OpBuilder<
"OpBuilder &b, OperationState &result, Value source, "
- "ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,"
- "ArrayRef<int64_t> staticStrides, ValueRange offsets, ValueRange sizes, "
- "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
- // Build a SubViewOp with all dynamic entries.
+ "ValueRange offsets, ValueRange sizes, "
+ "ValueRange strides, Type resultType = Type(), "
+ "ArrayRef<NamedAttribute> attrs = {}">,
OpBuilder<
- "OpBuilder &b, OperationState &result, Value source, "
- "ValueRange offsets, ValueRange sizes, ValueRange strides, "
- "ArrayRef<NamedAttribute> attrs = {}">
+ "OpBuilder &builder, OperationState &result, "
+ "Type resultType, Value source">
];
let extraClassDeclaration = [{
@@ -2668,34 +2670,13 @@ def SubViewOp : Std_Op<"subview", [
/// operands could not be retrieved.
LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides);
- /// Auxiliary range data structure and helper function that unpacks the
- /// offset, size and stride operands of the SubViewOp into a list of triples.
- /// Such a list of triple is sometimes more convenient to manipulate.
+ // Auxiliary range data structure and helper function that unpacks the
+ // offset, size and stride operands of the SubViewOp into a list of triples.
+ // Such a list of triple is sometimes more convenient to manipulate.
struct Range {
Value offset, size, stride;
};
SmallVector<Range, 8> getRanges();
-
- /// Return the rank of the result MemRefType.
- unsigned getRank() { return getType().getRank(); }
-
- static StringRef getStaticOffsetsAttrName() {
- return "static_offsets";
- }
- static StringRef getStaticSizesAttrName() {
- return "static_sizes";
- }
- static StringRef getStaticStridesAttrName() {
- return "static_strides";
- }
- static ArrayRef<StringRef> getSpecialAttrNames() {
- static SmallVector<StringRef, 4> names{
- getStaticOffsetsAttrName(),
- getStaticSizesAttrName(),
- getStaticStridesAttrName(),
- getOperandSegmentSizeAttr()};
- return names;
- }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 39dc5d203a61..553be944ab30 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1293,7 +1293,7 @@ static LogicalResult verify(DimOp op) {
auto indexAttr = op.getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
return op.emitOpError("requires an integer attribute named 'index'");
- int64_t index = indexAttr.getInt();
+ int64_t index = indexAttr.getValue().getSExtValue();
auto type = op.getOperand().getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
@@ -2183,272 +2183,59 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
// SubViewOp
//===----------------------------------------------------------------------===//
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void printSubViewListOfOperandsOrIntegers(
- OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- p << "[";
- unsigned idx = 0;
- llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
- int64_t val = a.cast<IntegerAttr>().getInt();
- if (isDynamic(val))
- p << values[idx++];
- else
- p << val;
- });
- p << "] ";
-}
-
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-/// 2. `ssa` is filled with "[%arg0, %arg1]".
-static ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
- StringRef attrName, int64_t dynVal,
- SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
- if (failed(parser.parseLSquare()))
- return failure();
- // 0-D.
- if (succeeded(parser.parseOptionalRSquare()))
- return success();
-
- SmallVector<int64_t, 4> attrVals;
- while (true) {
- OpAsmParser::OperandType operand;
- auto res = parser.parseOptionalOperand(operand);
- if (res.hasValue() && succeeded(res.getValue())) {
- ssa.push_back(operand);
- attrVals.push_back(dynVal);
- } else {
- Attribute attr;
- NamedAttrList placeholder;
- if (failed(parser.parseAttribute(attr, "_", placeholder)) ||
- !attr.isa<IntegerAttr>())
- return parser.emitError(parser.getNameLoc())
- << "expected SSA value or integer";
- attrVals.push_back(attr.cast<IntegerAttr>().getInt());
- }
-
- if (succeeded(parser.parseOptionalComma()))
- continue;
- if (failed(parser.parseRSquare()))
- return failure();
- else
- break;
- }
-
- auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
- result.addAttribute(attrName, arrayAttr);
- return success();
-}
-
-namespace {
-/// Helpers to write more idiomatic operations.
-namespace saturated_arith {
-struct Wrapper {
- explicit Wrapper(int64_t v) : v(v) {}
- operator int64_t() { return v; }
- int64_t v;
-};
-Wrapper operator+(Wrapper a, int64_t b) {
- if (ShapedType::isDynamicStrideOrOffset(a) ||
- ShapedType::isDynamicStrideOrOffset(b))
- return Wrapper(ShapedType::kDynamicStrideOrOffset);
- return Wrapper(a.v + b);
-}
-Wrapper operator*(Wrapper a, int64_t b) {
- if (ShapedType::isDynamicStrideOrOffset(a) ||
- ShapedType::isDynamicStrideOrOffset(b))
- return Wrapper(ShapedType::kDynamicStrideOrOffset);
- return Wrapper(a.v * b);
-}
-} // end namespace saturated_arith
-} // end namespace
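For illustration only (this block is not part of the patch): the removed saturated_arith wrapper above can be mimicked in a few lines of standalone C++. kDynamic below is a hypothetical stand-in for ShapedType::kDynamicStrideOrOffset; any arithmetic that touches the sentinel collapses to the sentinel, which is how the deleted helper kept "dynamic" values sticky.

#include <cstdint>
#include <iostream>
#include <limits>

// Hypothetical sentinel standing in for ShapedType::kDynamicStrideOrOffset.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Mirrors the removed saturated_arith::Wrapper: any operation involving the
// dynamic sentinel yields the sentinel, otherwise plain integer arithmetic.
struct Saturated {
  int64_t v;
  explicit Saturated(int64_t value) : v(value) {}
};

Saturated operator+(Saturated a, int64_t b) {
  if (a.v == kDynamic || b == kDynamic)
    return Saturated(kDynamic);
  return Saturated(a.v + b);
}

Saturated operator*(Saturated a, int64_t b) {
  if (a.v == kDynamic || b == kDynamic)
    return Saturated(kDynamic);
  return Saturated(a.v * b);
}

int main() {
  // Static case: offset 0 plus the d1 contribution 2 * 4 stays static (8).
  Saturated staticOffset = Saturated(0) + (Saturated(2) * 4).v;
  // Dynamic case: once a dynamic value enters, the result stays dynamic.
  Saturated dynamicOffset = Saturated(kDynamic) + (Saturated(2) * 4).v;
  std::cout << staticOffset.v << "\n";                // 8
  std::cout << (dynamicOffset.v == kDynamic) << "\n"; // 1
  return 0;
}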
-
-/// A subview result type can be fully inferred from the source type and the
-/// static representation of offsets, sizes and strides. Special sentinels
-/// encode the dynamic case.
-static Type inferSubViewResultType(MemRefType sourceMemRefType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides) {
- unsigned rank = sourceMemRefType.getRank();
- (void)rank;
- assert(staticOffsets.size() == rank &&
- "unexpected staticOffsets size mismatch");
- assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
- assert(staticStrides.size() == rank &&
- "unexpected staticStrides size mismatch");
-
- // Extract source offset and strides.
- int64_t sourceOffset;
- SmallVector<int64_t, 4> sourceStrides;
- auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
+// Returns a MemRefType with dynamic sizes and offset and the same stride as the
+// `memRefType` passed as argument.
+// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep
+// sizes and offset static.
+static Type inferSubViewResultType(MemRefType memRefType) {
+ auto rank = memRefType.getRank();
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto res = getStridesAndOffset(memRefType, strides, offset);
assert(succeeded(res) && "SubViewOp expected strided memref type");
(void)res;
- // Compute target offset whose value is:
- // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
- int64_t targetOffset = sourceOffset;
- for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
- auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
- using namespace saturated_arith;
- targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
- }
-
- // Compute target stride whose value is:
- // `sourceStrides_i * staticStrides_i`.
- SmallVector<int64_t, 4> targetStrides;
- targetStrides.reserve(staticOffsets.size());
- for (auto it : llvm::zip(sourceStrides, staticStrides)) {
- auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
- using namespace saturated_arith;
- targetStrides.push_back(Wrapper(sourceStride) * staticStride);
- }
-
- // The type is now known.
- return MemRefType::get(
- staticSizes, sourceMemRefType.getElementType(),
- makeStridedLinearLayoutMap(targetStrides, targetOffset,
- sourceMemRefType.getContext()),
- sourceMemRefType.getMemorySpace());
-}
-
-/// Print SubViewOp in the form:
-/// ```
-/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` strided-memref-type `to` strided-memref-type
-/// ```
-static void print(OpAsmPrinter &p, SubViewOp op) {
- int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
- p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
- p << op.getOperand(0);
- printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset);
- printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
- ShapedType::isDynamic);
- printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset);
- p.printOptionalAttrDict(op.getAttrs(),
- /*elided=*/{SubViewOp::getSpecialAttrNames()});
- p << " : " << op.getOperand(0).getType() << " to " << op.getType();
-}
-
-/// Parse SubViewOp of the form:
-/// ```
-/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` strided-memref-type `to` strided-memref-type
-/// ```
-static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType srcInfo;
- SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
- auto indexType = parser.getBuilder().getIndexType();
- Type srcType, dstType;
- if (parser.parseOperand(srcInfo))
- return failure();
- if (parseListOfOperandsOrIntegers(
- parser, result, SubViewOp::getStaticOffsetsAttrName(),
- ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
- parseListOfOperandsOrIntegers(parser, result,
- SubViewOp::getStaticSizesAttrName(),
- ShapedType::kDynamicSize, sizesInfo) ||
- parseListOfOperandsOrIntegers(
- parser, result, SubViewOp::getStaticStridesAttrName(),
- ShapedType::kDynamicStrideOrOffset, stridesInfo))
- return failure();
-
- auto b = parser.getBuilder();
- SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
- static_cast<int>(sizesInfo.size()),
- static_cast<int>(stridesInfo.size())};
- result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(),
- b.getI32VectorAttr(segmentSizes));
-
- return failure(
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.resolveOperand(srcInfo, srcType, result.operands) ||
- parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
- parser.resolveOperands(sizesInfo, indexType, result.operands) ||
- parser.resolveOperands(stridesInfo, indexType, result.operands) ||
- parser.parseKeywordType("to", dstType) ||
- parser.addTypeToList(dstType, result.types));
+ // Assume sizes and offset are fully dynamic for now until canonicalization
+ // occurs on the ranges. Typed strides don't change though.
+ offset = MemRefType::getDynamicStrideOrOffset();
+ // Overwrite strides because verifier will not pass.
+ // TODO(b/144419106): don't force degrade the strides to fully dynamic.
+ for (auto &stride : strides)
+ stride = MemRefType::getDynamicStrideOrOffset();
+ auto stridedLayout =
+ makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
+ SmallVector<int64_t, 4> sizes(rank, ShapedType::kDynamicSize);
+ return MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setAffineMaps(stridedLayout);
}
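For illustration only (not part of the patch): the arithmetic documented by the removed inferSubViewResultType, targetOffset = sourceOffset + sum_i(staticOffset_i * sourceStride_i) and targetStride_i = sourceStride_i * staticStride_i, as a standalone C++ sketch with a made-up composeSubViewLayout helper. The restored function above deliberately skips this and marks the offset, sizes and strides as fully dynamic instead.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone illustration (not the MLIR API) of the subview layout arithmetic.
struct StridedLayout {
  int64_t offset;
  std::vector<int64_t> strides;
};

StridedLayout composeSubViewLayout(const StridedLayout &source,
                                   const std::vector<int64_t> &offsets,
                                   const std::vector<int64_t> &strides) {
  assert(offsets.size() == source.strides.size() &&
         strides.size() == source.strides.size() && "rank mismatch");
  StridedLayout target{source.offset, {}};
  for (std::size_t i = 0, e = offsets.size(); i != e; ++i) {
    // targetOffset accumulates staticOffset_i * sourceStride_i.
    target.offset += offsets[i] * source.strides[i];
    // targetStride_i = sourceStride_i * staticStride_i.
    target.strides.push_back(source.strides[i] * strides[i]);
  }
  return target;
}

int main() {
  // memref<8x16x4xf32> with the identity layout: strides [64, 4, 1], offset 0.
  StridedLayout base{0, {64, 4, 1}};
  // The pre-revert Example 3: subview %0[0, 2, 0][4, 4, 4][1, 1, 1].
  StridedLayout view = composeSubViewLayout(base, {0, 2, 0}, {1, 1, 1});
  // Prints "8 [64, 4, 1]", i.e. the layout of
  // memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>.
  std::cout << view.offset << " [" << view.strides[0] << ", "
            << view.strides[1] << ", " << view.strides[2] << "]\n";
  return 0;
}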
-void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides, ValueRange offsets,
- ValueRange sizes, ValueRange strides,
- ArrayRef<NamedAttribute> attrs) {
- auto sourceMemRefType = source.getType().cast<MemRefType>();
- auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
- staticSizes, staticStrides);
- build(b, result, resultType, source, offsets, sizes, strides,
- b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
- b.getI64ArrayAttr(staticStrides));
- result.addAttributes(attrs);
-}
-
-/// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes`
-/// and `staticStrides` are automatically filled with source-memref-rank
-/// sentinel values that encode dynamic entries.
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
ValueRange offsets, ValueRange sizes,
- ValueRange strides,
+ ValueRange strides, Type resultType,
ArrayRef<NamedAttribute> attrs) {
- auto sourceMemRefType = source.getType().cast<MemRefType>();
- unsigned rank = sourceMemRefType.getRank();
- SmallVector<int64_t, 4> staticOffsetsVector;
- staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
- SmallVector<int64_t, 4> staticSizesVector;
- staticSizesVector.assign(rank, ShapedType::kDynamicSize);
- SmallVector<int64_t, 4> staticStridesVector;
- staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
- build(b, result, source, staticOffsetsVector, staticSizesVector,
- staticStridesVector, offsets, sizes, strides, attrs);
-}
-
-/// Verify that a particular offset/size/stride static attribute is well-formed.
-static LogicalResult
-verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
- ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
- ValueRange values) {
- /// Check static and dynamic offsets/sizes/strides breakdown.
- if (attr.size() != op.getRank())
- return op.emitError("expected ")
- << op.getRank() << " " << name << " values";
- unsigned expectedNumDynamicEntries =
- llvm::count_if(attr.getValue(), [&](Attribute attr) {
- return isDynamic(attr.cast<IntegerAttr>().getInt());
- });
- if (values.size() != expectedNumDynamicEntries)
- return op.emitError("expected ")
- << expectedNumDynamicEntries << " dynamic " << name << " values";
- return success();
+ if (!resultType)
+ resultType = inferSubViewResultType(source.getType().cast<MemRefType>());
+ build(b, result, resultType, source, offsets, sizes, strides);
+ result.addAttributes(attrs);
}
-/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
-static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
- return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
- }));
+void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
+ Type resultType, Value source) {
+ build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{},
+ resultType);
}
-/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
auto baseType = op.getBaseMemRefType().cast<MemRefType>();
auto subViewType = op.getType();
+ // The rank of the base and result subview must match.
+ if (baseType.getRank() != subViewType.getRank()) {
+ return op.emitError(
+ "expected rank of result type to match rank of base type ");
+ }
+
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
return op.emitError("
diff erent memory spaces specified for base memref "
@@ -2456,32 +2243,96 @@ static LogicalResult verify(SubViewOp op) {
<< baseType << " and subview memref type " << subViewType;
// Verify that the base memref type has a strided layout map.
- if (!isStrided(baseType))
- return op.emitError("base type ") << baseType << " is not strided";
+ int64_t baseOffset;
+ SmallVector<int64_t, 4> baseStrides;
+ if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset)))
+ return op.emitError("base type ") << subViewType << " is not strided";
+
+ // Verify that the result memref type has a strided layout map.
+ int64_t subViewOffset;
+ SmallVector<int64_t, 4> subViewStrides;
+ if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset)))
+ return op.emitError("result type ") << subViewType << " is not strided";
+
+ // Num offsets should either be zero or rank of memref.
+ if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic offsets specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
- // Verify static attributes offsets/sizes/strides.
- if (failed(verifySubViewOpPart(
- op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset, op.offsets())))
- return failure();
+ // Num sizes should either be zero or rank of memref.
+ if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic sizes specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
- if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(),
- op.static_sizes(), ShapedType::isDynamic,
- op.sizes())))
- return failure();
- if (failed(verifySubViewOpPart(
- op, "stride", op.getStaticStridesAttrName(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset, op.strides())))
- return failure();
+ // Num strides should either be zero or rank of memref.
+ if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) {
+ return op.emitError("expected number of dynamic strides specified to match "
+ "the rank of the result type ")
+ << subViewType;
+ }
- // Verify result type against inferred type.
- auto expectedType = inferSubViewResultType(
- op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
- extractFromI64ArrayAttr(op.static_sizes()),
- extractFromI64ArrayAttr(op.static_strides()));
- if (op.getType() != expectedType)
- return op.emitError("expected result type to be ") << expectedType;
+ // Verify that if the shape of the subview type is static, then sizes are not
+ // dynamic values, and vice versa.
+ if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) ||
+ (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) {
+ return op.emitError("invalid to specify dynamic sizes when subview result "
+ "type is statically shaped and viceversa");
+ }
+
+ // Verify that if dynamic sizes are specified, then the result memref type
+ // have full dynamic dimensions.
+ if (op.getNumSizes() > 0) {
+ if (llvm::any_of(subViewType.getShape(), [](int64_t dim) {
+ return dim != ShapedType::kDynamicSize;
+ })) {
+ // TODO: This is based on the assumption that number of size arguments are
+ // either 0, or the rank of the result type. It is possible to have more
+ // fine-grained verification where only particular dimensions are
+ // dynamic. That probably needs further changes to the shape op
+ // specification.
+ return op.emitError("expected shape of result type to be fully dynamic "
+ "when sizes are specified");
+ }
+ }
+
+ // Verify that if dynamic offsets are specified or base memref has dynamic
+ // offset or base memref has dynamic strides, then the subview offset is
+ // dynamic.
+ if ((op.getNumOffsets() > 0 ||
+ baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset())) &&
+ subViewOffset != MemRefType::getDynamicStrideOrOffset()) {
+ return op.emitError(
+ "expected result memref layout map to have dynamic offset");
+ }
+
+ // For now, verify that if dynamic strides are specified, then all the result
+ // memref type have dynamic strides.
+ if (op.getNumStrides() > 0) {
+ if (llvm::any_of(subViewStrides, [](int64_t stride) {
+ return stride != MemRefType::getDynamicStrideOrOffset();
+ })) {
+ return op.emitError("expected result type to have dynamic strides");
+ }
+ }
+ // If any of the base memref has dynamic stride, then the corresponding
+ // stride of the subview must also have dynamic stride.
+ assert(baseStrides.size() == subViewStrides.size());
+ for (auto stride : enumerate(baseStrides)) {
+ if (stride.value() == MemRefType::getDynamicStrideOrOffset() &&
+ subViewStrides[stride.index()] !=
+ MemRefType::getDynamicStrideOrOffset()) {
+ return op.emitError(
+ "expected result type to have dynamic stride along a dimension if "
+ "the base memref type has dynamic stride along that dimension");
+ }
+ }
return success();
}
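For illustration only (not part of the patch): a simplified standalone sketch of the operand-count rules the restored verifier enforces. operandCountsAreConsistent is a hypothetical helper and it reduces the shape check to a single boolean, whereas the verifier above additionally requires every result dimension to be dynamic when size operands are given.

#include <iostream>

// Hypothetical helper (not the MLIR API) condensing the "all or nothing"
// rules: offsets/sizes/strides each have either 0 operands or rank operands,
// and dynamic size operands are mutually exclusive with a statically shaped
// result type.
bool operandCountsAreConsistent(unsigned rank, unsigned numOffsets,
                                unsigned numSizes, unsigned numStrides,
                                bool resultHasStaticShape) {
  auto zeroOrRank = [rank](unsigned n) { return n == 0 || n == rank; };
  if (!zeroOrRank(numOffsets) || !zeroOrRank(numSizes) ||
      !zeroOrRank(numStrides))
    return false;
  return resultHasStaticShape == (numSizes == 0);
}

int main() {
  // Rank-3 subview with all three operand groups and a dynamic result shape.
  std::cout << operandCountsAreConsistent(3, 3, 3, 3, false) << "\n"; // 1
  // Two offsets and one size on a rank-3 result: rejected, like the invalid
  // test cases further down.
  std::cout << operandCountsAreConsistent(3, 2, 1, 0, false) << "\n"; // 0
  return 0;
}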
@@ -2502,9 +2353,37 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
- if (!strides().empty())
+ // If the strides are dynamic return failure.
+ if (getNumStrides())
+ return failure();
+
+ // When static, the stride operands can be retrieved by taking the strides of
+ // the result of the subview op, and dividing the strides of the base memref.
+ int64_t resultOffset, baseOffset;
+ SmallVector<int64_t, 2> resultStrides, baseStrides;
+ if (failed(
+ getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) ||
+ llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) ||
+ failed(getStridesAndOffset(getType(), resultStrides, resultOffset)))
return failure();
- staticStrides = extractFromI64ArrayAttr(static_strides());
+
+ assert(static_cast<int64_t>(resultStrides.size()) == getType().getRank() &&
+ baseStrides.size() == resultStrides.size() &&
+ "base and result memrefs must have the same rank");
+ assert(!llvm::is_contained(resultStrides,
+ MemRefType::getDynamicStrideOrOffset()) &&
+ "strides of subview op must be static, when there are no dynamic "
+ "strides specified");
+ staticStrides.resize(getType().getRank());
+ for (auto resultStride : enumerate(resultStrides)) {
+ auto baseStride = baseStrides[resultStride.index()];
+ // The result stride is expected to be a multiple of the base stride. Abort
+ // if that is not the case.
+ if (resultStride.value() < baseStride ||
+ resultStride.value() % baseStride != 0)
+ return failure();
+ staticStrides[resultStride.index()] = resultStride.value() / baseStride;
+ }
return success();
}
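For illustration only (not part of the patch): the stride recovery in the restored getStaticStrides, i.e. dividing each result stride by the corresponding base stride and failing when it is not a clean multiple, as a standalone C++ sketch with a hypothetical recoverStaticStrides helper. The numbers in main mirror the memref<12x32xf32> subview with result strides [64, 3] from the legalization tests below.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical standalone helper (not the MLIR API): recover static subview
// strides as resultStride_i / baseStride_i, failing when the result stride is
// not a positive multiple of the base stride.
std::optional<std::vector<int64_t>>
recoverStaticStrides(const std::vector<int64_t> &baseStrides,
                     const std::vector<int64_t> &resultStrides) {
  assert(baseStrides.size() == resultStrides.size() && "rank mismatch");
  std::vector<int64_t> staticStrides;
  for (std::size_t i = 0, e = baseStrides.size(); i != e; ++i) {
    if (resultStrides[i] < baseStrides[i] ||
        resultStrides[i] % baseStrides[i] != 0)
      return std::nullopt; // Mirrors the failure() path above.
    staticStrides.push_back(resultStrides[i] / baseStrides[i]);
  }
  return staticStrides;
}

int main() {
  // Base memref<12x32xf32> has strides [32, 1]; a subview result with strides
  // [64, 3] corresponds to static strides [2, 3].
  auto ok = recoverStaticStrides({32, 1}, {64, 3});
  std::cout << ok->at(0) << ", " << ok->at(1) << "\n"; // 2, 3
  // A result stride that is not a multiple of the base stride is rejected.
  std::cout << recoverStaticStrides({32, 1}, {48, 1}).has_value() << "\n"; // 0
  return 0;
}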
@@ -2512,80 +2391,136 @@ Value SubViewOp::getViewSource() { return source(); }
namespace {
-/// Take a list of `values` with potential new constant to extract and a list
-/// of `constantValues` with`values.size()` sentinel that evaluate to true by
-/// applying `isDynamic`.
-/// Detects the `values` produced by a ConstantIndexOp and places the new
-/// constant in place of the corresponding sentinel value.
-void canonicalizeSubViewPart(SmallVectorImpl<Value> &values,
- SmallVectorImpl<int64_t> &constantValues,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- bool hasNewStaticValue = llvm::any_of(
- values, [](Value val) { return matchPattern(val, m_ConstantIndex()); });
- if (hasNewStaticValue) {
- for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size();
- cstIdx != e; ++cstIdx) {
- // Was already static, skip.
- if (!isDynamic(constantValues[cstIdx]))
- continue;
- // Newly static, move from Value to constant.
- if (matchPattern(values[valIdx], m_ConstantIndex())) {
- constantValues[cstIdx] =
- cast<ConstantIndexOp>(values[valIdx].getDefiningOp()).getValue();
- // Erase for impl. simplicity. Reverse iterator if we really must.
- values.erase(std::next(values.begin(), valIdx));
- continue;
- }
- // Remains dynamic move to next value.
- ++valIdx;
+/// Pattern to rewrite a subview op with constant size arguments.
+class SubViewOpShapeFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
+ MemRefType subViewType = subViewOp.getType();
+ // Follow all or nothing approach for shapes for now. If all the operands
+ // for sizes are constants then fold it into the type of the result memref.
+ if (subViewType.hasStaticShape() ||
+ llvm::any_of(subViewOp.sizes(), [](Value operand) {
+ return !matchPattern(operand, m_ConstantIndex());
+ })) {
+ return failure();
+ }
+ SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
+ for (auto size : llvm::enumerate(subViewOp.sizes())) {
+ auto defOp = size.value().getDefiningOp();
+ assert(defOp);
+ staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
}
+ MemRefType newMemRefType =
+ MemRefType::Builder(subViewType).setShape(staticShape);
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
+ ArrayRef<Value>(), subViewOp.strides(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
+ subViewOp.getType());
+ return success();
}
-}
+};
-/// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
+// Pattern to rewrite a subview op with constant stride arguments.
+class SubViewOpStrideFolder final : public OpRewritePattern<SubViewOp> {
public:
using OpRewritePattern<SubViewOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SubViewOp subViewOp,
PatternRewriter &rewriter) const override {
- // No constant operand, just return;
- if (llvm::none_of(subViewOp.getOperands(), [](Value operand) {
- return matchPattern(operand, m_ConstantIndex());
- }))
+ if (subViewOp.getNumStrides() == 0) {
+ return failure();
+ }
+ // Follow all or nothing approach for strides for now. If all the operands
+ // for strides are constants then fold it into the strides of the result
+ // memref.
+ int64_t baseOffset, resultOffset;
+ SmallVector<int64_t, 4> baseStrides, resultStrides;
+ MemRefType subViewType = subViewOp.getType();
+ if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
+ baseOffset)) ||
+ failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset()) ||
+ llvm::any_of(subViewOp.strides(), [](Value stride) {
+ return !matchPattern(stride, m_ConstantIndex());
+ })) {
return failure();
+ }
- // At least one of offsets/sizes/strides is a new constant.
- // Form the new list of operands and constant attributes from the existing.
- SmallVector<Value, 8> newOffsets(subViewOp.offsets());
- SmallVector<int64_t, 8> newStaticOffsets =
- extractFromI64ArrayAttr(subViewOp.static_offsets());
- assert(newStaticOffsets.size() == subViewOp.getRank());
- canonicalizeSubViewPart(newOffsets, newStaticOffsets,
- ShapedType::isDynamicStrideOrOffset);
-
- SmallVector<Value, 8> newSizes(subViewOp.sizes());
- SmallVector<int64_t, 8> newStaticSizes =
- extractFromI64ArrayAttr(subViewOp.static_sizes());
- assert(newStaticOffsets.size() == subViewOp.getRank());
- canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic);
-
- SmallVector<Value, 8> newStrides(subViewOp.strides());
- SmallVector<int64_t, 8> newStaticStrides =
- extractFromI64ArrayAttr(subViewOp.static_strides());
- assert(newStaticOffsets.size() == subViewOp.getRank());
- canonicalizeSubViewPart(newStrides, newStaticStrides,
- ShapedType::isDynamicStrideOrOffset);
-
- // Create the new op in canonical form.
+ SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
+ for (auto stride : llvm::enumerate(subViewOp.strides())) {
+ auto defOp = stride.value().getDefiningOp();
+ assert(defOp);
+ assert(baseStrides[stride.index()] > 0);
+ staticStrides[stride.index()] =
+ cast<ConstantIndexOp>(defOp).getValue() * baseStrides[stride.index()];
+ }
+ AffineMap layoutMap = makeStridedLinearLayoutMap(
+ staticStrides, resultOffset, rewriter.getContext());
+ MemRefType newMemRefType =
+ MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
auto newSubViewOp = rewriter.create<SubViewOp>(
- subViewOp.getLoc(), subViewOp.source(), newStaticOffsets,
- newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides);
-
+ subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(),
+ subViewOp.sizes(), ArrayRef<Value>(), newMemRefType);
// Insert a memref_cast for compatibility of the uses of the op.
rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
subViewOp.getType());
+ return success();
+ }
+};
+// Pattern to rewrite a subview op with constant offset arguments.
+class SubViewOpOffsetFolder final : public OpRewritePattern<SubViewOp> {
+public:
+ using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SubViewOp subViewOp,
+ PatternRewriter &rewriter) const override {
+ if (subViewOp.getNumOffsets() == 0) {
+ return failure();
+ }
+ // Follow all or nothing approach for offsets for now. If all the operands
+ // for offsets are constants then fold it into the offset of the result
+ // memref.
+ int64_t baseOffset, resultOffset;
+ SmallVector<int64_t, 4> baseStrides, resultStrides;
+ MemRefType subViewType = subViewOp.getType();
+ if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides,
+ baseOffset)) ||
+ failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) ||
+ llvm::is_contained(baseStrides,
+ MemRefType::getDynamicStrideOrOffset()) ||
+ baseOffset == MemRefType::getDynamicStrideOrOffset() ||
+ llvm::any_of(subViewOp.offsets(), [](Value stride) {
+ return !matchPattern(stride, m_ConstantIndex());
+ })) {
+ return failure();
+ }
+
+ auto staticOffset = baseOffset;
+ for (auto offset : llvm::enumerate(subViewOp.offsets())) {
+ auto defOp = offset.value().getDefiningOp();
+ assert(defOp);
+ assert(baseStrides[offset.index()] > 0);
+ staticOffset +=
+ cast<ConstantIndexOp>(defOp).getValue() * baseStrides[offset.index()];
+ }
+
+ AffineMap layoutMap = makeStridedLinearLayoutMap(
+ resultStrides, staticOffset, rewriter.getContext());
+ MemRefType newMemRefType =
+ MemRefType::Builder(subViewType).setAffineMaps(layoutMap);
+ auto newSubViewOp = rewriter.create<SubViewOp>(
+ subViewOp.getLoc(), subViewOp.source(), ArrayRef<Value>(),
+ subViewOp.sizes(), subViewOp.strides(), newMemRefType);
+ // Insert a memref_cast for compatibility of the uses of the op.
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, newSubViewOp,
+ subViewOp.getType());
return success();
}
};
@@ -2698,7 +2633,8 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<SubViewOpFolder>(context);
+ results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
+ SubViewOpOffsetFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 41ed5315ab1c..4cc8e11294d7 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -839,7 +839,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG2:.*]]: !llvm.i32)
-func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -883,8 +883,7 @@ func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index,
// CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
%1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
- memref<64x4xf32, offset: 0, strides: [4, 1]>
- to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
return
}
@@ -900,7 +899,7 @@ func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index,
// CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32,
// CHECK32: %[[ARG2:.*]]: !llvm.i32)
-func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1], 3>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview_non_zero_addrspace(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -944,14 +943,13 @@ func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1
// CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
%1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
- memref<64x4xf32, offset: 0, strides: [4, 1], 3>
- to memref<?x?xf32, offset: ?, strides: [?, ?], 3>
+ memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3> to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>, 3>
return
}
// CHECK-LABEL: func @subview_const_size(
// CHECK32-LABEL: func @subview_const_size(
-func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -998,15 +996,14 @@ func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg
// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
// CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
// CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
- %1 = subview %0[%arg0, %arg1][4, 2][%arg0, %arg1] :
- memref<64x4xf32, offset: 0, strides: [4, 1]>
- to memref<4x2xf32, offset: ?, strides: [?, ?]>
+ %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] :
+ memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<4x2xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
return
}
// CHECK-LABEL: func @subview_const_stride(
// CHECK32-LABEL: func @subview_const_stride(
-func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
+func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -1049,15 +1046,14 @@ func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %a
// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
// CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
// CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
- %1 = subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] :
- memref<64x4xf32, offset: 0, strides: [4, 1]>
- to memref<?x?xf32, offset: ?, strides: [4, 2]>
+ %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] :
+ memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + d1 * 2 + s0)>>
return
}
// CHECK-LABEL: func @subview_const_stride_and_offset(
// CHECK32-LABEL: func @subview_const_stride_and_offset(
-func @subview_const_stride_and_offset(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>) {
+func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
// CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
@@ -1096,9 +1092,8 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, offset: 0, strides:
// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
// CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
// CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
- %1 = subview %0[0, 8][62, 3][1, 1] :
- memref<64x4xf32, offset: 0, strides: [4, 1]>
- to memref<62x3xf32, offset: 8, strides: [4, 1]>
+ %1 = subview %0[][][] :
+ memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
return
}
diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
index 1be148707458..bb9c2728dcb8 100644
--- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -7,7 +7,7 @@ func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
%c0 = constant 0 : index
// expected-error @+1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}}
%5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
- %25 = std.subview %5[%c0, %c0][%c1, %c1][1, 1] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
+ %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
return
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
index d3b339e82a88..3540a101c55b 100644
--- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir
@@ -11,7 +11,7 @@ func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : in
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return %1 : f32
}
@@ -25,8 +25,7 @@ func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : i
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
- memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
return %1 : f32
}
@@ -42,8 +41,7 @@ func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : i
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: store [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] :
- memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
return
}
@@ -57,8 +55,7 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 :
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: store [[ARG7]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
- %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
- memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
+ %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]>
store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
return
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
index 2e28079b4f15..c0d904adb51c 100644
--- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
@@ -28,7 +28,7 @@ func @fold_static_stride_subview
// CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[C3]]
// CHECK: %[[T9:.*]] = addi %[[ARG2]], %[[T8]]
// CHECK store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
- %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+ %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
%2 = sqrt %1 : f32
store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 5ca6de5023eb..52eddfcd69f8 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -103,8 +103,8 @@ func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) {
affine.for %arg4 = 0 to %13 step 264 {
%18 = dim %0, 0 : memref<?x?xf32>
%20 = std.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref<?x?xf32>
- to memref<?x?xf32, offset : ?, strides : [?, ?]>
- %24 = dim %20, 0 : memref<?x?xf32, offset : ?, strides : [?, ?]>
+ to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
+ %24 = dim %20, 0 : memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
affine.for %arg5 = 0 to %24 step 768 {
"foo"() : () -> ()
}
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index bd6a3e7d7033..d5148b3424c4 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -23,9 +23,9 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
loop.for %arg4 = %c0 to %6 step %c2 {
loop.for %arg5 = %c0 to %8 step %c3 {
loop.for %arg6 = %c0 to %7 step %c4 {
- %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
- %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
- %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
linalg.matmul(%11, %14, %17) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
}
}
@@ -88,9 +88,9 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
loop.for %arg4 = %c0 to %6 step %c2 {
loop.for %arg5 = %c0 to %8 step %c3 {
loop.for %arg6 = %c0 to %7 step %c4 {
- %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
- %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
- %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
+ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
+ %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
+ %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref<?x?xf64> to memref<?x?xf64, offset: ?, strides: [?, 1]>
linalg.matmul(%11, %14, %17) : memref<?x?xf64, offset: ?, strides: [?, 1]>, memref<?x?xf64, offset: ?, strides: [?, 1]>, memref<?x?xf64, offset: ?, strides: [?, 1]>
}
}
@@ -153,9 +153,9 @@ func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
loop.for %arg4 = %c0 to %6 step %c2 {
loop.for %arg5 = %c0 to %8 step %c3 {
loop.for %arg6 = %c0 to %7 step %c4 {
- %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
- %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
- %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+ %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
+ %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref<?x?xi32> to memref<?x?xi32, offset: ?, strides: [?, 1]>
linalg.matmul(%11, %14, %17) : memref<?x?xi32, offset: ?, strides: [?, 1]>, memref<?x?xi32, offset: ?, strides: [?, 1]>, memref<?x?xi32, offset: ?, strides: [?, 1]>
}
}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 41172aa22527..da098a82b1b0 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -10,14 +10,15 @@
// CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
// CHECK-DAG: #[[BASE_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
// CHECK-DAG: #[[BASE_MAP1:map[0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK-DAG: #[[BASE_MAP2:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 22 + d1)>
-// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>
-// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>
+// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
// CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)>
// CHECK-LABEL: func @func_with_ops(%arg0: f32) {
@@ -707,56 +708,41 @@ func @memref_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%c1 = constant 1 : index
%0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
- // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
- // CHECK-SAME: to memref<?x?x?xf32, #[[BASE_MAP3]]>
+ // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP0]]>
%1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1]
- : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
%2 = alloc()[%arg2] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
- // CHECK: subview %2[%c1] [%arg0] [%c1] :
- // CHECK-SAME: memref<64xf32, #[[BASE_MAP1]]>
- // CHECK-SAME: to memref<?xf32, #[[SUBVIEW_MAP1]]>
+ // CHECK: subview %2[%c1] [%arg0] [%c1] : memref<64xf32, #[[BASE_MAP1]]> to memref<?xf32, #[[SUBVIEW_MAP1]]>
%3 = subview %2[%c1][%arg0][%c1]
: memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to
memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
%4 = alloc() : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>>
- // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] :
- // CHECK-SAME: memref<64x22xf32, #[[BASE_MAP2]]>
- // CHECK-SAME: to memref<?x?xf32, #[[SUBVIEW_MAP2]]>
+ // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] : memref<64x22xf32, #[[BASE_MAP2]]> to memref<?x?xf32, #[[SUBVIEW_MAP2]]>
%5 = subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0]
- : memref<64x22xf32, offset:0, strides: [22, 1]> to
- memref<?x?xf32, offset:?, strides: [?, ?]>
+ : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>> to
+ memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
- // CHECK: subview %0[0, 2, 0] [4, 4, 4] [1, 1, 1] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
- // CHECK-SAME: to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]>
- %6 = subview %0[0, 2, 0][4, 4, 4][1, 1, 1]
- : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to
- memref<4x4x4xf32, offset:8, strides: [64, 4, 1]>
+ // CHECK: subview %0[] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]>
+ %6 = subview %0[][][]
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>>
%7 = alloc(%arg1, %arg2) : memref<?x?xf32>
- // CHECK: subview {{%.*}}[0, 0] [4, 4] [1, 1] :
- // CHECK-SAME: memref<?x?xf32>
- // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP4]]>
- %8 = subview %7[0, 0][4, 4][1, 1]
- : memref<?x?xf32> to memref<4x4xf32, offset: ?, strides:[?, 1]>
+ // CHECK: subview {{%.*}}[] [] [] : memref<?x?xf32> to memref<4x4xf32, #[[SUBVIEW_MAP4]]>
+ %8 = subview %7[][][]
+ : memref<?x?xf32> to memref<4x4xf32, offset: ?, strides:[?, ?]>
%9 = alloc() : memref<16x4xf32>
- // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [{{%.*}}, {{%.*}}] :
- // CHECK-SAME: memref<16x4xf32>
- // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP2]]
- %10 = subview %9[%arg1, %arg1][4, 4][%arg2, %arg2]
+ // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [{{%.*}}, {{%.*}}] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP4]]
+ %10 = subview %9[%arg1, %arg1][][%arg2, %arg2]
: memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[?, ?]>
-
- // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [2, 2] :
- // CHECK-SAME: memref<16x4xf32>
- // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP5]]
- %11 = subview %9[%arg1, %arg2][4, 4][2, 2]
+ // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP5]]
+ %11 = subview %9[%arg1, %arg2][][]
: memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[8, 2]>
-
return
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index b0535047874f..0f9fb3ccada5 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -976,22 +976,33 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
- %0 = alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2>
+ %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2>
// expected-error @+1 {{different memory spaces}}
- %1 = subview %0[0, 0, 0][%arg2][1, 1, 1]
- : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to
+ %1 = subview %0[][%arg2][]
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2> to
memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>>
return
}
// -----
+func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+ %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
+ // expected-error@+1 {{is not strided}}
+ %1 = subview %0[][%arg2][]
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0, d1, d2)>>
+ return
+}
+
+// -----
+
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>>
// expected-error@+1 {{is not strided}}
- %1 = subview %0[0, 0, 0][%arg2][1, 1, 1]
+ %1 = subview %0[][%arg2][]
: memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to
- memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]>
+ memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>>
return
}
@@ -999,8 +1010,8 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<8x16x4xf32>
- // expected-error@+1 {{expected 3 offset values}}
- %1 = subview %0[%arg0, %arg1][%arg2][1, 1, 1]
+ // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}}
+ %1 = subview %0[%arg0, %arg1][%arg2][]
: memref<8x16x4xf32> to
memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]>
return
@@ -1010,7 +1021,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = alloc() : memref<8x16x4xf32>
- // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>'}}
+ // expected-error@+1 {{expected result type to have dynamic strides}}
%1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
: memref<8x16x4xf32> to
memref<?x?x?xf32, offset: ?, strides: [64, 4, 1]>
@@ -1019,6 +1030,106 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
// -----
+func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+ %0 = alloc() : memref<8x16x4xf32>
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
+ %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1]
+ : memref<8x16x4xf32> to
+ memref<?x?x?xf32, offset: 0, strides: [?, ?, ?]>
+ return
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
+ // expected-error@+1 {{expected rank of result type to match rank of base type}}
+ %0 = subview %arg1[%arg0, %arg0][][%arg0, %arg0] : memref<?x?xf32> to memref<?xf32>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
+ // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}}
+ %0 = subview %arg1[%arg0][][] : memref<?x?xf32> to memref<4x4xf32, offset: ?, strides: [4, 1]>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
+ // expected-error@+1 {{expected number of dynamic sizes specified to match the rank of the result type}}
+ %0 = subview %arg1[][%arg0][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
+ // expected-error@+1 {{expected number of dynamic strides specified to match the rank of the result type}}
+ %0 = subview %arg1[][][%arg0] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
+ // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}}
+ %0 = subview %arg1[][%arg0, %arg0][] : memref<?x?xf32> to memref<4x8xf32, offset: ?, strides: [?, ?]>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<?x?xf32>) {
+ // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}}
+ %0 = subview %arg1[][][] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) {
+ // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
+ %0 = subview %arg1[%arg0, %arg0][][] : memref<16x4xf32> to memref<4x2xf32>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: ?, strides: [4, 1]>) {
+ // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
+ %0 = subview %arg1[][][] : memref<16x4xf32, offset: ?, strides: [4, 1]> to memref<4x2xf32>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 8, strides:[?, 1]>) {
+ // expected-error@+1 {{expected result memref layout map to have dynamic offset}}
+ %0 = subview %arg1[][][] : memref<16x4xf32, offset: 8, strides:[?, 1]> to memref<4x2xf32>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) {
+ // expected-error@+1 {{expected result type to have dynamic strides}}
+ %0 = subview %arg1[][][%arg0, %arg0] : memref<16x4xf32> to memref<4x2xf32>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 0, strides:[?, ?]>) {
+ // expected-error@+1 {{expected result type to have dynamic stride along a dimension if the base memref type has dynamic stride along that dimension}}
+ %0 = subview %arg1[][][] : memref<16x4xf32, offset: 0, strides:[?, ?]> to memref<4x2xf32, offset:?, strides:[2, 1]>
+}
+
+// -----
+
+func @invalid_subview(%arg0 : index, %arg1 : memref<?x8x?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ // expected-error@+1 {{expected shape of result type to be fully dynamic when sizes are specified}}
+ %0 = subview %arg1[%c0, %c0, %c0][%c1, %arg0, %c1][%c1, %c1, %c1] : memref<?x8x?xf32> to memref<?x8x?xf32, offset:?, strides:[?, ?, ?]>
+ return
+}
+
+// -----
+
func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
// expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
%0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>
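
For contrast with the negative cases added above: the expected-error messages imply that when dynamic sizes are supplied they must be supplied for every dimension, and the result shape must then be fully dynamic (and vice versa). A minimal well-formed counterpart, assuming exactly those rules; the function and value names are illustrative only.

func @subview_dynamic_sizes(%m : memref<?x?xf32>, %s0 : index, %s1 : index) {
  // Dynamic sizes for both dimensions, so the result shape is fully dynamic;
  // offsets and strides are left empty and come from the result layout map.
  %sv = subview %m[][%s0, %s1][]
      : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  return
}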
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index dfcf086c73de..e4090ccd6073 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -427,7 +427,7 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
return %c, %d : memref<? x ? x i32>, memref<? x ? x f32>
}
-#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
#map2 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)>
// CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index,
@@ -684,138 +684,106 @@ func @view(%arg0 : index) -> (f32, f32, f32, f32) {
// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 128 + s0 + d1 * 28 + d2 * 11)>
// CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2 + 79)>
-// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2 * 2)>
-// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
-// CHECK-DAG: #[[SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)>
-
+// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)>
+// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)>
// CHECK-LABEL: func @subview
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
// CHECK: %[[C0:.*]] = constant 0 : index
%c0 = constant 0 : index
- // CHECK-NOT: constant 1 : index
+ // CHECK: %[[C1:.*]] = constant 1 : index
%c1 = constant 1 : index
- // CHECK-NOT: constant 2 : index
+ // CHECK: %[[C2:.*]] = constant 2 : index
%c2 = constant 2 : index
- // Folded but reappears after subview folding into dim.
// CHECK: %[[C7:.*]] = constant 7 : index
%c7 = constant 7 : index
- // Folded but reappears after subview folding into dim.
// CHECK: %[[C11:.*]] = constant 11 : index
%c11 = constant 11 : index
- // CHECK-NOT: constant 15 : index
%c15 = constant 15 : index
// CHECK: %[[ALLOC0:.*]] = alloc()
- %0 = alloc() : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]>
+ %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
// Test: subview with constant base memref and constant operands is folded.
// Note that the subview uses the base memrefs layout map because it used
// zero offset and unit stride arguments.
- // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
- // CHECK-SAME: to memref<7x11x2xf32, #[[BASE_MAP0]]>
+ // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[BASE_MAP0]]>
%1 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
- : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
- memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
- %v0 = load %1[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
-
- // Test: subview with one dynamic operand can also be folded.
- // CHECK: subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
- // CHECK-SAME: to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ %v0 = load %1[%c0, %c0, %c0] : memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+
+ // Test: subview with one dynamic operand should not be folded.
+ // CHECK: subview %[[ALLOC0]][%[[C0]], %[[ARG0]], %[[C0]]] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]>
%2 = subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
- : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
- memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
- store %v0, %2[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ store %v0, %2[%c0, %c0, %c0] : memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
// CHECK: %[[ALLOC1:.*]] = alloc(%[[ARG0]])
- %3 = alloc(%arg0) : memref<?x16x4xf32, offset : 0, strides : [64, 4, 1]>
+ %3 = alloc(%arg0) : memref<?x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>
// Test: subview with constant operands but dynamic base memref is folded as long as the strides and offset of the base memref are static.
- // CHECK: subview %[[ALLOC1]][0, 0, 0] [7, 11, 15] [1, 1, 1] :
- // CHECK-SAME: memref<?x16x4xf32, #[[BASE_MAP0]]>
- // CHECK-SAME: to memref<7x11x15xf32, #[[BASE_MAP0]]>
+ // CHECK: subview %[[ALLOC1]][] [] [] : memref<?x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[BASE_MAP0]]>
%4 = subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
- : memref<?x16x4xf32, offset : 0, strides : [64, 4, 1]> to
- memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
- store %v0, %4[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ : memref<?x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ store %v0, %4[%c0, %c0, %c0] : memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
// Test: subview offset operands are folded correctly w.r.t. base strides.
- // CHECK: subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
- // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
+ // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
%5 = subview %0[%c1, %c2, %c7] [%c7, %c11, %c2] [%c1, %c1, %c1]
- : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
- memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
- store %v0, %5[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ store %v0, %5[%c0, %c0, %c0] : memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
// Test: subview stride operands are folded correctly w.r.t. base strides.
- // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [2, 7, 11] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]>
- // CHECK-SAME: to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
+ // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]>
%6 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c2, %c7, %c11]
- : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
- memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
- store %v0, %6[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to
+ memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ store %v0, %6[%c0, %c0, %c0] : memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
// Test: subview shape are folded, but offsets and strides are not even if base memref is static
- // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
- // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
- %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] :
- memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- store %v0, %10[%arg1, %arg1, %arg1] :
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
+ %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ store %v0, %10[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// Test: subview strides are folded, but offsets and shape are not even if base memref is static
- // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 7, 11] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
- // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP4]]
- %11 = subview %0[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c7, %c11] :
- memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- store %v0, %11[%arg0, %arg0, %arg0] :
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP4]]
+ %11 = subview %0[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c7, %c11] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ store %v0, %11[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// Test: subview offsets are folded, but strides and shape are not even if base memref is static
- // CHECK: subview %[[ALLOC0]][1, 2, 7] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] :
- // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to
- // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP5]]
- %13 = subview %0[%c1, %c2, %c7] [%arg1, %arg1, %arg1] [%arg0, %arg0, %arg0] :
- memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- store %v0, %13[%arg1, %arg1, %arg1] :
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC0]][] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<?x?x?xf32, #[[SUBVIEW_MAP5]]
+ %13 = subview %0[%c1, %c2, %c7] [%arg1, %arg1, %arg1] [%arg0, %arg0, %arg0] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ store %v0, %13[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// CHECK: %[[ALLOC2:.*]] = alloc(%[[ARG0]], %[[ARG0]], %[[ARG1]])
%14 = alloc(%arg0, %arg0, %arg1) : memref<?x?x?xf32>
// Test: subview shape are folded, even if base memref is not static
- // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
- // CHECK-SAME: memref<?x?x?xf32> to
- // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
- %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] :
- memref<?x?x?xf32> to
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<?x?x?xf32> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]>
+ %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref<?x?x?xf32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
store %v0, %15[%arg1, %arg1, %arg1] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- // TEST: subview strides are folded, in the type only the most minor stride is folded.
- // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 2, 2] :
- // CHECK-SAME: memref<?x?x?xf32> to
- // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP6]]
- %16 = subview %14[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c2, %c2] :
- memref<?x?x?xf32> to
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // TEST: subview strides are not folded when the base memref is not static
+ // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[C2]], %[[C2]], %[[C2]]] : memref<?x?x?xf32> to memref<?x?x?xf32, #[[SUBVIEW_MAP3]]
+ %16 = subview %14[%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] [%c2, %c2, %c2] : memref<?x?x?xf32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
store %v0, %16[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
- // TEST: subview offsets are folded but the type offset remains dynamic, when the base memref is not static
- // CHECK: subview %[[ALLOC2]][1, 1, 1] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] :
- // CHECK-SAME: memref<?x?x?xf32> to
- // CHECK-SAME: memref<?x?x?xf32, #[[SUBVIEW_MAP3]]
- %17 = subview %14[%c1, %c1, %c1] [%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] :
- memref<?x?x?xf32> to
- memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
+ // TEST: subview offsets are not folded when the base memref is not static
+ // CHECK: subview %[[ALLOC2]][%[[C1]], %[[C1]], %[[C1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<?x?x?xf32> to memref<?x?x?xf32, #[[SUBVIEW_MAP3]]
+ %17 = subview %14[%c1, %c1, %c1] [%arg0, %arg0, %arg0] [%arg1, %arg1, %arg1] : memref<?x?x?xf32> to memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
store %v0, %17[%arg0, %arg0, %arg0] : memref<?x?x?xf32, offset: ?, strides: [?, ?, ?]>
// CHECK: %[[ALLOC3:.*]] = alloc() : memref<12x4xf32>
@@ -823,26 +791,20 @@ func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
%c4 = constant 4 : index
// TEST: subview strides are maintained when sizes are folded
- // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [2, 4] [1, 1] :
- // CHECK-SAME: memref<12x4xf32> to
- // CHECK-SAME: memref<2x4xf32, #[[SUBVIEW_MAP7]]>
- %19 = subview %18[%arg1, %arg1] [%c2, %c4] [1, 1] :
- memref<12x4xf32> to
- memref<?x?xf32, offset: ?, strides:[4, 1]>
+ // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [] [] : memref<12x4xf32> to memref<2x4xf32, #[[SUBVIEW_MAP6]]>
+ %19 = subview %18[%arg1, %arg1] [%c2, %c4] [] : memref<12x4xf32> to memref<?x?xf32, offset: ?, strides:[4, 1]>
store %v0, %19[%arg1, %arg1] : memref<?x?xf32, offset: ?, strides:[4, 1]>
// TEST: subview strides and sizes are maintained when offsets are folded
- // CHECK: subview %[[ALLOC3]][2, 4] [12, 4] [1, 1] :
- // CHECK-SAME: memref<12x4xf32> to
- // CHECK-SAME: memref<12x4xf32, #[[SUBVIEW_MAP8]]>
- %20 = subview %18[%c2, %c4] [12, 4] [1, 1] :
- memref<12x4xf32> to
- memref<12x4xf32, offset: ?, strides:[4, 1]>
+ // CHECK: subview %[[ALLOC3]][] [] [] : memref<12x4xf32> to memref<12x4xf32, #[[SUBVIEW_MAP7]]>
+ %20 = subview %18[%c2, %c4] [] [] : memref<12x4xf32> to memref<12x4xf32, offset: ?, strides:[4, 1]>
store %v0, %20[%arg1, %arg1] : memref<12x4xf32, offset: ?, strides:[4, 1]>
// Test: dim on subview is rewritten to size operand.
- %7 = dim %4, 0 : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
- %8 = dim %4, 1 : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+ %7 = dim %4, 0 : memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
+ %8 = dim %4, 1 : memref<?x?x?xf32,
+ affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>>
// CHECK: return %[[C7]], %[[C11]]
return %7, %8 : index, index
@@ -929,3 +891,15 @@ func @tensor_divi_unsigned_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// CHECK: return %[[ARG]]
return %res : tensor<4x5xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_folding_subview
+func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
+ %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
+ // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
+ %1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset:? , strides: [?, ?]>
+ // CHECK-NEXT: return %{{.*}}
+ return %1: memref<?x?xf32, offset:? , strides: [?, ?]>
+}
+
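
On the "dim on subview is rewritten to size operand" expectation in the canonicalize test above, a minimal sketch of the pattern being relied on, assuming the restored fold also applies when the size operand is dynamic; the function and value names are illustrative only.

func @dim_of_subview(%base : memref<?x?xf32>, %s0 : index, %s1 : index) -> index {
  %sv = subview %base[][%s0, %s1][]
      : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
  // If the dim-of-subview fold applies as in the @subview test, this dim is
  // expected to be rewritten to %s0 rather than recomputed from the memref.
  %d = dim %sv, 0 : memref<?x?xf32, offset: ?, strides: [?, ?]>
  return %d : index
}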