[llvm] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #90040)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 25 04:28:31 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: Gaurav Shukla (Shukla-Gaurav)
<details>
<summary>Changes</summary>
This patch generalizes `tensor.expand_shape` and `memref.expand_shape` to consume the output shape as a list of SSA values (with static sizes captured as attributes). This makes it possible to implement generic reshape operations with dynamic shapes using `collapse_shape`/`expand_shape` pairs.

The `output_shape` input to `expand_shape` follows the mixed static/dynamic representation that is also used in `tensor.extract_slice`.
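For illustration, a minimal sketch of the new assembly form, adapted from the op documentation in this patch; `%t`, `%m0`, `%sz0`, and `%sz1` are placeholder SSA values (`%sz0`/`%sz1` of `index` type):

```mlir
// Expand a dynamic dimension into two dynamic dimensions. The trailing
// static size 32 is recorded in `static_output_shape`, while %sz0 and %sz1
// are passed as dynamic `output_shape` operands.
%e = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
    : tensor<?x32xf32> into tensor<?x?x32xf32>

// The memref op uses the same form.
%r = memref.expand_shape %m0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
    : memref<?x32xf32> into memref<?x?x32xf32>
```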
Differential Revision: https://reviews.llvm.org/D140821
---
Patch is 259.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90040.diff
53 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+58-22)
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+56-23)
- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+46-12)
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+2-3)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (-1)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+12-4)
- (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+6-4)
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+8-5)
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+49-18)
- (modified) mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp (+1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+5-3)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+75-7)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+9-2)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+71-14)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+3)
- (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+16-8)
- (modified) mlir/lib/Dialect/Utils/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+69-14)
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+3-4)
- (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+8-8)
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+2-2)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+19-10)
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+81-33)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+1-1)
- (modified) mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir (+3-1)
- (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+3-3)
- (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+10-10)
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+18-12)
- (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+71-37)
- (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+1-1)
- (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+57-44)
- (modified) mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (+14-10)
- (modified) mlir/test/Dialect/Linalg/reshape_control_fusion.mlir (+1-1)
- (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+192-100)
- (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-3)
- (modified) mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir (+14-14)
- (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+3-1)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+17-18)
- (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+9-7)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+12-10)
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+14-24)
- (modified) mlir/test/Dialect/MemRef/ops.mlir (+40-32)
- (modified) mlir/test/Dialect/MemRef/runtime-verification.mlir (+3-2)
- (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+6-6)
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+14-10)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+53-59)
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-3)
- (modified) mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir (+4-2)
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+5-16)
- (modified) mlir/test/Dialect/Tensor/ops.mlir (+16-2)
- (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+7-7)
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..14b8d95ea15b41 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
MemRef_Op<mnemonic, !listconcat(traits,
[Pure, ViewLikeOpInterface])>,
- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)>{
code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Value getViewSource() { return getSrc(); }
}];
- let assemblyFormat = [{
- $src $reassociation attr-dict `:` type($src) `into` type($result)
- }];
-
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
Example:
```mlir
- %r = memref.expand_shape %0 [[0, 1], [2]]
- : memref<?x?xf32> into memref<?x5x?xf32>
+ %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+ : memref<?x32xf32> into memref<?x?x32xf32>
```
- At most one dimension of a reassociation group (e.g., [0, 1] above) may be
- dynamic in the result type. Otherwise, the op would be ambiguous, as it
- would not be clear how the source dimension is extended.
-
If an op can be statically proven to be invalid (e.g, an expansion from
`memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
there must be a dynamic result dimension in the corresponding reassociation
group. Same for strides.
+ The representation for the output shape supports a partially-static
+ specification via attributes specified through the `static_output_shape`
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+ corresponding entry has a dynamic value. There must be exactly as many SSA
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+ `static_output_shape`.
+
Note: This op currently assumes that the inner strides are of the
source/result layout map are the faster-varying ones.
}];
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+ Variadic<Index>:$output_shape,
+ DenseI64ArrayAttr:$static_output_shape);
+
+ let assemblyFormat = [{
+ $src $reassociation `output_shape`
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+ type($src) `into` type($result)
+ }];
+
let builders = [
// Builders using ReassociationIndices.
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape)>,
+
+ // It will infer output shape using inferOutputShape() method.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+ // Builder using ReassociationExprs.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation),
[{
- build($_builder, $_state, resultType, src, attrs);
- $_state.addAttribute("reassociation",
- getReassociationIndicesAttribute($_builder, reassociation));
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices);
}]>,
- // Builder using ReassociationExprs.
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationMaps,
+ outputShape);
}]>,
+ // Builder that infers the result layout map. The result shape must be
+ // specified. Otherwise, the op may be ambiguous. The output shape for
+ // the op will be inferred using the inferOutputShape() method.
+ OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation)>,
+
// Builder that infers the result layout map. The result shape must be
// specified. Otherwise, the op may be ambiguous.
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
- "ArrayRef<ReassociationIndices>":$reassociation)>
+ "ArrayRef<ReassociationIndices>":$reassociation,
+ "ArrayRef<OpFoldResult>":$outputShape)>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
static FailureOr<MemRefType> computeExpandedType(
MemRefType srcType, ArrayRef<int64_t> resultShape,
ArrayRef<ReassociationIndices> reassociation);
+
+ // Infer the output shape for a memref.expand_shape when it is possible
+ // to do so.
+ static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+ OpBuilder &b, Location loc, MemRefType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape);
}];
let hasVerifier = 1;
@@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
source/result layout map are the faster-varying ones.
}];
+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let builders = [
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
@@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, src, reassociationMaps, attrs);
}]>,
@@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cf7f3e89079c1c..a403e89a39f98c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Tensor_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure])>,
- Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
- Results<(outs AnyRankedTensor:$result)> {
+ Results<(outs AnyTensor:$result)> {
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
}
}];
- let assemblyFormat = [{
- $src $reassociation attr-dict `:` type($src) `into` type($result)
- }];
-
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -1102,43 +1097,75 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
rank than the operand `src` whose dimension sizes are a reassociation of
`src`.
- A reassociation is defined as a continuous grouping of dimensions. It is
- represented with an array of DenseI64ArrayAttr attribute. Entries in the
- array are referred to as reassociation maps.
+ A reassociation is defined as a continuous grouping of dimensions and is
+ represented with an array of DenseI64ArrayAttr attribute. The reassociation
+ maps applied to the result tensor with the higher rank must result in the
+ operand tensor with the smaller rank.
- The reassociation maps are applied to the result shape to obtain the operand
- shape.
+ The representation for the output shape supports a partially-static
+ specification via attributes specified through the `static_output_shape`
+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
+ corresponding entry has a dynamic value. There must be exactly as many SSA
+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+ `static_output_shape`.
Example:
```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
- %b = tensor.expand_shape %a [[0, 1], [2]]
- : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+ : tensor<?x32xf32> into tensor<?x?x32xf32>
```
}];
+
+ let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+ Variadic<Index>:$output_shape,
+ DenseI64ArrayAttr:$static_output_shape);
+
+ let assemblyFormat = [{
+ $src $reassociation `output_shape`
+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+ type($src) `into` type($result)
+ }];
+
let builders = [
// Builders using ReassociationIndices.
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape)>,
+
+ // It will infer output shape using inferOutputShape() method.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+ // Builder using ReassociationExprs.
+ OpBuilder<(ins "Type":$resultType, "Value":$src,
+ "ArrayRef<ReassociationExprs>":$reassociation),
[{
- build($_builder, $_state, resultType, src, attrs);
- $_state.addAttribute("reassociation",
- getReassociationIndicesAttribute($_builder, reassociation));
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices);
}]>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+ "ArrayRef<OpFoldResult>":$outputShape),
[{
- auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+ auto reassociationIndices =
+ convertReassociationMapsToIndices(reassociation);
+ build($_builder, $_state, resultType, src, reassociationIndices,
+ outputShape);
}]>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+ // Infer the output shape for a tensor.expand_shape when it is possible
+ // to do so.
+ static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+ OpBuilder &b, Location loc, RankedTensorType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape);
}];
let hasVerifier = 1;
@@ -1146,6 +1173,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
let summary = "operation to produce a tensor with a smaller rank";
+ let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
let description = [{
The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1191,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
: tensor<?x?x?xf32> into tensor<?x?xf32>
```
}];
+
+ let assemblyFormat = [{
+ $src $reassociation attr-dict `:` type($src) `into` type($result)
+ }];
+
let builders = [
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
@@ -1174,7 +1207,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, src, reassociationMaps, attrs);
}]>,
@@ -1192,7 +1225,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
- convertReassociationMapsToIndices($_builder, reassociation);
+ convertReassociationMapsToIndices(reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index ae9824f728da4d..8a41a0a18b0ab3 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,27 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName() { return "reassociation"; }
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the output shape of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` and `staticOutputShape`, following
+// the conventions for the output_shape and static_output_shape inputs to the
+// expand_shape ops.
+std::optional<SmallVector<OpFoldResult>>
+inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<OpFoldResult> inputShape);
+
/// Compose reassociation maps that are used in pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
/// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +83,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
- OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+ ArrayRef<ReassociationExprs> reassociationExprs);
/// Return the reassociations maps to use to reshape given the source type and
/// the target type when possible. Return std::nullopt when this computation
@@ -140,14 +161,11 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
op.getReassociationIndices(), isExpansion);
}
-/// Verify that shapes of the reshaped types using following rules
-/// 1) if a dimension in the collapsed type is static, then the corresponding
-/// dimensions in the expanded shape should be
+/// Verify that shapes of the reshaped types using following rule:
+/// if a dimension in the collapsed type is static, then the corresponding
+/// dimensions in the expanded shape should be
/// a) static
/// b) the product should be same as the collaped shape.
-/// 2) if a dimension in the collaped type is dynamic, one and only one of the
-/// corresponding dimensions in the expanded type should be dynamic. This
-/// rule is only needed with reshape operations that are expanding.
LogicalResult reshapeLikeShapesAreCompatible(
function_ref<LogicalResult(const Twine &)> emitError,
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
@@ -156,9 +174,11 @@ LogicalResult reshapeLikeShapesAreCompatible(
/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);
+enum class ReshapeOpKind { kExpand, kCollapse };
+
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -181,8 +201,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
rewriter.getContext());
if (!reassociationIndices)
return failure();
- rewriter.replaceOpWithNewOp<ReshapeOpTy>(
- reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+ if constexpr (opKind == ReshapeOpKind::kExpand) {
+ SmallVector<OpFoldResult> outputShape(
+ getMixedValues(reshapeOp.getStaticOutputShape(),
+ reshapeOp.getOutputShape(), rewriter));
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+ outputShape);
+ } else {
+ rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+ reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+ }
return success();
}
};
@@ -215,7 +245,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
//
/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
/// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+ typename DimOpTy, typename TensorTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -322,8 +353,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
if (!composedReassociation)
return failure();
+ SmallVector<OpFoldResult> outputShape(getMixedValues(
+ expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
rewriter.replaceOpWithNewOp<ExpandOpTy>(
- expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+ expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+ outputShape);
return success();
}
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 20f019666a2e6a..594bcf5dbb399a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,9 +125,8 @@ SmallVect...
[truncated]
``````````
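To make the motivation above concrete, here is a hypothetical sketch (not part of the patch) of a fully dynamic reshape expressed as a `collapse_shape`/`expand_shape` pair using the generalized op; `%src`, `%d0`, `%d1`, and `%d2` are placeholder SSA values:

```mlir
// Flatten a dynamically shaped 2-D tensor, then re-expand it into three
// dynamic dimensions whose sizes are supplied as SSA values. The previous
// op definition rejected this expansion, since more than one dynamic
// dimension appears in a single reassociation group.
%flat = tensor.collapse_shape %src [[0, 1]]
    : tensor<?x?xf32> into tensor<?xf32>
%reshaped = tensor.expand_shape %flat [[0, 1, 2]] output_shape [%d0, %d1, %d2]
    : tensor<?xf32> into tensor<?x?x?xf32>
```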
</details>
https://github.com/llvm/llvm-project/pull/90040