[Mlir-commits] [mlir] 2bd7ee4 - [mlir][memref] Fix ExpandShapeOp verifier
Matthias Springer
llvmlistbot at llvm.org
Thu Mar 31 01:08:01 PDT 2022
Author: Matthias Springer
Date: 2022-03-31T17:05:52+09:00
New Revision: 2bd7ee45666f1093bafe1a37e1c9ade8aa78ddd2
URL: https://github.com/llvm/llvm-project/commit/2bd7ee45666f1093bafe1a37e1c9ade8aa78ddd2
DIFF: https://github.com/llvm/llvm-project/commit/2bd7ee45666f1093bafe1a37e1c9ade8aa78ddd2.diff
LOG: [mlir][memref] Fix ExpandShapeOp verifier
* Complete rewrite of the verifier.
* CollapseShapeOp verifier will be updated in a subsequent commit.
* Update and expand op documentation.
* Add a new builder that infers the result type based on the source type, result shape and reassociation indices. In essence, only the result layout map is inferred.
Differential Revision: https://reviews.llvm.org/D122641
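The layout inference in the last bullet can be shown with a small standalone sketch. This is not MLIR code. It assumes plain `std::vector<int64_t>` shapes and strides and uses a hypothetical `kDynamic` sentinel for `?`. It mirrors only the stride propagation of `computeNextStride` and `computeExpandedLayoutMap` in the diff below; offset handling and the source-stride ordering check are left out. The names `nextStride`, `expandStrides` and `Group` are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel standing in for a dynamic ("?") size or stride.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// One reassociation group: contiguous result dimension indices.
using Group = std::vector<int64_t>;

// Stride of a dimension, computed from the stride and size of the next inner
// dimension. Dynamic if either input is dynamic (cf. computeNextStride).
int64_t nextStride(int64_t prevStride, int64_t prevDimSize) {
  if (prevStride == kDynamic || prevDimSize == kDynamic)
    return kDynamic;
  return prevStride * prevDimSize;
}

// Infers the result strides of an expansion. The innermost result dimension
// of each group inherits the source stride. Every outer dimension in the
// group gets the product of its inner neighbor's stride and size.
std::vector<int64_t> expandStrides(const std::vector<int64_t> &srcStrides,
                                   const std::vector<int64_t> &resultShape,
                                   const std::vector<Group> &reassociation) {
  assert(srcStrides.size() == reassociation.size() && "one group per source dim");
  std::vector<int64_t> resultStrides(resultShape.size(), 0);
  for (std::size_t g = 0; g < reassociation.size(); ++g) {
    const Group &group = reassociation[g];
    assert(!group.empty() && "groups of a non-zero-rank source are non-empty");
    // Innermost (last) dimension of the group keeps the source stride.
    resultStrides[group.back()] = srcStrides[g];
    // Remaining dimensions of the group, visited from inner to outer.
    for (std::size_t i = group.size() - 1; i-- > 0;) {
      int64_t d = group[i];
      resultStrides[d] = nextStride(resultStrides[d + 1], resultShape[d + 1]);
    }
  }
  return resultStrides;
}
```

Take the example in the `computeExpandedLayoutMap` comment: source strides `[1000, ?, 2]`, result shape `[2, 10, 2, 5, 5]` and reassociation `[[0, 1], [2, 3], [4]]`. For this input the sketch yields `[10000, 1000, ?, ?, 2]`. For the `%l1` case in `ops.mlir`, where strides `[4000, 2]` are expanded to `2x15x20`, it yields `[60000, 4000, 2]`.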
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c7519f1125f12..ffc6ed2263bc3 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1196,7 +1196,9 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
SmallVector<AffineMap, 4> getReassociationMaps();
+
SmallVector<ReassociationExprs, 4> getReassociationExprs();
+
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
SmallVector<ReassociationIndices, 4> reassociationIndices;
for (auto attr : reassociation())
@@ -1206,8 +1208,11 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
})));
return reassociationIndices;
};
+
MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+
MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+
Value getViewSource() { return src(); }
}];
@@ -1224,36 +1229,45 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
The `memref.expand_shape` op produces a new view with a higher rank whose
- sizes are a reassociation of the original `view`. Depending on whether or
- not the reassociated MemRefType is contiguous, the resulting memref may
- require explicit alloc and copies.
-
- A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of I64ArrayAttr attribute.
+ sizes are a reassociation of the original `view`. The operation is limited
+ to reassociations where a dimension is expanded into one or more
+ contiguous dimensions. Such reassociations never require additional allocs
+ or copies.
- For now, it is assumed that either:
- 1. a reassociation produces and consumes contiguous MemRefType or,
- 2. the reshape op will be folded into its consumers (by changing the shape
- of the computations).
- All other cases are undefined behavior and a reshape op may not lower to
- LLVM if it cannot be proven statically that it does not require alloc+copy.
-
- The operand memref type when dimensions can be zero-ranked if the result
- memref type is statically shaped with all dimensions being unit extent. In
- such case the reassociation map is empty.
-
- The verification rule is that the reassociation maps are applied to the
- result memref with the larger rank to obtain the operand memref with the
- smaller rank.
+ A reassociation is defined as a grouping of dimensions and is represented
+ with an array of I64ArrayAttr attributes.
Example:
```mlir
- // Dimension expansion i -> (i', j') and (k) -> (k')
- %1 = memref.expand_shape %0 [[0, 1], [2]] :
- memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
+ %r = memref.expand_shape %0 [[0, 1], [2]]
+ : memref<?x?xf32> into memref<?x5x?xf32>
```
+
+ At most one dimension of a reassociation group (e.g., [0, 1] above) may be
+ dynamic in the result type. Otherwise, the op would be ambiguous, as it
+ would not be clear how the source dimension is extended.
+
+ If an op can be statically proven to be invalid (e.g., an expansion from
+ `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
+ it cannot statically be proven invalid (e.g., the full example above; it is
+ unclear whether the first source dimension is divisible by 5), the op is
+ accepted by the verifier. However, if the op is in fact invalid at runtime,
+ the behavior is undefined.
+
+ The source memref can be zero-ranked. In that case, the reassociation
+ indices must be empty and the result shape may only consist of unit
+ dimensions.
+
+ For simplicity, this op may not be used to cast dynamicity of dimension
+ sizes and/or strides. That is, a source dimension is dynamic if and only if
+ the corresponding reassociation group contains a dynamic result dimension.
+ The same applies to strides.
+
+ Note: This op currently assumes that the inner strides of the
+ source/result layout map are the faster-varying ones.
}];
+
let builders = [
// Builders using ReassociationIndices.
OpBuilder<(ins "Type":$resultType, "Value":$src,
@@ -1264,6 +1278,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
$_state.addAttribute("reassociation",
getReassociationIndicesAttribute($_builder, reassociation));
}]>,
+
+ // Builder using ReassociationExprs.
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
@@ -1271,8 +1287,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
auto reassociationMaps =
convertReassociationMapsToIndices($_builder, reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
- }]>
+ }]>,
+
+ // Builder that infers the result layout map. The result shape must be
+ // specified. Otherwise, the op may be ambiguous.
+ OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation)>
];
+
let extraClassDeclaration = commonExtraClassDeclaration;
let hasVerifier = 1;
}
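The shape rules in the `expand_shape` description above can be condensed into a standalone check. This is a sketch under stated assumptions, not the MLIR verifier. It follows the logic of `verifyCollapsedShape` (added to `MemRefOps.cpp` below) with `allowMultipleDynamicDimsPerGroup` fixed to false. It uses a hypothetical `kDyn` sentinel for dynamic sizes and returns a short, made-up message instead of emitting diagnostics. The function name `checkExpansion` is illustrative.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// Hypothetical sentinel for a dynamic ("?") dimension size.
constexpr int64_t kDyn = std::numeric_limits<int64_t>::min();

// Returns "" if `expanded` is a valid expansion of `collapsed` under
// `reassociation`, otherwise a short description of the violated rule.
std::string checkExpansion(const std::vector<int64_t> &collapsed,
                           const std::vector<int64_t> &expanded,
                           const std::vector<std::vector<int64_t>> &reassociation) {
  // One reassociation group per collapsed (source) dimension.
  if (collapsed.size() != reassociation.size())
    return "wrong number of reassociation groups";
  int64_t nextDim = 0;  // next expected expanded dimension index
  for (std::size_t c = 0; c < reassociation.size(); ++c) {
    bool foundDynamic = false;
    int64_t product = 1;  // product of the static sizes in the group
    for (int64_t e : reassociation[c]) {
      if (e != nextDim++)
        return "indices not contiguous";
      if (e >= static_cast<int64_t>(expanded.size()))
        return "index out of bounds";
      if (expanded[e] == kDyn) {
        if (foundDynamic)
          return "multiple dynamic dims in group";
        foundDynamic = true;
      } else {
        product *= expanded[e];
      }
    }
    // The op may not be used to cast dynamicity.
    if ((collapsed[c] == kDyn) != foundDynamic)
      return "dynamicity mismatch";
    // Fully static group: sizes must multiply to the collapsed size.
    if (!foundDynamic && product != collapsed[c])
      return "static size mismatch";
  }
  if (collapsed.empty()) {
    // Rank-0 source: every result dimension must be 1.
    for (int64_t d : expanded)
      if (d != 1)
        return "rank-0 source needs unit result dims";
  } else if (nextDim != static_cast<int64_t>(expanded.size())) {
    return "expanded rank mismatch";
  }
  return "";
}
```

With this sketch, the documented expansion of `memref<10xf32>` into `memref<2x6xf32>` via `[[0, 1]]` reports a static size mismatch, because 2 * 6 = 12, not 10. The expansion of `memref<?x?xf32>` into `memref<?x5x?xf32>` via `[[0, 1], [2]]` passes, since divisibility of the dynamic size by 5 cannot be checked statically.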
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 67f166da5ec1a..6b81895a4772a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1558,9 +1558,106 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
// Reassociative reshape ops
//===----------------------------------------------------------------------===//
+/// Helper function that computes a stride based on the size/stride of the
+/// previous dimension.
+///
+/// E.g., memref<20x10x5xf32, offset: 0, strides: [50, 5, 1]>
+/// ^^
+/// compute this one
+/// prevStride = 5, prevDimSize = 10
+/// nextStride = 5 * 10 = 50
+static int64_t computeNextStride(int64_t prevStride, int64_t prevDimSize) {
+ if (ShapedType::isDynamicStrideOrOffset(prevStride))
+ return ShapedType::kDynamicStrideOrOffset;
+
+ if (ShapedType::isDynamic(prevDimSize))
+ return ShapedType::kDynamicStrideOrOffset;
+
+ return prevStride * prevDimSize;
+}
+
+/// Helper function for verifying the shape of ExpandShapeOp and
+/// CollapseShapeOp result and operand. Layout maps are verified separately.
+///
+/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
+/// allowed in a reassociation group.
+static LogicalResult
+verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
+ ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociation,
+ bool allowMultipleDynamicDimsPerGroup) {
+ // There must be one reassociation group per collapsed dimension.
+ if (collapsedShape.size() != reassociation.size())
+ return op->emitOpError("invalid number of reassociation groups: found ")
+ << reassociation.size() << ", expected " << collapsedShape.size();
+
+ // The next expected expanded dimension index (while iterating over
+ // reassociation indices).
+ int64_t nextDim = 0;
+ for (const auto &it : llvm::enumerate(reassociation)) {
+ ReassociationIndices group = it.value();
+ int64_t collapsedDim = it.index();
+
+ bool foundDynamic = false;
+ for (int64_t expandedDim : group) {
+ if (expandedDim != nextDim++)
+ return op->emitOpError("reassociation indices must be contiguous");
+
+ if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
+ return op->emitOpError("reassociation index ")
+ << expandedDim << " is out of bounds";
+
+ // Check if there are multiple dynamic dims in a reassociation group.
+ if (ShapedType::isDynamic(expandedShape[expandedDim])) {
+ if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
+ return op->emitOpError(
+ "at most one dimension in a reassociation group may be dynamic");
+ foundDynamic = true;
+ }
+ }
+
+ // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
+ if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
+ return op->emitOpError("collapsed dim (")
+ << collapsedDim
+ << ") must be dynamic if and only if reassociation group is "
+ "dynamic";
+
+ // If all dims in the reassociation group are static, the size of the
+ // collapsed dim can be verified.
+ if (!foundDynamic) {
+ int64_t groupSize = 1;
+ for (int64_t expandedDim : group)
+ groupSize *= expandedShape[expandedDim];
+ if (groupSize != collapsedShape[collapsedDim])
+ return op->emitOpError("collapsed dim size (")
+ << collapsedShape[collapsedDim]
+ << ") must equal reassociation group size (" << groupSize << ")";
+ }
+ }
+
+ if (collapsedShape.empty()) {
+ // Rank 0: All expanded dimensions must be 1.
+ for (int64_t d : expandedShape)
+ if (d != 1)
+ return op->emitOpError(
+ "rank 0 memrefs can only be extended/collapsed with/from ones");
+ } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
+ // Rank >= 1: Number of dimensions among all reassociation groups must match
+ // the result memref rank.
+ return op->emitOpError("expanded rank (")
+ << expandedShape.size()
+ << ") inconsistent with number of reassociation indices (" << nextDim
+ << ")";
+ }
+
+ return success();
+}
+
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
+
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
return convertReassociationIndicesToExprs(getContext(),
getReassociationIndices());
@@ -1569,6 +1666,7 @@ SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
+
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
return convertReassociationIndicesToExprs(getContext(),
getReassociationIndices());
@@ -1702,8 +1800,123 @@ static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
return success();
}
+/// Compute the layout map after expanding a given source MemRef type with the
+/// specified reassociation indices.
+static FailureOr<AffineMap>
+computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
+ ArrayRef<ReassociationIndices> reassociation) {
+ SmallVector<int64_t> srcStrides, resultStrides(resultShape.size(), 0);
+ int64_t srcOffset;
+ if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ return failure();
+ assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
+
+ // Ensure that inner strides are the fastest-varying ones. Other source layout
+ // maps are currently not supported.
+ int64_t lastStride = 0;
+ for (int64_t s : llvm::reverse(srcStrides)) {
+ if (!ShapedType::isDynamicStrideOrOffset(s)) {
+ if (s < lastStride)
+ return failure();
+ lastStride = s;
+ }
+ }
+
+ // Iterate over all reassociation groups from the back. Example:
+ // strides = [1000, ?, 2]
+ // source shape = [20, 10, 5]
+ // result shape = [ 2, 10, 2, 5, 5]
+ // reassociation = [[0, 1], [2, 3], [4]]
+ for (const auto &it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
+ ReassociationIndices indices = std::get<0>(it);
+ int64_t srcGroupStride = std::get<1>(it);
+
+ // The first result dimension (least significant one) in each reassociation
+ // group has the same stride as the corresponding source dimension. E.g.:
+ // reassociation = [[0, 1], [2, 3], [4]]
+ // | | |
+ // v v v
+ // 1000 ? 2
+ resultStrides[indices.pop_back_val()] = srcGroupStride;
+
+ // Compute the strides for the remaining dims in the reassociation group.
+ for (int64_t resultDim : llvm::reverse(indices)) {
+ // E.g.:
+ // reassociation = [[0, 1], [2, 3], [4]]
+ // |
+ // v
+ // 1000 * 10 = 10000
+ //
+ // If the previous stride or the previous dimension was dynamic, then this
+ // stride will also be dynamic.
+ resultStrides[resultDim] = computeNextStride(resultStrides[resultDim + 1],
+ resultShape[resultDim + 1]);
+ }
+ }
+
+ return makeStridedLinearLayoutMap(resultStrides, srcOffset,
+ srcType.getContext());
+}
+
+static FailureOr<MemRefType>
+computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
+ ArrayRef<ReassociationIndices> reassociation) {
+ if (srcType.getLayout().isIdentity()) {
+ // If the source is contiguous (i.e., no layout map specified), so is the
+ // result.
+ MemRefLayoutAttrInterface layout;
+ return MemRefType::get(resultShape, srcType.getElementType(), layout,
+ srcType.getMemorySpace());
+ }
+
+ // Source may not be contiguous. Compute the layout map.
+ FailureOr<AffineMap> computedLayout =
+ computeExpandedLayoutMap(srcType, resultShape, reassociation);
+ if (failed(computedLayout))
+ return failure();
+ auto computedType =
+ MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+ srcType.getMemorySpaceAsInt());
+ return canonicalizeStridedLayout(computedType);
+}
+
+void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
+ ArrayRef<int64_t> resultShape, Value src,
+ ArrayRef<ReassociationIndices> reassociation) {
+ // Only ranked memref source values are supported.
+ auto srcType = src.getType().cast<MemRefType>();
+ FailureOr<MemRefType> resultType =
+ computeExpandedType(srcType, resultShape, reassociation);
+ // Failure of this assertion usually indicates a problem with the source
+ // type, e.g., could not get strides/offset.
+ assert(succeeded(resultType) && "could not compute layout");
+ build(builder, result, *resultType, src, reassociation);
+}
+
LogicalResult ExpandShapeOp::verify() {
- return verifyReshapeOp(*this, getResultType(), getSrcType());
+ MemRefType srcType = getSrcType();
+ MemRefType resultType = getResultType();
+
+ // Verify result shape.
+ if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
+ resultType.getShape(),
+ getReassociationIndices(),
+ /*allowMultipleDynamicDimsPerGroup=*/false)))
+ return failure();
+
+ // Compute expected result type (including layout map).
+ FailureOr<MemRefType> expectedResultType = computeExpandedType(
+ srcType, resultType.getShape(), getReassociationIndices());
+ if (failed(expectedResultType))
+ return emitOpError("invalid source layout map");
+
+ // Check actual result type.
+ auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
+ if (*expectedResultType != canonicalizedResultType)
+ return emitOpError("expected expanded type to be ")
+ << *expectedResultType << " but found " << canonicalizedResultType;
+
+ return success();
}
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
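`computeExpandedLayoutMap` only supports source layouts in which the inner strides vary fastest. For other layouts, `ExpandShapeOp::verify` reports "invalid source layout map". The sketch below shows that check on its own. It assumes plain vectors and a hypothetical `kDynamicStride` sentinel rather than MLIR's `ShapedType` constants. The function name `innerStridesAreFastestVarying` is illustrative.

```cpp
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for a dynamic ("?") stride.
constexpr int64_t kDynamicStride = std::numeric_limits<int64_t>::min();

// Mirrors the ordering loop in computeExpandedLayoutMap. Walking from the
// innermost dimension outward, each static stride must be >= the last static
// stride seen. Dynamic strides are skipped.
bool innerStridesAreFastestVarying(const std::vector<int64_t> &strides) {
  int64_t lastStride = 0;
  for (auto it = strides.rbegin(); it != strides.rend(); ++it) {
    if (*it == kDynamicStride)
      continue;
    if (*it < lastStride)
      return false;
    lastStride = *it;
  }
  return true;
}
```

Consider the strides `[100, 10, 50, 1]` from the `@expand_shape_unsupported_src_layout` test. The walk sees 1, then 50, then 10 < 50, so it rejects the layout. The strides `[4000, 2]` and `[1000, ?, 2]` are accepted. Because the comparison is strict, equal adjacent strides also pass.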
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index a714fcb074284..08638234c2259 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -393,8 +393,17 @@ func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
// -----
func @expand_shape(%arg0: memref<f32>) {
- // expected-error @+1 {{expected non-zero memref ranks}}
+ // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 0}}
%0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+ return
+}
+
+// -----
+
+func @expand_shape(%arg0: memref<f32>) {
+ // expected-error @+1 {{rank 0 memrefs can only be extended/collapsed with/from ones}}
+ %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x2xf32>
+ return
}
// -----
@@ -407,12 +416,22 @@ func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
// -----
func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
- // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+ // expected-error @+1 {{op reassociation index 0 is out of bounds}}
%0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
}
// -----
+func @expand_shape_invalid_result_layout(
+ %arg0: memref<30x20xf32, offset : 100, strides : [4000, 2]>) {
+ // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}}
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] :
+ memref<30x20xf32, offset : 100, strides : [4000, 2]>
+ into memref<2x15x20xf32, offset : 100, strides : [5000, 4000, 2]>
+}
+
+// -----
+
func @collapse_shape(%arg0: memref<?xf32>) {
// expected-error @+1 {{expected to collapse or expand dims}}
%0 = memref.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
@@ -446,7 +465,7 @@ func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
func @expand_shape_illegal_dynamic_memref
(%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32> {
- // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+ // expected-error @+1 {{at most one dimension in a reassociation group may be dynamic}}
%0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
: memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
return %0 : memref<?x?x?x4x?xf32>
@@ -456,7 +475,7 @@ func @expand_shape_illegal_dynamic_memref
func @expand_shape_illegal_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
- // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+ // expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
%0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
: memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
return %0 : memref<2x3x2x4x5xf32>
@@ -476,7 +495,7 @@ func @collapse_shape_illegal_static_memref
func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
-> memref<?x4x5xf32> {
- // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+ // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
: memref<?x?xf32> into memref<?x4x5xf32>
return %0 : memref<?x4x5xf32>
@@ -486,7 +505,7 @@ func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
-> memref<?x4x5xf32> {
- // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+ // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
%0 = memref.expand_shape %arg0 [[0], [1, 2]]
: memref<?x?xf32> into memref<?x4x5xf32>
return %0 : memref<?x4x5xf32>
@@ -494,6 +513,28 @@ func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
// -----
+func @expand_shape_unsupported_src_layout(
+ %arg0 : memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>)
+ -> memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]> {
+ // expected-error @+1 {{invalid source layout map}}
+ %0 = memref.expand_shape %arg0 [[0], [1], [2, 3], [4]] :
+ memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>
+ into memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
+ return %0 : memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
+}
+
+// -----
+
+func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>)
+ -> memref<?x4x5xf32> {
+ // expected-error @+1 {{collapsed dim size (21) must equal reassociation group size (20)}}
+ %0 = memref.expand_shape %arg0 [[0], [1, 2]]
+ : memref<?x21xf32> into memref<?x4x5xf32>
+ return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
func @collapse_shape_illegal_mixed_memref(%arg0 : memref<?x4x5xf32>)
-> memref<?x?xf32> {
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 6191cdab02e2e..d001f35580ab0 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -104,107 +104,147 @@ func @memref_alloca_scope() {
return
}
-func @expand_collapse_shape_static(%arg0: memref<3x4x5xf32>,
- %arg1: tensor<3x4x5xf32>,
- %arg2: tensor<3x?x5xf32>) {
+// CHECK-LABEL: func @expand_collapse_shape_static
+func @expand_collapse_shape_static(
+ %arg0: memref<3x4x5xf32>,
+ %arg1: tensor<3x4x5xf32>,
+ %arg2: tensor<3x?x5xf32>,
+ %arg3: memref<30x20xf32, offset : 100, strides : [4000, 2]>,
+ %arg4: memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>>,
+ %arg5: memref<f32>) {
// Reshapes that collapse and expand back a contiguous buffer.
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
memref<3x4x5xf32> into memref<12x5xf32>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
%r0 = memref.expand_shape %0 [[0, 1], [2]] :
memref<12x5xf32> into memref<3x4x5xf32>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
%1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
memref<3x4x5xf32> into memref<3x20xf32>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
%r1 = memref.expand_shape %1 [[0], [1, 2]] :
memref<3x20xf32> into memref<3x4x5xf32>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
%2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
memref<3x4x5xf32> into memref<60xf32>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
+// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
%r2 = memref.expand_shape %2 [[0, 1, 2]] :
- memref<60xf32> into memref<3x4x5xf32>
+ memref<60xf32> into memref<3x4x5xf32>
+
+// CHECK: memref.expand_shape {{.*}} []
+// CHECK-SAME: memref<f32> into memref<1x1xf32>
+ %r5 = memref.expand_shape %arg5 [] :
+ memref<f32> into memref<1x1xf32>
+
+// Reshapes with a custom layout map.
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+ %l0 = memref.expand_shape %arg3 [[0], [1, 2]] :
+ memref<30x20xf32, offset : 100, strides : [4000, 2]>
+ into memref<30x4x5xf32, offset : 100, strides : [4000, 10, 2]>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+ %l1 = memref.expand_shape %arg3 [[0, 1], [2]] :
+ memref<30x20xf32, offset : 100, strides : [4000, 2]>
+ into memref<2x15x20xf32, offset : 100, strides : [60000, 4000, 2]>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+ %r4 = memref.expand_shape %arg4 [[0], [1, 2]] :
+ memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>> into
+ memref<1x1x5xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 5 + s0 + d2 + d1 * 5)>>
+
// Reshapes that expand and collapse back a contiguous buffer with some 1's.
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
%3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
%r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] :
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+
// Reshapes on tensors.
+// CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
%t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] :
tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+
+// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
%rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] :
tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+
+// CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
%t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] :
tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+
+// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
%rt1 = tensor.collapse_shape %t1 [[0], [1, 2], [3, 4]] :
tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
return
}
-// CHECK-LABEL: func @expand_collapse_shape_static
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
-// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
-// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-//
-// CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-// CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
-
+// CHECK-LABEL: func @expand_collapse_shape_dynamic
func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
%arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>,
%arg3: memref<?x42xf32, offset : 0, strides : [42, 1]>) {
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
memref<?x?x?xf32> into memref<?x?xf32>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
%r0 = memref.expand_shape %0 [[0, 1], [2]] :
memref<?x?xf32> into memref<?x4x?xf32>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
%1 = memref.collapse_shape %arg1 [[0, 1], [2]] :
memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
memref<?x?xf32, offset : 0, strides : [?, 1]>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
%r1 = memref.expand_shape %1 [[0, 1], [2]] :
memref<?x?xf32, offset : 0, strides : [?, 1]> into
memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
%2 = memref.collapse_shape %arg2 [[0, 1], [2]] :
memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
memref<?x?xf32, offset : ?, strides : [?, 1]>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
%r2 = memref.expand_shape %2 [[0, 1], [2]] :
memref<?x?xf32, offset : ?, strides : [?, 1]> into
memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]]
+// CHECK-SAME: memref<?x42xf32, #[[$strided2D42]]> into memref<?xf32>
%3 = memref.collapse_shape %arg3 [[0, 1]] :
memref<?x42xf32, offset : 0, strides : [42, 1]> into
memref<?xf32, offset : 0, strides : [1]>
+
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]]
+// CHECK-SAME: memref<?xf32> into memref<?x42xf32>
%r3 = memref.expand_shape %3 [[0, 1]] :
- memref<?xf32, offset : 0, strides : [1]> into
- memref<?x42xf32, offset : 0, strides : [42, 1]>
+ memref<?xf32, offset : 0, strides : [1]> into memref<?x42xf32>
return
}
-// CHECK-LABEL: func @expand_collapse_shape_dynamic
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-// CHECK-SAME: memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]]
-// CHECK-SAME: memref<?x42xf32, #[[$strided2D42]]> into memref<?xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]]
-// CHECK-SAME: memref<?xf32> into memref<?x42xf32, #[[$strided2D42]]>
func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
-> (memref<f32>, memref<1x1xf32>) {