[Mlir-commits] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)
Gaurav Shukla
llvmlistbot at llvm.org
Fri Apr 12 09:12:32 PDT 2024
https://github.com/Shukla-Gaurav updated https://github.com/llvm/llvm-project/pull/69267
From 7bbd5fad68cbb0d55244911d8c45d844eaf215ae Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos <ramiroleal050 at gmail.com>
Date: Mon, 16 Oct 2023 17:02:23 -0700
Subject: [PATCH] [MLIR] Generalize expand_shape to take shape as explicit
input
*DO NOT SUBMIT*
(This patch needs to be tested in IREE backend)
This patch generalizes tensor.expand_shape and memref.expand_shape to consume
the output shape as a list of SSA values. This enables us to implement generic
reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.
The output_shape input to expand_shape follows the static/dynamic representation
that's also used in `tensor.extract_slice`.
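The static/dynamic representation mentioned above can be illustrated with a
small standalone sketch. It is not the MLIR implementation: `DynamicSize`,
`MixedSize`, `decomposeMixedSizes`, `composeMixedSizes`, and the concrete
value of `kDynamic` are hypothetical stand-ins that only mirror the roles of
`OpFoldResult`, `decomposeMixedValues`, `getMixedValues`, and
`ShapedType::kDynamic`.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins, not MLIR types: a dynamic size is modelled by the
// name of the SSA value that carries it.
using DynamicSize = std::string;
using MixedSize = std::variant<int64_t, DynamicSize>;

// Sentinel marking a dynamic entry. It plays the role of
// ShapedType::kDynamic; the concrete value is an assumption of this sketch.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Split a mixed shape into (static sizes with sentinels, dynamic values).
std::pair<std::vector<int64_t>, std::vector<DynamicSize>>
decomposeMixedSizes(const std::vector<MixedSize> &mixed) {
  std::vector<int64_t> staticSizes;
  std::vector<DynamicSize> dynamicSizes;
  for (const MixedSize &size : mixed) {
    if (const auto *constant = std::get_if<int64_t>(&size)) {
      staticSizes.push_back(*constant);
    } else {
      // One sentinel per dynamic value, in order of appearance.
      staticSizes.push_back(kDynamic);
      dynamicSizes.push_back(std::get<DynamicSize>(size));
    }
  }
  return {staticSizes, dynamicSizes};
}

// Inverse of decomposeMixedSizes. Rejects inputs where the number of
// sentinels differs from the number of dynamic values.
std::vector<MixedSize>
composeMixedSizes(const std::vector<int64_t> &staticSizes,
                  const std::vector<DynamicSize> &dynamicSizes) {
  std::vector<MixedSize> mixed;
  std::size_t nextDynamic = 0;
  for (int64_t size : staticSizes) {
    if (size != kDynamic) {
      mixed.emplace_back(size);
      continue;
    }
    if (nextDynamic == dynamicSizes.size())
      throw std::invalid_argument("fewer dynamic values than kDynamic entries");
    mixed.emplace_back(dynamicSizes[nextDynamic++]);
  }
  if (nextDynamic != dynamicSizes.size())
    throw std::invalid_argument("more dynamic values than kDynamic entries");
  return mixed;
}
```

In this model, the output shape `[%sz0, %sz1, 32]` decomposes into the static
list `[kDynamic, kDynamic, 32]` and the dynamic list `[%sz0, %sz1]`, so the
number of dynamic values always equals the number of sentinel entries.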
Differential Revision: https://reviews.llvm.org/D140821
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 88 +++++++++++---
.../mlir/Dialect/Tensor/IR/TensorOps.td | 99 +++++++++++----
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 50 +++++++-
.../mlir/Dialect/Utils/StaticValueUtils.h | 5 +-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 1 -
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 15 ++-
.../Transforms/ConvertConv2DToImg2Col.cpp | 2 +-
.../Transforms/DataLayoutPropagation.cpp | 10 +-
.../Linalg/Transforms/DropUnitDims.cpp | 25 ++--
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 15 +--
.../Linalg/Transforms/SplitReduction.cpp | 1 +
.../Dialect/Linalg/Transforms/Transforms.cpp | 8 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 41 ++++++-
.../Transforms/SparseTensorRewriting.cpp | 11 +-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 56 ++++++---
.../BufferizableOpInterfaceImpl.cpp | 5 +-
.../Transforms/PackAndUnpackPatterns.cpp | 24 ++--
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 86 ++++++++++---
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 7 +-
.../expand-then-convert-to-llvm.mlir | 16 +--
.../MemRefToLLVM/memref-to-llvm.mlir | 4 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 29 +++--
.../TosaToTensor/tosa-to-tensor.mlir | 114 +++++++++++++-----
.../Linalg/bubble-up-extract-slice-op.mlir | 4 +-
mlir/test/Dialect/Linalg/collapse-dim.mlir | 6 +-
.../Linalg/convert-conv2d-to-img2col.mlir | 20 +--
.../Linalg/data-layout-propagation.mlir | 30 +++--
.../Dialect/Linalg/flatten-elementwise.mlir | 2 +-
.../Dialect/Linalg/fusion-push-reshape.mlir | 24 ++--
mlir/test/Dialect/MemRef/ops.mlir | 72 ++++++-----
.../Dialect/SparseTensor/sparse_reshape.mlir | 12 +-
mlir/test/Dialect/Tensor/bufferize.mlir | 16 +--
mlir/test/Dialect/Tensor/canonicalize.mlir | 112 ++++++++---------
mlir/test/Dialect/Tensor/fold-empty-op.mlir | 5 +-
.../Tensor/fold-reassociative-reshapes.mlir | 6 +-
mlir/test/Dialect/Tensor/ops.mlir | 18 ++-
.../Dialect/Tensor/simplify-pack-unpack.mlir | 14 +--
37 files changed, 715 insertions(+), 338 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..8ea46bf3b275e5 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
MemRef_Op<mnemonic, !listconcat(traits,
[Pure, ViewLikeOpInterface])>,
- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)>{
code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Value getViewSource() { return getSrc(); }
}];
- let assemblyFormat = [{
- $src $reassociation attr-dict `:` type($src) `into` type($result)
- }];
-
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
Example:
```mlir
- %r = memref.expand_shape %0 [[0, 1], [2]]
- : memref<?x?xf32> into memref<?x5x?xf32>
+ %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+ : memref<?x32xf32> into memref<?x?x32xf32>
```
- At most one dimension of a reassociation group (e.g., [0, 1] above) may be
- dynamic in the result type. Otherwise, the op would be ambiguous, as it
- would not be clear how the source dimension is extended.
-
If an op can be statically proven to be invalid (e.g, an expansion from
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,29 +1613,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
there must be a dynamic result dimension in the corresponding reassociation
group. Same for strides.
+ The representation for the output shape supports a partially-static
+ specification via attributes specified through the `static_output_shape`
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+ corresponding entry has a dynamic value. There must be exactly as many SSA
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+ `static_output_shape`.
+
Note: This op currently assumes that the inner strides are of the
source/result layout map are the faster-varying ones.
}];
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+ Variadic<Index>:$output_shape,
+ DenseI64ArrayAttr:$static_output_shape);
+
+ let assemblyFormat = [{
+ $src $reassociation `output_shape`
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+ type($src) `into` type($result)
+ }];
+
let builders = [
// Builders using ReassociationIndices.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation),
+ [{
+ SmallVector<OpFoldResult> inputShape =
+ getMixedSizes($_builder, $_state.location, src);
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+ auto status =
+ inferOutputShape($_builder, $_state.location,
+ resultType.cast<MemRefType>(),
+ reassociation, inputShape, outputShape);
+ (void) status;
+ assert(succeeded(status) && "unable to infer output shape");
+ build($_builder, $_state, resultType.cast<MemRefType>(), src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ outputShape.second, outputShape.first);
+ }]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- build($_builder, $_state, resultType, src, attrs);
- $_state.addAttribute("reassociation",
- getReassociationIndicesAttribute($_builder, reassociation));
+ auto [staticOutputShape, dynamicOutputShape] =
+ decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+ build($_builder, $_state, resultType, src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ dynamicOutputShape, staticOutputShape);
}]>,
// Builder using ReassociationExprs.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation),
+ [{
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices);
+ }]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationMaps,
+ outputShape);
}]>,
// Builder that infers the result layout map. The result shape must be
@@ -1657,6 +1691,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
static FailureOr<MemRefType> computeExpandedType(
MemRefType srcType, ArrayRef<int64_t> resultShape,
ArrayRef<ReassociationIndices> reassociation);
+
+ // Infer the output shape for a memref.expand_shape when it is possible
+ // to do so.
+ static LogicalResult inferOutputShape(
+ OpBuilder &b, Location loc, MemRefType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
}];
let hasVerifier = 1;
@@ -1707,6 +1749,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
source/result layout map are the faster-varying ones.
}];
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let builders = [
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
@@ -1718,7 +1766,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, src, reassociationMaps, attrs);
}]>,
@@ -1736,7 +1784,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cf7f3e89079c1c..d3a1250f7e8985 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Tensor_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure])>,
- Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
- Results<(outs AnyRankedTensor:$result)> {
+ Results<(outs AnyTensor:$result)> {
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
}
}];
- let assemblyFormat = [{
- $src $reassociation attr-dict `:` type($src) `into` type($result)
- }];
-
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -1102,43 +1097,95 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
rank than the operand `src` whose dimension sizes are a reassociation of
`src`.
- A reassociation is defined as a continuous grouping of dimensions. It is
- represented with an array of DenseI64ArrayAttr attribute. Entries in the
- array are referred to as reassociation maps.
+ A reassociation is defined as a continuous grouping of dimensions and is
+ represented with an array of DenseI64ArrayAttr attributes. The reassociation
+ maps applied to the result tensor with the higher rank must result in the
+ operand tensor with the smaller rank.
- The reassociation maps are applied to the result shape to obtain the operand
- shape.
+ The representation for the output shape supports a partially-static
+ specification via attributes specified through the `static_output_shape`
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+ corresponding entry has a dynamic value. There must be exactly as many SSA
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+ `static_output_shape`.
Example:
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
- %b = tensor.expand_shape %a [[0, 1], [2]]
- : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+ : tensor<?x32xf32> into tensor<?x?x32xf32>
```
}];
+
+ let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+ Variadic<Index>:$output_shape,
+ DenseI64ArrayAttr:$static_output_shape);
+
+ let assemblyFormat = [{
+ $src $reassociation `output_shape`
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+ type($src) `into` type($result)
+ }];
+
let builders = [
// Builders using ReassociationIndices.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation),
+ [{
+ SmallVector<OpFoldResult> inputShape =
+ getMixedSizes($_builder, $_state.location, src);
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+ auto status =
+ inferOutputShape($_builder, $_state.location,
+ resultType.cast<RankedTensorType>(),
+ reassociation, inputShape, outputShape);
+ (void) status;
+ assert(succeeded(status) && "unable to infer output shape");
+ build($_builder, $_state, resultType.cast<RankedTensorType>(), src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ outputShape.second, outputShape.first);
+ }]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- build($_builder, $_state, resultType, src, attrs);
- $_state.addAttribute("reassociation",
- getReassociationIndicesAttribute($_builder, reassociation));
+ auto [staticOutputShape, dynamicOutputShape] =
+ decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+ build($_builder, $_state, resultType, src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ dynamicOutputShape, staticOutputShape);
+ }]>,
+
+ // Builder using ReassociationExprs.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation),
+ [{
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices);
}]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices,
+ outputShape);
}]>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+ // Infer the output shape for a tensor.expand_shape when it is possible
+ // to do so.
+ static LogicalResult inferOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
}];
let hasVerifier = 1;
@@ -1146,6 +1193,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
let summary = "operation to produce a tensor with a smaller rank";
+ let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
let description = [{
The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1211,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
: tensor<?x?x?xf32> into tensor<?x?xf32>
```
}];
+
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let builders = [
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
@@ -1174,7 +1227,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, src, reassociationMaps, attrs);
}]>,
@@ -1192,7 +1245,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index ae9824f728da4d..e46d4ff558916f 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,28 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the result type of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` as a pair of static sizes and
+// dynamic SSA values, following the conventions for the static_output_shape
+// and output_shape inputs to the expand_shape ops.
+LogicalResult inferExpandShapeOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
+
/// Compose reassociation maps that are used in pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
/// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +84,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
- OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+ ArrayRef<ReassociationExprs> reassociationExprs);
/// Return the reassociations maps to use to reshape given the source type and
/// the target type when possible. Return std::nullopt when this computation
@@ -156,9 +178,11 @@ LogicalResult reshapeLikeShapesAreCompatible(
/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);
+enum class ReshapeOpKind { kExpand, kCollapse };
+
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -181,8 +205,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
rewriter.getContext());
if (!reassociationIndices)
return failure();
- rewriter.replaceOpWithNewOp<ReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+ if constexpr (opKind == ReshapeOpKind::kExpand) {
+ SmallVector<OpFoldResult> outputShape(
+ getMixedValues(reshapeOp.getStaticOutputShape(),
+ reshapeOp.getOutputShape(), rewriter));
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+ outputShape);
+ } else {
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+ }
return success();
}
};
@@ -215,7 +249,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
//
/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
/// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+ typename DimOpTy, typename TensorTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -322,8 +357,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (!composedReassociation)
return failure();
+ SmallVector<OpFoldResult> outputShape(getMixedValues(
+ expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
rewriter.replaceOpWithNewOp<ExpandOpTy>(
- expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+ expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+ outputShape);
return success();
}
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 20f019666a2e6a..594bcf5dbb399a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,9 +125,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
/// Decompose a vector of mixed static or dynamic values into the
/// corresponding pair of arrays. This is the inverse function of
/// `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
- const SmallVectorImpl<OpFoldResult> &mixedValues);
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
/// Helper to sort `values` according to matching `keys`.
SmallVector<Value>
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412be..22ff502a7be7a4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..fb43dae305b8de 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -586,9 +586,18 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
return failure();
Location loc = oldFill.getLoc();
- auto newInit = rewriter.create<TensorReshapeOp>(
- loc, reshapeOp.getResultType(), oldFill.output(),
- reshapeOp.getReassociation());
+ TensorReshapeOp newInit;
+ if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
+
+ newInit = rewriter.create<TensorReshapeOp>(
+ loc, reshapeOp.getResultType(), oldFill.output(),
+ reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
+ reshapeOp.getStaticOutputShape());
+ } else {
+ newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
+ oldFill.output(),
+ reshapeOp.getReassociation());
+ }
rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
ValueRange{newInit});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 420b04b3ee28cf..81d44ba04fa1d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -349,7 +349,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
{2, 3}};
- Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+ auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
batchMatVecReassociationIndice);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7fd88dec71d491..9a2493a59e019e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -757,7 +757,10 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
- ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+ auto expandTy = expandOp.getType().dyn_cast<RankedTensorType>();
+ if (!expandTy)
+ return failure();
+ ArrayRef<int64_t> dstShape = expandTy.getShape();
SmallVector<ReassociationIndices> reassocIndices =
expandOp.getReassociationIndices();
// Project inner tile pos to the dim pos after expanding. For example, if dims
@@ -796,9 +799,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
nextPos += 1;
}
- RankedTensorType newExpandType =
- tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
- projectedInnerDimsPos, newOuterDimsPerm);
+ RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
+ expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
expandOp.getLoc(), newExpandType, unPackOp.getSource(),
newReassocIndices);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 023ea277bcf499..cef61b20e40a00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -252,7 +253,7 @@ replaceUnitDimIndexOps(GenericOp genericOp,
/// Expand the given `value` so that the type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
-static Value
+static FailureOr<Value>
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
ArrayRef<ReassociationIndices> reassociation,
ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
@@ -272,8 +273,9 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
assert(rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
- return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
- reassociation);
+ return rewriter
+ .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
+ .getResult();
}
/// Collapse the given `value` so that the type matches the type of
@@ -536,9 +538,13 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
resultReplacements.push_back(result);
continue;
}
- resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
- reassociations[opOperandIndex],
- options.rankReductionStrategy));
+ FailureOr<Value> expandedValue = expandValue(
+ rewriter, loc, result, origDest, reassociations[opOperandIndex],
+ options.rankReductionStrategy);
+ if (failed(expandedValue)) {
+ return rewriter.notifyMatchFailure(genericOp, "unable to expand result");
+ }
+ resultReplacements.push_back(*expandedValue);
}
rewriter.replaceOp(genericOp, resultReplacements);
@@ -669,10 +675,13 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
padOp.getResultType().getElementType());
}
- Value expandedValue =
+ FailureOr<Value> expandedValue =
expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
reassociationMap, options.rankReductionStrategy);
- rewriter.replaceOp(padOp, expandedValue);
+ if (failed(expandedValue)) {
+ return rewriter.notifyMatchFailure(padOp, "unable to expand result");
+ }
+ rewriter.replaceOp(padOp, *expandedValue);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 373e9cfc3ce719..5a34ecbaf50b15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -843,8 +843,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
- linalgOp.getLoc(), expandedOutputType, opOperand.get(),
- reassociation));
+ loc, expandedOutputType, opOperand.get(), reassociation));
} else {
outputs.push_back(opOperand.get());
}
@@ -1615,15 +1614,17 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
+ Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
- Value result = rewriter.create<memref::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
+ MemRefType expandShapeResultType = MemRefType::get(
+ originalResultType.getShape(), originalResultType.getElementType());
+ result = rewriter.create<memref::ExpandShapeOp>(
+ loc, expandShapeResultType, collapsedOpResult, reassociation);
} else {
- Value result = rewriter.create<tensor::ExpandShapeOp>(
+ result = rewriter.create<tensor::ExpandShapeOp>(
loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
}
+ results.push_back(result);
} else {
results.push_back(collapsedOpResult);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 6559c86c9e0ff5..5bfdbc6d0bb59c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -114,6 +114,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
Type newType = RankedTensorType::get(
newShape,
cast<RankedTensorType>(operand->get().getType()).getElementType());
+
Value newInput = b.create<tensor::ExpandShapeOp>(
loc, newType, operand->get(), reassociation);
newInputs.push_back(newInput);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a17bc8e4cd318f..c41a899b2e6f5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -329,11 +329,13 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
/*transposeOp=*/nullptr};
}
}
+
// 5. Expand from the padded result to the stripMinedShape.
+ auto expandShapeResultType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc,
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
- padOp.getResult(), packingMetadata.reassociations);
+ loc, expandShapeResultType, padOp.getResult(),
+ packingMetadata.reassociations);
// 6. Transpose stripMinedShape to packedShape.
SmallVector<int64_t> transpPerm =
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..32e5f1578d0a39 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2237,6 +2237,17 @@ FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
srcType.getMemorySpace());
}
+LogicalResult ExpandShapeOp::inferOutputShape(
+ OpBuilder &b, Location loc, MemRefType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+ auto expandedTensorType =
+ cast<RankedTensorType>(getTensorTypeFromMemRefType(expandedType));
+ return inferExpandShapeOutputShape(b, loc, expandedTensorType, reassociation,
+ inputShape, outputShape);
+}
+
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> resultShape, Value src,
ArrayRef<ReassociationIndices> reassociation) {
@@ -2247,6 +2258,8 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
// Failure of this assertion usually indicates a problem with the source
// type, e.g., could not get strides/offset.
assert(succeeded(resultType) && "could not compute layout");
+ // SmallVector<OpFoldResult> outputShape(
+ // getMixedValues(resultShape, ValueRange{}, builder));
build(builder, result, *resultType, src, reassociation);
}
@@ -2280,14 +2293,28 @@ LogicalResult ExpandShapeOp::verify() {
return emitOpError("expected expanded type to be ")
<< *expectedResultType << " but found " << resultType;
+ if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+ return emitOpError("expected number of static shape bounds to be equal to "
+ "the output rank (")
+ << resultType.getRank() << ") but found "
+ << getStaticOutputShape().size() << " inputs instead";
+
+ if ((int64_t)getOutputShape().size() !=
+ llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+ return emitOpError("mismatch in dynamic dims in output_shape and "
+ "static_output_shape: static_output_shape has ")
+ << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+ << " dynamic dims while output_shape has " << getOutputShape().size()
+ << " values";
+
return success();
}
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
- context);
+ results.add<
+ ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}
/// Compute the layout map after collapsing a given source MemRef type with the
@@ -2488,9 +2515,11 @@ struct CollapseShapeOpMemRefCastFolder
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
- CollapseShapeOpMemRefCastFolder>(context);
+ results.add<
+ ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+ memref::DimOp, MemRefType>,
+ CollapseShapeOpMemRefCastFolder>(context);
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b117c1694e45b8..149597df332907 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -952,8 +952,15 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
auto rtp = getRankedTensorType(op.getResult());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
- op.getReassociation());
+ ReshapeOp reshape;
+ if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
+ reshape = rewriter.create<ReshapeOp>(
+ loc, denseTp, op.getSrc(), op.getReassociation(),
+ op.getOutputShape(), op.getStaticOutputShape());
+ } else {
+ reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+ op.getReassociation());
+ }
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
rewriter.replaceOp(op, convert);
return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..e86d318167e421 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1606,6 +1606,15 @@ int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
llvm_unreachable("could not find reassociation group");
}
+LogicalResult ExpandShapeOp::inferOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+ return inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
+ inputShape, outputShape);
+}
+
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
@@ -1689,7 +1698,24 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
}
LogicalResult ExpandShapeOp::verify() {
- return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
+ auto srcType = getSrcType();
+ auto resultType = getResultType();
+
+ if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+ return emitOpError("expected number of static shape dims to be equal to "
+ "the output rank (")
+ << resultType.getRank() << ") but found "
+ << getStaticOutputShape().size() << " inputs instead";
+
+ if ((int64_t)getOutputShape().size() !=
+ llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+ return emitOpError("mismatch in dynamic dims in output_shape and "
+ "static_output_shape: static_output_shape has ")
+ << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+ << " dynamic dims while output_shape has " << getOutputShape().size()
+ << " values";
+
+ return verifyTensorReshapeOp(*this, resultType, srcType);
}
LogicalResult CollapseShapeOp::verify() {
@@ -1873,23 +1899,25 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
- FoldReshapeWithConstant<ExpandShapeOp>,
- FoldReshapeWithSplat<ExpandShapeOp>,
- FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
- FoldDimOfCollapseShape>(context);
+ results.add<
+ ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ FoldReshapeWithConstant<ExpandShapeOp>,
+ FoldReshapeWithSplat<ExpandShapeOp>,
+ FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+ FoldDimOfCollapseShape>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
- FoldReshapeWithConstant<CollapseShapeOp>,
- FoldReshapeWithSplat<CollapseShapeOp>,
- FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
- context);
+ results.add<
+ ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+ tensor::DimOp, RankedTensorType>,
+ FoldReshapeWithConstant<CollapseShapeOp>,
+ FoldReshapeWithSplat<CollapseShapeOp>,
+ FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
+ context);
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 58ea4cc4da3c36..8958bb20098b7d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -338,9 +338,10 @@ struct ExpandShapeOpInterface
// Memref result type is inferred by the builder based on reassociation
// indices and result shape.
- replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
- rewriter, op, tensorResultType.getShape(), *buffer,
+ Value memrefExpandOp = rewriter.create<memref::ExpandShapeOp>(
+ op->getLoc(), tensorResultType.getShape(), *buffer,
expandShapeOp.getReassociationIndices());
+ replaceOpWithBufferizedValues(rewriter, op, memrefExpandOp);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 666ac56c6cd5cd..7011ce23b55a6b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -52,12 +52,16 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
- Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
- Type newOperandType, ArrayAttr reassociation) const {
+ FailureOr<Value>
+ insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+ Type newOperandType,
+ ArrayRef<ReassociationIndices> reassociation) const {
if (operand.getType() == newOperandType)
return operand;
- return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
- reassociation);
+ return rewriter
+ .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+ reassociation)
+ .getResult();
}
/// Returns success() if it is only packing on the innermost dimension.
@@ -96,10 +100,14 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
return failure();
- Value expanded = insertExpand(
- rewriter, packOp.getLoc(), packOp.getSource(), destType,
- getReassociationIndicesAttribute(rewriter, *reassociation));
- rewriter.replaceOp(packOp, expanded);
+ FailureOr<Value> expanded =
+ insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
+ *reassociation);
+ if (failed(expanded)) {
+ return rewriter.notifyMatchFailure(
+ packOp, "unable to expand source of tensor.pack");
+ }
+ rewriter.replaceOp(packOp, *expanded);
return success();
}
};
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 41c7af4593c77c..bf517e15a37dc6 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -8,6 +8,8 @@
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -16,6 +18,69 @@
using namespace mlir;
+LogicalResult mlir::inferExpandShapeOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+ SmallVector<Value> outputShapeValues;
+ SmallVector<int64_t> outputShapeInts;
+ // For zero-rank inputs, all dims in result shape are unit extent.
+ if (inputShape.empty()) {
+ outputShapeInts.resize(expandedType.getRank(), 1);
+ outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+ return success();
+ }
+
+ // Check for all static shapes.
+ if (expandedType.hasStaticShape()) {
+ ArrayRef<int64_t> shapeInts = expandedType.getShape();
+ outputShapeInts.assign(shapeInts.begin(), shapeInts.end());
+ outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+ return success();
+ }
+
+ outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+ for (const auto &it : llvm::enumerate(reassociation)) {
+ ReassociationIndices indexGroup = it.value();
+
+ int64_t indexGroupStaticSizesProductInt = 1;
+ bool foundDynamic = false;
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ // Cannot infer expanded shape with multiple dynamic dims in the
+ // same reassociation group!
+ if (ShapedType::isDynamic(outputDimSize)) {
+ if (foundDynamic)
+ return failure();
+ foundDynamic = true;
+ } else {
+ outputShapeInts[index] = outputDimSize;
+ indexGroupStaticSizesProductInt *= outputDimSize;
+ }
+ }
+ if (!foundDynamic)
+ continue;
+
+ int64_t inputIndex = it.index();
+ // The source dim is assumed to be dynamic here (the reshape does not cast
+ // a static dim to a dynamic one), so the OpFoldResult must hold a Value.
+ Value indexGroupSize = inputShape[inputIndex].get<Value>();
+ Value indexGroupStaticSizesProduct =
+ b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+ Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
+ loc, indexGroupSize, indexGroupStaticSizesProduct);
+ outputShapeValues.push_back(dynamicDimSize);
+ }
+
+ if ((int64_t)outputShapeValues.size() !=
+ llvm::count(outputShapeInts, ShapedType::kDynamic))
+ return failure();
+
+ outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+ return success();
+}
+
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
ShapedType targetType) {
@@ -168,7 +233,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute(
}
SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
- OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+ ArrayRef<ReassociationExprs> reassociationExprs) {
SmallVector<ReassociationIndices, 2> reassociationIndices;
for (const auto &exprs : reassociationExprs) {
ReassociationIndices indices;
@@ -230,24 +295,17 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
unsigned expandedDimStart = 0;
for (const auto &map : llvm::enumerate(reassociationMaps)) {
- std::optional<int64_t> dynamicShape;
+ bool foundDynamicShape = false;
int64_t linearizedStaticShape = 1;
+
for (const auto &dim : llvm::enumerate(
expandedShape.slice(expandedDimStart, map.value().size()))) {
- if (ShapedType::isDynamic(dim.value())) {
- if (isExpandingReshape && dynamicShape) {
- return emitError("invalid to have a single dimension (" +
- Twine(map.index()) +
- ") expanded into multiple dynamic dims (" +
- Twine(expandedDimStart + dynamicShape.value()) +
- "," + Twine(expandedDimStart + dim.index()) + ")");
- }
- dynamicShape = dim.index();
- } else {
+ if (ShapedType::isDynamic(dim.value()))
+ foundDynamicShape = true;
+ else
linearizedStaticShape *= dim.value();
- }
}
- if (dynamicShape) {
+ if (foundDynamicShape) {
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
return emitError(
"expected dimension " + Twine(map.index()) +
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 1e8197e1094424..74a53709592dd2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -180,9 +180,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
- const SmallVectorImpl<OpFoldResult> &mixedValues) {
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
SmallVector<int64_t> staticValues;
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
@@ -193,7 +192,7 @@ decomposeMixedValues(Builder &b,
dynamicValues.push_back(it.get<Value>());
}
}
- return {b.getI64ArrayAttr(staticValues), dynamicValues};
+ return {staticValues, dynamicValues};
}
/// Helper to sort `values` according to matching `keys`.
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 87d613986c7c3f..b86103422b0745 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -453,7 +453,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// Reshapes that expand a contiguous tensor with some 1's.
- %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+ %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
return %0 : memref<1x3x4x1x5xf32>
}
@@ -510,7 +510,7 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
// -----
func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
- %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+ %0 = memref.expand_shape %arg0 [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
return %0 : memref<1x1xf32>
}
@@ -571,13 +571,13 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// -----
-func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
- %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32>
+func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<1x2x?xf32> {
+ %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [1, 2, %sz0]: memref<1x?xf32> into memref<1x2x?xf32>
return %0 : memref<1x2x?xf32>
}
// CHECK-LABEL: func.func @expand_shape_dynamic(
-// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32>) -> memref<1x2x?xf32> {
+// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
@@ -614,15 +614,15 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
// -----
func.func @expand_shape_dynamic_with_non_identity_layout(
- %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
+ %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>, %sz0: index) ->
memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
- %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
+ %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [1, 2, %sz0] :
memref<1x?xf32, strided<[?, ?], offset: ?>> into
memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
}
// CHECK-LABEL: func.func @expand_shape_dynamic_with_non_identity_layout(
-// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK-SAME: %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 37999d6fc14ad1..baf9cfe610a5a0 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -334,9 +334,9 @@ memref.global "private" @gv4 : memref<f32> = dense<1.0> {alignment = 64}
// CHECK-LABEL: func @expand_shape_static(
// CHECK-SAME: %[[ARG:.*]]: memref<{{.*}}>)
func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
- // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]]
+ // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
// Reshapes that expand a contiguous tensor with some 1's.
- %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+ %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
return %0 : memref<1x3x4x1x5xf32>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1fa783f05f04ee..7114555f5af9b5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -344,7 +344,7 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
- // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>
+ // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
// CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
@@ -867,7 +867,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
// CHECK: }
- // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
+ // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xf32> into tensor<1x4xf32>
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>
// CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
@@ -878,7 +878,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
// CHECK: [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield [[RES]] : f32
// CHECK: }
- // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
+ // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xf32> into tensor<5x1xf32>
%1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xf32>) -> tensor<5x1xf32>
// CHECK: arith.constant 1.0
@@ -916,7 +916,10 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
// CHECK: }
- // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
+ // CHECK: %[[C0_0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor<?x4xf32>
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [%[[DIM_1]], 1, 4] : tensor<?x4xf32> into tensor<?x1x4xf32>
%0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
return
}
@@ -934,7 +937,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
// CHECK: %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
// CHECK: }
- // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}] : tensor<f32> into tensor<1xf32>
+ // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}] output_shape [1] : tensor<f32> into tensor<1xf32>
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<?xf32>) -> tensor<1xf32>
return
}
@@ -954,7 +957,10 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
// CHECK: %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[RES]] : f32
// CHECK: }
- // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
+ // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+ // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C1_0]] : tensor<5x?xf32>
+ // CHECK: %[[C1_2:.+]] = arith.constant 1 : index
+ // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [5, %[[DIM_1]], 1] : tensor<5x?xf32> into tensor<5x?x1xf32>
%0 = tosa.reduce_prod %arg0 {axis = 2 : i32} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
return
}
@@ -974,7 +980,10 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
// CHECK: %[[MAX:.+]] = arith.maximumf %[[ARG1]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[MAX]] : f32
// CHECK: }
- // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
+ // CHECK: %[[C0_0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor<?xf32>
+ // CHECK: %[[C1_2:.+]] = arith.constant 1 : index
+ // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0, 1]] output_shape [%[[DIM_1]], 1] : tensor<?xf32> into tensor<?x1xf32>
%0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<?x?xf32>) -> tensor<?x1xf32>
return
}
@@ -992,7 +1001,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
// CHECK: }
- // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
+ // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32>
%0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
// CHECK: [[INIT:%.+]] = tensor.empty()
@@ -1003,7 +1012,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
// CHECK: [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
// CHECK: linalg.yield [[RES]] : i32
// CHECK: }
- // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
+ // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xi32> into tensor<5x1xi32>
%1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xi32>) -> tensor<5x1xi32>
// CHECK: arith.constant 1
@@ -1039,7 +1048,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
// CHECK: [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
// CHECK: linalg.yield [[RES]] : i1
// CHECK: }
- // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
+ // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi1> into tensor<1x4xi1>
%0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<5x4xi1>) -> tensor<1x4xi1>
// CHECK: arith.constant false
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index a8a3c42e168422..b8c3d56f21f10c 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -14,7 +14,7 @@ func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor<f32>) -> tensor<f32>
// CHECK-LABEL: test_reshape_0d_up_s2d_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
// CHECK: return %[[VAL_1]] : tensor<?xf32>
func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
@@ -26,7 +26,7 @@ func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
// CHECK-LABEL: test_reshape_0d_up_s2d_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
// CHECK: return %[[VAL_1]] : tensor<?xf32>
func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32> {
@@ -38,7 +38,7 @@ func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32>
// CHECK-LABEL: test_reshape_0d_up_s2s_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: return %[[VAL_0]] : tensor<1xf32>
func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<1xf32>
@@ -49,7 +49,7 @@ func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
// CHECK-LABEL: test_reshape_0d_up_s2s_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: return %[[VAL_0]] : tensor<1xf32>
func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor<f32>) -> tensor<1xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
@@ -83,8 +83,12 @@ func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor<f32
// CHECK-LABEL: test_reshape_1d_up_d2d_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
-// CHECK: return %[[VAL_0]] : tensor<2x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C2]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, %[[VAL_0]]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[EXPANDED]] : tensor<2x?xf32>
func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
@@ -94,7 +98,7 @@ func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32>
// CHECK-LABEL: test_reshape_1d_up_s2s_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<6xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
// CHECK: return %[[VAL_0]] : tensor<2x3xf32>
func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
@@ -128,8 +132,12 @@ func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6
// CHECK-LABEL: test_reshape_2d_same_d2d_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x2xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
-// CHECK: return %[[VAL_1]] : tensor<2x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C2]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, %[[DIV]]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[EXPANDED]] : tensor<2x?xf32>
func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
return %0 : tensor<2x?xf32>
@@ -140,7 +148,7 @@ func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf
// CHECK-LABEL: test_reshape_2d_same_s2d_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [4, 2] : tensor<8xf32> into tensor<4x2xf32>
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
@@ -153,7 +161,7 @@ func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf
// CHECK-LABEL: test_reshape_2d_same_s2d_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [4, 2] : tensor<8xf32> into tensor<4x2xf32>
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
// CHECK: return %[[VAL_2]] : tensor<?x2xf32>
func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
@@ -166,7 +174,7 @@ func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?
// CHECK-LABEL: test_reshape_2d_same_s2s_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
// CHECK: return %[[VAL_1]] : tensor<2x3xf32>
func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
@@ -178,7 +186,11 @@ func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2
// CHECK-LABEL: test_reshape_3d_same_d2d_auto_empty
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<0x3x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C0_0]] : index
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [0, 3, %[[DIV]]] : tensor<?xf32> into tensor<0x3x?xf32>
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<0x3x?xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
@@ -191,7 +203,11 @@ func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tens
// CHECK-LABEL: test_reshape_3d_same_d2d_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x?x?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<2x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C8]] : index
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, %[[DIV]], 4] : tensor<?xf32> into tensor<2x?x4xf32>
// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?x?xf32> {
@@ -204,7 +220,11 @@ func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?
// CHECK-LABEL: test_reshape_3d_same_d2d_auto_identity
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x3x4xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x3x4xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x3x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, 3, %[[DIV]]] : tensor<?xf32> into tensor<2x3x?xf32>
// CHECK: return %[[VAL_1]] : tensor<2x3x?xf32>
func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, -1>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
@@ -216,8 +236,12 @@ func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> t
// CHECK-LABEL: test_reshape_3d_same_d2d_explicit_empty
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x2xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2xf32> to tensor<?x?x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 2] : tensor<?xf32> into tensor<?x3x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x2xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, 2>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
@@ -229,8 +253,12 @@ func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) ->
// CHECK-LABEL: test_reshape_3d_same_d2d_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<?x?x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C12:.*]] = arith.constant 12 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C12]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 4] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x4xf32> to tensor<?x?x?xf32>
// CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
@@ -253,8 +281,12 @@ func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor<?x3x4xf32>)
// CHECK-LABEL: test_reshape_3d_same_d2s_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<2x3x4xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C8]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, %[[DIV]], 4] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x?x4xf32> to tensor<2x3x4xf32>
// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
@@ -266,8 +298,12 @@ func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3
// CHECK-LABEL: test_reshape_3d_same_d2s_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<2x3x4xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C12:.*]] = arith.constant 12 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C12]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 4] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x4xf32> to tensor<2x3x4xf32>
// CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
@@ -288,10 +324,14 @@ func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>)
// CHECK-LABEL: test_reshape_3d_up_d2s_explicit
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : tensor<?xf32> into tensor<?x3x2x1xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
-// CHECK: return %[[VAL_2]] : tensor<1x3x2x1xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] output_shape [%[[VAL_0]], 3, 2, 1] : tensor<?xf32> into tensor<?x3x2x1xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
+// CHECK: return %[[CAST]] : tensor<1x3x2x1xf32>
func.func @test_reshape_3d_up_d2s_explicit(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
%0 = tosa.reshape %input {new_shape = array<i64: 1, 3, 2, 1>} : (tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32>
return %0 : tensor<1x3x2x1xf32>
@@ -313,9 +353,13 @@ func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor<?x?x?x?xf32>) -> tens
// CHECK-LABEL: test_reshape_5d_down_d2d_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?x2x3xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x?x?x2x3xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x2x3xf32>
-// CHECK: return %[[VAL_1]] : tensor<?x2x3xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x?x?x2x3xf32> into tensor<?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 2, 3] : tensor<?xf32> into tensor<?x2x3xf32>
+// CHECK: return %[[EXPANDED]] : tensor<?x2x3xf32>
func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2, 3>} : (tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32>
return %0 : tensor<?x2x3xf32>
@@ -325,9 +369,13 @@ func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor
// CHECK-LABEL: test_reshape_6d_down_d2d_auto
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x?x5x7x11xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x5x77xf32>
-// CHECK: return %[[VAL_1]] : tensor<?x5x77xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C385:.*]] = arith.constant 385 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C385]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 5, 77] : tensor<?xf32> into tensor<?x5x77xf32>
+// CHECK: return %[[EXPANDED]] : tensor<?x5x77xf32>
func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
return %0 : tensor<?x5x77xf32>
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
index 0e353a1fa43fcb..4bf81820f0e805 100644
--- a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -165,7 +165,9 @@ func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
%init = tensor.empty(%width) : tensor<1x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
%slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
- %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
+ %c0 = arith.constant 0 : index
+ %sz0 = tensor.dim %slice, %c0 : tensor<?xf32>
+ %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] output_shape [1, 1, 1, %sz0] : tensor<?xf32> into tensor<1x1x1x?xf32>
return %expand : tensor<1x1x1x?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 547320f5338747..61bedecbdca5a4 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -52,7 +52,7 @@ func.func @collapse_parallel(
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
// CHECK: } -> tensor<2x32x40960xf32>
-// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
+// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] output_shape [2, 32, 10, 4096] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
// -----
@@ -127,8 +127,8 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
+// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK: }
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index a6431996353121..c7c846d7ecc9c5 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -50,7 +50,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
// CHECK: IR printer: transformed
-// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -78,7 +78,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<1x196x16xf32>
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: return %[[RESULT]]
func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
@@ -204,7 +204,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<8x196x16xf32>
-// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [8, 14, 14, 16] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
// CHECK: return %[[CS_FINAL]]
func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_hwcf
@@ -269,7 +269,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<8x16x196xf32>
-// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
+// CHECK: %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
// CHECK: return %[[CS_FINAL]]
func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
%0 = linalg.conv_2d_nchw_fchw
@@ -310,7 +310,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
// CHECK: IR printer: transformed
-// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -338,7 +338,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: } -> tensor<1x196x16xf32>
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
// CHECK: return %[[RESULT]]
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
@@ -378,7 +378,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32
// CHECK: linalg.yield %[[ADD]] : i32
// CHECK: } -> tensor<1x196x16xi32>
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
// CHECK: return %[[RESULT]]
func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> {
@@ -416,7 +416,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
// CHECK: linalg.yield %[[ADD]] : complex<f32>
// CHECK: } -> tensor<1x196x16xcomplex<f32>>
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
// CHECK: return %[[RESULT]]
func.func @conv_complex(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f32>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
@@ -459,7 +459,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
// CHECK: linalg.yield %[[ADD]] : complex<f32>
// CHECK: } -> tensor<1x196x16xcomplex<f32>>
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
// CHECK: return %[[RESULT]]
func.func @conv_complex_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f16>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
@@ -500,7 +500,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
// CHECK: linalg.yield %[[ADD]] : complex<f32>
// CHECK: } -> tensor<1x196x16xcomplex<f32>>
-// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
// CHECK: return %[[RESULT]]
func.func @conv_complex_f16_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xf16>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 79d61ab757e327..bee08503298fd4 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -988,17 +988,20 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4
// -----
-func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
- %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C32:.+]] = arith.constant 32 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
+// CHECK: %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C32]] : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
@@ -1009,12 +1012,12 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
%6 = tensor.empty() : tensor<4x3072x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
- %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] output_shape [4, 12, 256, 256] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
func.return %expanded : tensor<4x12x256x256xf32>
}
// CHECK-LABEL: @push_down_permuted_unpack_through_expand
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] output_shape [4, 32, 12, 32, 8, 8] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
// CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32>
@@ -1024,29 +1027,32 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>
func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
%6 = tensor.empty() : tensor<48x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
- %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] output_shape [3, 16, 1, 256] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
func.return %expanded : tensor<3x16x1x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] output_shape [3, 2, 1, 32, 8, 8] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
// CHECK: return %[[UNPACK]] : tensor<3x16x1x256xf32>
// -----
-func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [8] into %6 : tensor<?x32x8xf32> -> tensor<?x256xf32>
- %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
func.return %expanded : tensor<?x256x256xf32>
}
// CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[C256:.+]] = arith.constant 256 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8xf32>
+// CHECK: %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C256]] : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] output_shape [%[[SZ0]], 256, 32, 8] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x256x32x8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor<?x256x32x8xf32> -> tensor<?x256x256xf32>
@@ -1057,11 +1063,11 @@ func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>,
func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
%6 = tensor.empty() : tensor<3072x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
- %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+ %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
func.return %expanded : tensor<256x12x256xf32>
}
// CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
// CHECK: return %[[EXPANDED]] : tensor<256x12x256xf32>
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 5a27fe76b13411..9fe50a521d2d81 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -26,7 +26,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32>
// CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
-// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
+// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]] output_shape [32, 7] : tensor<224xf32> into tensor<32x7xf32>
func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
%0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32>
return %0 : tensor<32x7xf32>
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index f1c729ef963ba8..751ece37bc094f 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -4,15 +4,19 @@
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-LABEL: func @reshape
-// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
+// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
+// CHECK: %[[C112:.*]] = arith.constant 112 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<?x16xf32> into tensor<?x112x16xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
+// CHECK: %[[VAL_1:.*]] = arith.divui %[[DIM]], %[[C112]] : index
+// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
// CHECK: return %[[RR]] : tensor<?x112x16xf32>
-func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
- %0 = tensor.expand_shape %A [[0, 1], [2]]
+func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
+ %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]
: tensor<?x16xf32> into tensor<?x112x16xf32>
%2 = linalg.generic {indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
@@ -39,13 +43,13 @@ func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
-// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: return %[[RR]] : tensor<112x112x16xf32>
func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
%C: tensor<16xf32>) -> tensor<112x112x16xf32> {
- %0 = tensor.expand_shape %A [[0, 1], [2]]
+ %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
- %1 = tensor.expand_shape %B [[0, 1], [2]]
+ %1 = tensor.expand_shape %B [[0, 1], [2]] output_shape [112, 112, 16]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%2 = tensor.empty() : tensor<112x112x16xf32>
%3 = linalg.generic {indexing_maps = [
@@ -69,11 +73,11 @@ func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
// Negative test, since the second source is broadcasted from d1 we cannot merge
// d0 and d1 dimensions
// CHECK-LABEL: func @reshape_negative
-// CHECK: tensor.expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: tensor.expand_shape {{.*}} {{\[\[}}0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: linalg.generic
// CHECK: } -> tensor<112x112x16xf32>
func.func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
- %20 = tensor.expand_shape %A [[0, 1], [2]]
+ %20 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
: tensor<12544x16xf32> into tensor<112x112x16xf32>
%21 = tensor.empty() : tensor<112x112x16xf32>
%22 = linalg.generic {indexing_maps = [
@@ -96,7 +100,7 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
%cst_6 = arith.constant 1.000000e+00 : f32
%cst_7 = arith.constant 7.000000e+00 : f32
%cst_8 = arith.constant 1.1920929E-7 : f32
- %25 = tensor.expand_shape %arg0 [[0, 1], [2]]
+ %25 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 5]
: tensor<6x5xi32> into tensor<2x3x5xi32>
%26 = tensor.empty() : tensor<2x3x5xf32>
%28 = linalg.generic {
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 2d69904f27db5e..60fb0ffeee2403 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -106,9 +106,9 @@ func.func @expand_collapse_shape_static(
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
memref<3x4x5xf32> into memref<12x5xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [3, 4, 5]
// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
- %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+ %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 5] :
memref<12x5xf32> into memref<3x4x5xf32>
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
@@ -116,9 +116,9 @@ func.func @expand_collapse_shape_static(
%1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
memref<3x4x5xf32> into memref<3x20xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [3, 4, 5]
// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
- %r1 = memref.expand_shape %1 [[0], [1, 2]] :
+ %r1 = memref.expand_shape %1 [[0], [1, 2]] output_shape [3, 4, 5] :
memref<3x20xf32> into memref<3x4x5xf32>
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
@@ -126,29 +126,29 @@ func.func @expand_collapse_shape_static(
%2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
memref<3x4x5xf32> into memref<60xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] output_shape [3, 4, 5]
// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
- %r2 = memref.expand_shape %2 [[0, 1, 2]] :
+ %r2 = memref.expand_shape %2 [[0, 1, 2]] output_shape [3, 4, 5] :
memref<60xf32> into memref<3x4x5xf32>
-// CHECK: memref.expand_shape {{.*}} []
+// CHECK: memref.expand_shape {{.*}} [] output_shape [1, 1]
// CHECK-SAME: memref<f32> into memref<1x1xf32>
- %r5 = memref.expand_shape %arg5 [] :
+ %r5 = memref.expand_shape %arg5 [] output_shape [1, 1] :
memref<f32> into memref<1x1xf32>
// Reshapes with a custom layout map.
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
- %l0 = memref.expand_shape %arg3 [[0], [1, 2]] :
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [30, 4, 5]
+ %l0 = memref.expand_shape %arg3 [[0], [1, 2]] output_shape [30, 4, 5] :
memref<30x20xf32, strided<[4000, 2], offset: 100>>
into memref<30x4x5xf32, strided<[4000, 10, 2], offset: 100>>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
- %l1 = memref.expand_shape %arg3 [[0, 1], [2]] :
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [2, 15, 20]
+ %l1 = memref.expand_shape %arg3 [[0, 1], [2]] output_shape [2, 15, 20] :
memref<30x20xf32, strided<[4000, 2], offset: 100>>
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
- %r4 = memref.expand_shape %arg4 [[0], [1, 2]] :
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [1, 1, 5]
+ %r4 = memref.expand_shape %arg4 [[0], [1, 2]] output_shape [1, 1, 5] :
memref<1x5xf32, strided<[5, 1], offset: ?>> into
memref<1x1x5xf32, strided<[5, 5, 1], offset: ?>>
@@ -164,9 +164,9 @@ func.func @expand_collapse_shape_static(
memref<2049xi64, strided<[?], offset: ?>>
// Reshapes that expand and collapse back a contiguous buffer with some 1's.
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
- %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
+ %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
@@ -176,15 +176,18 @@ func.func @expand_collapse_shape_static(
// Reshapes on tensors.
// CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
- %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] :
+ %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
%rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] :
tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+// CHECK: tensor.dim %arg2, {{.*}} : tensor<3x?x5xf32>
// CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
- %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] :
+ %c1 = arith.constant 1 : index
+ %sz1 = tensor.dim %arg2, %c1 : tensor<3x?x5xf32>
+ %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] output_shape [1, 3, %sz1, 1, 5] :
tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
@@ -197,15 +200,18 @@ func.func @expand_collapse_shape_static(
func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>>,
%arg2: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>,
- %arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>) {
+ %arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>,
+ %arg4: index,
+ %arg5: index,
+ %arg6: index) {
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
memref<?x?x?xf32> into memref<?x?xf32>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
- %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+ %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
memref<?x?xf32> into memref<?x4x?xf32>
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
@@ -214,9 +220,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>> into
memref<?x?xf32, strided<[?, 1], offset: 0>>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
// CHECK-SAME: memref<?x?xf32, strided<[?, 1]>> into memref<?x4x?xf32, strided<[?, ?, 1]>>
- %r1 = memref.expand_shape %1 [[0, 1], [2]] :
+ %r1 = memref.expand_shape %1 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
memref<?x?xf32, strided<[?, 1], offset: 0>> into
memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>
@@ -226,9 +232,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> into
memref<?x?xf32, strided<[?, 1], offset: ?>>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
// CHECK-SAME: memref<?x?xf32, strided<[?, 1], offset: ?>> into memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>
- %r2 = memref.expand_shape %2 [[0, 1], [2]] :
+ %r2 = memref.expand_shape %2 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
memref<?x?xf32, strided<[?, 1], offset: ?>> into
memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>
@@ -238,9 +244,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
memref<?x42xf32, strided<[42, 1], offset: 0>> into
memref<?xf32, strided<[1]>>
-// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]]
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] output_shape [%arg6, 42]
// CHECK-SAME: memref<?xf32, strided<[1]>> into memref<?x42xf32>
- %r3 = memref.expand_shape %3 [[0, 1]] :
+ %r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] :
memref<?xf32, strided<[1]>> into memref<?x42xf32>
return
}
@@ -248,12 +254,12 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
func.func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
-> (memref<f32>, memref<1x1xf32>) {
%0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
- %1 = memref.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
+ %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
return %0, %1 : memref<f32>, memref<1x1xf32>
}
// CHECK-LABEL: func @expand_collapse_shape_zero_dim
// CHECK: memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
-// CHECK: memref.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+// CHECK: memref.expand_shape %{{.*}} [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
func.func @collapse_shape_to_dynamic
(%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32> {
@@ -270,16 +276,18 @@ func.func @collapse_shape_to_dynamic
// CHECK-LABEL: func @expand_collapse_shape_transposed_layout
func.func @expand_collapse_shape_transposed_layout(
%m0: memref<?x?xf32, strided<[1, 10], offset: 0>>,
- %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>) {
+ %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>,
+ %sz0: index,
+ %sz1: index) {
- %r0 = memref.expand_shape %m0 [[0], [1, 2]] :
+ %r0 = memref.expand_shape %m0 [[0], [1, 2]] output_shape [%sz0, %sz1, 5] :
memref<?x?xf32, strided<[1, 10], offset: 0>> into
memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>>
%rr0 = memref.collapse_shape %r0 [[0], [1, 2]] :
memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>> into
memref<?x?xf32, strided<[1, 10], offset: 0>>
- %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] :
+ %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] output_shape [2, 2, 5, 2, 3] :
memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into
memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>>
%rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] :
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index edb53fa024c26b..c96f9c31443db3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -12,7 +12,7 @@
//
// CHECK-ROUND-LABEL: func.func @sparse_expand(
// CHECK-ROUND-SAME: %[[A:.*]]: tensor<100xf64, #sparse{{[0-9]*}}>) -> tensor<10x10xf64, #sparse{{[0-9]*}}>
-// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}>
+// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [10, 10] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: return %[[E]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL: func.func @sparse_expand(
@@ -39,7 +39,7 @@
// CHECK: return %[[NT1]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
//
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
- %0 = tensor.expand_shape %arg0 [[0, 1]] :
+ %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [10, 10] :
tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
return %0 : tensor<10x10xf64, #SparseMatrix>
}
@@ -94,8 +94,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// roundtrip:
//
// CHECK-ROUND-LABEL: func.func @dynamic_sparse_expand(
-// CHECK-ROUND-SAME: %[[A:.*]]: tensor<?xf64, #sparse{{[0-9]*}}>) -> tensor<?x10xf64, #sparse{{[0-9]*}}>
-// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<?xf64, #sparse{{[0-9]*}}> into tensor<?x10xf64, #sparse{{[0-9]*}}>
+// CHECK-ROUND-SAME: %[[A:.*]]: tensor<?xf64, #sparse{{[0-9]*}}>, %[[SZ0:.*]]: index) -> tensor<?x10xf64, #sparse{{[0-9]*}}>
+// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [%[[SZ0]], 10] : tensor<?xf64, #sparse{{[0-9]*}}> into tensor<?x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: return %[[E]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL: func.func @dynamic_sparse_expand(
@@ -127,8 +127,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-NOT: sparse_tensor.convert
// CHECK: return %[[NT1]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
//
-func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
- %0 = tensor.expand_shape %arg0 [[0, 1]] :
+func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>, %sz0: index) -> tensor<?x10xf64, #SparseMatrix> {
+ %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 10] :
tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix>
return %0 : tensor<?x10xf64, #SparseMatrix>
}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 815bc383af95a6..64d9c86ba2dee5 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -367,11 +367,11 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
// CHECK-LABEL: func @tensor.expand_shape(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
-func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
+func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
- // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
- %0 = tensor.expand_shape %t1 [[0, 1], [2]]
+ // CHECK-SAME: [0, 1], [2]] output_shape [2, %{{.*}}, 10] : memref<?x10xf32> into memref<2x?x10xf32>
+ %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
: tensor<?x10xf32> into tensor<2x?x10xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
@@ -384,14 +384,14 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
func.func @tensor.expand_shape_of_slice(
- %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
+ %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
%0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
tensor<?x20xf32> to tensor<?x10xf32>
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
- // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
- %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
+ // CHECK-SAME: [0, 1], [2, 3]] output_shape [%{{.*}}, 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+ %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
tensor<?x10xf32> into tensor<?x7x2x5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
@@ -407,8 +407,8 @@ func.func @tensor.expand_shape_of_scalar_slice(
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref<?xf32> to memref<f32, strided<[], offset: ?>>
%0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
- %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
+ %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ac365c9d297e88..23921a824f2136 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -4,7 +4,7 @@
// CHECK-LABEL: expand_shape_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
- %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
+ %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32>
return %0 : tensor<5xf32>
}
@@ -13,7 +13,7 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
// CHECK-LABEL: expand_shape_rank0_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
- %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+ %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor<f32> into tensor<f32>
return %0 : tensor<f32>
}
@@ -1051,29 +1051,28 @@ func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4
// -----
-func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
+func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
-> tensor<?x6x4x?x5xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
: tensor<?x?xf32> into tensor<?x4x?xf32>
- %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]]
- : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
return %1 : tensor<?x6x4x?x5xf32>
}
// CHECK-LABEL: compose_expand_of_expand
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5]
// CHECK-NOT: tensor.expand_shape
// -----
func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
-> tensor<1x1x1xf32> {
- %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
- %1 = tensor.expand_shape %0 [[0, 1, 2]]
+ %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
+ %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
: tensor<1xf32> into tensor<1x1x1xf32>
return %1 : tensor<1x1x1xf32>
}
// CHECK-LABEL: compose_expand_of_expand_of_zero_dim
-// CHECK: tensor.expand_shape %{{.*}} []
+// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1]
// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
// -----
@@ -1093,7 +1092,7 @@ func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
// -----
func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
: tensor<12x4xf32> into tensor<3x4x4xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
: tensor<3x4x4xf32> into tensor<12x4xf32>
@@ -1104,9 +1103,9 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
// -----
-func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
+func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index)
-> tensor<?x?xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
: tensor<?x?xf32> into tensor<?x4x?xf32>
%1 = tensor.collapse_shape %0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1121,7 +1120,7 @@ func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
-> tensor<24x5x42x8xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
: tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
- %1 = tensor.expand_shape %0 [[0, 1, 2, 3]]
+ %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8]
: tensor<40320xf32> into tensor<24x5x42x8xf32>
return %1 : tensor<24x5x42x8xf32>
}
@@ -1137,7 +1136,7 @@ func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
-> tensor<2x3x4x5x6x7x8xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
: tensor<24x5x42x8xf32> into tensor<40320xf32>
- %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]]
+ %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8]
: tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
return %1 : tensor<2x3x4x5x6x7x8xf32>
}
@@ -1149,16 +1148,16 @@ func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
// -----
-func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
+func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>, %arg1: index, %arg2: index, %arg3: index)
-> tensor<?x?xi64> {
- %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+ %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1]
: tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
%1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
: tensor<?x?x?x1xi64> into tensor<?x?xi64>
return %1 : tensor<?x?xi64>
}
// CHECK-LABEL: func @compose_collapse_of_expand
-// CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>)
+// CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
// CHECK-NEXT: tensor.collapse_shape %[[ARG]]
// CHECK-SAME: [0, 1], [2]
// CHECK-SAME: : tensor<?x?x?xi64> into tensor<?x?xi64>
@@ -1167,14 +1166,14 @@ func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
-> tensor<4x512xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
+ %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512]
: tensor<2048xf32> into tensor<1x4x1x512xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
: tensor<1x4x1x512xf32> into tensor<4x512xf32>
return %1 : tensor<4x512xf32>
}
// CHECK: func @compose_collapse_of_expand_1D
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512]
// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
// -----
@@ -1183,14 +1182,14 @@ func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>
-> tensor<1x1x1x1xf32> {
%0 = tensor.collapse_shape %arg0 []
: tensor<1x1x1xf32> into tensor<f32>
- %1 = tensor.expand_shape %0 []
+ %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1]
: tensor<f32> into tensor<1x1x1x1xf32>
return %1 : tensor<1x1x1x1xf32>
}
// CHECK: func @compose_expand_of_collapse_0_rank_to_expand
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1xf32>
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME: [0], [1], [2, 3]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [1, 1, 1, 1]
// CHECK: return %[[RESULT]]
// -----
@@ -1199,7 +1198,7 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
-> tensor<1x1x1xf32> {
%0 = tensor.collapse_shape %arg0 []
: tensor<1x1x1x1xf32> into tensor<f32>
- %1 = tensor.expand_shape %0 []
+ %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1]
: tensor<f32> into tensor<1x1x1xf32>
return %1 : tensor<1x1x1xf32>
}
@@ -1214,8 +1213,8 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
// CHECK-LABEL: func @zero_rank_reshape_multi
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
- %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
- %1 = tensor.expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
+ %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
+ %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32>
%2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
return %2 : tensor<f32>
}
@@ -1250,7 +1249,7 @@ func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
// -----
func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
+ %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512]
: tensor<4x512xf32> into tensor<1x4x1x512xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
: tensor<1x4x1x512xf32> into tensor<2048xf32>
@@ -1264,42 +1263,40 @@ func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048x
func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
-> tensor<4x512x1x1xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
- : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
+ %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
: tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
return %1 : tensor<4x512x1x1xf32>
}
// CHECK: func @fold_collapse_of_expand_unit_dims
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1]
// CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
// -----
func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
-> tensor<4x512x1x512x4xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
- : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
+ %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
: tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
return %1 : tensor<4x512x1x512x4xf32>
}
// CHECK: func @compose_collapse_of_expand_unit_dims
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4]
// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
// -----
func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
-> tensor<2x1xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
+ %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1]
: tensor<2xf32> into tensor<2x1x1xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
// CHECK: func @compose_collapse_of_expand_trailing_unit_dims
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
@@ -1321,14 +1318,13 @@ func.func @compose_collapse_of_collapse_unit_dims_dynamic(
func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
-> tensor<2x1xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
- : tensor<2xf32> into tensor<2x1x1xf32>
+ %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2]]
: tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
// CHECK: func @fold_collapse_of_expand_trailing_unit_dims
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
// -----
@@ -1349,8 +1345,7 @@ func.func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
-> tensor<12x42xf32> {
- %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
- : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
+ %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]]
: tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
return %1 : tensor<12x42xf32>
@@ -1361,9 +1356,9 @@ func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf3
// -----
-func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
+func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
-> tensor<?x?xf32> {
- %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
+ %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2]
: tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
: tensor<?x?x1x?xf32> into tensor<?x?xf32>
@@ -1378,7 +1373,7 @@ func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>
func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
-> tensor<2x6x16xf32> {
- %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8]
: tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
%1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]]
: tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
@@ -1392,7 +1387,7 @@ func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
-> tensor<12x1xf32> {
- %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
+ %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
: tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
%1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
: tensor<3x2x2x1xf32> into tensor<12x1xf32>
@@ -1401,7 +1396,7 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
// CHECK: func @no_fold_collapse_of_expand_empty_expr
// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
// CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME: [0], [1], [2, 3]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
// CHECK: %[[RES:.+]] = tensor.collapse_shape %[[RARG0]]
// CHECK-SAME: [0, 1, 2], [3]
// CHECK: return %[[RES:.+]] : tensor<12x1xf32>
@@ -1410,7 +1405,7 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
%c0 = arith.constant dense<42> : tensor<2x8xi32>
- %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+ %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
: tensor<2x8xi32> into tensor<2x4x2xi32>
return %0 : tensor<2x4x2xi32>
}
@@ -1421,7 +1416,7 @@ func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
// -----
func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
%c0 = tensor.splat %arg : tensor<2x4xf32>
- %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+ %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2]
: tensor<2x4xf32> into tensor<2x2x2xf32>
return %0 : tensor<2x2x2xf32>
}
@@ -1434,13 +1429,12 @@ func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
// -----
// CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
-// CHECK-SAME: %[[F:.+]]: f32
-// CHECK-SAME: %[[M:.+]]: index
-func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index) -> tensor<2x2x?xf32> {
- // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+// CHECK-SAME: (%[[F:.+]]: f32, %[[M:.+]]: index, %[[SZ0:.+]]: index)
+func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> {
+ // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
%c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
- %0 = tensor.expand_shape %c0 [[0], [1, 2]] : tensor<2x?xf32> into tensor<2x2x?xf32>
+ %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32>
return %0 : tensor<2x2x?xf32>
}
@@ -1475,7 +1469,7 @@ func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x
func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
%c0 = arith.constant dense<42> : tensor<2x8xi16>
- %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+ %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
: tensor<2x8xi16> into tensor<2x4x2xi16>
return %0 : tensor<2x4x2xi16>
}
@@ -1488,7 +1482,7 @@ func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
%c0 = arith.constant dense<42.0> : tensor<2x8xf32>
- %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+ %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
: tensor<2x8xf32> into tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32>
}
@@ -1501,7 +1495,7 @@ func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
%c0 = arith.constant dense<42.0> : tensor<2x8xf64>
- %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+ %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
: tensor<2x8xf64> into tensor<2x4x2xf64>
return %0 : tensor<2x4x2xf64>
}
@@ -1851,7 +1845,7 @@ func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
// CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32>
// CHECK: return %[[FROM]] : tensor<1xi32>
%0 = tensor.from_elements %arg0 : tensor<i32>
- %1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
+ %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<i32> into tensor<1xi32>
return %1 : tensor<1xi32>
}
@@ -2073,9 +2067,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
// CHECK: return %[[apply]]
-func.func @dim_of_expand_shape(%t: tensor<?x?xf32>) -> index {
+func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
%c2 = arith.constant 2 : index
- %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
+ %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
: tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
%1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
return %1 : index
@@ -2107,9 +2101,9 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
// CHECK-LABEL: func @collapse_expand_fold_to_cast(
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
// CHECK: return %[[t]]
-func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>, %sz0: index) -> (tensor<?xf32>)
{
- %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+ %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
%1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
return %1 : tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 15f841f2128edb..e200a4f8926130 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -13,10 +13,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
-func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
%0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
- %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
- : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 625408dfefe216..d3ac6ce792f365 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -11,9 +11,11 @@ func.func @expand_shape_of_rank_reducing_extract(
{
%0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
: tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
- %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
+ %c0 = arith.constant 0 : index
+ %sz0 = tensor.dim %0, %c0 : tensor<?x1x5xf32>
+ %1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%sz0, 1, 1, 5]
: tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
- %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
+ %2 = tensor.expand_shape %0 [[0, 1], [2], [3]] output_shape [%sz0, 1, 1, 5]
: tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
}
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2b0a74acce0826..378137a14b59ff 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -194,12 +194,26 @@ func.func @insert_slice(
func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
-> (tensor<f32>, tensor<1x1xf32>) {
%0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
- %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
+ %1 = tensor.expand_shape %0 [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
return %0, %1 : tensor<f32>, tensor<1x1xf32>
}
// CHECK-LABEL: func @tensor_reshape_zero_dim
// CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
-// CHECK: tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor<?x?xf32>, %sz0 : index, %sz1 : index, %sz2 : index)
+ -> (tensor<5x?x?x?xf32>) {
+ %1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [5, %sz0, %sz1, %sz2] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+ return %1 : tensor<5x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> {
+// CHECK: %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} output_shape [5, %arg1, %arg2, %arg3] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+// CHECK: return %expanded : tensor<5x?x?x?xf32>
+// CHECK: }
+
// -----
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 9948c0246e6ed6..5a2eade0ecccf1 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: func.func @single_dim_packing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
@@ -27,7 +27,7 @@ func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x3
// CHECK-LABEL: func.func @single_last_inner_dim_packing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
%empty = tensor.empty() : tensor<5x8x32xf32>
@@ -39,7 +39,7 @@ func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8
// CHECK-LABEL: func.func @pack_1d_with_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<64xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<64xf32> into tensor<2x32xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [2, 32] : tensor<64xf32> into tensor<2x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<2x32xf32>
func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf32> {
%empty = tensor.empty() : tensor<2x32xf32>
@@ -51,7 +51,7 @@ func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf3
// CHECK-LABEL: func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
%empty = tensor.empty() : tensor<5x8x32xf32>
@@ -85,7 +85,7 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
// CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 32, 1, 1]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
%empty = tensor.empty() : tensor<1x32x1x1xf32>
@@ -98,7 +98,7 @@ func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf3
// CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 16, 1, 2]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
%empty = tensor.empty() : tensor<1x16x1x2xf32>
@@ -111,7 +111,7 @@ func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf3
// CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 16, 2, 1]
// CHECK: return %[[EXPANDED]]
func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
%empty = tensor.empty() : tensor<1x16x2x1xf32>
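
Note on the test updates above: in every updated expand_shape, the static entries of output_shape within one reassociation group multiply out to the corresponding source dimension (for example [1, 4, 1, 512] against tensor<4x512xf32> with [[0, 1, 2], [3]]). The Python sketch below restates that relationship for illustration only. It is not the verifier from this patch. The function name output_shape_matches is hypothetical, None stands in for a dynamic dimension (`?` or an SSA index value), and the rank-0 rule (all result dims equal to 1) is an assumption read off the rank-0 tests in this diff.

```python
from math import prod
from typing import Optional, Sequence


def output_shape_matches(
    source_shape: Sequence[Optional[int]],
    reassociation: Sequence[Sequence[int]],
    output_shape: Sequence[Optional[int]],
) -> bool:
    """Check static consistency of an expand_shape-style output shape.

    Hypothetical helper, not part of MLIR. None marks a dynamic dimension.
    """
    # Rank-0 source: assumed to expand only into unit dimensions,
    # as in `tensor<f32> into tensor<1x1x1xf32>`.
    if len(source_shape) == 0:
        return len(reassociation) == 0 and all(d == 1 for d in output_shape)

    # One reassociation group per source dimension.
    if len(reassociation) != len(source_shape):
        return False

    # The groups must list the output dimensions 0..rank-1 in order.
    flat = [d for group in reassociation for d in group]
    if flat != list(range(len(output_shape))):
        return False

    for src_dim, group in zip(source_shape, reassociation):
        dims = [output_shape[d] for d in group]
        # A dynamic size on either side can only be checked at run time.
        if src_dim is None or any(d is None for d in dims):
            continue
        if prod(dims) != src_dim:
            return False
    return True


# Shapes taken from the tests in this diff.
print(output_shape_matches([4, 512], [[0, 1, 2], [3]], [1, 4, 1, 512]))
print(output_shape_matches([6, 5, None], [[0, 1], [2], [3, 4, 5]],
                           [2, 3, 5, 4, None, 7]))
```

Groups containing a dynamic size are skipped rather than rejected, since their product is only known at run time.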