[Mlir-commits] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)

Gaurav Shukla llvmlistbot at llvm.org
Wed Apr 10 05:41:49 PDT 2024


https://github.com/Shukla-Gaurav updated https://github.com/llvm/llvm-project/pull/69267

>From 3f0ee5961221923906f66b8c09d2ccb0d67a671c Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos <ramiroleal050 at gmail.com>
Date: Mon, 16 Oct 2023 17:02:23 -0700
Subject: [PATCH] [MLIR] Generalize expand_shape to take shape as explicit
 input

*DO NOT SUBMIT*
(This patch needs to be tested in IREE backend)

This patch generalizes tensor.expand_shape and memref.expand_shape to consume
the output shape as a list of SSA values.  This enables us to implement generic
reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation
that's also used in `tensor.extract_slice`.

Differential Revision: https://reviews.llvm.org/D140821
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 88 +++++++++++++----
 .../mlir/Dialect/Tensor/IR/TensorOps.td       | 99 ++++++++++++++-----
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      | 50 ++++++++--
 .../mlir/Dialect/Utils/StaticValueUtils.h     |  5 +-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  1 -
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 15 ++-
 .../Transforms/ConvertConv2DToImg2Col.cpp     |  2 +-
 .../Transforms/DataLayoutPropagation.cpp      | 10 +-
 .../Linalg/Transforms/DropUnitDims.cpp        | 25 +++--
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 15 +--
 .../Linalg/Transforms/SplitReduction.cpp      |  1 +
 .../Dialect/Linalg/Transforms/Transforms.cpp  |  8 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 43 ++++++--
 .../Transforms/SparseTensorRewriting.cpp      | 11 ++-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 56 ++++++++---
 .../BufferizableOpInterfaceImpl.cpp           |  7 +-
 .../Transforms/PackAndUnpackPatterns.cpp      | 24 +++--
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    | 86 +++++++++++++---
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp   |  7 +-
 .../expand-then-convert-to-llvm.mlir          |  8 +-
 .../MemRefToLLVM/memref-to-llvm.mlir          |  4 +-
 .../Linalg/bubble-up-extract-slice-op.mlir    |  4 +-
 mlir/test/Dialect/Tensor/bufferize.mlir       | 16 +--
 mlir/test/Dialect/Tensor/fold-empty-op.mlir   |  5 +-
 .../Tensor/fold-reassociative-reshapes.mlir   |  6 +-
 mlir/test/Dialect/Tensor/ops.mlir             | 18 +++-
 .../Dialect/Tensor/simplify-pack-unpack.mlir  | 14 +--
 27 files changed, 469 insertions(+), 159 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..8ea46bf3b275e5 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     MemRef_Op<mnemonic, !listconcat(traits,
       [Pure, ViewLikeOpInterface])>,
-    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
   code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Value getViewSource() { return getSrc(); }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     Example:
 
     ```mlir
-    %r = memref.expand_shape %0 [[0, 1], [2]]
-        : memref<?x?xf32> into memref<?x5x?xf32>
+    %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : memref<?x32xf32> into memref<?x?x32xf32>
     ```
 
-    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
-    dynamic in the result type. Otherwise, the op would be ambiguous, as it
-    would not be clear how the source dimension is extended.
-
     If an op can be statically proven to be invalid (e.g, an expansion from
     `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
     it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,29 +1613,72 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     there must be a dynamic result dimension in the corresponding reassociation
     group. Same for strides.
 
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
+
     Note: This op currently assumes that the inner strides are of the
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation),
+    [{
+      SmallVector<OpFoldResult> inputShape = 
+          getMixedSizes($_builder, $_state.location, src);
+      std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+      auto status = 
+          inferOutputShape($_builder, $_state.location, 
+                           resultType.cast<MemRefType>(), 
+                           reassociation, inputShape, outputShape);
+      (void) status;
+      assert(succeeded(status) && "unable to infer output shape"); 
+      build($_builder, $_state, resultType.cast<MemRefType>(), src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            outputShape.second, outputShape.first);
+    }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-                          getReassociationIndicesAttribute($_builder, reassociation));
+      auto [staticOutputShape, dynamicOutputShape] =
+          decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+      build($_builder, $_state, resultType, src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            dynamicOutputShape, staticOutputShape);
     }]>,
 
     // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
+    [{
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
+    }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps,
+            outputShape);
     }]>,
 
     // Builder that infers the result layout map. The result shape must be
@@ -1657,6 +1691,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     static FailureOr<MemRefType> computeExpandedType(
         MemRefType srcType, ArrayRef<int64_t> resultShape,
         ArrayRef<ReassociationIndices> reassociation);
+
+    // Infer the output shape for a memref.expand_shape when it is possible
+    // to do so.
+    static LogicalResult inferOutputShape(
+        OpBuilder &b, Location loc, MemRefType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape,
+        std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
   }];
 
   let hasVerifier = 1;
@@ -1707,6 +1749,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1718,7 +1766,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1736,7 +1784,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cf7f3e89079c1c..d3a1250f7e8985 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyRankedTensor:$result)> {
+    Results<(outs AnyTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1102,43 +1097,95 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     rank than the operand `src` whose dimension sizes are a reassociation of
     `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions. It is
-    represented with an array of DenseI64ArrayAttr attribute. Entries in the
-    array are referred to as reassociation maps.
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of DenseI64ArrayAttr attribute.  The reassociation
+    maps applied to the result tensor with the higher rank must result in the
+    operand tensor with the smaller rank.
 
-    The reassociation maps are applied to the result shape to obtain the operand
-    shape.
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
 
     Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]]
-        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : tensor<?x32xf32> into tensor<?x?x32xf32>
     ```
   }];
+
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation),
+    [{
+      SmallVector<OpFoldResult> inputShape = 
+          getMixedSizes($_builder, $_state.location, src);
+      std::pair<SmallVector<int64_t>, SmallVector<Value>> outputShape;
+      auto status = 
+          inferOutputShape($_builder, $_state.location, 
+                           resultType.cast<RankedTensorType>(), 
+                           reassociation, inputShape, outputShape);
+      (void) status;
+      assert(succeeded(status) && "unable to infer output shape"); 
+      build($_builder, $_state, resultType.cast<RankedTensorType>(), src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            outputShape.second, outputShape.first);
+    }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-          getReassociationIndicesAttribute($_builder, reassociation));
+      auto [staticOutputShape, dynamicOutputShape] =
+          decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+      build($_builder, $_state, resultType, src,
+            getReassociationIndicesAttribute($_builder, reassociation),
+            dynamicOutputShape, staticOutputShape);
+    }]>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
+    [{
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices,
+            outputShape);
     }]>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+    // Infer the output shape for a tensor.expand_shape when it is possible
+    // to do so.
+    static LogicalResult inferOutputShape(
+        OpBuilder &b, Location loc, RankedTensorType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape,
+        std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
   }];
 
   let hasVerifier = 1;
@@ -1146,6 +1193,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1211,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1174,7 +1227,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1192,7 +1245,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index ae9824f728da4d..e46d4ff558916f 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,28 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the result type of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` as a (static sizes, dynamic SSA
+// values) pair, following the conventions for the static_output_shape and
+// output_shape inputs to the expand_shape ops.
+LogicalResult inferExpandShapeOutputShape(
+    OpBuilder &b, Location loc, RankedTensorType expandedType,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> inputShape,
+    std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape);
+
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +84,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
 
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+    ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -156,9 +178,11 @@ LogicalResult reshapeLikeShapesAreCompatible(
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
+enum class ReshapeOpKind { kExpand, kCollapse };
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -181,8 +205,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+    if constexpr (opKind == ReshapeOpKind::kExpand) {
+      SmallVector<OpFoldResult> outputShape(
+          getMixedValues(reshapeOp.getStaticOutputShape(),
+                         reshapeOp.getOutputShape(), rewriter));
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+          outputShape);
+    } else {
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+    }
     return success();
   }
 };
@@ -215,7 +249,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+          typename DimOpTy, typename TensorTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -322,8 +357,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (!composedReassociation)
       return failure();
 
+    SmallVector<OpFoldResult> outputShape(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+        outputShape);
     return success();
   }
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 20f019666a2e6a..594bcf5dbb399a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,9 +125,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 /// Decompose a vector of mixed static or dynamic values into the
 /// corresponding pair of arrays. This is the inverse function of
 /// `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
-                     const SmallVectorImpl<OpFoldResult> &mixedValues);
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
 
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412be..22ff502a7be7a4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..fb43dae305b8de 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -586,9 +586,18 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
       return failure();
 
     Location loc = oldFill.getLoc();
-    auto newInit = rewriter.create<TensorReshapeOp>(
-        loc, reshapeOp.getResultType(), oldFill.output(),
-        reshapeOp.getReassociation());
+    TensorReshapeOp newInit;
+    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
+
+      newInit = rewriter.create<TensorReshapeOp>(
+          loc, reshapeOp.getResultType(), oldFill.output(),
+          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
+          reshapeOp.getStaticOutputShape());
+    } else {
+      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
+                                                 oldFill.output(),
+                                                 reshapeOp.getReassociation());
+    }
     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                         ValueRange{newInit});
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 420b04b3ee28cf..81d44ba04fa1d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -349,7 +349,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
   SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                       {2, 3}};
 
-  Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+  auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
       loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
       batchMatVecReassociationIndice);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7fd88dec71d491..9a2493a59e019e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -757,7 +757,10 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
 
-  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+  auto expandTy = expandOp.getType().dyn_cast<RankedTensorType>();
+  if (!expandTy)
+    return failure();
+  ArrayRef<int64_t> dstShape = expandTy.getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       expandOp.getReassociationIndices();
   // Project inner tile pos to the dim pos after expanding. For example, if dims
@@ -796,9 +799,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
     nextPos += 1;
   }
 
-  RankedTensorType newExpandType =
-      tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
-                                      projectedInnerDimsPos, newOuterDimsPerm);
+  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
+      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
   auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
       expandOp.getLoc(), newExpandType, unPackOp.getSource(),
       newReassocIndices);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 023ea277bcf499..cef61b20e40a00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -252,7 +253,7 @@ replaceUnitDimIndexOps(GenericOp genericOp,
 /// Expand the given `value` so that the type matches the type of `origDest`.
 /// The `reassociation` is used when `rankReductionStrategy` is set to
 /// `RankReductionStrategy::ReassociativeReshape`.
-static Value
+static FailureOr<Value>
 expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
             ArrayRef<ReassociationIndices> reassociation,
             ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
@@ -272,8 +273,9 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
   assert(rankReductionStrategy ==
              ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
          "unknown rank reduction strategy");
-  return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
-                                                reassociation);
+  return rewriter
+      .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
+      .getResult();
 }
 
 /// Collapse the given `value` so that the type matches the type of
@@ -536,9 +538,13 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
       resultReplacements.push_back(result);
       continue;
     }
-    resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
-                                             reassociations[opOperandIndex],
-                                             options.rankReductionStrategy));
+    FailureOr<Value> expandedValue = expandValue(
+        rewriter, loc, result, origDest, reassociations[opOperandIndex],
+        options.rankReductionStrategy);
+    if (failed(expandedValue)) {
+      return rewriter.notifyMatchFailure(genericOp, "unable to expand result");
+    }
+    resultReplacements.push_back(*expandedValue);
   }
 
   rewriter.replaceOp(genericOp, resultReplacements);
@@ -669,10 +675,13 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
           padOp.getResultType().getElementType());
     }
 
-    Value expandedValue =
+    FailureOr<Value> expandedValue =
         expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                     reassociationMap, options.rankReductionStrategy);
-    rewriter.replaceOp(padOp, expandedValue);
+    if (failed(expandedValue)) {
+      return rewriter.notifyMatchFailure(padOp, "unable to expand result");
+    }
+    rewriter.replaceOp(padOp, *expandedValue);
     return success();
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 373e9cfc3ce719..5a34ecbaf50b15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -843,8 +843,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          linalgOp.getLoc(), expandedOutputType, opOperand.get(),
-          reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation));
     } else {
       outputs.push_back(opOperand.get());
     }
@@ -1615,15 +1614,17 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
-        Value result = rewriter.create<memref::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
-        results.push_back(result);
+        MemRefType expandShapeResultType = MemRefType::get(
+            originalResultType.getShape(), originalResultType.getElementType());
+        result = rewriter.create<memref::ExpandShapeOp>(
+            loc, expandShapeResultType, collapsedOpResult, reassociation);
       } else {
-        Value result = rewriter.create<tensor::ExpandShapeOp>(
+        result = rewriter.create<tensor::ExpandShapeOp>(
             loc, originalResultType, collapsedOpResult, reassociation);
-        results.push_back(result);
       }
+      results.push_back(result);
     } else {
       results.push_back(collapsedOpResult);
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 6559c86c9e0ff5..5bfdbc6d0bb59c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -114,6 +114,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     Type newType = RankedTensorType::get(
         newShape,
         cast<RankedTensorType>(operand->get().getType()).getElementType());
+
     Value newInput = b.create<tensor::ExpandShapeOp>(
         loc, newType, operand->get(), reassociation);
     newInputs.push_back(newInput);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a17bc8e4cd318f..c41a899b2e6f5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -329,11 +329,13 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                              /*transposeOp=*/nullptr};
     }
   }
+
   // 5. Expand from the padded result to the stripMinedShape.
+  auto expandShapeResultType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
+      loc, expandShapeResultType, padOp.getResult(),
+      packingMetadata.reassociations);
 
   // 6. Transpose stripMinedShape to packedShape.
   SmallVector<int64_t> transpPerm =
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..4c205d3bb360a8 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2237,6 +2237,17 @@ FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
                          srcType.getMemorySpace());
 }
 
+LogicalResult ExpandShapeOp::inferOutputShape(
+    OpBuilder &b, Location loc, MemRefType expandedType,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> inputShape,
+    std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+  auto expandedTensorType =
+      getTensorTypeFromMemRefType(expandedType).cast<RankedTensorType>();
+  return inferExpandShapeOutputShape(b, loc, expandedTensorType, reassociation,
+                                     inputShape, outputShape);
+}
+
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                           ArrayRef<int64_t> resultShape, Value src,
                           ArrayRef<ReassociationIndices> reassociation) {
@@ -2247,7 +2258,9 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
   // Failure of this assertion usually indicates a problem with the source
   // type, e.g., could not get strides/offset.
   assert(succeeded(resultType) && "could not compute layout");
-  build(builder, result, *resultType, src, reassociation);
+  SmallVector<OpFoldResult> outputShape(
+      getMixedValues(resultShape, ValueRange{}, builder));
+  build(builder, result, *resultType, src, reassociation, outputShape);
 }
 
 LogicalResult ExpandShapeOp::verify() {
@@ -2280,14 +2293,28 @@ LogicalResult ExpandShapeOp::verify() {
     return emitOpError("expected expanded type to be ")
            << *expectedResultType << " but found " << resultType;
 
+  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+    return emitOpError("expected number of static shape bounds to be equal to "
+                       "the output rank (")
+           << resultType.getRank() << ") but found "
+           << getStaticOutputShape().size() << " inputs instead";
+
+  if ((int64_t)getOutputShape().size() !=
+      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+    return emitOpError("mismatch in dynamic dims in output_shape and "
+                       "static_output_shape: static_output_shape has ")
+           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+           << " dynamic dims while output_shape has " << getOutputShape().size()
+           << " values";
+
   return success();
 }
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
-      context);
+  results.add<
+      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -2488,9 +2515,11 @@ struct CollapseShapeOpMemRefCastFolder
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
-              CollapseShapeOpMemRefCastFolder>(context);
+  results.add<
+      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+                                memref::DimOp, MemRefType>,
+      CollapseShapeOpMemRefCastFolder>(context);
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b117c1694e45b8..149597df332907 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -952,8 +952,15 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
       auto rtp = getRankedTensorType(op.getResult());
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
-                                                op.getReassociation());
+      ReshapeOp reshape;
+      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
+        reshape = rewriter.create<ReshapeOp>(
+            loc, denseTp, op.getSrc(), op.getReassociation(),
+            op.getOutputShape(), op.getStaticOutputShape());
+      } else {
+        reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
+                                             op.getReassociation());
+      }
       Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
       rewriter.replaceOp(op, convert);
       return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..e86d318167e421 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1606,6 +1606,15 @@ int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
   llvm_unreachable("could not find reassociation group");
 }
 
+LogicalResult ExpandShapeOp::inferOutputShape(
+    OpBuilder &b, Location loc, RankedTensorType expandedType,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> inputShape,
+    std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+  return inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
+                                     inputShape, outputShape);
+}
+
 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
@@ -1689,7 +1698,24 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
+  auto srcType = getSrcType();
+  auto resultType = getResultType();
+
+  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+    return emitOpError("expected number of static shape dims to be equal to "
+                       "the output rank (")
+           << resultType.getRank() << ") but found "
+           << getStaticOutputShape().size() << " inputs instead";
+
+  if ((int64_t)getOutputShape().size() !=
+      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+    return emitOpError("mismatch in dynamic dims in output_shape and "
+                       "static_output_shape: static_output_shape has ")
+           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+           << " dynamic dims while output_shape has " << getOutputShape().size()
+           << " values";
+
+  return verifyTensorReshapeOp(*this, resultType, srcType);
 }
 
 LogicalResult CollapseShapeOp::verify() {
@@ -1873,23 +1899,25 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-              FoldReshapeWithConstant<ExpandShapeOp>,
-              FoldReshapeWithSplat<ExpandShapeOp>,
-              FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
-              FoldDimOfCollapseShape>(context);
+  results.add<
+      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      FoldReshapeWithConstant<ExpandShapeOp>,
+      FoldReshapeWithSplat<ExpandShapeOp>,
+      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+      FoldDimOfCollapseShape>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results
-      .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
-           FoldReshapeWithConstant<CollapseShapeOp>,
-           FoldReshapeWithSplat<CollapseShapeOp>,
-           FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
-          context);
+  results.add<
+      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+                                tensor::DimOp, RankedTensorType>,
+      FoldReshapeWithConstant<CollapseShapeOp>,
+      FoldReshapeWithSplat<CollapseShapeOp>,
+      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
+      context);
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 58ea4cc4da3c36..9d28fdcd348ab0 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -338,9 +338,12 @@ struct ExpandShapeOpInterface
 
     // Memref result type is inferred by the builder based on reassociation
     // indices and result shape.
-    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, tensorResultType.getShape(), *buffer,
+    auto expandTy = MemRefType::get(tensorResultType.getShape(),
+                                    tensorResultType.getElementType());
+    Value memrefExpandOp = rewriter.create<memref::ExpandShapeOp>(
+        op->getLoc(), expandTy, *buffer,
         expandShapeOp.getReassociationIndices());
+    replaceOpWithBufferizedValues(rewriter, op, memrefExpandOp);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 666ac56c6cd5cd..7011ce23b55a6b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -52,12 +52,16 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
 
-  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
-                     Type newOperandType, ArrayAttr reassociation) const {
+  FailureOr<Value>
+  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+               Type newOperandType,
+               ArrayRef<ReassociationIndices> reassociation) const {
     if (operand.getType() == newOperandType)
       return operand;
-    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
-                                                  reassociation);
+    return rewriter
+        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+                                       reassociation)
+        .getResult();
   }
 
   /// Returns success() if it is only packing on the innermost dimension.
@@ -96,10 +100,14 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
       return failure();
-    Value expanded = insertExpand(
-        rewriter, packOp.getLoc(), packOp.getSource(), destType,
-        getReassociationIndicesAttribute(rewriter, *reassociation));
-    rewriter.replaceOp(packOp, expanded);
+    FailureOr<Value> expanded =
+        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
+                     *reassociation);
+    if (failed(expanded)) {
+      return rewriter.notifyMatchFailure(
+          packOp, "unable to expand source of tensor.pack");
+    }
+    rewriter.replaceOp(packOp, *expanded);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 41c7af4593c77c..bf517e15a37dc6 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 
@@ -16,6 +18,69 @@
 
 using namespace mlir;
 
+LogicalResult mlir::inferExpandShapeOutputShape(
+    OpBuilder &b, Location loc, RankedTensorType expandedType,
+    ArrayRef<ReassociationIndices> reassociation,
+    ArrayRef<OpFoldResult> inputShape,
+    std::pair<SmallVector<int64_t>, SmallVector<Value>> &outputShape) {
+  SmallVector<Value> outputShapeValues;
+  SmallVector<int64_t> outputShapeInts;
+  // For zero-rank inputs, all dims in result shape are unit extent.
+  if (inputShape.empty()) {
+    outputShapeInts.resize(expandedType.getRank(), 1);
+    outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+    return success();
+  }
+
+  // Check for all static shapes.
+  if (expandedType.hasStaticShape()) {
+    ArrayRef<int64_t> shapeInts = expandedType.getShape();
+    outputShapeInts.assign(shapeInts.begin(), shapeInts.end());
+    outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+    return success();
+  }
+
+  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+  for (const auto &it : llvm::enumerate(reassociation)) {
+    ReassociationIndices indexGroup = it.value();
+
+    int64_t indexGroupStaticSizesProductInt = 1;
+    bool foundDynamic = false;
+    for (int64_t index : indexGroup) {
+      int64_t outputDimSize = expandedType.getDimSize(index);
+      // Cannot infer expanded shape with multiple dynamic dims in the
+      // same reassociation group!
+      if (ShapedType::isDynamic(outputDimSize)) {
+        if (foundDynamic)
+          return failure();
+        foundDynamic = true;
+      } else {
+        outputShapeInts[index] = outputDimSize;
+        indexGroupStaticSizesProductInt *= outputDimSize;
+      }
+    }
+    if (!foundDynamic)
+      continue;
+
+    int64_t inputIndex = it.index();
+    // Assumes the reshape does not cast dynamism: a group with a dynamic
+    // output dim has a dynamic input dim, so get<Value>() is valid here.
+    Value indexGroupSize = inputShape[inputIndex].get<Value>();
+    Value indexGroupStaticSizesProduct =
+        b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+    Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
+        loc, indexGroupSize, indexGroupStaticSizesProduct);
+    outputShapeValues.push_back(dynamicDimSize);
+  }
+
+  if (llvm::count(outputShapeInts, ShapedType::kDynamic) !=
+      outputShapeValues.size())
+    return failure();
+
+  outputShape = std::make_pair(outputShapeInts, outputShapeValues);
+  return success();
+}
+
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
@@ -168,7 +233,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute(
 }
 
 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+    ArrayRef<ReassociationExprs> reassociationExprs) {
   SmallVector<ReassociationIndices, 2> reassociationIndices;
   for (const auto &exprs : reassociationExprs) {
     ReassociationIndices indices;
@@ -230,24 +295,17 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
   unsigned expandedDimStart = 0;
   for (const auto &map : llvm::enumerate(reassociationMaps)) {
-    std::optional<int64_t> dynamicShape;
+    bool foundDynamicShape = false;
     int64_t linearizedStaticShape = 1;
+
     for (const auto &dim : llvm::enumerate(
              expandedShape.slice(expandedDimStart, map.value().size()))) {
-      if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicShape) {
-          return emitError("invalid to have a single dimension (" +
-                           Twine(map.index()) +
-                           ") expanded into multiple dynamic dims (" +
-                           Twine(expandedDimStart + dynamicShape.value()) +
-                           "," + Twine(expandedDimStart + dim.index()) + ")");
-        }
-        dynamicShape = dim.index();
-      } else {
+      if (ShapedType::isDynamic(dim.value()))
+        foundDynamicShape = true;
+      else
         linearizedStaticShape *= dim.value();
-      }
     }
-    if (dynamicShape) {
+    if (foundDynamicShape) {
       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
         return emitError(
             "expected dimension " + Twine(map.index()) +
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 1e8197e1094424..74a53709592dd2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -180,9 +180,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
-                     const SmallVectorImpl<OpFoldResult> &mixedValues) {
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
   SmallVector<int64_t> staticValues;
   SmallVector<Value> dynamicValues;
   for (const auto &it : mixedValues) {
@@ -193,7 +192,7 @@ decomposeMixedValues(Builder &b,
       dynamicValues.push_back(it.get<Value>());
     }
   }
-  return {b.getI64ArrayAttr(staticValues), dynamicValues};
+  return {staticValues, dynamicValues};
 }
 
 /// Helper to sort `values` according to matching `keys`.
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 87d613986c7c3f..9158592f755e7f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -453,7 +453,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 
 func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
   // Reshapes that expand a contiguous tensor with some 1's.
-  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
       : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
   return %0 : memref<1x3x4x1x5xf32>
 }
@@ -510,7 +510,7 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
 // -----
 
 func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
-  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+  %0 = memref.expand_shape %arg0 [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
   return %0 : memref<1x1xf32>
 }
 
@@ -571,8 +571,8 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 
 // -----
 
-func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
-  %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32>
+func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<1x2x?xf32> {
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [1, 2, %sz0]: memref<1x?xf32> into memref<1x2x?xf32>
   return %0 : memref<1x2x?xf32>
 }
 
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 37999d6fc14ad1..baf9cfe610a5a0 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -334,9 +334,9 @@ memref.global "private" @gv4 : memref<f32> = dense<1.0> {alignment = 64}
 // CHECK-LABEL: func @expand_shape_static(
 // CHECK-SAME:         %[[ARG:.*]]: memref<{{.*}}>)
 func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
-  // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]]
+  // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
   // Reshapes that expand a contiguous tensor with some 1's.
-  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
       : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
   return %0 : memref<1x3x4x1x5xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
index 0e353a1fa43fcb..4bf81820f0e805 100644
--- a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -165,7 +165,9 @@ func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
   %init = tensor.empty(%width) : tensor<1x?xf32>
   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
   %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
-  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
+  %c0 = arith.constant 0 : index
+  %sz0 = tensor.dim %slice, %c0 : tensor<?xf32>
+  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] output_shape [1, 1, 1, %sz0] : tensor<?xf32> into tensor<1x1x1x?xf32>
   return %expand : tensor<1x1x1x?xf32>
 }
 
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 815bc383af95a6..64d9c86ba2dee5 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -367,11 +367,11 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
 
 // CHECK-LABEL: func @tensor.expand_shape(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
-func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
+func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
   // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
-  // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
-  %0 = tensor.expand_shape %t1 [[0, 1], [2]]
+  // CHECK-SAME: [0, 1], [2]] output_shape [2, %sz0, 10] : memref<?x10xf32> into memref<2x?x10xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
       : tensor<?x10xf32> into tensor<2x?x10xf32>
 
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
@@ -384,14 +384,14 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
 func.func @tensor.expand_shape_of_slice(
-    %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
+    %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
       tensor<?x20xf32> to tensor<?x10xf32>
   // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
-  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
-  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
+  // CHECK-SAME: [0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
       tensor<?x10xf32> into tensor<?x7x2x5xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
@@ -407,8 +407,8 @@ func.func @tensor.expand_shape_of_scalar_slice(
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] :  memref<?xf32> to memref<f32, strided<[], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
-  %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
+  %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
   return %1 : tensor<1xf32>
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 15f841f2128edb..e200a4f8926130 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -13,10 +13,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
-func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
   %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
-  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 625408dfefe216..d3ac6ce792f365 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -11,9 +11,11 @@ func.func @expand_shape_of_rank_reducing_extract(
 {
   %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
       : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
+  %c0 = arith.constant 0 : index
+  %sz0 = tensor.dim %0, %c0 : tensor<?x1x5xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%sz0, 1, 1, 5]
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
-  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
+  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]] output_shape [%sz0, 1, 1, 5]
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
   return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
 }
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2b0a74acce0826..378137a14b59ff 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -194,12 +194,26 @@ func.func @insert_slice(
 func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
     -> (tensor<f32>, tensor<1x1xf32>) {
   %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
-  %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
+  %1 = tensor.expand_shape %0 [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
   return %0, %1 : tensor<f32>, tensor<1x1xf32>
 }
 // CHECK-LABEL: func @tensor_reshape_zero_dim
 //       CHECK:   tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
-//       CHECK:   tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+//       CHECK:   tensor.expand_shape %{{.*}} [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor<?x?xf32>, %sz0 : index, %sz1 : index, %sz2 : index)
+    -> (tensor<5x?x?x?xf32>) {
+  %1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [5, %sz0, %sz1, %sz2] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+  return %1 : tensor<5x?x?x?xf32>
+}
+
+// CHECK-LABEL:  func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> {
+//       CHECK:    %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} output_shape [5, %arg1, %arg2, %arg3] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+//       CHECK:    return %expanded : tensor<5x?x?x?xf32>
+//       CHECK:  }
+
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 9948c0246e6ed6..5a2eade0ecccf1 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func.func @single_dim_packing(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<8x32xf32>
 func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
   %empty = tensor.empty() : tensor<8x32xf32>
@@ -27,7 +27,7 @@ func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x3
 
 // CHECK-LABEL: func.func @single_last_inner_dim_packing(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
 func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
   %empty = tensor.empty() : tensor<5x8x32xf32>
@@ -39,7 +39,7 @@ func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8
 
 // CHECK-LABEL: func.func @pack_1d_with_outer_dims_perm(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<64xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<64xf32> into tensor<2x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [2, 32] : tensor<64xf32> into tensor<2x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<2x32xf32>
 func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf32> {
   %empty = tensor.empty() :  tensor<2x32xf32>
@@ -51,7 +51,7 @@ func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf3
 
 // CHECK-LABEL: func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
 func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
   %empty = tensor.empty() : tensor<5x8x32xf32>
@@ -85,7 +85,7 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
 
 // CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 32, 1, 1]
 // CHECK:         return %[[EXPANDED]]
 func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
   %empty = tensor.empty() : tensor<1x32x1x1xf32>
@@ -98,7 +98,7 @@ func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf3
 
 // CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 16, 1, 2]
 // CHECK:         return %[[EXPANDED]]
 func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
   %empty = tensor.empty() : tensor<1x16x1x2xf32>
@@ -111,7 +111,7 @@ func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf3
 
 // CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 16, 2, 1]
 // CHECK:         return %[[EXPANDED]]
 func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
   %empty = tensor.empty() : tensor<1x16x2x1xf32>



More information about the Mlir-commits mailing list