[Mlir-commits] [mlir] [DRAFT] Generalize expand_shape to take shape as explicit input (PR #69267)
Gaurav Shukla
llvmlistbot at llvm.org
Thu Nov 16 08:02:15 PST 2023
https://github.com/Shukla-Gaurav updated https://github.com/llvm/llvm-project/pull/69267
From 273d71d369603d044edbe42e90b080f33615f7af Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos <ramiroleal050 at gmail.com>
Date: Mon, 16 Oct 2023 17:02:23 -0700
Subject: [PATCH] [MLIR] Generalize expand_shape to take shape as explicit
input
*DO NOT SUBMIT*
(This patch is for early design feedback only. Notably, tests have not been
updated and the implementation is incomplete in some cases.)
This patch generalizes tensor.expand_shape and memref.expand_shape to consume
the output shape as a list of SSA values. This enables us to implement generic
reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.
The output_shape input to expand_shape follows the static/dynamic representation
that's also used in `tensor.extract_slice`.
Differential Revision: https://reviews.llvm.org/D140821
---
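To make the static/dynamic representation described in the commit message concrete, here is a minimal standalone sketch. It does not use MLIR: `Value`, `MixedSize`, `decomposeMixedSizes`, and `hasConsistentOutputShape` are hypothetical stand-ins for `mlir::Value`, `OpFoldResult`, `decomposeMixedValues`, and the verifier check added in this patch, and `kDynamic` is assumed to mirror `ShapedType::kDynamic`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed sentinel value for this sketch).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Stand-in for an SSA value of index type; only its name is tracked here.
struct Value {
  std::string name;
};

// Stand-in for OpFoldResult: either a constant size or an SSA value.
using MixedSize = std::variant<int64_t, Value>;

// Split mixed sizes into (static sizes, dynamic values). Every dynamic entry
// is recorded as kDynamic in the static list and appended to the value list.
std::pair<std::vector<int64_t>, std::vector<Value>>
decomposeMixedSizes(const std::vector<MixedSize> &mixed) {
  std::vector<int64_t> staticSizes;
  std::vector<Value> dynamicValues;
  for (const MixedSize &size : mixed) {
    if (const int64_t *constant = std::get_if<int64_t>(&size)) {
      staticSizes.push_back(*constant);
    } else {
      staticSizes.push_back(kDynamic);
      dynamicValues.push_back(std::get<Value>(size));
    }
  }
  return {staticSizes, dynamicValues};
}

// The invariant the op description states: exactly one SSA value per
// kDynamic entry in the static list.
bool hasConsistentOutputShape(const std::vector<int64_t> &staticSizes,
                              const std::vector<Value> &dynamicValues) {
  std::size_t numDynamic = 0;
  for (int64_t size : staticSizes)
    if (size == kDynamic)
      ++numDynamic;
  return numDynamic == dynamicValues.size();
}
```

With this encoding, an output shape written as `[%sz0, %sz1, 32]` would decompose into the static list `[kDynamic, kDynamic, 32]` plus the two values `%sz0` and `%sz1`.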
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 88 +++++++++++++----
.../mlir/Dialect/Tensor/IR/TensorOps.td | 97 +++++++++++++-----
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 50 ++++++++--
.../mlir/Dialect/Utils/StaticValueUtils.h | 5 +-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 1 -
.../Conversion/TosaToTensor/TosaToTensor.cpp | 1 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 15 ++-
.../Transforms/ConvertConv2DToImg2Col.cpp | 2 +-
.../Linalg/Transforms/DropUnitDims.cpp | 18 ++--
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 15 +--
.../Linalg/Transforms/SplitReduction.cpp | 1 +
.../Dialect/Linalg/Transforms/Transforms.cpp | 8 +-
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 43 ++++++--
.../Transforms/SparseTensorRewriting.cpp | 11 ++-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 75 ++++++++++----
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 99 ++++++++++++++++---
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 7 +-
mlir/test/Dialect/Tensor/ops.mlir | 18 +++-
18 files changed, 432 insertions(+), 122 deletions(-)
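The builders in this patch that take no explicit output shape rely on `inferOutputShape`, which is documented to work only "when it is possible to do so". The sketch below shows one way such inference could work on concrete sizes, assuming a group is inferable only if it has at most one dynamic result dimension. `inferExpandedSizes` is a hypothetical name, the code operates on plain integers rather than SSA values, and it is not the implementation in `ReshapeOpsUtils.cpp`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed sentinel value for this sketch).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper: computes concrete result sizes for an expansion from
// concrete input sizes, the result type's shape (kDynamic marks unknown dims),
// and the reassociation groups. Returns std::nullopt when inference fails.
std::optional<std::vector<int64_t>>
inferExpandedSizes(const std::vector<int64_t> &inputSizes,
                   const std::vector<int64_t> &resultTypeShape,
                   const std::vector<std::vector<int64_t>> &reassociation) {
  if (inputSizes.size() != reassociation.size())
    return std::nullopt;
  const int64_t resultRank = static_cast<int64_t>(resultTypeShape.size());
  std::vector<int64_t> resultSizes(resultTypeShape);
  for (std::size_t group = 0; group < reassociation.size(); ++group) {
    int64_t staticProduct = 1;
    std::optional<int64_t> dynamicDim;
    for (int64_t dim : reassociation[group]) {
      if (dim < 0 || dim >= resultRank)
        return std::nullopt;
      if (resultTypeShape[dim] != kDynamic) {
        staticProduct *= resultTypeShape[dim];
        continue;
      }
      // Two dynamic dims in one group make the split ambiguous.
      if (dynamicDim)
        return std::nullopt;
      dynamicDim = dim;
    }
    if (dynamicDim) {
      // The single dynamic dim absorbs whatever the static dims leave over.
      if (staticProduct == 0 || inputSizes[group] % staticProduct != 0)
        return std::nullopt;
      resultSizes[*dynamicDim] = inputSizes[group] / staticProduct;
    } else if (staticProduct != inputSizes[group]) {
      // Fully static group whose sizes do not multiply to the input size.
      return std::nullopt;
    }
  }
  return resultSizes;
}
```

Under this assumption, an expansion such as `tensor<?x32xf32> into tensor<?x?x32xf32>` cannot be inferred from the types alone, because the first group has two dynamic result dimensions. That is the case the explicit `output_shape` operands are meant to cover.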
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 657f601aad2f5a1..4f5c204099ad2e3 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1554,7 +1554,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
MemRef_Op<mnemonic, !listconcat(traits,
[Pure, ViewLikeOpInterface])>,
- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)>{
code commonExtraClassDeclaration = [{
@@ -1579,10 +1578,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Value getViewSource() { return getSrc(); }
}];
- let assemblyFormat = [{
- $src $reassociation attr-dict `:` type($src) `into` type($result)
- }];
-
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -1604,14 +1599,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
Example:
```mlir
- %r = memref.expand_shape %0 [[0, 1], [2]]
- : memref<?x?xf32> into memref<?x5x?xf32>
+ %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+ : memref<?x32xf32> into memref<?x?x32xf32>
```
- At most one dimension of a reassociation group (e.g., [0, 1] above) may be
- dynamic in the result type. Otherwise, the op would be ambiguous, as it
- would not be clear how the source dimension is extended.
-
If an op can be statically proven to be invalid (e.g, an expansion from
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1628,29 +1619,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
there must be a dynamic result dimension in the corresponding reassociation
group. Same for strides.
+ The representation for the output shape supports a partially-static
+ specification via attributes specified through the `static_output_shape`
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+ corresponding entry has a dynamic value. There must be exactly as many SSA
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+ `static_output_shape`.
+
Note: This op currently assumes that the inner strides are of the
source/result layout map are the faster-varying ones.
}];
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+ Variadic<Index>:$output_shape,
+ DenseI64ArrayAttr:$static_output_shape);
+
+ let assemblyFormat = [{
+ $src $reassociation `output_shape`
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+ type($src) `into` type($result)
+ }];
+
let builders = [
// Builders using ReassociationIndices.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation),
+ [{
+ SmallVector<OpFoldResult> inputShape =
+ getMixedSizes($_builder, $_state.location, src);
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+ auto status =
+ inferOutputShape($_builder, $_state.location,
+ resultType.cast<MemRefType>(),
+ reassociation, inputShape, outputShape);
+ (void) status;
+ assert(succeeded(status) && "unable to infer output shape");
+ build($_builder, $_state, resultType.cast<MemRefType>(), src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ outputShape.second, outputShape.first);
+ }]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- build($_builder, $_state, resultType, src, attrs);
- $_state.addAttribute("reassociation",
- getReassociationIndicesAttribute($_builder, reassociation));
+ auto [staticOutputShape, dynamicOutputShape] =
+ decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+ build($_builder, $_state, resultType, src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ dynamicOutputShape, staticOutputShape);
}]>,
// Builder using ReassociationExprs.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation),
+ [{
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices);
+ }]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationMaps,
+ outputShape);
}]>,
// Builder that infers the result layout map. The result shape must be
@@ -1663,6 +1697,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
static FailureOr<MemRefType> computeExpandedType(
MemRefType srcType, ArrayRef<int64_t> resultShape,
ArrayRef<ReassociationIndices> reassociation);
+
+ // Infer the output shape for a memref.expand_shape when it is possible
+ // to do so.
+ static LogicalResult inferOutputShape(
+ OpBuilder &b, Location loc, MemRefType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
}];
let hasVerifier = 1;
@@ -1713,6 +1755,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
source/result layout map are the faster-varying ones.
}];
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let builders = [
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
@@ -1724,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, src, reassociationMaps, attrs);
}]>,
@@ -1742,7 +1790,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7ae27407a9526e7..e7e083a3d85197b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -998,8 +998,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Tensor_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure])>,
- Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
- Results<(outs AnyRankedTensor:$result)> {
+ Results<(outs AnyTensor:$result)> {
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1022,10 +1021,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
}
}];
- let assemblyFormat = [{
- $src $reassociation attr-dict `:` type($src) `into` type($result)
- }];
-
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -1038,11 +1033,16 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
rank whose sizes are a reassociation of the original `src`.
A reassociation is defined as a continuous grouping of dimensions and is
- represented with an array of DenseI64ArrayAttr attribute.
+ represented with an array of DenseI64ArrayAttr attribute. The reassociation
+ maps applied to the result tensor with the higher rank must result in the
+ operand tensor with the smaller rank.
- The verification rule is that the reassociation maps are applied to the
- result tensor with the higher rank to obtain the operand tensor with the
- smaller rank.
+ The representation for the output shape supports a partially-static
+ specification via attributes specified through the `static_output_shape`
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+ corresponding entry has a dynamic value. There must be exactly as many SSA
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+ `static_output_shape`.
The operand tensor type of a reshape can be zero-ranked if the result
tensor type is statically shaped with all dimensions being unit extent. In
@@ -1052,32 +1052,79 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
- %b = tensor.expand_shape %a [[0, 1], [2]]
- : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+ : tensor<?x32xf32> into tensor<?x?x32xf32>
```
}];
+
+ let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+ Variadic<Index>:$output_shape,
+ DenseI64ArrayAttr:$static_output_shape);
+
+ let assemblyFormat = [{
+ $src $reassociation `output_shape`
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+ type($src) `into` type($result)
+ }];
+
let builders = [
// Builders using ReassociationIndices.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation),
+ [{
+ SmallVector<OpFoldResult> inputShape =
+ getMixedSizes($_builder, $_state.location, src);
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+ auto status =
+ inferOutputShape($_builder, $_state.location,
+ resultType.cast<RankedTensorType>(),
+ reassociation, inputShape, outputShape);
+ (void) status;
+ assert(succeeded(status) && "unable to infer output shape");
+ build($_builder, $_state, resultType.cast<RankedTensorType>(), src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ outputShape.second, outputShape.first);
+ }]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- build($_builder, $_state, resultType, src, attrs);
- $_state.addAttribute("reassociation",
- getReassociationIndicesAttribute($_builder, reassociation));
+ auto [staticOutputShape, dynamicOutputShape] =
+ decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+ build($_builder, $_state, resultType, src,
+ getReassociationIndicesAttribute($_builder, reassociation),
+ dynamicOutputShape, staticOutputShape);
+ }]>,
+
+ // Builder using ReassociationExprs.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation),
+ [{
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices);
}]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices,
+ outputShape);
}]>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+ // Infer the output shape for a tensor.expand_shape when it is possible
+ // to do so.
+ static LogicalResult inferOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
}];
let hasVerifier = 1;
@@ -1085,6 +1132,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
let summary = "operation to produce a tensor with a smaller rank";
+ let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
let description = [{
The `tensor.collapse_shape` op produces a new tensor with a smaller
rank whose sizes are a reassociation of the original `src`.
@@ -1108,6 +1156,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
: tensor<?x?x?xf32> into tensor<?x?xf32>
```
}];
+
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let builders = [
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
@@ -1119,7 +1172,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, src, reassociationMaps, attrs);
}]>,
@@ -1137,7 +1190,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c929dee0f272c..3d673a2eff97e16 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,28 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the result type of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` as a (static sizes, dynamic values)
+// pair, following the conventions for the static_output_shape and output_shape
+// inputs to the expand_shape ops.
+LogicalResult inferExpandShapeOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
+
/// Compose reassociation maps that are used in pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
/// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +84,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
- OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+ ArrayRef<ReassociationExprs> reassociationExprs);
/// Return the reassociations maps to use to reshape given the source type and
/// the target type when possible. Return std::nullopt when this computation
@@ -166,9 +188,11 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);
+enum class ReshapeOpKind { kExpand, kCollapse };
+
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -191,8 +215,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
rewriter.getContext());
if (!reassociationIndices)
return failure();
- rewriter.replaceOpWithNewOp<ReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+ if constexpr (opKind == ReshapeOpKind::kExpand) {
+ SmallVector<OpFoldResult> outputShape(
+ getMixedValues(reshapeOp.getStaticOutputShape(),
+ reshapeOp.getOutputShape(), rewriter));
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+ outputShape);
+ } else {
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+ }
return success();
}
};
@@ -225,7 +259,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
//
/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
/// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+ typename DimOpTy, typename TensorTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -332,8 +367,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (!composedReassociation)
return failure();
+ SmallVector<OpFoldResult> outputShape(getMixedValues(
+ expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
rewriter.replaceOpWithNewOp<ExpandOpTy>(
- expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+ expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+ outputShape);
return success();
}
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..1842817042ea037 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -124,9 +124,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
/// Decompose a vector of mixed static or dynamic values into the
/// corresponding pair of arrays. This is the inverse function of
/// `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
- const SmallVectorImpl<OpFoldResult> &mixedValues);
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
/// Helper to sort `values` according to matching `keys`.
SmallVector<Value>
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96ff..a1ea68ee30af308 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 06ec53d19b1e956..0c502fc8d788a47 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -194,6 +194,7 @@ Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
loc, "tosa.reshape Cannot expand into given shape");
return {};
}
+
return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
reassociationMap);
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d12ba8c4c59b33f..27413ddc90b5c75 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -586,9 +586,18 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
return failure();
Location loc = oldFill.getLoc();
- auto newInit = rewriter.create<TensorReshapeOp>(
- loc, reshapeOp.getResultType(), oldFill.output(),
- reshapeOp.getReassociation());
+ TensorReshapeOp newInit;
+ if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
+
+ newInit = rewriter.create<TensorReshapeOp>(
+ loc, reshapeOp.getResultType(), oldFill.output(),
+ reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
+ reshapeOp.getStaticOutputShape());
+ } else {
+ newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
+ oldFill.output(),
+ reshapeOp.getReassociation());
+ }
rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
ValueRange{newInit});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index e7629d79494bd47..83d50c64a20ae09 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -349,7 +349,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
{2, 3}};
- Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+ auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
batchMatVecReassociationIndice);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 6fbf35145578716..be6b4fe62052723 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -254,7 +255,7 @@ replaceUnitDimIndexOps(GenericOp genericOp,
/// Expand the given `value` so that the type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
-static Value
+static FailureOr<Value>
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
ArrayRef<ReassociationIndices> reassociation,
ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
@@ -274,8 +275,9 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
assert(rankReductionStrategy ==
ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
"unknown rank reduction strategy");
- return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
- reassociation);
+ return rewriter
+ .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
+ .getResult();
}
/// Collapse the given `value` so that the type matches the type of
@@ -538,9 +540,13 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
resultReplacements.push_back(result);
continue;
}
- resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
- reassociations[opOperandIndex],
- options.rankReductionStrategy));
+ FailureOr<Value> expandedValue = expandValue(
+ rewriter, loc, result, origDest, reassociations[opOperandIndex],
+ options.rankReductionStrategy);
+ if (failed(expandedValue)) {
+ return rewriter.notifyMatchFailure(genericOp, "unable to expand result");
+ }
+ resultReplacements.push_back(*expandedValue);
}
rewriter.replaceOp(genericOp, resultReplacements);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c27..8bcb25b295628db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -829,8 +829,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
- genericOp.getLoc(), expandedOutputType, opOperand.get(),
- reassociation));
+ loc, expandedOutputType, opOperand.get(), reassociation));
} else {
outputs.push_back(opOperand.get());
}
@@ -1576,15 +1575,17 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
+ Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
- Value result = rewriter.create<memref::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
+ MemRefType expandShapeResultType = MemRefType::get(
+ originalResultType.getShape(), originalResultType.getElementType());
+ result = rewriter.create<memref::ExpandShapeOp>(
+ loc, expandShapeResultType, collapsedOpResult, reassociation);
} else {
- Value result = rewriter.create<tensor::ExpandShapeOp>(
+ result = rewriter.create<tensor::ExpandShapeOp>(
loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
}
+ results.push_back(result);
} else {
results.push_back(collapsedOpResult);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 6559c86c9e0ff50..5bfdbc6d0bb59c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -114,6 +114,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
Type newType = RankedTensorType::get(
newShape,
cast<RankedTensorType>(operand->get().getType()).getElementType());
+
Value newInput = b.create<tensor::ExpandShapeOp>(
loc, newType, operand->get(), reassociation);
newInputs.push_back(newInput);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 10dfbe6cec781d5..63901ad645b17b9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -351,11 +351,13 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
/*transposeOp=*/nullptr};
}
}
+
// 5. Expand from the padded result to the stripMinedShape.
+ auto expandShapeResultType =
+ RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
- loc,
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
- padOp.getResult(), packingMetadata.reassociations);
+ loc, expandShapeResultType, padOp.getResult(),
+ packingMetadata.reassociations);
// 6. Transpose stripMinedShape to packedShape.
SmallVector<int64_t> transpPerm =
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fae8..7b74de78830242d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2240,6 +2240,17 @@ FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
srcType.getMemorySpace());
}
+LogicalResult ExpandShapeOp::inferOutputShape(
+ OpBuilder &b, Location loc, MemRefType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+ auto expandedTensorType =
+ getTensorTypeFromMemRefType(expandedType).cast<RankedTensorType>();
+ return inferExpandShapeOutputShape(b, loc, expandedTensorType, reassociation,
+ inputShape, outputShape);
+}
+
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> resultShape, Value src,
ArrayRef<ReassociationIndices> reassociation) {
@@ -2250,7 +2261,9 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
// Failure of this assertion usually indicates a problem with the source
// type, e.g., could not get strides/offset.
assert(succeeded(resultType) && "could not compute layout");
- build(builder, result, *resultType, src, reassociation);
+ SmallVector<OpFoldResult> outputShape(
+ getMixedValues(resultShape, ValueRange{}, builder));
+ build(builder, result, *resultType, src, reassociation, outputShape);
}
LogicalResult ExpandShapeOp::verify() {
@@ -2279,14 +2292,28 @@ LogicalResult ExpandShapeOp::verify() {
return emitOpError("expected expanded type to be ")
<< *expectedResultType << " but found " << resultType;
+ if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+ return emitOpError("expected number of static shape bounds to be equal to "
+ "the output rank (")
+ << resultType.getRank() << ") but found "
+ << getStaticOutputShape().size() << " inputs instead";
+
+ if ((int64_t)getOutputShape().size() !=
+ llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+ return emitOpError("mismatch in dynamic dims in output_shape and "
+ "static_output_shape: static_output_shape has ")
+ << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+ << " dynamic dims while output_shape has " << getOutputShape().size()
+ << " values";
+
return success();
}
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
- context);
+ results.add<
+ ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
}
/// Compute the layout map after collapsing a given source MemRef type with the
@@ -2484,9 +2511,11 @@ struct CollapseShapeOpMemRefCastFolder
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
- CollapseShapeOpMemRefCastFolder>(context);
+ results.add<
+ ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+ memref::DimOp, MemRefType>,
+ CollapseShapeOpMemRefCastFolder>(context);
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 3fe0c551be57a4d..7556c76a4cc46c8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -820,8 +820,15 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
auto rtp = getRankedTensorType(op.getResult());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
- auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
- op.getReassociation());
+ ReshapeOp reshape;
+ if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
+ reshape = rewriter.create<ReshapeOp>(
+ loc, denseTp, op.getSrc(), op.getReassociation(),
+ op.getOutputShape(), op.getStaticOutputShape());
+ } else {
+ reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+ op.getReassociation());
+ }
Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
rewriter.replaceOp(op, convert);
return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..7f0808946912fdf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1417,6 +1417,15 @@ int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
llvm_unreachable("could not find reassociation group");
}
+LogicalResult ExpandShapeOp::inferOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+ return inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
+ inputShape, outputShape);
+}
+
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
@@ -1506,6 +1515,20 @@ LogicalResult ExpandShapeOp::verify() {
return emitOpError("expected rank expansion, but found source rank ")
<< srcType.getRank() << " >= result rank " << resultType.getRank();
+ if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+ return emitOpError("expected number of static shape dims to be equal to "
+ "the output rank (")
+ << resultType.getRank() << ") but found "
+ << getStaticOutputShape().size() << " inputs instead";
+
+ if ((int64_t)getOutputShape().size() !=
+ llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+ return emitOpError("mismatch in dynamic dims in output_shape and "
+ "static_output_shape: static_output_shape has ")
+ << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+ << " dynamic dims while output_shape has " << getOutputShape().size()
+ << " values";
+
return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
}
@@ -1696,23 +1719,25 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
- FoldReshapeWithConstant<ExpandShapeOp>,
- FoldReshapeWithSplat<ExpandShapeOp>,
- FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
- FoldDimOfCollapseShape>(context);
+ results.add<
+ ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ FoldReshapeWithConstant<ExpandShapeOp>,
+ FoldReshapeWithSplat<ExpandShapeOp>,
+ FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+ FoldDimOfCollapseShape>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
- ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
- FoldReshapeWithConstant<CollapseShapeOp>,
- FoldReshapeWithSplat<CollapseShapeOp>,
- FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
- context);
+ results.add<
+ ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+ ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+ tensor::DimOp, RankedTensorType>,
+ FoldReshapeWithConstant<CollapseShapeOp>,
+ FoldReshapeWithSplat<CollapseShapeOp>,
+ FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
+ context);
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
@@ -3263,12 +3288,16 @@ namespace {
struct SimplifyPackToExandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
- Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
- Type newOperandType, ArrayAttr reassociation) const {
+ FailureOr<Value>
+ insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+ Type newOperandType,
+ ArrayRef<ReassociationIndices> reassociation) const {
if (operand.getType() == newOperandType)
return operand;
- return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
- reassociation);
+ return rewriter
+ .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+ reassociation)
+ .getResult();
}
LogicalResult matchAndRewrite(PackOp packOp,
@@ -3281,10 +3310,14 @@ struct SimplifyPackToExandShape : public OpRewritePattern<PackOp> {
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
return failure();
- Value expanded = insertExpand(
- rewriter, packOp.getLoc(), packOp.getSource(), destType,
- getReassociationIndicesAttribute(rewriter, *reassociation));
- rewriter.replaceOp(packOp, expanded);
+ FailureOr<Value> expanded =
+ insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
+ *reassociation);
+ if (failed(expanded)) {
+ return rewriter.notifyMatchFailure(
+ packOp, "unable to expand source of tensor.pack");
+ }
+ rewriter.replaceOp(packOp, *expanded);
return success();
}
};
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 41c7af4593c77ce..92d842b252f4e6c 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -8,6 +8,8 @@
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -16,6 +18,82 @@
using namespace mlir;
+LogicalResult mlir::inferExpandShapeOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape,
+ std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+ SmallVector<Value> outputShapeValues;
+ SmallVector<int64_t> outputShapeInts;
+ // For zero-rank inputs, all dims in result shape are unit extent.
+ if (inputShape.empty()) {
+ outputShapeInts.resize(expandedType.getRank(), 1);
+ outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+ return success();
+ }
+
+ outputShapeValues.resize(expandedType.getRank());
+ outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+
+ for (const auto &it : llvm::enumerate(reassociation)) {
+ ReassociationIndices indexGroup = it.value();
+
+ int64_t indexGroupStaticSizesProductInt = 1;
+ bool foundDynamic = false;
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ // Cannot infer expanded shape with multiple dynamic dims in the
+ // same reassociation group!
+ if (ShapedType::isDynamic(outputDimSize)) {
+ if (foundDynamic)
+ return failure();
+ foundDynamic = true;
+ } else {
+ indexGroupStaticSizesProductInt *= outputDimSize;
+ }
+ }
+ Value indexGroupStaticSizesProduct =
+ b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+
+ int64_t inputIndex = it.index();
+ for (int64_t index : indexGroup) {
+ if (ShapedType::isDynamic(expandedType.getDimSize(index))) {
+ // Call get<Value>() under the assumption that we're not casting
+ // dynamism.
+ Value indexGroupSize = inputShape[inputIndex].get<Value>();
+
+ // Create an AffineMap representing the division operation.
+ MLIRContext *context = b.getContext();
+ AffineExpr dividend = getAffineSymbolExpr(0, context);
+ AffineExpr divisor = getAffineSymbolExpr(1, context);
+ AffineMap divisionMap = AffineMap::get(/*numDims=*/0, /*numSymbols=*/2,
+ {dividend.floorDiv(divisor)});
+ Value dynamicDimSize = b.createOrFold<affine::AffineApplyOp>(
+ loc, divisionMap,
+ ValueRange({indexGroupSize, indexGroupStaticSizesProduct}));
+ outputShapeValues[index] = dynamicDimSize;
+ }
+ }
+
+ for (int64_t index : indexGroup) {
+ int64_t outputDimSize = expandedType.getDimSize(index);
+ if (ShapedType::isDynamic(outputDimSize))
+ continue;
+ outputShapeInts[index] = outputDimSize;
+ }
+ }
+
+ if (static_cast<uint64_t>(
+ llvm::count(outputShapeInts, ShapedType::kDynamic)) !=
+ (outputShapeValues.size() - llvm::count(outputShapeValues, Value{})))
+ return failure();
+
+ llvm::erase(outputShapeValues, Value{});
+
+ outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+ return success();
+}
+
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
ShapedType targetType) {
@@ -168,7 +246,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute(
}
SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
- OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+ ArrayRef<ReassociationExprs> reassociationExprs) {
SmallVector<ReassociationIndices, 2> reassociationIndices;
for (const auto &exprs : reassociationExprs) {
ReassociationIndices indices;
@@ -230,24 +308,17 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
unsigned expandedDimStart = 0;
for (const auto &map : llvm::enumerate(reassociationMaps)) {
- std::optional<int64_t> dynamicShape;
+ bool foundDynamicShape = false;
int64_t linearizedStaticShape = 1;
+
for (const auto &dim : llvm::enumerate(
expandedShape.slice(expandedDimStart, map.value().size()))) {
- if (ShapedType::isDynamic(dim.value())) {
- if (isExpandingReshape && dynamicShape) {
- return emitError("invalid to have a single dimension (" +
- Twine(map.index()) +
- ") expanded into multiple dynamic dims (" +
- Twine(expandedDimStart + dynamicShape.value()) +
- "," + Twine(expandedDimStart + dim.index()) + ")");
- }
- dynamicShape = dim.index();
- } else {
+ if (ShapedType::isDynamic(dim.value()))
+ foundDynamicShape = true;
+ else
linearizedStaticShape *= dim.value();
- }
}
- if (dynamicShape) {
+ if (foundDynamicShape) {
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
return emitError(
"expected dimension " + Twine(map.index()) +
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..a6ebdb162b6168a 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -184,9 +184,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
/// Decompose a vector of mixed static or dynamic values into the corresponding
/// pair of arrays. This is the inverse function of `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
- const SmallVectorImpl<OpFoldResult> &mixedValues) {
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
SmallVector<int64_t> staticValues;
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
@@ -197,7 +196,7 @@ decomposeMixedValues(Builder &b,
dynamicValues.push_back(it.get<Value>());
}
}
- return {b.getI64ArrayAttr(staticValues), dynamicValues};
+ return {staticValues, dynamicValues};
}
/// Helper to sort `values` according to matching `keys`.
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 71a0489b23f5f2d..ee036a69944b296 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -177,12 +177,26 @@ func.func @insert_slice(
func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
-> (tensor<f32>, tensor<1x1xf32>) {
%0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
- %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
+ %1 = tensor.expand_shape %0 [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
return %0, %1 : tensor<f32>, tensor<1x1xf32>
}
// CHECK-LABEL: func @tensor_reshape_zero_dim
// CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
-// CHECK: tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor<?x?xf32>, %sz0 : index, %sz1 : index, %sz2 : index)
+ -> (tensor<5x?x?x?xf32>) {
+ %1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [5, %sz0, %sz1, %sz2] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+ return %1 : tensor<5x?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> {
+// CHECK: %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} output_shape [5, %arg1, %arg2, %arg3] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+// CHECK: return %expanded : tensor<5x?x?x?xf32>
+// CHECK: }
+
// -----
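
For readers following the shape inference in inferExpandShapeOutputShape and the
new verifier checks, here is a minimal standalone sketch of the arithmetic, not
the MLIR implementation. It uses plain integers instead of SSA values. The names
resolveGroup, dynamicCountsMatch and the kDynamic sentinel are hypothetical and
only stand in for the patch's logic; the guard against a zero static product is
an addition of this sketch and is not in the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumption: any sentinel works here).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper: resolve the extents of one reassociation group.
// `collapsedSize` is the runtime size of the source dimension and `group`
// holds the result extents, with kDynamic marking an unknown extent.
// Returns std::nullopt when the group has more than one dynamic extent,
// mirroring the early `return failure()` in the patch.
std::optional<std::vector<int64_t>>
resolveGroup(int64_t collapsedSize, const std::vector<int64_t> &group) {
  assert(collapsedSize >= 0 && "runtime sizes are non-negative");
  int64_t staticProduct = 1;
  bool foundDynamic = false;
  for (int64_t extent : group) {
    if (extent == kDynamic) {
      if (foundDynamic)
        return std::nullopt; // Two unknowns cannot be inferred from one size.
      foundDynamic = true;
    } else {
      staticProduct *= extent;
    }
  }
  // Sketch-only guard: avoid dividing by zero for zero-sized static extents.
  if (foundDynamic && staticProduct == 0)
    return std::nullopt;

  std::vector<int64_t> resolved;
  resolved.reserve(group.size());
  for (int64_t extent : group) {
    // For non-negative operands, integer division matches floorDiv.
    resolved.push_back(extent == kDynamic ? collapsedSize / staticProduct
                                          : extent);
  }
  return resolved;
}

// Hypothetical helper mirroring the second verifier check: the number of
// kDynamic entries in static_output_shape must equal the number of dynamic
// output_shape operands.
bool dynamicCountsMatch(const std::vector<int64_t> &staticShape,
                        std::size_t numDynamicValues) {
  auto numDynamic = static_cast<std::size_t>(
      std::count(staticShape.begin(), staticShape.end(), kDynamic));
  return numDynamic == numDynamicValues;
}
```

With a source dimension of size 20 and a group of {5, kDynamic}, resolveGroup
returns {5, 4}. A group of {kDynamic, kDynamic} yields std::nullopt, which is
the case where the caller has to pass output_shape explicitly.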