[Mlir-commits] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 09:03:23 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-bufferization

Author: Gaurav Shukla (Shukla-Gaurav)

Changes

This patch generalizes tensor.expand_shape and memref.expand_shape to take the output shape as an explicit list of SSA values. This makes it possible to implement generic reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation that's also used in `tensor.extract_slice`.
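
For illustration, here is a minimal sketch of the new form (following the assembly format introduced in this patch, with the `output_shape` keyword; `%d0` and `%d1` are assumed to be `index` values holding the desired dynamic sizes):

```mlir
// A generic dynamic reshape of %t : tensor<?x?xf32> into a 3-D result,
// written as a collapse to 1-D followed by an expand with explicit sizes.
%flat = tensor.collapse_shape %t [[0, 1]]
    : tensor<?x?xf32> into tensor<?xf32>
%reshaped = tensor.expand_shape %flat [[0, 1, 2]] output_shape [%d0, %d1, 8]
    : tensor<?xf32> into tensor<?x?x8xf32>
```

Previously, the expand above would have been rejected because more than one dimension of the reassociation group is dynamic in the result type; the explicit sizes remove that ambiguity.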

Differential Revision: https://reviews.llvm.org/D140821

---

Patch is 227.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69267.diff


51 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+68-20) 
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+76-23) 
- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+44-6) 
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+2-3) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (-1) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+12-3) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+6-4) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+17-8) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+8-7) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp (+1) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+5-3) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+35-6) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+9-2) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+42-14) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+3-2) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+16-8) 
- (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+72-14) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+3-4) 
- (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+8-8) 
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+2-2) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+19-10) 
- (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+81-33) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir (+3-1) 
- (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+3-3) 
- (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+10-10) 
- (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+18-12) 
- (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+71-37) 
- (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+57-44) 
- (modified) mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (+14-10) 
- (modified) mlir/test/Dialect/Linalg/reshape_control_fusion.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+27-22) 
- (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-3) 
- (modified) mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir (+14-14) 
- (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+3-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+17-18) 
- (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+9-7) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+12-10) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+14-24) 
- (modified) mlir/test/Dialect/MemRef/ops.mlir (+40-32) 
- (modified) mlir/test/Dialect/MemRef/runtime-verification.mlir (+3-2) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+6-6) 
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+8-8) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+53-59) 
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-3) 
- (modified) mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir (+4-2) 
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+5-16) 
- (modified) mlir/test/Dialect/Tensor/ops.mlir (+16-2) 
- (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+7-7) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..8ea46bf3b275e5 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     MemRef_Op<mnemonic, !listconcat(traits,
       [Pure, ViewLikeOpInterface])>,
-    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
   code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Value getViewSource() { return getSrc(); }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     Example:
 
     ```mlir
-    %r = memref.expand_shape %0 [[0, 1], [2]]
-        : memref<?x?xf32> into memref<?x5x?xf32>
+    %r = memref.expand_shape %0 [[0, 1], [2]] [%sz0, %sz1, 32]
+        : memref<?x32xf32> into memref<?x?x32xf32>
     ```
 
-    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
-    dynamic in the result type. Otherwise, the op would be ambiguous, as it
-    would not be clear how the source dimension is extended.
-
     If an op can be statically proven to be invalid (e.g, an expansion from
     `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
     it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,29 +1613,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     there must be a dynamic result dimension in the corresponding reassociation
     group. Same for strides.
 
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
+
     Note: This op currently assumes that the inner strides are of the
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation),
+    [{
+      SmallVector<OpFoldResult> inputShape = 
+          getMixedSizes($_builder, $_state.location, src);
+      std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+      auto status = 
+          inferOutputShape($_builder, $_state.location, 
+                           resultType.cast<MemRefType>(), 
+                           reassociation, inputShape, outputShape);
+      (void) status;
+      assert(succeeded(status) && "unable to infer output shape"); 
+      build($_builder, $_state, resultType.cast<MemRefType>(), src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            outputShape.second, outputShape.first);
+    }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-                          getReassociationIndicesAttribute($_builder, reassociation));
+      auto [staticOutputShape, dynamicOutputShape] =
+          decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+      build($_builder, $_state, resultType, src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            dynamicOutputShape, staticOutputShape);
     }]>,
 
     // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
+    [{
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
+    }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps,
+            outputShape);
     }]>,
 
     // Builder that infers the result layout map. The result shape must be
@@ -1657,6 +1691,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     static FailureOr<MemRefType> computeExpandedType(
         MemRefType srcType, ArrayRef<int64_t> resultShape,
         ArrayRef<ReassociationIndices> reassociation);
+
+    // Infer the output shape for a memref.expand_shape when it is possible
+    // to do so.
+    static LogicalResult inferOutputShape(
+        OpBuilder &b, Location loc, MemRefType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape,
+        std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
   }];
 
   let hasVerifier = 1;
@@ -1707,6 +1749,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1718,7 +1766,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1736,7 +1784,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cf7f3e89079c1c..d3a1250f7e8985 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyRankedTensor:$result)> {
+    Results<(outs AnyTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1102,43 +1097,95 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     rank than the operand `src` whose dimension sizes are a reassociation of
     `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions. It is
-    represented with an array of DenseI64ArrayAttr attribute. Entries in the
-    array are referred to as reassociation maps.
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of DenseI64ArrayAttr attribute.  The reassociation
+    maps applied to the result tensor with the higher rank must result in the
+    operand tensor with the smaller rank.
 
-    The reassociation maps are applied to the result shape to obtain the operand
-    shape.
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
 
     Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]]
-        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]] [%sz0, %sz1, 32]
+        : tensor<?x32xf32> into tensor<?x?x32xf32>
     ```
   }];
+
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation),
+    [{
+      SmallVector<OpFoldResult> inputShape = 
+          getMixedSizes($_builder, $_state.location, src);
+      std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+      auto status = 
+          inferOutputShape($_builder, $_state.location, 
+                           resultType.cast<RankedTensorType>(), 
+                           reassociation, inputShape, outputShape);
+      (void) status;
+      assert(succeeded(status) && "unable to infer output shape"); 
+      build($_builder, $_state, resultType.cast<RankedTensorType>(), src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            outputShape.second, outputShape.first);
+    }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-          getReassociationIndicesAttribute($_builder, reassociation));
+      auto [staticOutputShape, dynamicOutputShape] =
+          decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+      build($_builder, $_state, resultType, src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            dynamicOutputShape, staticOutputShape);
+    }]>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
+    [{
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices,
+            outputShape);
     }]>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+    // Infer the output shape for a tensor.expand_shape when it is possible
+    // to do so.
+    static LogicalResult inferOutputShape(
+        OpBuilder &b, Location loc, RankedTensorType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape,
+        std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
   }];
 
   let hasVerifier = 1;
@@ -1146,6 +1193,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1211,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1174,7 +1227,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1192,7 +1245,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index ae9824f728da4d..e46d4ff558916f 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,28 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the output shape of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` and `staticOutputShape`, following
+// the conventions for the output_shape and static_output_shape inputs to the
+// expand_shape ops.
+LogicalResult inferExpandShapeOutputShape(
+    OpBuilder &b, Location loc, RankedTensorType expandedType,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> inputShape,
+    std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
+
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +84,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
 
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+    ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -156,9 +178,11 @@ LogicalResult reshapeLikeShapesAreCompatible(
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
+enum class ReshapeOpKind { kExpand, kCollapse };
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -181,8 +205,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+    if constexpr (opKind == ReshapeOpKind::kExpand) {
+      SmallVector<OpFoldResult> outputShape(
+          getMixedValues(reshapeOp.getStaticOutputShape(),
+                         reshapeOp.getOutputShape(), rewriter));
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+          outputShape);
+    } else {
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+    }
     return success();
   }
 };
@@ -215,7 +249,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+          typename DimOpTy, typename TensorTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -322,8 +357,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (!composedReassociation)
       return failure();
 
+    SmallVector<OpFoldResult> outputShape(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+        outputShape);
     return success();
   }
 
diff --git a/mlir/include/mlir/Dia...
[truncated]

``````````
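
For a sense of how the new builders might be used from C++, here is a hedged sketch based on the `ArrayRef<OpFoldResult>` overload added above. The helper function, the concrete shapes, and the reassociation are illustrative assumptions, not part of the patch:

```cpp
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Illustrative helper (not part of the patch): expand the leading dynamic
// dimension of a tensor<?x32xf32> into (%sz0, 4), producing a
// tensor<?x4x32xf32>, with the output shape passed as mixed static/dynamic
// OpFoldResults.
static Value expandLeadingDim(OpBuilder &b, Location loc, Value src,
                              Value sz0) {
  auto srcType = cast<RankedTensorType>(src.getType());
  auto resultType = RankedTensorType::get({ShapedType::kDynamic, 4, 32},
                                          srcType.getElementType());

  // Result dims 0 and 1 collapse back to source dim 0; result dim 2 maps to
  // source dim 1.
  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};

  // One dynamic entry (%sz0) and two static entries (4 and 32).
  SmallVector<OpFoldResult> outputShape = {sz0, b.getIndexAttr(4),
                                           b.getIndexAttr(32)};

  auto expand = b.create<tensor::ExpandShapeOp>(loc, resultType, src,
                                                reassociation, outputShape);
  return expand.getResult();
}
```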



https://github.com/llvm/llvm-project/pull/69267


More information about the Mlir-commits mailing list