[llvm] [mlir] [MLIR] Generalize expand_shape to take shape as explicit input (PR #69267)

Gaurav Shukla via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 18 13:25:43 PDT 2024


https://github.com/Shukla-Gaurav updated https://github.com/llvm/llvm-project/pull/69267

>From ea8dee0a6602724c39d8f43404b10b39b619e905 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos <ramiroleal050 at gmail.com>
Date: Mon, 16 Oct 2023 17:02:23 -0700
Subject: [PATCH] [MLIR] Generalize expand_shape to take shape as explicit
 input

This patch generalizes tensor.expand_shape and memref.expand_shape to consume
the output shape as a list of SSA values.  This enables us to implement generic
reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation
that's also used in `tensor.extract_slice`.

Differential Revision: https://reviews.llvm.org/D140821
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  80 +++--
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  79 +++--
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |   5 +
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      |  58 +++-
 .../mlir/Dialect/Utils/StaticValueUtils.h     |   5 +-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |   1 -
 mlir/lib/Dialect/Linalg/IR/CMakeLists.txt     |   1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |   9 +-
 .../Transforms/ConvertConv2DToImg2Col.cpp     |   2 +-
 .../Transforms/DataLayoutPropagation.cpp      |  10 +-
 .../Linalg/Transforms/DropUnitDims.cpp        |  13 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  67 ++--
 .../Linalg/Transforms/SplitReduction.cpp      |   1 +
 .../Dialect/Linalg/Transforms/Transforms.cpp  |   8 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  82 ++++-
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/SparseTensorRewriting.cpp      |   8 +-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  85 ++++-
 .../BufferizableOpInterfaceImpl.cpp           |   3 +
 .../Transforms/PackAndUnpackPatterns.cpp      |  24 +-
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       |  19 ++
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    |  83 ++++-
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp   |   7 +-
 .../expand-then-convert-to-llvm.mlir          |  16 +-
 .../MemRefToLLVM/memref-to-llvm.mlir          |   4 +-
 .../TosaToLinalg/tosa-to-linalg.mlir          |  29 +-
 .../TosaToTensor/tosa-to-tensor.mlir          | 114 +++++--
 ...ot-bufferize-empty-tensor-elimination.mlir |   2 +-
 .../Linalg/bubble-up-extract-slice-op.mlir    |   4 +-
 mlir/test/Dialect/Linalg/collapse-dim.mlir    |   6 +-
 .../Linalg/convert-conv2d-to-img2col.mlir     |  20 +-
 .../Linalg/data-layout-propagation.mlir       |  30 +-
 .../Dialect/Linalg/drop-unit-extent-dims.mlir | 108 ++++---
 .../Dialect/Linalg/flatten-elementwise.mlir   |   2 +-
 .../fuse-with-reshape-by-collapsing.mlir      | 101 +++---
 .../Dialect/Linalg/fusion-push-reshape.mlir   |  24 +-
 .../Linalg/reshape_control_fusion.mlir        |   2 +-
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 292 ++++++++++++------
 .../resolve-shaped-type-result-dims.mlir      |   5 +-
 .../Linalg/transform-op-split-reduction.mlir  |  28 +-
 .../Linalg/vectorization-with-patterns.mlir   |   4 +-
 mlir/test/Dialect/MemRef/canonicalize.mlir    |  35 +--
 .../MemRef/expand-strided-metadata.mlir       |  16 +-
 .../Dialect/MemRef/fold-memref-alias-ops.mlir |  22 +-
 mlir/test/Dialect/MemRef/invalid.mlir         |  38 +--
 mlir/test/Dialect/MemRef/ops.mlir             |  72 +++--
 .../Dialect/MemRef/runtime-verification.mlir  |   5 +-
 .../Dialect/SparseTensor/sparse_reshape.mlir  |  12 +-
 mlir/test/Dialect/Tensor/bufferize.mlir       |  24 +-
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 112 ++++---
 mlir/test/Dialect/Tensor/fold-empty-op.mlir   |   5 +-
 .../Tensor/fold-reassociative-reshapes.mlir   |   6 +-
 mlir/test/Dialect/Tensor/invalid.mlir         |  21 +-
 mlir/test/Dialect/Tensor/ops.mlir             |  18 +-
 .../Dialect/Tensor/simplify-pack-unpack.mlir  |  14 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 56 files changed, 1209 insertions(+), 634 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..14b8d95ea15b41 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     MemRef_Op<mnemonic, !listconcat(traits,
       [Pure, ViewLikeOpInterface])>,
-    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
   code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Value getViewSource() { return getSrc(); }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     Example:
 
     ```mlir
-    %r = memref.expand_shape %0 [[0, 1], [2]]
-        : memref<?x?xf32> into memref<?x5x?xf32>
+    %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : memref<?x32xf32> into memref<?x?x32xf32>
     ```
 
-    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
-    dynamic in the result type. Otherwise, the op would be ambiguous, as it
-    would not be clear how the source dimension is extended.
-
     If an op can be statically proven to be invalid (e.g, an expansion from
     `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
     it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     there must be a dynamic result dimension in the corresponding reassociation
     group. Same for strides.
 
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
+
     Note: This op currently assumes that the inner strides are of the
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape)>,
+
+    // Builder that infers the output shape via inferOutputShape().
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-                          getReassociationIndicesAttribute($_builder, reassociation));
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
 
-    // Builder using ReassociationExprs.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps,
+            outputShape);
     }]>,
 
+    // Builder that infers the result layout map. The result shape must be
+    // specified. Otherwise, the op may be ambiguous. The output shape for
+    // the op will be inferred using the inferOutputShape() method.
+    OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
+               "ArrayRef<ReassociationIndices>":$reassociation)>,
+
     // Builder that infers the result layout map. The result shape must be
     // specified. Otherwise, the op may be ambiguous.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
-               "ArrayRef<ReassociationIndices>":$reassociation)>
+               "ArrayRef<ReassociationIndices>":$reassociation,
+               "ArrayRef<OpFoldResult>":$outputShape)>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     static FailureOr<MemRefType> computeExpandedType(
         MemRefType srcType, ArrayRef<int64_t> resultShape,
         ArrayRef<ReassociationIndices> reassociation);
+
+    // Infer the output shape for a memref.expand_shape when it is possible
+    // to do so.
+    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+        OpBuilder &b, Location loc, MemRefType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cf7f3e89079c1c..a403e89a39f98c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyRankedTensor:$result)> {
+    Results<(outs AnyTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1102,43 +1097,75 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     rank than the operand `src` whose dimension sizes are a reassociation of
     `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions. It is
-    represented with an array of DenseI64ArrayAttr attribute. Entries in the
-    array are referred to as reassociation maps.
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of DenseI64ArrayAttr attribute.  The reassociation
+    maps applied to the result tensor with the higher rank must result in the
+    operand tensor with the smaller rank.
 
-    The reassociation maps are applied to the result shape to obtain the operand
-    shape.
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
 
     Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]]
-        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : tensor<?x32xf32> into tensor<?x?x32xf32>
     ```
   }];
+
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape)>,
+
+    // Builder that infers the output shape via inferOutputShape().
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-          getReassociationIndicesAttribute($_builder, reassociation));
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices,
+            outputShape);
     }]>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+    // Infer the output shape for a tensor.expand_shape when it is possible
+    // to do so.
+    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+        OpBuilder &b, Location loc, RankedTensorType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1146,6 +1173,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1191,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1174,7 +1207,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1192,7 +1225,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index d09c9e36f6ff88..da8b278b48e5a3 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -26,6 +26,11 @@ PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
 SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
                                           Value rankedTensor);
 
+/// Creates Reshape op.
+template <typename ReshapeOp>
+Value createReshapeOp(ReshapeOp oldReshapeOp, OpBuilder &b, Location loc,
+                      RankedTensorType resultTy, Value src);
+
 /// Returns the transposed `rankedTensorType` if `transposeVector` is non-empty.
 /// Fail if `transposeVector` is not a permutation matching the tensor rank.
 FailureOr<RankedTensorType>
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index ae9824f728da4d..8a41a0a18b0ab3 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,27 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the result type of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the inferred output shape as a single list of `OpFoldResult`s that
+// combines the static and dynamic sizes carried by the static_output_shape and
+// output_shape inputs of the expand_shape ops, or `std::nullopt` on failure.
+std::optional<SmallVector<OpFoldResult>>
+inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<OpFoldResult> inputShape);
+
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +83,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
 
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+    ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -140,14 +161,11 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
       op.getReassociationIndices(), isExpansion);
 }
 
-/// Verify that shapes of the reshaped types using following rules
-/// 1) if a dimension in the collapsed type is static, then the corresponding
-///    dimensions in the expanded shape should be
+/// Verify the shapes of the reshaped types using the following rule:
+/// if a dimension in the collapsed type is static, then the corresponding
+/// dimensions in the expanded shape should be
 ///    a) static
 ///    b) the product should be same as the collaped shape.
-/// 2) if a dimension in the collaped type is dynamic, one and only one of the
-///    corresponding dimensions in the expanded type should be dynamic. This
-///    rule is only needed with reshape operations that are expanding.
 LogicalResult reshapeLikeShapesAreCompatible(
     function_ref<LogicalResult(const Twine &)> emitError,
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
@@ -156,9 +174,11 @@ LogicalResult reshapeLikeShapesAreCompatible(
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
+enum class ReshapeOpKind { kExpand, kCollapse };
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -181,8 +201,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+    if constexpr (opKind == ReshapeOpKind::kExpand) {
+      SmallVector<OpFoldResult> outputShape(
+          getMixedValues(reshapeOp.getStaticOutputShape(),
+                         reshapeOp.getOutputShape(), rewriter));
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+          outputShape);
+    } else {
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+    }
     return success();
   }
 };
@@ -215,7 +245,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+          typename DimOpTy, typename TensorTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -322,8 +353,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (!composedReassociation)
       return failure();
 
+    SmallVector<OpFoldResult> outputShape(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+        outputShape);
     return success();
   }
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 20f019666a2e6a..594bcf5dbb399a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,9 +125,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 /// Decompose a vector of mixed static or dynamic values into the
 /// corresponding pair of arrays. This is the inverse function of
 /// `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
-                     const SmallVectorImpl<OpFoldResult> &mixedValues);
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
 
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index af19ebaea937d0..4b29449c0302f3 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index c187563b8f0c4e..0e60dd0cfbcee2 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
   MLIRMathDialect
   MLIRMemRefDialect
   MLIRTensorDialect
+  MLIRTensorUtils
   MLIRTilingInterface
   MLIRValueBoundsOpInterface
   MLIRViewLikeInterface
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..ed0143c6284ca5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -586,12 +587,12 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
       return failure();
 
     Location loc = oldFill.getLoc();
-    auto newInit = rewriter.create<TensorReshapeOp>(
-        loc, reshapeOp.getResultType(), oldFill.output(),
-        reshapeOp.getReassociation());
+    Value newInit = tensor::createReshapeOp(
+        reshapeOp, rewriter, loc, reshapeOp.getResultType(), oldFill.output());
+    if (!newInit)
+      return failure();
     rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                         ValueRange{newInit});
-
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 420b04b3ee28cf..81d44ba04fa1d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -349,7 +349,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
   SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                       {2, 3}};
 
-  Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
+  auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
       loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
       batchMatVecReassociationIndice);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7fd88dec71d491..9a2493a59e019e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -757,7 +757,10 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
   ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
   ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
 
-  ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
+  auto expandTy = expandOp.getType().dyn_cast<RankedTensorType>();
+  if (!expandTy)
+    return failure();
+  ArrayRef<int64_t> dstShape = expandTy.getShape();
   SmallVector<ReassociationIndices> reassocIndices =
       expandOp.getReassociationIndices();
   // Project inner tile pos to the dim pos after expanding. For example, if dims
@@ -796,9 +799,8 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
     nextPos += 1;
   }
 
-  RankedTensorType newExpandType =
-      tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
-                                      projectedInnerDimsPos, newOuterDimsPerm);
+  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
+      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
   auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
       expandOp.getLoc(), newExpandType, unPackOp.getSource(),
       newReassocIndices);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 023ea277bcf499..65efa18af18f65 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -272,8 +273,9 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
   assert(rankReductionStrategy ==
              ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
          "unknown rank reduction strategy");
-  return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
-                                                reassociation);
+  return rewriter
+      .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation)
+      .getResult();
 }
 
 /// Collapse the given `value` so that the type matches the type of
@@ -536,9 +538,10 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
       resultReplacements.push_back(result);
       continue;
     }
-    resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
-                                             reassociations[opOperandIndex],
-                                             options.rankReductionStrategy));
+    Value expandedValue = expandValue(rewriter, loc, result, origDest,
+                                      reassociations[opOperandIndex],
+                                      options.rankReductionStrategy);
+    resultReplacements.push_back(expandedValue);
   }
 
   rewriter.replaceOp(genericOp, resultReplacements);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 373e9cfc3ce719..bf3a737f1c2bf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -625,14 +625,14 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }
 
-/// Epanding the body of a linalg operation requires adaptations of the accessed
-/// loop indices. Specifically, access of indices in the original operation need
-/// to be replaced with linearizations of indices in the expanded op. That
-/// requires the shape of the expanded dimensions to be static (at least all but
-/// the most significant). For now check that these are all statically sized.
-/// Note that this could be extended to handle dynamic case, but the
-/// implementation below uses `affine.apply` which seems to have issues when the
-/// shapes are not static.
+/// Expanding the body of a linalg operation requires adaptations of the
+/// accessed loop indices. Specifically, access of indices in the original
+/// operation need to be replaced with linearizations of indices in the expanded
+/// op. That requires the shape of the expanded dimensions to be static (at
+/// least all but the most significant). For now check that these are all
+/// statically sized. Note that this could be extended to handle dynamic case,
+/// but the implementation below uses `affine.apply` which seems to have issues
+/// when the shapes are not static.
 static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
@@ -750,6 +750,31 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
   }
 }
 
+/// Checks whether a single dynamic dimension is expanded into multiple
+/// dynamic dimensions.
+static LogicalResult
+validateDynamicDimExpansion(LinalgOp linalgOp,
+                            const ExpansionInfo &expansionInfo,
+                            PatternRewriter &rewriter) {
+  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
+    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
+    if (expandedShape.size() == 1)
+      continue;
+    bool foundDynamic = false;
+    for (int64_t shape : expandedShape) {
+      if (ShapedType::isDynamic(shape)) {
+        if (foundDynamic) {
+          return rewriter.notifyMatchFailure(
+              linalgOp, "cannot infer expanded shape with multiple dynamic "
+                        "dims in the same reassociation group");
+        }
+        foundDynamic = true;
+      }
+    }
+  }
+  return success();
+}
+
 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
 /// that those conditions have been satisfied.
@@ -759,6 +784,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                            PatternRewriter &rewriter) {
   assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
          "preconditions for fuse operation failed");
+
+  Location loc = linalgOp.getLoc();
   // Check if reshape is expanding or collapsing.
   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
@@ -778,6 +805,11 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
           expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return std::nullopt;
 
+  // TODO: With the support of multiple dynamic dims expansion in
+  // tensor.expand_shape op, this case can be handled.
+  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
+    return std::nullopt;
+
   if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
     return std::nullopt;
 
@@ -816,15 +848,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            linalgOp.getLoc(), expandedOperandType, opOperand->get(),
-            reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation));
         continue;
       }
     }
     expandedOpOperands.push_back(opOperand->get());
   }
 
-  Location loc = linalgOp.getLoc();
   SmallVector<Value> outputs;
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
@@ -843,8 +873,7 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          linalgOp.getLoc(), expandedOutputType, opOperand.get(),
-          reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation));
     } else {
       outputs.push_back(opOperand.get());
     }
@@ -1615,15 +1644,17 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
-        Value result = rewriter.create<memref::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
-        results.push_back(result);
+        MemRefType expandShapeResultType = MemRefType::get(
+            originalResultType.getShape(), originalResultType.getElementType());
+        result = rewriter.create<memref::ExpandShapeOp>(
+            loc, expandShapeResultType, collapsedOpResult, reassociation);
       } else {
-        Value result = rewriter.create<tensor::ExpandShapeOp>(
+        result = rewriter.create<tensor::ExpandShapeOp>(
             loc, originalResultType, collapsedOpResult, reassociation);
-        results.push_back(result);
       }
+      results.push_back(result);
     } else {
       results.push_back(collapsedOpResult);
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 6559c86c9e0ff5..5bfdbc6d0bb59c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -114,6 +114,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     Type newType = RankedTensorType::get(
         newShape,
         cast<RankedTensorType>(operand->get().getType()).getElementType());
+
     Value newInput = b.create<tensor::ExpandShapeOp>(
         loc, newType, operand->get(), reassociation);
     newInputs.push_back(newInput);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a17bc8e4cd318f..c41a899b2e6f5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -329,11 +329,13 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                              /*transposeOp=*/nullptr};
     }
   }
+
   // 5. Expand from the padded result to the stripMinedShape.
+  auto expandShapeResultType =
+      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
-      padOp.getResult(), packingMetadata.reassociations);
+      loc, expandShapeResultType, padOp.getResult(),
+      packingMetadata.reassociations);
 
   // 6. Transpose stripMinedShape to packedShape.
   SmallVector<int64_t> transpPerm =
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..ced7fdd0a90f0f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2237,6 +2237,44 @@ FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
                          srcType.getMemorySpace());
 }
 
+FailureOr<SmallVector<OpFoldResult>>
+ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
+                                MemRefType expandedType,
+                                ArrayRef<ReassociationIndices> reassociation,
+                                ArrayRef<OpFoldResult> inputShape) {
+  std::optional<SmallVector<OpFoldResult>> outputShape =
+      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
+                                  inputShape);
+  if (!outputShape)
+    return failure();
+  return *outputShape;
+}
+
+void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
+                          Type resultType, Value src,
+                          ArrayRef<ReassociationIndices> reassociation,
+                          ArrayRef<OpFoldResult> outputShape) {
+  auto [staticOutputShape, dynamicOutputShape] =
+      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+  build(builder, result, resultType.cast<MemRefType>(), src,
+        getReassociationIndicesAttribute(builder, reassociation),
+        dynamicOutputShape, staticOutputShape);
+}
+
+void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
+                          Type resultType, Value src,
+                          ArrayRef<ReassociationIndices> reassociation) {
+  SmallVector<OpFoldResult> inputShape =
+      getMixedSizes(builder, result.location, src);
+  MemRefType memrefResultTy = resultType.cast<MemRefType>();
+  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
+      builder, result.location, memrefResultTy, reassociation, inputShape);
+  // Failure of this assertion usually indicates the presence of multiple
+  // dynamic dimensions in the same reassociation group.
+  assert(succeeded(outputShape) && "unable to infer output shape");
+  build(builder, result, memrefResultTy, src, reassociation, *outputShape);
+}
+
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                           ArrayRef<int64_t> resultShape, Value src,
                           ArrayRef<ReassociationIndices> reassociation) {
@@ -2250,6 +2288,20 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, *resultType, src, reassociation);
 }
 
+void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
+                          ArrayRef<int64_t> resultShape, Value src,
+                          ArrayRef<ReassociationIndices> reassociation,
+                          ArrayRef<OpFoldResult> outputShape) {
+  // Only ranked memref source values are supported.
+  auto srcType = llvm::cast<MemRefType>(src.getType());
+  FailureOr<MemRefType> resultType =
+      ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
+  // Failure of this assertion usually indicates a problem with the source
+  // type, e.g., could not get strides/offset.
+  assert(succeeded(resultType) && "could not compute layout");
+  build(builder, result, *resultType, src, reassociation, outputShape);
+}
+
 LogicalResult ExpandShapeOp::verify() {
   MemRefType srcType = getSrcType();
   MemRefType resultType = getResultType();
@@ -2266,7 +2318,7 @@ LogicalResult ExpandShapeOp::verify() {
   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                   resultType.getShape(),
                                   getReassociationIndices(),
-                                  /*allowMultipleDynamicDimsPerGroup=*/false)))
+                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
     return failure();
 
   // Compute expected result type (including layout map).
@@ -2280,14 +2332,28 @@ LogicalResult ExpandShapeOp::verify() {
     return emitOpError("expected expanded type to be ")
            << *expectedResultType << " but found " << resultType;
 
+  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+    return emitOpError("expected number of static shape bounds to be equal to "
+                       "the output rank (")
+           << resultType.getRank() << ") but found "
+           << getStaticOutputShape().size() << " inputs instead";
+
+  if ((int64_t)getOutputShape().size() !=
+      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+    return emitOpError("mismatch in dynamic dims in output_shape and "
+                       "static_output_shape: static_output_shape has ")
+           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+           << " dynamic dims while output_shape has " << getOutputShape().size()
+           << " values";
+
   return success();
 }
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(
-      context);
+  results.add<
+      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
 /// Compute the layout map after collapsing a given source MemRef type with the
@@ -2488,9 +2554,11 @@ struct CollapseShapeOpMemRefCastFolder
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
-              CollapseShapeOpMemRefCastFolder>(context);
+  results.add<
+      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+                                memref::DimOp, MemRefType>,
+      CollapseShapeOpMemRefCastFolder>(context);
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index af3a1b48f45af9..1a3ac2276b2228 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -47,6 +47,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   MLIRSparseTensorEnums
   MLIRSparseTensorUtils
   MLIRTensorDialect
+  MLIRTensorUtils
   MLIRTransforms
   MLIRVectorDialect
 )
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b117c1694e45b8..c6ee0a696cc361 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
@@ -952,8 +953,11 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
       auto rtp = getRankedTensorType(op.getResult());
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
-      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
-                                                op.getReassociation());
+      Value reshape =
+          tensor::createReshapeOp(op, rewriter, loc, denseTp, op.getSrc());
+      if (!reshape)
+        return failure();
+
       Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
       rewriter.replaceOp(op, convert);
       return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..857e45a79b2bf6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1606,6 +1606,44 @@ int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
   llvm_unreachable("could not find reassociation group");
 }
 
+FailureOr<SmallVector<OpFoldResult>>
+ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
+                                RankedTensorType expandedType,
+                                ArrayRef<ReassociationIndices> reassociation,
+                                ArrayRef<OpFoldResult> inputShape) {
+  std::optional<SmallVector<OpFoldResult>> outputShape =
+      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
+                                  inputShape);
+  if (!outputShape)
+    return failure();
+  return *outputShape;
+}
+
+void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
+                          Type resultType, Value src,
+                          ArrayRef<ReassociationIndices> reassociation,
+                          ArrayRef<OpFoldResult> outputShape) {
+  auto [staticOutputShape, dynamicOutputShape] =
+      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
+  build(builder, result, resultType.cast<RankedTensorType>(), src,
+        getReassociationIndicesAttribute(builder, reassociation),
+        dynamicOutputShape, staticOutputShape);
+}
+
+void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
+                          Type resultType, Value src,
+                          ArrayRef<ReassociationIndices> reassociation) {
+  SmallVector<OpFoldResult> inputShape =
+      getMixedSizes(builder, result.location, src);
+  auto tensorResultTy = resultType.cast<RankedTensorType>();
+  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
+      builder, result.location, tensorResultTy, reassociation, inputShape);
+  // Failure of this assertion usually indicates the presence of multiple
+  // dynamic dimensions in the same reassociation group.
+  assert(succeeded(outputShape) && "unable to infer output shape");
+  build(builder, result, tensorResultTy, src, reassociation, *outputShape);
+}
+
 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
@@ -1689,7 +1727,24 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
 }
 
 LogicalResult ExpandShapeOp::verify() {
-  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
+  auto srcType = getSrcType();
+  auto resultType = getResultType();
+
+  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
+    return emitOpError("expected number of static shape dims to be equal to "
+                       "the output rank (")
+           << resultType.getRank() << ") but found "
+           << getStaticOutputShape().size() << " inputs instead";
+
+  if ((int64_t)getOutputShape().size() !=
+      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
+    return emitOpError("mismatch in dynamic dims in output_shape and "
+                       "static_output_shape: static_output_shape has ")
+           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
+           << " dynamic dims while output_shape has " << getOutputShape().size()
+           << " values";
+
+  return verifyTensorReshapeOp(*this, resultType, srcType);
 }
 
 LogicalResult CollapseShapeOp::verify() {
@@ -1873,23 +1928,25 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
-              ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-              FoldReshapeWithConstant<ExpandShapeOp>,
-              FoldReshapeWithSplat<ExpandShapeOp>,
-              FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
-              FoldDimOfCollapseShape>(context);
+  results.add<
+      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      FoldReshapeWithConstant<ExpandShapeOp>,
+      FoldReshapeWithSplat<ExpandShapeOp>,
+      FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+      FoldDimOfCollapseShape>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results
-      .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
-           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
-           FoldReshapeWithConstant<CollapseShapeOp>,
-           FoldReshapeWithSplat<CollapseShapeOp>,
-           FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
-          context);
+  results.add<
+      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
+      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
+                                tensor::DimOp, RankedTensorType>,
+      FoldReshapeWithConstant<CollapseShapeOp>,
+      FoldReshapeWithSplat<CollapseShapeOp>,
+      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
+      context);
 }
 
 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 58ea4cc4da3c36..d078a575f40dda 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -338,6 +338,9 @@ struct ExpandShapeOpInterface
 
     // Memref result type is inferred by the builder based on reassociation
     // indices and result shape.
+    // TODO: Instead of inferring the output shape argument of the
+    // memref.expand_shape op, use the output_shape argument of the
+    // tensor.expand_shape op.
     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
         rewriter, op, tensorResultType.getShape(), *buffer,
         expandShapeOp.getReassociationIndices());
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index 666ac56c6cd5cd..7011ce23b55a6b 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -52,12 +52,16 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
 
-  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
-                     Type newOperandType, ArrayAttr reassociation) const {
+  FailureOr<Value>
+  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
+               Type newOperandType,
+               ArrayRef<ReassociationIndices> reassociation) const {
     if (operand.getType() == newOperandType)
       return operand;
-    return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
-                                                  reassociation);
+    return rewriter
+        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
+                                       reassociation)
+        .getResult();
   }
 
   /// Returns success() if it is only packing on the innermost dimension.
@@ -96,10 +100,14 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
         getReassociationIndicesForReshape(sourceType, destType);
     if (!reassociation)
       return failure();
-    Value expanded = insertExpand(
-        rewriter, packOp.getLoc(), packOp.getSource(), destType,
-        getReassociationIndicesAttribute(rewriter, *reassociation));
-    rewriter.replaceOp(packOp, expanded);
+    FailureOr<Value> expanded =
+        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
+                     *reassociation);
+    if (failed(expanded)) {
+      return rewriter.notifyMatchFailure(
+          packOp, "unable to expand source of tensor.pack");
+    }
+    rewriter.replaceOp(packOp, *expanded);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 15381ec520e211..ded257ad15ead0 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -53,6 +53,25 @@ SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
   return dynamicDims;
 }
 
+template <typename ReshapeOp>
+Value mlir::tensor::createReshapeOp(ReshapeOp oldReshapeOp, OpBuilder &b,
+                                    Location loc, RankedTensorType resultTy,
+                                    Value src) {
+  if constexpr (std::is_same<ReshapeOp, mlir::tensor::ExpandShapeOp>::value) {
+    return b
+        .create<ReshapeOp>(loc, resultTy, src, oldReshapeOp.getReassociation(),
+                           oldReshapeOp.getOutputShape(),
+                           oldReshapeOp.getStaticOutputShape())
+        .getResult();
+  }
+  if constexpr (std::is_same<ReshapeOp, mlir::tensor::CollapseShapeOp>::value) {
+    return b
+        .create<ReshapeOp>(loc, resultTy, src, oldReshapeOp.getReassociation())
+        .getResult();
+  }
+  return {};
+}
+
 FailureOr<RankedTensorType>
 mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
                                     ArrayRef<int64_t> transposeVector) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 41c7af4593c77c..6161faf7e30e11 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 
@@ -16,6 +17,67 @@
 
 using namespace mlir;
 
+std::optional<SmallVector<OpFoldResult>>
+mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc,
+                                  ShapedType expandedType,
+                                  ArrayRef<ReassociationIndices> reassociation,
+                                  ArrayRef<OpFoldResult> inputShape) {
+
+  SmallVector<Value> outputShapeValues;
+  SmallVector<int64_t> outputShapeInts;
+  // For zero-rank inputs, all dims in result shape are unit extent.
+  if (inputShape.empty()) {
+    outputShapeInts.resize(expandedType.getRank(), 1);
+    return getMixedValues(outputShapeInts, outputShapeValues, b);
+  }
+
+  // Check for all static shapes.
+  if (expandedType.hasStaticShape()) {
+    ArrayRef<int64_t> staticShape = expandedType.getShape();
+    outputShapeInts.assign(staticShape.begin(), staticShape.end());
+    return getMixedValues(outputShapeInts, outputShapeValues, b);
+  }
+
+  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
+  for (const auto &it : llvm::enumerate(reassociation)) {
+    ReassociationIndices indexGroup = it.value();
+
+    int64_t indexGroupStaticSizesProductInt = 1;
+    bool foundDynamicShape = false;
+    for (int64_t index : indexGroup) {
+      int64_t outputDimSize = expandedType.getDimSize(index);
+      // Cannot infer expanded shape with multiple dynamic dims in the
+      // same reassociation group!
+      if (ShapedType::isDynamic(outputDimSize)) {
+        if (foundDynamicShape)
+          return std::nullopt;
+        foundDynamicShape = true;
+      } else {
+        outputShapeInts[index] = outputDimSize;
+        indexGroupStaticSizesProductInt *= outputDimSize;
+      }
+    }
+    if (!foundDynamicShape)
+      continue;
+
+    int64_t inputIndex = it.index();
+    // Call get<Value>() under the assumption that we're not casting
+    // dynamism, i.e. the source dim of a dynamic group is itself dynamic.
+    Value indexGroupSize = inputShape[inputIndex].get<Value>();
+    Value indexGroupStaticSizesProduct =
+        b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
+    Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
+        loc, indexGroupSize, indexGroupStaticSizesProduct);
+    outputShapeValues.push_back(dynamicDimSize);
+  }
+
+  if ((int64_t)outputShapeValues.size() !=
+      llvm::count(outputShapeInts, ShapedType::kDynamic))
+    return std::nullopt;
+
+  return getMixedValues(outputShapeInts, outputShapeValues, b);
+}
+
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                         ShapedType targetType) {
@@ -168,7 +230,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute(
 }
 
 SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
+    ArrayRef<ReassociationExprs> reassociationExprs) {
   SmallVector<ReassociationIndices, 2> reassociationIndices;
   for (const auto &exprs : reassociationExprs) {
     ReassociationIndices indices;
@@ -230,24 +292,17 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
     ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
   unsigned expandedDimStart = 0;
   for (const auto &map : llvm::enumerate(reassociationMaps)) {
-    std::optional<int64_t> dynamicShape;
+    bool foundDynamicShape = false;
     int64_t linearizedStaticShape = 1;
+
     for (const auto &dim : llvm::enumerate(
              expandedShape.slice(expandedDimStart, map.value().size()))) {
-      if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicShape) {
-          return emitError("invalid to have a single dimension (" +
-                           Twine(map.index()) +
-                           ") expanded into multiple dynamic dims (" +
-                           Twine(expandedDimStart + dynamicShape.value()) +
-                           "," + Twine(expandedDimStart + dim.index()) + ")");
-        }
-        dynamicShape = dim.index();
-      } else {
+      if (ShapedType::isDynamic(dim.value()))
+        foundDynamicShape = true;
+      else
         linearizedStaticShape *= dim.value();
-      }
     }
-    if (dynamicShape) {
+    if (foundDynamicShape) {
       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
         return emitError(
             "expected dimension " + Twine(map.index()) +
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 1e8197e1094424..74a53709592dd2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -180,9 +180,8 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
-std::pair<ArrayAttr, SmallVector<Value>>
-decomposeMixedValues(Builder &b,
-                     const SmallVectorImpl<OpFoldResult> &mixedValues) {
+std::pair<SmallVector<int64_t>, SmallVector<Value>>
+decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
   SmallVector<int64_t> staticValues;
   SmallVector<Value> dynamicValues;
   for (const auto &it : mixedValues) {
@@ -193,7 +192,7 @@ decomposeMixedValues(Builder &b,
       dynamicValues.push_back(it.get<Value>());
     }
   }
-  return {b.getI64ArrayAttr(staticValues), dynamicValues};
+  return {staticValues, dynamicValues};
 }
 
 /// Helper to sort `values` according to matching `keys`.
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 87d613986c7c3f..b86103422b0745 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -453,7 +453,7 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 
 func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
   // Reshapes that expand a contiguous tensor with some 1's.
-  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
       : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
   return %0 : memref<1x3x4x1x5xf32>
 }
@@ -510,7 +510,7 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
 // -----
 
 func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
-  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x1xf32>
+  %0 = memref.expand_shape %arg0 [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
   return %0 : memref<1x1xf32>
 }
 
@@ -571,13 +571,13 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
 
 // -----
 
-func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
-  %0 = memref.expand_shape %arg0 [[0], [1, 2]]: memref<1x?xf32> into memref<1x2x?xf32>
+func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<1x2x?xf32> {
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [1, 2, %sz0]: memref<1x?xf32> into memref<1x2x?xf32>
   return %0 : memref<1x2x?xf32>
 }
 
 // CHECK-LABEL:   func.func @expand_shape_dynamic(
-// CHECK-SAME:                                    %[[ARG:.*]]: memref<1x?xf32>) -> memref<1x2x?xf32> {
+// CHECK-SAME:              %[[ARG:.*]]: memref<1x?xf32>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
@@ -614,15 +614,15 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
 // -----
 
 func.func @expand_shape_dynamic_with_non_identity_layout(
-            %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
+            %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>, %sz0: index) ->
             memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
-  %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [1, 2, %sz0] :
     memref<1x?xf32, strided<[?, ?], offset: ?>> into
     memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
   return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 }
 // CHECK-LABEL:   func.func @expand_shape_dynamic_with_non_identity_layout(
-// CHECK-SAME:                                                             %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK-SAME:        %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 37999d6fc14ad1..baf9cfe610a5a0 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -334,9 +334,9 @@ memref.global "private" @gv4 : memref<f32> = dense<1.0> {alignment = 64}
 // CHECK-LABEL: func @expand_shape_static(
 // CHECK-SAME:         %[[ARG:.*]]: memref<{{.*}}>)
 func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
-  // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]]
+  // CHECK: memref.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
   // Reshapes that expand a contiguous tensor with some 1's.
-  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
       : memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
   return %0 : memref<1x3x4x1x5xf32>
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 445e8be47678d5..4eaecfd117ece7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -348,7 +348,7 @@ func.func @test_add_2d_all_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
 func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
 
-  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] : tensor<3x4xf32> into tensor<1x3x4xf32>
+  // CHECK: %[[ARG0_EXPANDED:.*]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [1, 3, 4] : tensor<3x4xf32> into tensor<1x3x4xf32>
   // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<2x3x4xf32>
   // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[ARG0_EXPANDED]], %[[ARG1]] : tensor<1x3x4xf32>, tensor<2x3x4xf32>) outs(%[[VAL_0]] : tensor<2x3x4xf32>) {
   // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32):
@@ -871,7 +871,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
   // CHECK:  }
-  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xf32> into tensor<1x4xf32>
+  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xf32> into tensor<1x4xf32>
   %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xf32>) -> tensor<1x4xf32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
@@ -882,7 +882,7 @@ func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield [[RES]] : f32
   // CHECK:  }
-  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
+  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xf32> into tensor<5x1xf32>
   %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xf32>) -> tensor<5x1xf32>
 
   // CHECK: arith.constant 1.0
@@ -920,7 +920,10 @@ func.func @reduce_float_dyn(%arg0: tensor<?x5x4xf32>) -> () {
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
   // CHECK:  }
-  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] : tensor<?x4xf32> into tensor<?x1x4xf32>
+  // CHECK: %[[C0_0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor<?x4xf32>
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [%[[DIM_1]], 1, 4] : tensor<?x4xf32> into tensor<?x1x4xf32>
   %0 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<?x5x4xf32>) -> tensor<?x1x4xf32>
   return
 }
@@ -938,7 +941,7 @@ func.func @reduce_float_dyn_rank_1(%arg0: tensor<?xf32>) -> () {
   // CHECK:   %[[RES:.+]] = arith.addf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
   // CHECK:  }
-  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}] : tensor<f32> into tensor<1xf32>
+  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}] output_shape [1] : tensor<f32> into tensor<1xf32>
   %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<?xf32>) -> tensor<1xf32>
   return
 }
@@ -958,7 +961,10 @@ func.func @reduce_float_dyn_nonzero_batch(%arg0: tensor<5x?x4xf32>) -> () {
   // CHECK:   %[[RES:.+]] = arith.mulf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[RES]] : f32
   // CHECK:  }
-  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
+  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C1_0]] : tensor<5x?xf32>
+  // CHECK: %[[C1_2:.+]] = arith.constant 1 : index
+  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0], [1, 2]] output_shape [5, %[[DIM_1]], 1] : tensor<5x?xf32> into tensor<5x?x1xf32>
   %0 = tosa.reduce_prod %arg0 {axis = 2 : i32} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }
@@ -978,7 +984,10 @@ func.func @reduce_float_dyn_multiple(%arg0: tensor<?x?xf32>) -> () {
   // CHECK:   %[[MAX:.+]] = arith.maximumf %[[ARG1]], %[[ARG2]] : f32
   // CHECK:   linalg.yield %[[MAX]] : f32
   // CHECK:  }
-  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0, 1]] : tensor<?xf32> into tensor<?x1xf32>
+  // CHECK: %[[C0_0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM_1:.+]] = tensor.dim %[[REDUCE]], %[[C0_0]] : tensor<?xf32>
+  // CHECK: %[[C1_2:.+]] = arith.constant 1 : index
+  // CHECK: tensor.expand_shape %[[REDUCE]] {{\[}}[0, 1]] output_shape [%[[DIM_1]], 1] : tensor<?xf32> into tensor<?x1xf32>
   %0 = tosa.reduce_max %arg0 {axis = 1 : i32} : (tensor<?x?xf32>) -> tensor<?x1xf32>
   return
 }
@@ -996,7 +1005,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
   // CHECK:  }
-  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xi32> into tensor<1x4xi32>
+  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32>
   %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xi32>) -> tensor<1x4xi32>
 
   // CHECK: [[INIT:%.+]] = tensor.empty()
@@ -1007,7 +1016,7 @@ func.func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
   // CHECK:   [[RES:%.+]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
   // CHECK:   linalg.yield [[RES]] : i32
   // CHECK:  }
-  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<5xi32> into tensor<5x1xi32>
+  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xi32> into tensor<5x1xi32>
   %1 = tosa.reduce_sum %arg0 {axis = 1 : i32} : (tensor<5x4xi32>) -> tensor<5x1xi32>
 
   // CHECK: arith.constant 1
@@ -1043,7 +1052,7 @@ func.func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
   // CHECK:   [[RES:%.+]] = arith.andi %[[ARG1]], %[[ARG2]] : i1
   // CHECK:   linalg.yield [[RES]] : i1
   // CHECK:  }
-  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] : tensor<4xi1> into tensor<1x4xi1>
+  // CHECK: tensor.expand_shape [[REDUCE]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xi1> into tensor<1x4xi1>
   %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<5x4xi1>) -> tensor<1x4xi1>
 
   // CHECK: arith.constant false
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index a8a3c42e168422..b8c3d56f21f10c 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -14,7 +14,7 @@ func.func @test_reshape_0d_same_s2s_explicit(%arg0: tensor<f32>) -> tensor<f32>
 
 // CHECK-LABEL: test_reshape_0d_up_s2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
 // CHECK: return %[[VAL_1]] : tensor<?xf32>
 func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
@@ -26,7 +26,7 @@ func.func @test_reshape_0d_up_s2d_auto(%arg0: tensor<f32>) -> tensor<?xf32> {
 
 // CHECK-LABEL: test_reshape_0d_up_s2d_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1xf32> to tensor<?xf32>
 // CHECK: return %[[VAL_1]] : tensor<?xf32>
 func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32> {
@@ -38,7 +38,7 @@ func.func @test_reshape_0d_up_s2d_explicit(%arg0: tensor<f32>) -> tensor<?xf32>
 
 // CHECK-LABEL: test_reshape_0d_up_s2s_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 // CHECK: return %[[VAL_0]] : tensor<1xf32>
 func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1>} : (tensor<f32>) -> tensor<1xf32>
@@ -49,7 +49,7 @@ func.func @test_reshape_0d_up_s2s_auto(%arg0: tensor<f32>) -> tensor<1xf32> {
 
 // CHECK-LABEL: test_reshape_0d_up_s2s_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 // CHECK: return %[[VAL_0]] : tensor<1xf32>
 func.func @test_reshape_0d_up_s2s_explicit(%arg0: tensor<f32>) -> tensor<1xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 1>} : (tensor<f32>) -> tensor<1xf32>
@@ -83,8 +83,12 @@ func.func @test_reshape_1d_down_s2s_explicit(%arg0: tensor<1xf32>) -> tensor<f32
 
 // CHECK-LABEL: test_reshape_1d_up_d2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
-// CHECK: return %[[VAL_0]] : tensor<2x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C2]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, %[[VAL_0]]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[EXPANDED]] : tensor<2x?xf32>
 func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
@@ -94,7 +98,7 @@ func.func @test_reshape_1d_up_d2d_auto(%arg0: tensor<?xf32>) -> tensor<2x?xf32>
 
 // CHECK-LABEL: test_reshape_1d_up_s2s_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<6xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: %[[VAL_0:.*]] = tensor.expand_shape %[[ARG_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
 // CHECK: return %[[VAL_0]] : tensor<2x3xf32>
 func.func @test_reshape_1d_up_s2s_explicit(%arg0: tensor<6xf32>) -> tensor<2x3xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<6xf32>) -> tensor<2x3xf32>
@@ -128,8 +132,12 @@ func.func @test_reshape_2d_down_s2s_explicit(%arg0: tensor<2x3xf32>) -> tensor<6
 // CHECK-LABEL: test_reshape_2d_same_d2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x2xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<?x2xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<?xf32> into tensor<2x?xf32>
-// CHECK: return %[[VAL_1]] : tensor<2x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C2]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, %[[DIV]]] : tensor<?xf32> into tensor<2x?xf32>
+// CHECK: return %[[EXPANDED]] : tensor<2x?xf32>
 func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1>} : (tensor<?x2xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
@@ -140,7 +148,7 @@ func.func @test_reshape_2d_same_d2d_auto(%arg0: tensor<?x2xf32>) -> tensor<2x?xf
 // CHECK-LABEL: test_reshape_2d_same_s2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [4, 2] : tensor<8xf32> into tensor<4x2xf32>
 // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
 // CHECK: return %[[VAL_2]] : tensor<?x2xf32>
 func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
@@ -153,7 +161,7 @@ func.func @test_reshape_2d_same_s2d_auto(%arg0: tensor<2x4xf32>) -> tensor<?x2xf
 // CHECK-LABEL: test_reshape_2d_same_s2d_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x4xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<2x4xf32> into tensor<8xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<8xf32> into tensor<4x2xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [4, 2] : tensor<8xf32> into tensor<4x2xf32>
 // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<4x2xf32> to tensor<?x2xf32>
 // CHECK: return %[[VAL_2]] : tensor<?x2xf32>
 func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?x2xf32> {
@@ -166,7 +174,7 @@ func.func @test_reshape_2d_same_s2d_explicit(%arg0: tensor<2x4xf32>) -> tensor<?
 // CHECK-LABEL: test_reshape_2d_same_s2s_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1]] : tensor<3x2xf32> into tensor<6xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
 // CHECK: return %[[VAL_1]] : tensor<2x3xf32>
 func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3>} : (tensor<3x2xf32>) -> tensor<2x3xf32>
@@ -178,7 +186,11 @@ func.func @test_reshape_2d_same_s2s_explicit(%arg0: tensor<3x2xf32>) -> tensor<2
 // CHECK-LABEL: test_reshape_3d_same_d2d_auto_empty
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<0x3x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C0_0]] : index
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [0, 3, %[[DIV]]] : tensor<?xf32> into tensor<0x3x?xf32>
 // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<0x3x?xf32> to tensor<?x?x?xf32>
 // CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
 func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
@@ -191,7 +203,11 @@ func.func @test_reshape_3d_same_d2d_auto_empty(%arg0: tensor<3x2x?xf32>) -> tens
 // CHECK-LABEL: test_reshape_3d_same_d2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<2x?x?xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<2x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C8]] : index
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, %[[DIV]], 4] : tensor<?xf32> into tensor<2x?x4xf32>
 // CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<?x?x?xf32>
 // CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
 func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?x?xf32> {
@@ -204,7 +220,11 @@ func.func @test_reshape_3d_same_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x?
 // CHECK-LABEL: test_reshape_3d_same_d2d_auto_identity
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x3x4xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x3x4xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x3x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, 3, %[[DIV]]] : tensor<?xf32> into tensor<2x3x?xf32>
 // CHECK: return %[[VAL_1]] : tensor<2x3x?xf32>
 func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> tensor<2x3x?xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, -1>} : (tensor<?x3x4xf32>) -> tensor<2x3x?xf32>
@@ -216,8 +236,12 @@ func.func @test_reshape_3d_same_d2d_auto_identity(%arg0: tensor<?x3x4xf32>) -> t
 // CHECK-LABEL: test_reshape_3d_same_d2d_explicit_empty
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<3x2x?xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<3x2x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x2xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2xf32> to tensor<?x?x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 2] : tensor<?xf32> into tensor<?x3x2xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x2xf32> to tensor<?x?x?xf32>
 // CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
 func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) -> tensor<?x?x?xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 0, 3, 2>} : (tensor<3x2x?xf32>) -> tensor<?x?x?xf32>
@@ -229,8 +253,12 @@ func.func @test_reshape_3d_same_d2d_explicit_empty(%arg0: tensor<3x2x?xf32>) ->
 // CHECK-LABEL: test_reshape_3d_same_d2d_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<?x?x?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C12:.*]] = arith.constant 12 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C12]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 4] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x4xf32> to tensor<?x?x?xf32>
 // CHECK: return %[[VAL_2]] : tensor<?x?x?xf32>
 func.func @test_reshape_3d_same_d2d_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
@@ -253,8 +281,12 @@ func.func @test_reshape_3d_same_d2d_explicit_identity(%arg0: tensor<?x3x4xf32>)
 // CHECK-LABEL: test_reshape_3d_same_d2s_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<2x?x4xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<2x?x4xf32> to tensor<2x3x4xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C8]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [2, %[[DIV]], 4] : tensor<?xf32> into tensor<2x?x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<2x?x4xf32> to tensor<2x3x4xf32>
 // CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
 func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, -1, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
@@ -266,8 +298,12 @@ func.func @test_reshape_3d_same_d2s_auto(%arg0: tensor<?x?x?xf32>) -> tensor<2x3
 // CHECK-LABEL: test_reshape_3d_same_d2s_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x3x4xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x4xf32> to tensor<2x3x4xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C12:.*]] = arith.constant 12 : index
+// CHECK: %[[DIV:.*]] = arith.divui %[[DIM]], %[[C12]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] output_shape [%[[DIV]], 3, 4] : tensor<?xf32> into tensor<?x3x4xf32>
+// CHECK: %[[VAL_2:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x4xf32> to tensor<2x3x4xf32>
 // CHECK: return %[[VAL_2]] : tensor<2x3x4xf32>
 func.func @test_reshape_3d_same_d2s_explicit(%arg0: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 2, 3, 4>} : (tensor<?x?x?xf32>) -> tensor<2x3x4xf32>
@@ -288,10 +324,14 @@ func.func @test_reshape_3d_same_s2s_explicit_identity(%arg0: tensor<2x3x4xf32>)
 
 // CHECK-LABEL: test_reshape_3d_up_d2s_explicit
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : tensor<?xf32> into tensor<?x3x2x1xf32>
-// CHECK: %[[VAL_2:.*]] = tensor.cast %[[VAL_1]] : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
-// CHECK: return %[[VAL_2]] : tensor<1x3x2x1xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2]] : tensor<?x?x?xf32> into tensor<?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2, 3]] output_shape [%[[VAL_0]], 3, 2, 1] : tensor<?xf32> into tensor<?x3x2x1xf32>
+// CHECK: %[[CAST:.*]] = tensor.cast %[[EXPANDED]] : tensor<?x3x2x1xf32> to tensor<1x3x2x1xf32>
+// CHECK: return %[[CAST]] : tensor<1x3x2x1xf32>
 func.func @test_reshape_3d_up_d2s_explicit(%input: tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32> {
   %0 = tosa.reshape %input {new_shape = array<i64: 1, 3, 2, 1>} : (tensor<?x?x?xf32>) -> tensor<1x3x2x1xf32>
   return %0 : tensor<1x3x2x1xf32>
@@ -313,9 +353,13 @@ func.func @test_reshape_4d_down_d2s_explicit(%arg0: tensor<?x?x?x?xf32>) -> tens
 
 // CHECK-LABEL: test_reshape_5d_down_d2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?x?x?x2x3xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x?x?x2x3xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x2x3xf32>
-// CHECK: return %[[VAL_1]] : tensor<?x2x3xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4]] : tensor<?x?x?x2x3xf32> into tensor<?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C6]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 2, 3] : tensor<?xf32> into tensor<?x2x3xf32>
+// CHECK: return %[[EXPANDED]] : tensor<?x2x3xf32>
 func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 2, 3>} : (tensor<?x?x?x2x3xf32>) -> tensor<?x2x3xf32>
   return %0 : tensor<?x2x3xf32>
@@ -325,9 +369,13 @@ func.func @test_reshape_5d_down_d2d_auto(%arg0: tensor<?x?x?x2x3xf32>) -> tensor
 
 // CHECK-LABEL: test_reshape_6d_down_d2d_auto
 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<1x2x?x5x7x11xf32>
-// CHECK: %[[VAL_0:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
-// CHECK: %[[VAL_1:.*]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1, 2]] : tensor<?xf32> into tensor<?x5x77xf32>
-// CHECK: return %[[VAL_1]] : tensor<?x5x77xf32>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG_0]] {{\[\[}}0, 1, 2, 3, 4, 5]] : tensor<1x2x?x5x7x11xf32> into tensor<?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[COLLAPSED]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[C385:.*]] = arith.constant 385 : index
+// CHECK: %[[VAL_0:.*]] = arith.divui %[[DIM]], %[[C385]] : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[COLLAPSED]] {{\[\[}}0, 1, 2]] output_shape [%[[VAL_0]], 5, 77] : tensor<?xf32> into tensor<?x5x77xf32>
+// CHECK: return %[[EXPANDED]] : tensor<?x5x77xf32>
 func.func @test_reshape_6d_down_d2d_auto(%arg0: tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32> {
   %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: -1, 5, 77>} : (tensor<1x2x?x5x7x11xf32>) -> tensor<?x5x77xf32>
   return %0 : tensor<?x5x77xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 9a3e14b6d39178..efe59af97d9649 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -132,7 +132,7 @@ func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
   %cst = arith.constant 8.0 : f32
   %0 = tensor.empty() : tensor<128xf32>
   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
-  %2 = tensor.expand_shape %1 [[0, 1, 2]]
+  %2 = tensor.expand_shape %1 [[0, 1, 2]] output_shape [1, 1, 128]
       : tensor<128xf32> into tensor<1x1x128xf32>
   %3 = tensor.insert_slice %2 into %t[2, 3, 0][1, 1, 128][1, 1, 1]
       : tensor<1x1x128xf32> into tensor<5x6x128xf32>
diff --git a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
index 0e353a1fa43fcb..4bf81820f0e805 100644
--- a/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir
@@ -165,7 +165,9 @@ func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
   %init = tensor.empty(%width) : tensor<1x?xf32>
   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
   %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
-  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
+  %c0 = arith.constant 0 : index
+  %sz0 = tensor.dim %slice, %c0 : tensor<?xf32>
+  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] output_shape [1, 1, 1, %sz0] : tensor<?xf32> into tensor<1x1x1x?xf32>
   return %expand : tensor<1x1x1x?xf32>
 }
 
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 547320f5338747..61bedecbdca5a4 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -52,7 +52,7 @@ func.func @collapse_parallel(
 //  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"]}
 //  CHECK-SAME:     ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
 //       CHECK:   } -> tensor<2x32x40960xf32>
-//       CHECK:  tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
+//       CHECK:  tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] output_shape [2, 32, 10, 4096] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
 
 // -----
 
@@ -127,8 +127,8 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
 // CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
 // CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
 // CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
-// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
-// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
+// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
+// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
 // CHECK:         }
 
diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
index a6431996353121..c7c846d7ecc9c5 100644
--- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
+++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir
@@ -50,7 +50,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
 
 // CHECK: IR printer: transformed
-// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -78,7 +78,7 @@ module attributes {transform.with_named_sequence} {
 //                CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
 //                CHECK:     linalg.yield %[[ADD]] : f32
 //                CHECK: } -> tensor<1x196x16xf32>
-//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 //      CHECK: return %[[RESULT]]
 
 func.func @conv_16433136(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
@@ -204,7 +204,7 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
 //      CHECK:     linalg.yield %[[ADD]] : f32
 //      CHECK:   } -> tensor<8x196x16xf32>
-//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
+//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1, 2], [3]] output_shape [8, 14, 14, 16] : tensor<8x196x16xf32> into tensor<8x14x14x16xf32>
 //      CHECK:   return %[[CS_FINAL]]
 func.func @batch_nhwc_conv(%arg0: tensor<8x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<8x14x14x16xf32>) -> tensor<8x14x14x16xf32> {
     %0 = linalg.conv_2d_nhwc_hwcf
@@ -269,7 +269,7 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
 //      CHECK:     linalg.yield %[[ADD]] : f32
 //      CHECK:   } -> tensor<8x16x196xf32>
-//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
+//      CHECK:   %[[CS_FINAL:.+]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0], [1], [2, 3]] output_shape [8, 16, 14, 14] : tensor<8x16x196xf32> into tensor<8x16x14x14xf32>
 //      CHECK:   return %[[CS_FINAL]]
 func.func @batch_nchw_conv(%arg0: tensor<8x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<8x16x14x14xf32>) -> tensor<8x16x14x14xf32> {
     %0 = linalg.conv_2d_nchw_fchw
@@ -310,7 +310,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK: linalg.yield %[[EXTRACTED_INPUT]] : f32
 
 // CHECK: IR printer: transformed
-// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+// CHECK: tensor.expand_shape %{{[^ ]*}} {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
@@ -338,7 +338,7 @@ module attributes {transform.with_named_sequence} {
 //                CHECK:     %[[ADD:.+]] = arith.addf %[[MUL]], %[[ARG2]] : f32
 //                CHECK:     linalg.yield %[[ADD]] : f32
 //                CHECK: } -> tensor<1x196x16xf32>
-//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>
 //      CHECK: return %[[RESULT]]
 
 func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
@@ -378,7 +378,7 @@ module attributes {transform.with_named_sequence} {
 //                CHECK:     %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32
 //                CHECK:     linalg.yield %[[ADD]] : i32
 //                CHECK: } -> tensor<1x196x16xi32>
-//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32>
 //      CHECK: return %[[RESULT]]
 
 func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> {
@@ -416,7 +416,7 @@ module attributes {transform.with_named_sequence} {
 //                CHECK:     %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
 //                CHECK:     linalg.yield %[[ADD]] : complex<f32>
 //                CHECK: } -> tensor<1x196x16xcomplex<f32>>
-//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
 //      CHECK: return %[[RESULT]]
 
 func.func @conv_complex(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f32>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
@@ -459,7 +459,7 @@ module attributes {transform.with_named_sequence} {
 //                CHECK:     %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
 //                CHECK:     linalg.yield %[[ADD]] : complex<f32>
 //                CHECK: } -> tensor<1x196x16xcomplex<f32>>
-//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
 //      CHECK: return %[[RESULT]]
 
 func.func @conv_complex_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xcomplex<f16>>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
@@ -500,7 +500,7 @@ module attributes {transform.with_named_sequence} {
 //                CHECK:     %[[ADD:.+]] = complex.add %[[MUL]], %[[ARG2]] : complex<f32>
 //                CHECK:     linalg.yield %[[ADD]] : complex<f32>
 //                CHECK: } -> tensor<1x196x16xcomplex<f32>>
-//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
+//      CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xcomplex<f32>> into tensor<1x14x14x16xcomplex<f32>>
 //      CHECK: return %[[RESULT]]
 
 func.func @conv_complex_f16_extended(%arg0: tensor<1x16x16x4xcomplex<f32>>, %arg1: tensor<3x3x4x16xf16>, %arg2: tensor<1x14x14x16xcomplex<f32>>) -> tensor<1x14x14x16xcomplex<f32>> {
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 79d61ab757e327..bee08503298fd4 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -988,17 +988,20 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4
 
 // -----
 
-func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
   %6 = tensor.empty(%dim) : tensor<?x256xf32>
   %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
-  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
   func.return %expanded : tensor<?x256x256xf32>
 }
 // CHECK-LABEL: func.func @push_down_unpack_through_expand
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C32:.+]] = arith.constant 32 : index
 // CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
+// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
+// CHECK:         %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C32]] : index
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
 // CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
@@ -1009,12 +1012,12 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
 func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>) -> tensor<4x12x256x256xf32> {
   %6 = tensor.empty() : tensor<4x3072x256xf32>
   %unpack = tensor.unpack %5 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 8] into %6 : tensor<4x32x384x8x8xf32> -> tensor<4x3072x256xf32>
-  %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0], [1, 2], [3]] output_shape [4, 12, 256, 256] : tensor<4x3072x256xf32> into tensor<4x12x256x256xf32>
   func.return %expanded : tensor<4x12x256x256xf32>
 }
 // CHECK-LABEL: @push_down_permuted_unpack_through_expand
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] output_shape [4, 32, 12, 32, 8, 8] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
 // CHECK:         return %[[UNPACK]] : tensor<4x12x256x256xf32>
@@ -1024,29 +1027,32 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>
 func.func @push_down_unpack_through_unit_expand(%5: tensor<6x32x8x8xf32>) -> tensor<3x16x1x256xf32> {
   %6 = tensor.empty() : tensor<48x256xf32>
   %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<6x32x8x8xf32> -> tensor<48x256xf32>
-  %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1, 2], [3]] output_shape [3, 16, 1, 256] : tensor<48x256xf32> into tensor<3x16x1x256xf32>
   func.return %expanded : tensor<3x16x1x256xf32>
 }
 // CHECK-LABEL: func.func @push_down_unpack_through_unit_expand
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3], [4], [5]] output_shape [3, 2, 1, 32, 8, 8] : tensor<6x32x8x8xf32> into tensor<3x2x1x32x8x8xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<3x16x1x256xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 3] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<3x2x1x32x8x8xf32> -> tensor<3x16x1x256xf32>
 // CHECK:         return %[[UNPACK]] : tensor<3x16x1x256xf32>
 
 // -----
 
-func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index) -> tensor<?x256x256xf32> {
+func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
   %6 = tensor.empty(%dim) : tensor<?x256xf32>
   %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [1] inner_tiles = [8] into %6 : tensor<?x32x8xf32> -> tensor<?x256xf32>
-  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<?x256xf32> into tensor<?x256x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [%sz0, 256, 256] : tensor<?x256xf32> into tensor<?x256x256xf32>
   func.return %expanded : tensor<?x256x256xf32>
 }
 // CHECK-LABEL: func.func @push_down_unpack_through_expand_on_outer_dims
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME:      %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK:         %[[C256:.+]] = arith.constant 256 : index
 // CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
+// CHECK:         %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8xf32>
+// CHECK:         %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C256]] : index
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]] output_shape [%[[SZ0]], 256, 32, 8] : tensor<?x32x8xf32> into tensor<?x256x32x8xf32>
 // CHECK:         %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x256x32x8xf32>
 // CHECK:         %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED:.+]] outer_dims_perm = [0, 1, 2] inner_dims_pos = [2] inner_tiles = [8] into %[[EMPTY]] : tensor<?x256x32x8xf32> -> tensor<?x256x256xf32>
@@ -1057,11 +1063,11 @@ func.func @push_down_unpack_through_expand_on_outer_dims(%5: tensor<?x32x8xf32>,
 func.func @no_push_down_unpack_through_non_divisible_expand(%5: tensor<384x32x8x8xf32>) -> tensor<256x12x256xf32> {
   %6 = tensor.empty() : tensor<3072x256xf32>
   %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<384x32x8x8xf32> -> tensor<3072x256xf32>
-  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+  %expanded = tensor.expand_shape %unpack [[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
   func.return %expanded : tensor<256x12x256xf32>
 }
 // CHECK-LABEL: func.func @no_push_down_unpack_through_non_divisible_expand
 // CHECK-SAME:      %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] : tensor<3072x256xf32> into tensor<256x12x256xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[UNPACK]] {{\[}}[0, 1], [2]] output_shape [256, 12, 256] : tensor<3072x256xf32> into tensor<256x12x256xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<256x12x256xf32>
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index c140b6abcc37a2..a9cbaaf7fdc485 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -25,13 +25,22 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 //   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
 //   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
 // CHECK-LABEL: func @drop_one_trip_loops
-//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
-//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+//       CHECK: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
+//       CHECK: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]]
+//       CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]]
 //       CHECK: linalg.generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]
+//       CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]]
+//       CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]]
+//       CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]]
+//       CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]]
+//       CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]]
+//       CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
 
 //   CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 //   CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
@@ -70,13 +79,18 @@ func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32,
 }
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
 //   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
+//   CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)>
 // CHECK-LABEL: func @drop_one_trip_loops_all_ones
+//       CHECK: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[C1:.*]] = arith.constant 1 : index
 //       CHECK: tensor.collapse_shape %{{.*}} []
 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
 //       CHECK: linalg.generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME:   iterator_types = ["parallel"]
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
+//       CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32>
+//       CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]]
+//       CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
 
 // -----
 
@@ -232,8 +246,8 @@ func.func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor
 
 func.func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32>
 {
-  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<5xf32> into tensor<1x5xf32>
-  %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<5xf32> into tensor<5x1xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [1, 5] : tensor<5xf32> into tensor<1x5xf32>
+  %1 = tensor.expand_shape %arg1 [[0, 1]] output_shape [5, 1] : tensor<5xf32> into tensor<5x1xf32>
   %2 = linalg.generic #trait
      ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>)
     outs(%shape : tensor<5x5xf32>) {
@@ -331,7 +345,6 @@ func.func @fold_unit_dim_for_empty_tensor(%input: tensor<1x1000xf32>) -> tensor<
 
 //       CHECK: func @fold_unit_dim_for_empty_tensor
 
-
 //       CHECK: %[[INPUT_RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32>
 //       CHECK: %[[INIT:.+]] = tensor.empty() : tensor<f32>
 //       CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<f32>) -> tensor<f32>
@@ -340,7 +353,7 @@ func.func @fold_unit_dim_for_empty_tensor(%input: tensor<1x1000xf32>) -> tensor<
 //  CHECK-SAME:     iterator_types = ["reduction"]
 //  CHECK-SAME:   ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
 //  CHECK-SAME:   outs(%[[FILL]] : tensor<f32>)
-//       CHECK: %[[GENERIC_RESHAPE:.+]] = tensor.expand_shape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
+//       CHECK: %[[GENERIC_RESHAPE:.+]] = tensor.expand_shape %[[GENERIC]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
 //       CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32>
 
 
@@ -364,11 +377,11 @@ func.func @fold_slice(
 //      CHECK:   %[[SLICE1:.+]] = tensor.extract_slice %[[ARG0]]
 // CHECK-SAME:       to tensor<?x?x?xf32>
 //      CHECK:   %[[RESULT1:.+]] = tensor.expand_shape %[[SLICE1]]
-// CHECK-SAME:       [0, 1], [2], [3, 4, 5, 6]
+// CHECK-SAME:       {{\[\[}}0, 1], [2], [3, 4, 5, 6]] output_shape [1, %arg5, %arg6, 1, %arg7, 1, 1] : tensor<?x?x?xf32> into tensor<1x?x?x1x?x1x1xf32>
 //      CHECK:   %[[SLICE2:.+]] = tensor.extract_slice %[[ARG1]]
 // CHECK-SAME:       to tensor<?x?x?xf32>
 //      CHECK:   %[[RESULT2:.+]] = tensor.expand_shape %[[SLICE2]]
-// CHECK-SAME:       [0, 1], [2], [3, 4, 5, 6]
+// CHECK-SAME:       {{\[\[}}0, 1], [2], [3, 4, 5, 6]] output_shape [1, %arg5, %arg6, 1, %arg7, 1, 1] : tensor<?x?x?xf32> into tensor<1x?x?x1x?x1x1xf32>
 //      CHECK:   return %[[RESULT1]], %[[RESULT2]]
 
 // -----
@@ -391,20 +404,27 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
   } -> tensor<1x?xf32>
   return %3 : tensor<1x?xf32>
 }
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
 //      CHECK: func @unit_dim_for_reduction
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x?xf32>
-//  CHECK-DAG:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
-//      CHECK:   %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
-//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
-//      CHECK:   %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
+//      CHECK: %[[C1:.+]] = arith.constant 1 : index
+//      CHECK: %[[CST:.+]] = arith.constant 1.000000e+00 : f32
+//      CHECK: %[[C3:.+]] = arith.constant 3 : index
+//      CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<1x?x1x?xf32>
+//      CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+//      CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
+//      CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+//      CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP2]]]
 // CHECK-SAME:     iterator_types = ["parallel", "reduction"]
 // CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
-//      CHECK:   return %[[RESULT_RESHAPE]]
+//      CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32>
+//      CHECK: %[[VAL_3:.*]] = affine.apply #[[MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]]
+//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[RESULT]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor<?xf32> into tensor<1x?xf32>
+//      CHECK: return %[[EXPANDED]] : tensor<1x?xf32>
 
 // -----
 
@@ -437,7 +457,7 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1
 // CHECK-SAME:     iterator_types = ["parallel"]
 // CHECK-SAME:     ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
 // CHECK-SAME:     outs(%[[INIT2]] : tensor<1xf32>)
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, 1]
 //      CHECK:   return %[[RESULT_RESHAPE]]
 
 // -----
@@ -460,20 +480,28 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
   } -> tensor<?x1xf32>
   return %3 : tensor<?x1xf32>
 }
-//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
 //      CHECK: func @unit_dim_for_reduction_inner
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<?x1x?x1xf32>
-//  CHECK-DAG:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
-//      CHECK:   %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
-//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
-//      CHECK:   %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP3]]]
+//      CHECK: %[[C1:.*]] = arith.constant 1 : index
+//      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+//      CHECK: %[[C2:.*]] = arith.constant 2 : index
+//      CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x1x?x1xf32>
+//      CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
+//      CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
+//      CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
+//      CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP2]]]
 // CHECK-SAME:     iterator_types = ["parallel", "reduction"]
 // CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
-//      CHECK:   return %[[RESULT_RESHAPE]]
+//      CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x1x?x1xf32>
+//      CHECK: %[[VAL_3:.+]] = affine.apply #[[MAP3]]()[%[[DIM_0]], %[[C1]]]
+//      CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor<?xf32> into tensor<?x1xf32>
+//      CHECK: return %[[RESULT_RESHAPE]]
 
 // -----
 
@@ -484,7 +512,7 @@ func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
 // CHECK-LABEL: func @slice_unit_dims
 //       CHECK:   %[[SLICE:.+]] = tensor.extract_slice
 //  CHECK-SAME:     tensor<1x3xf32> to tensor<f32>
-//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] []
+//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] [] output_shape [1, 1]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -496,7 +524,7 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x
 // CHECK-LABEL: func @rank_reduced_extract_slice
 //       CHECK:   %[[SLICE:.+]] = tensor.extract_slice
 //  CHECK-SAME:     tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
-//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]]
+//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]] output_shape [1, 3, 3]
 //       CHECK:   return %[[RESULT]]
 
 // -----
@@ -709,8 +737,8 @@ func.func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref
 
 func.func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32>
 {
-  %0 = memref.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
-  %1 = memref.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [1, 5] : memref<5xf32> into memref<1x5xf32>
+  %1 = memref.expand_shape %arg1 [[0, 1]] output_shape [5, 1] : memref<5xf32> into memref<5x1xf32>
   linalg.generic #trait
      ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>)
     outs(%shape : memref<5x5xf32>) {
@@ -966,7 +994,7 @@ func.func @drop_unit_pad_dims(%arg0: tensor<1x1x3x1x1xf32>) -> tensor<1x2x3x1x3x
 //       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 0] high[0, 0, 2]
 //       CHECK:   } : tensor<1x3x1xf32> to tensor<2x3x3xf32>
 //       CHECK:   tensor.expand_shape %[[PADDED]]
-//  CHECK-SAME:     {{\[}}[0, 1], [2, 3], [4]{{\]}} : tensor<2x3x3xf32> into tensor<1x2x3x1x3xf32>
+//  CHECK-SAME:     {{\[}}[0, 1], [2, 3], [4]{{\]}} output_shape [1, 2, 3, 1, 3] : tensor<2x3x3xf32> into tensor<1x2x3x1x3xf32>
 
 // CHECK-SLICES-LABEL: func @drop_unit_pad_dims
 //       CHECK-SLICES:   %[[EXTRACT:.+]] = tensor.extract_slice
@@ -989,13 +1017,19 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
   return %0 : tensor<1x?xf32>
 }
 
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)>
 // CHECK-LABEL: func @drop_unit_pad_dynamic_dims
+//       CHECK:   %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 //       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
 //  CHECK-SAME:     {{\[}}[0, 1]{{\]}} : tensor<1x?xf32> into tensor<?xf32>
 //       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6]
 //       CHECK:   } : tensor<?xf32> to tensor<?xf32>
-//       CHECK:   tensor.expand_shape %[[PADDED]]
-//  CHECK-SAME:     {{\[}}[0, 1]{{\]}} : tensor<?xf32> into tensor<1x?xf32>
+//       CHECK:   %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32>
+//       CHECK:   %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]]
+//       CHECK:   %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]]
+//       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor<?xf32> into tensor<1x?xf32>
 
 // CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)>
 
@@ -1052,4 +1086,4 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te
 //       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0] high[0, 0]
 //       CHECK:   } : tensor<383x128xf32> to tensor<384x128xf32>
 //       CHECK:   tensor.expand_shape %[[PADDED]]
-//  CHECK-SAME:     {{\[}}[0, 1], [2]] : tensor<384x128xf32> into tensor<1x384x128xf32>
+//  CHECK-SAME:     {{\[}}[0, 1], [2]] output_shape [1, 384, 128] : tensor<384x128xf32> into tensor<1x384x128xf32>
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
index 5a27fe76b13411..9fe50a521d2d81 100644
--- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -26,7 +26,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:                         %[[ARG1:.*]]: tensor<32x7xf32>
 // CHECK-NEXT:    %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
 // CHECK-NEXT:    %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
-// CHECK-NEXT:    %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]] output_shape [32, 7] : tensor<224xf32> into tensor<32x7xf32>
 func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
     %0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) ->  tensor<32x7xf32>
     return %0 :  tensor<32x7xf32>
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 50d308b6a9fee1..0d40df534a3bb7 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -9,8 +9,7 @@
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
 func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
     %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
-  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
-      : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
   %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
   %generic = linalg.generic {
     indexing_maps = [#map0, #map1, #map2, #map3],
@@ -40,7 +39,7 @@ func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
 // CHECK-SAME:       outs(%[[INIT_RESHAPE]] :
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
 //      CHECK:   return %[[RESULT_RESHAPE]]
 
 //      CONTROL: func @fuse_by_collapsing(
@@ -60,8 +59,7 @@ func.func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>,
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
 func.func @fuse_by_collapsing_indexing_op(%arg0 : tensor<2x12x5x336x9xi32>,
     %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
-  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
-      : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
   %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
   %generic = linalg.generic {
     indexing_maps = [#map0, #map1, #map2, #map3],
@@ -122,8 +120,7 @@ func.func @fuse_by_collapsing_indexing_op(%arg0 : tensor<2x12x5x336x9xi32>,
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
 func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi32>,
     %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> {
-  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
-      : tensor<9x56x2x60x6xi32> into tensor<9x7x8x2x3x4x5x6xi32>
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [9, 7, 8, 2, 3, 4, 5, 6] : tensor<9x56x2x60x6xi32> into tensor<9x7x8x2x3x4x5x6xi32>
   %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
   %generic = linalg.generic {
     indexing_maps = [#map0, #map1, #map2, #map3],
@@ -154,7 +151,7 @@ func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi3
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
 // CHECK-SAME:       outs(%[[INIT_RESHAPE]] :
-//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}} output_shape [2, 3, 4, 5, 6, 7, 8, 9]
 //      CHECK:   return %[[RESULT_RESHAPE]]
 
 // -----
@@ -165,11 +162,11 @@ func.func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi3
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
 #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
 func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
-    %arg1 : tensor<?x?x?xi32>, %arg2 : tensor<?x?x?x?xi32>) -> tensor<?x3x?x5x?x7x?x?xi32> {
+    %arg1 : tensor<?x?x?xi32>, %arg2 : tensor<?x?x?x?xi32>, %sz0: index, %sz1: index, %sz2: index, %sz3: index, %sz4: index) -> tensor<?x3x?x5x?x7x?x?xi32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
-  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [%sz0, 7, %sz1, %sz2, 3, %sz3, 5, %sz4]
       : tensor<?x?x?x?x?xi32> into tensor<?x7x?x?x3x?x5x?xi32>
   %d0 = tensor.dim %arg1, %c2 : tensor<?x?x?xi32>
   %d2 = tensor.dim %arg2, %c2 : tensor<?x?x?x?xi32>
@@ -203,8 +200,8 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
     } -> tensor<?x3x?x5x?x7x?x?xi32>
   return %generic : tensor<?x3x?x5x?x7x?x?xi32>
 }
-//      CHECK: func @fuse_by_collapsing_dynamic(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?x?x?xi32>
+//      CHECK: func @fuse_by_collapsing_dynamic
+// CHECK-SAME:     (%[[ARG0:.+]]: tensor<?x?x?x?x?xi32>, %[[ARG1:.+]]: tensor<?x?x?xi32>, %[[ARG2:.+]]: tensor<?x?x?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index)
 //  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //  CHECK-DAG:   %[[C5:.+]] = arith.constant 5 : index
 //      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]]
@@ -224,8 +221,8 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
-func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) -> tensor<2x5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32>
+func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 6, %sz0, 5] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32>
   %1 = linalg.generic {
       indexing_maps = [#map0, #map1],
       iterator_types = ["parallel", "reduction", "reduction", "parallel"]}
@@ -240,7 +237,8 @@ func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) -
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 //      CHECK: func @fuse_reductions(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<2x?x5xf32>
-// CHECK-SAME:     %[[ARG1:.+]]: tensor<2x5xf32>) -> tensor<2x5xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<2x5xf32>
+// CHECK-SAME:     %[[SZ0:.+]]: index) -> tensor<2x5xf32>
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:       iterator_types = ["parallel", "reduction", "parallel"]
@@ -253,7 +251,7 @@ func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) -
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 func.func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x3x4x5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 3, 4, 5] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
   %init = tensor.empty(): tensor<2x3x4x5xf32>
   %1 = linalg.generic {
       indexing_maps = [#map0, #map1, #map0],
@@ -280,7 +278,7 @@ func.func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tenso
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
 func.func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2xf32>) -> tensor<2x4x3x5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 3, 4, 5] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
   %init = tensor.empty() : tensor<2x4x3x5xf32>
   %1 = linalg.generic {
       indexing_maps = [#map0, #map1, #map2],
@@ -307,7 +305,7 @@ func.func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %ar
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 3, 4, 5] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32>
   %init = tensor.empty() : tensor<2x5xf32>
   %1 = linalg.generic {
       indexing_maps = [#map0, #map1, #map2],
@@ -335,8 +333,8 @@ func.func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 :
 #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tensor<2x3x4x5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
-  %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<20xf32> into tensor<4x5xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [2, 3] : tensor<6xf32> into tensor<2x3xf32>
+  %1 = tensor.expand_shape %arg1 [[0, 1]] output_shape [4, 5] : tensor<20xf32> into tensor<4x5xf32>
     %init = tensor.empty() : tensor<2x3x4x5xf32>
   %2 = linalg.generic {
       indexing_maps = [#map0, #map1, #map2],
@@ -359,8 +357,8 @@ func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tens
 // CHECK-SAME:       iterator_types = ["parallel", "parallel"]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%{{.+}}: tensor<6x20xf32>)
-//      CHECK:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]{{\]}}
-//      CHECK:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1], [2], [3]{{\]}}
+//      CHECK:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]{{\]}} output_shape [6, 4, 5]
+//      CHECK:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, 3, 4, 5]
 //      CHECK:   return %[[RESHAPE2]]
 
 //  CONTROL-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -375,14 +373,14 @@ func.func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tens
 //      CONTROL:     %[[GENERIC:.+]] = linalg.generic
 // CONTROL-SAME:         ins(%[[EXPAND]], %[[ARG1]] :
 // CONTROL-SAME:         outs(%[[INIT_RESHAPE]] :
-//      CONTROL:     %[[RESULT:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}}
+//      CONTROL:     %[[RESULT:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}} output_shape [2, 3, 4, 5]
 
 // -----
 
 // Corner case that isnt handled currently.
 #map = affine_map<(d0) -> (d0)>
 func.func @zero_D_test(%arg0: tensor<f32>) -> tensor<1xf32> {
-  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
   %init = tensor.empty() : tensor<1xf32>
   %1 = linalg.generic {
       indexing_maps = [#map, #map],
@@ -404,8 +402,8 @@ func.func @zero_D_test(%arg0: tensor<f32>) -> tensor<1xf32> {
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4x?x?x8xf32>) -> tensor<4x?x?x8xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xf32> into tensor<?x4x?x8xf32>
+func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4x?x?x8xf32>, %sz0: index, %sz1: index) -> tensor<4x?x?x8xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%sz0, 4, %sz1, 8] : tensor<?x?xf32> into tensor<?x4x?x8xf32>
   %1 = linalg.generic {
       indexing_maps = [#map0, #map1, #map1],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
@@ -419,10 +417,12 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-//      CHECK: func @fuse_only_one_reassociation(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
-// CHECK-SAME:     %[[ARG1:.+]]: tensor<4x?x?x8xf32>
-//  CHECK-DAG:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+//      CHECK: func @fuse_only_one_reassociation
+// CHECK-SAME:     (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
+//  CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : index
+//  CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
 //  CHECK-DAG:   %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -431,17 +431,20 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME:       outs(%[[COLLAPSE_ARG1_1]] :
-//      CHECK:   %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}}
-//      CHECK:   return %[[EXPAND_GENERIC]]
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C8]] : index
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+//      CHECK:   return %[[EXPANDED_3]]
 
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>
-func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>) -> tensor<?x8x?x4xi32> {
+func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1: index) -> tensor<?x8x?x4xi32> {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%sz0, 4, %sz1, 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
   %d0 = tensor.dim %0, %c0 : tensor<?x4x?x8xi32>
   %d1 = tensor.dim %0, %c2 : tensor<?x4x?x8xi32>
   %init = tensor.empty(%d1, %d0) : tensor<?x8x?x4xi32>
@@ -465,10 +468,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>) -> tensor<?x8x?x4x
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 //      CHECK: func @fold_non_consecutive_dims(
-// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xi32>)
-//  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
-//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
-//      CHECK:   %[[INIT:.+]] = tensor.empty
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
+//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
 //      CHECK:   %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -487,8 +496,12 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>) -> tensor<?x8x?x4x
 //  CHECK-DAG:       %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 //  CHECK-DAG:       %[[T7:.+]] = arith.index_cast %[[T6]]
 //      CHECK:       linalg.yield %[[T7]]
-//      CHECK:   %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2, 3]{{\]}}
-//      CHECK:   return %[[EXPAND_GENERIC]]
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
+//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C8]] : index
+//      CHECK:   %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C4]] : index
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+//      CHECK:   return %[[EXPANDED_3]]
 
 // -----
 
@@ -496,10 +509,10 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>) -> tensor<?x8x?x4x
 // So no change in the code.
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
 #map1 = affine_map<(d0, d1, d2, d3) -> ()>
-func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>) -> tensor<i32> {
+func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1: index) -> tensor<i32> {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [%sz0, 4, %sz1, 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
   %init = tensor.empty() : tensor<i32>
   %1 = linalg.generic {
       indexing_maps = [#map0, #map1],
@@ -519,8 +532,8 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>) -> te
   return %1 : tensor<i32>
 }
 //      CHECK: func @no_fold_non_consecutive_reduction_dims(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xi32>)
-//      CHECK:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
+//      CHECK:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPAND_ARG0]] :
 //      CHECK:   return %[[GENERIC]]
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index f1c729ef963ba8..751ece37bc094f 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -4,15 +4,19 @@
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
 
 // CHECK-LABEL: func @reshape
-// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>)
+// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
+//      CHECK: %[[C112:.*]] = arith.constant 112 : index
+//      CHECK: %[[C0:.*]] = arith.constant 0 : index
 //      CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<?x16xf32> into tensor<?x112x16xf32>
+//      CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
+//      CHECK: %[[VAL_1:.*]] = arith.divui %[[DIM]], %[[C112]] : index
+//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
 //      CHECK: return %[[RR]] : tensor<?x112x16xf32>
-func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>) -> tensor<?x112x16xf32> {
-  %0 = tensor.expand_shape %A [[0, 1], [2]]
+func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
+  %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]
       : tensor<?x16xf32> into tensor<?x112x16xf32>
   %2 = linalg.generic {indexing_maps = [
     affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
@@ -39,13 +43,13 @@ func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
-//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] : tensor<12544x16xf32> into tensor<112x112x16xf32>
+//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
 //      CHECK: return %[[RR]] : tensor<112x112x16xf32>
 func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
   %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
-  %0 = tensor.expand_shape %A [[0, 1], [2]]
+  %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
       : tensor<12544x16xf32> into tensor<112x112x16xf32>
-  %1 = tensor.expand_shape %B [[0, 1], [2]]
+  %1 = tensor.expand_shape %B [[0, 1], [2]] output_shape [112, 112, 16]
       : tensor<12544x16xf32> into tensor<112x112x16xf32>
   %2 = tensor.empty() : tensor<112x112x16xf32>
   %3 = linalg.generic {indexing_maps = [
@@ -69,11 +73,11 @@ func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
 // Negative test, since the second source is broadcasted from d1 we cannot merge
 // d0 and d1 dimensions
 // CHECK-LABEL: func @reshape_negative
-// CHECK: tensor.expand_shape {{.*}} : tensor<12544x16xf32> into tensor<112x112x16xf32>
+// CHECK: tensor.expand_shape {{.*}} {{\[\[}}0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
 // CHECK: linalg.generic
 // CHECK: } -> tensor<112x112x16xf32>
 func.func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
-  %20 = tensor.expand_shape %A [[0, 1], [2]]
+  %20 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
       : tensor<12544x16xf32> into tensor<112x112x16xf32>
   %21 = tensor.empty() : tensor<112x112x16xf32>
   %22 = linalg.generic {indexing_maps = [
@@ -96,7 +100,7 @@ func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
   %cst_6 = arith.constant 1.000000e+00 : f32
   %cst_7 = arith.constant 7.000000e+00 : f32
   %cst_8 = arith.constant 1.1920929E-7 : f32
-  %25 = tensor.expand_shape %arg0 [[0, 1], [2]]
+  %25 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 5]
       : tensor<6x5xi32> into tensor<2x3x5xi32>
   %26 = tensor.empty() : tensor<2x3x5xf32>
   %28 = linalg.generic {
diff --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
index ab948988b7b6e7..0f0337a3604e00 100644
--- a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
@@ -48,7 +48,7 @@ func.func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : te
       ^bb0(%arg2: f32):
         linalg.yield %cst : f32
       } -> tensor<?x?xf32>
-  %0 = tensor.expand_shape %fill [[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
+  %0 = tensor.expand_shape %fill [[0, 1], [2]] output_shape [1, %d0, %d1] : tensor<?x?xf32> into tensor<1x?x?xf32>
   %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
       outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
   return %1 : tensor<1x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 342c067b5c4ba4..f42666f81bbadd 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,10 +30,20 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0], [1], [2, 3]
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0], [1], [2, 3]
+//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
+//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM_1]], %[[C4]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_4]], %[[C4]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -50,7 +60,9 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 #map1 = affine_map<(d0, d1) -> ()>
 func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
                                          %arg1 : tensor<?x?xf32>,
-                                         %arg2 : f32) ->
+                                         %arg2 : f32,
+                                         %sz0: index,
+                                         %sz1: index) ->
                                          tensor<?x4x?x5xf32>
 {
   %0 = linalg.generic {
@@ -63,7 +75,7 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
       %2 = arith.addf %1, %arg5 : f32
       linalg.yield %2 : f32
   } -> tensor<?x?xf32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, 4, %sz1, 5] :
     tensor<?x?xf32> into tensor<?x4x?x5xf32>
   return %1 : tensor<?x4x?x5xf32>
 }
@@ -75,14 +87,22 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0], [1, 2, 3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0], [1, 2, 3]
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0], [1, 2, 3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
+//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM_0]], %[[C20]] : index
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C20]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_5]], %[[C20]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -94,7 +114,7 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // -----
 
 func.func @reshape_as_consumer_permutation
-  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
     -> tensor<?x2x?x3x4x?xf32> {
   %c = linalg.generic {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
@@ -107,8 +127,7 @@ func.func @reshape_as_consumer_permutation
          %1 = arith.addf %arg0, %arg1 : f32
          linalg.yield %1 : f32
        } -> tensor<?x?x?xf32>
-  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
-       : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
   return %d : tensor<?x2x?x3x4x?xf32>
 }
 //  CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
@@ -117,15 +136,27 @@ func.func @reshape_as_consumer_permutation
 //      CHECK: func @reshape_as_consumer_permutation
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0, 1, 2], [3, 4], [5]
-// CHECK-SAME:     tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0, 1, 2], [3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<3x4x?x?xf32>
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0, 1], [2], [3, 4, 5]]
-// CHECK-SAME:     tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
+//      CHECK:   %[[C12:.+]] = arith.constant 12 : index
+//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C12]] : index
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C2]] : index
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_2]], %[[C12]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+//      CHECK:   %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+//      CHECK:   %[[VAL_3:.+]] = arith.divui %[[DIM_5]], %[[C2]] : index
+//      CHECK:   %[[VAL_4:.+]] = arith.divui %[[DIM_7]], %[[C12]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
@@ -152,7 +183,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
       %2 = arith.mulf %arg1, %arg2 : f32
       linalg.yield %2 : f32
     } -> tensor<264x4xf32>
-  %2 = tensor.expand_shape %1 [[0, 1], [2]] :
+  %2 = tensor.expand_shape %1 [[0, 1], [2]] output_shape [8, 33, 4] :
     tensor<264x4xf32> into tensor<8x33x4xf32>
   return %2 : tensor<8x33x4xf32>
 }
@@ -163,12 +194,8 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 //  CHECK-DAG:   %[[CST:.+]] = arith.constant
 // CHECK-SAME:     : tensor<8x33x4xf32>
 //  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0, 1], [2]
-// CHECK-SAME:     tensor<264x4xf32> into tensor<8x33x4xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[INIT]]
-// CHECK-SAME:     [0, 1], [2]
-// CHECK-SAME:     : tensor<264x4xf32> into tensor<8x33x4xf32>
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
 //      CHECK:   %[[T2:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel"]
@@ -232,7 +259,8 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                         %arg1 : tensor<?x?xi32>) ->
+                                         %arg1 : tensor<?x?xi32>,
+                                         %sz0: index, %sz1: index) ->
                                          tensor<?x?x4x5xi32>
 {
   %0 = linalg.generic {
@@ -250,7 +278,7 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
       %5 = arith.addi %3, %4 : i32
       linalg.yield %5 : i32
   } -> tensor<?x?xi32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
     tensor<?x?xi32> into tensor<?x?x4x5xi32>
   return %1 : tensor<?x?x4x5xi32>
 }
@@ -302,8 +330,7 @@ func.func @reshape_as_consumer_permutation
          %7 = arith.addi %5, %6 : i32
          linalg.yield %7 : i32
        } -> tensor<6x4x210xi32>
-  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
-       : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
   return %d : tensor<2x3x4x5x6x7xi32>
 }
 
@@ -319,13 +346,9 @@ func.func @reshape_as_consumer_permutation
 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
 //  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
 //   CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
-//   CHECK-DAG:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
-//  CHECK-SAME:     [0, 1, 2], [3, 4], [5]
-//   CHECK-DAG:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
-//  CHECK-SAME:     [0, 1, 2], [3]
-//   CHECK-DAG:   %[[T3:.+]] = tensor.expand_shape %[[INIT]]
-//  CHECK-SAME:     [0, 1], [2], [3, 4, 5]
-//  CHECK-SAME:     : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+//       CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32>
+//       CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32>
+//       CHECK:   %[[T3:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
 //       CHECK:   %[[T4:.+]] = linalg.generic
 //  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 //  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
@@ -411,7 +434,8 @@ func.func @reshape_as_producer_projected_permutation(
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d1, d0)>
 func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
-                                                   %arg1 : tensor<?x?xf32>) ->
+                                                   %arg1 : tensor<?x?xf32>,
+                                                   %sz0: index, %sz1: index) ->
                                                    tensor<?x?x4x5xf32>
 {
   %0 = linalg.generic {
@@ -423,7 +447,7 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
       %1 = arith.mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
   } -> tensor<?x?xf32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
     tensor<?x?xf32> into tensor<?x?x4x5xf32>
   return %1 : tensor<?x?x4x5xf32>
 }
@@ -433,15 +457,22 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
 //      CHECK: func @generic_op_reshape_consumer_fusion_projected
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0, 1, 2], [3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0, 1, 2], [3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0], [1, 2, 3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
+//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C20]] : index
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_1]], %[[C20]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_5]], %[[C20]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -466,6 +497,7 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
     } -> tensor<?xf32>
   return %3 : tensor<?xf32>
 }
+
 //      CHECK: func @no_fuse_dynamic_dims
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
 //      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
@@ -503,7 +535,8 @@ func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi
 // -----
 
 func.func @reshape_as_consumer_permutation_with_multiple_results
-  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>)
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
+   %sz1: index, %sz2: index, %sz3: index, %sz4: index)
     -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
   %c:2 = linalg.generic {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
@@ -517,10 +550,8 @@ func.func @reshape_as_consumer_permutation_with_multiple_results
          %1 = arith.addf %arg0, %arg1 : f32
          linalg.yield %1, %1 : f32, f32
        } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-  %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]]
-       : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
-  %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]]
-       : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+  %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+  %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]] output_shape [%sz3, %sz4, 2, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
   return %d, %e : tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>
 }
 //  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
@@ -528,17 +559,40 @@ func.func @reshape_as_consumer_permutation_with_multiple_results
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
 //  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)>
 //      CHECK: func @reshape_as_consumer_permutation_with_multiple_results
-// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
-// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-//  CHECK-DAG:  %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5]{{\]}}
-//  CHECK-DAG:  %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1, 2], [3]{{\]}}
-//  CHECK-DAG:  %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3, 4, 5]{{\]}}
-//  CHECK-DAG:  %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3, 4, 5]{{\]}}
-//      CHECK:  %[[GENERIC:.+]]:2 = linalg.generic
-// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
-// CHECK-SAME:      ins(%[[RESHAPE0]], %[[RESHAPE1]] :
-// CHECK-SAME:      outs(%[[RESHAPE2]], %[[RESHAPE3]] :
-//      CHECK:  return %[[GENERIC]]#0, %[[GENERIC]]#1
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index
+//       CHECK:   %[[C12:.+]] = arith.constant 12 : index
+//       CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+//       CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+//       CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+//       CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C12]] : index
+//       CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C2]] : index
+//       CHECK:   %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
+//       CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//       CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_2]], %[[C12]] : index
+//       CHECK:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
+//       CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+//       CHECK:   %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+//       CHECK:   %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+//       CHECK:   %[[VAL_3:.+]] = arith.divui %[[DIM_5]], %[[C2]] : index
+//       CHECK:   %[[VAL_4:.+]] = arith.divui %[[DIM_7]], %[[C12]] : index
+//       CHECK:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
+//       CHECK:   %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
+//       CHECK:   %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
+//       CHECK:   %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
+//       CHECK:   %[[VAL_5:.+]] = arith.divui %[[DIM_10]], %[[C2]] : index
+//       CHECK:   %[[VAL_6:.+]] = arith.divui %[[DIM_11]], %[[C12]] : index
+//       CHECK:   %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
+//       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
+//  CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+//  CHECK-SAME:      ins(%[[RESHAPE0]], %[[RESHAPE1]] :
+//  CHECK-SAME:      outs(%[[RESHAPE2]], %[[RESHAPE3]] :
+//       CHECK:  return %[[GENERIC]]#0, %[[GENERIC]]#1
 
 // -----
 
@@ -556,7 +610,7 @@ module {
         %2 = arith.addf %arg4, %arg5 : f32
         linalg.yield %2, %2 : f32, f32
       } -> (tensor<512xf32>, tensor<200x512xf32>)
-    %1 = tensor.expand_shape %0#1 [[0, 1, 2], [3]] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
+    %1 = tensor.expand_shape %0#1 [[0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
     return %1 : tensor<25x8x1x512xf32>
   }
 }
@@ -567,7 +621,7 @@ module {
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512xf32>
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<512xf32>
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<200x512xf32>
-//      CHECK:   %[[OUTS:.+]] = tensor.expand_shape %[[ARG3]] {{\[}}[0, 1, 2], [3]{{\]}}
+//      CHECK:   %[[OUTS:.+]] = tensor.expand_shape %[[ARG3]] {{\[\[}}0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
 //      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
@@ -581,7 +635,9 @@ module {
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
                                                         %arg1 : tensor<?x?xf32>,
-                                                        %arg2 : tensor<?x?xf32>) ->
+                                                        %arg2 : tensor<?x?xf32>,
+                                                        %sz0: index,
+                                                        %sz1: index) ->
                                                         tensor<?x?x4x5xf32>
 {
   %0 = linalg.generic {
@@ -593,7 +649,7 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
       %1 = arith.mulf %arg3, %arg4 : f32
       linalg.yield %1 : f32
   } -> tensor<?x?xf32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
     tensor<?x?xf32> into tensor<?x?x4x5xf32>
   return %1 : tensor<?x?x4x5xf32>
 }
@@ -605,12 +661,18 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0, 1, 2], [3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
-// CHECK-SAME:     [0], [1, 2, 3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
+//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C20]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C20]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
@@ -650,10 +712,21 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0, 1], [2], [3, 4]
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
-// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
+//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x4x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x4x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C8]] : index
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C7]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C8]] : index
+//      CHECK:   %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C7]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
@@ -668,12 +741,14 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x
 
 func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
                                               %arg1 : tensor<?x?xf32>,
-                                              %arg2 : tensor<?x?xf32>) ->
+                                              %arg2 : tensor<?x?xf32>,
+                                              %sz0: index,
+                                              %sz1: index) ->
                                               tensor<?x?x4x5xf32>
 {
   %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
        outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
     tensor<?x?xf32> into tensor<?x?x4x5xf32>
   return %1 : tensor<?x?x4x5xf32>
 }
@@ -683,15 +758,22 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0], [1, 2, 3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0], [1, 2, 3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
-//      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]]
-// CHECK-SAME:     [0], [1, 2, 3]
-// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
+//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM_0]], %[[C20]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_2]], %[[C20]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_5]], %[[C20]] : index
+//      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
 //      CHECK:   %[[T4:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -721,10 +803,20 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
-// CHECK-SAME:     [0, 1], [2, 3]
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
-// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
+//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_0:.+]] = arith.divui %[[DIM]], %[[C7]] : index
+//      CHECK:   %[[VAL_1:.+]] = arith.divui %[[DIM_0]], %[[C8]] : index
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
+//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
+//      CHECK:   %[[VAL_2:.+]] = arith.divui %[[DIM_1]], %[[C7]] : index
+//      CHECK:   %[[VAL_3:.+]] = arith.divui %[[DIM_2]], %[[C8]] : index
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 4262cd23e7469d..8fb84248c9613b 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -199,13 +199,12 @@ func.func @empty_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
 
 // -----
 
-func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
+func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (index, index, index)
 {
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
   %c4 = arith.constant 4 : index
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   %1 = tensor.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
   %2 = tensor.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
index 006d6105677e97..31e9fd00cffa04 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -13,8 +13,8 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 //  CHECK-LABEL: @matmul_split
 //  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
-//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 4, 64] : tensor<16x256xf32> into tensor<16x4x64xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 64, 32] : tensor<256x32xf32> into tensor<4x64x32xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
@@ -65,7 +65,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: ten
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
 //CHECK-LABEL: @generic_split_1d
 //  CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [4, 8] : tensor<32xf32> into tensor<4x8xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic
@@ -119,8 +119,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL:  func @generic_split_3d
 //  CHECK-DAG: %[[ID:.*]] = arith.constant 0xFF800000 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
-//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
@@ -177,8 +177,8 @@ func.func @generic_split_3d_ninf(%input: tensor<32x2xf32>, %input_2: tensor<5x32
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL:  func @generic_split_3d_ninf
 //  CHECK-DAG: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
-//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
@@ -218,8 +218,8 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 //  CHECK-LABEL: @matmul_split
 //  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32>
-//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 64, 4] : tensor<16x256xf32> into tensor<16x64x4xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [64, 4, 32] : tensor<256x32xf32> into tensor<64x4x32xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
@@ -270,7 +270,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: ten
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
 //CHECK-LABEL: @generic_split_1d
 //  CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic
@@ -324,8 +324,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL:  func @generic_split_3d
 //  CHECK-DAG: %[[ID:.*]] = arith.constant 0x7F800000 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
-//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
@@ -382,8 +382,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL:  func @generic_split_3d
 //  CHECK-DAG: %[[ID:.*]] = arith.constant 3.40282347E+38 : f32
-//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
-//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
 //      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 58d4b21ea2dd90..d7ff1ded9d9332 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1710,10 +1710,12 @@ module attributes {transform.with_named_sequence} {
 #map = affine_map<(d0) -> (d0)>
 // CHECK-LABEL:   @not_vectorizable
 func.func @not_vectorizable(%arg0: tensor<1x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<1x128xf32> {
+  %c0 = arith.constant 0 : index
   %0 = tensor.empty() : tensor<1x128xf32>
   %1 = scf.for %arg5 = %arg2 to %arg1 step %arg3 iter_args(%arg6 = %0) -> (tensor<1x128xf32>) {
     %extracted_slice = tensor.extract_slice %arg6[0, 0] [1, %arg1] [1, 1] : tensor<1x128xf32> to tensor<?xf32>
-    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+    %sz0 = tensor.dim %extracted_slice, %c0 : tensor<?xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
     %extracted_slice_0 = tensor.extract_slice %arg0[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
     %extracted_slice_1 = tensor.extract_slice %expanded[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
     %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_0 : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 506ed1f1c10b10..f442a61dc31ed1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -13,7 +13,7 @@ func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
 // CHECK-LABEL: expand_shape_identity_fold
 // CHECK-NEXT: return
 func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
-  %0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
+  %0 = memref.expand_shape %arg0 [[0], [1]] output_shape [5, 4] : memref<5x4xi8> into memref<5x4xi8>
   return %0 : memref<5x4xi8>
 }
 
@@ -23,7 +23,7 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8>
 // CHECK-NEXT: return
 func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
   %0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
-  %1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
+  %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<i8> into memref<1x1xi8>
   return %1 : memref<1x1xi8>
 }
 
@@ -455,9 +455,9 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
 // -----
 
 func.func @do_not_compose_collapse_of_expand_non_identity_layout(
-    %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>)
+    %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
     -> memref<?xf32, strided<[?], offset: 0>> {
-  %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
+  %1 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] :
     memref<?x?xf32, strided<[?, 1], offset: 0>> into
     memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>
   %2 = memref.collapse_shape %1 [[0, 1, 2]] :
@@ -471,35 +471,34 @@ func.func @do_not_compose_collapse_of_expand_non_identity_layout(
 
 // -----
 
-func.func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
+func.func @compose_expand_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index, %sz2: index, %sz3: index)
     -> memref<?x6x4x5x?xf32> {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
       : memref<?x?xf32> into memref<?x4x?xf32>
-  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]]
-      : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%sz2, 6, 4, 5, %sz3] : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
   return %1 : memref<?x6x4x5x?xf32>
 }
 // CHECK-LABEL: func @compose_expand_of_expand
-//       CHECK:   memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//       CHECK:   memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%{{.*}}, 6, 4, 5, %{{.*}}]
 //   CHECK-NOT:   memref.expand_shape
 
 // -----
 
 func.func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
     -> memref<1x1x1xf32> {
-  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
-  %1 = memref.expand_shape %0 [[0, 1, 2]]
+  %0 = memref.expand_shape %arg0 [] output_shape [1] : memref<f32> into memref<1xf32>
+  %1 = memref.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
       : memref<1xf32> into memref<1x1x1xf32>
   return %1 : memref<1x1x1xf32>
 }
 // CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
-//       CHECK:   memref.expand_shape %{{.*}} []
+//       CHECK:   memref.expand_shape %{{.*}} [] output_shape [1, 1, 1]
 //  CHECK-SAME:     memref<f32> into memref<1x1x1xf32>
 
 // -----
 
 func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
       : memref<12x4xf32> into memref<3x4x4xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<3x4x4xf32> into memref<12x4xf32>
@@ -510,9 +509,9 @@ func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
 
 // -----
 
-func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>)
+func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index)
     -> memref<?x?xf32> {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
       : memref<?x?xf32> into memref<?x4x?xf32>
   %1 = memref.collapse_shape %0 [[0, 1], [2]]
       : memref<?x4x?xf32> into memref<?x?xf32>
@@ -525,7 +524,7 @@ func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>)
 
 func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32> {
   %0 = memref.cast %arg0 : memref<?x?xf32> to memref<8x4xf32>
-  %1 = memref.expand_shape %0 [[0, 1], [2]]
+  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [2, 4, 4]
       : memref<8x4xf32> into memref<2x4x4xf32>
   return %1 : memref<2x4x4xf32>
 }
@@ -981,10 +980,10 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
 //  CHECK-SAME:     %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
 //       CHECK:   %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
 //       CHECK:   return %[[casted]]
-func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
+func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0: index)
     -> (memref<?xf32, 3>)
 {
-  %0 = memref.expand_shape %m [[0, 1]]
+  %0 = memref.expand_shape %m [[0, 1]] output_shape [1, %sz0]
       : memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
   %1 = memref.collapse_shape %0 [[0, 1]]
       : memref<1x?xf32, 3> into memref<?xf32, 3>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 28b70043005940..fdfaa72168d188 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -421,10 +421,11 @@ func.func @simplify_expand_shape(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
     %offset0: index, %offset1: index, %offset2: index,
     %size0: index, %size1: index, %size2: index,
-    %stride0: index, %stride1: index, %stride2: index)
+    %stride0: index, %stride1: index, %stride2: index,
+    %sz0: index, %sz1: index)
     -> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {
 
-  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
+  %subview = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
     memref<?x?xf32, strided<[?,?], offset: ?>> into
       memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 
@@ -491,7 +492,7 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
        index, index, index, index, index,
        index, index, index, index, index) {
 
-  %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] :
+  %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] output_shape [3, 5, 2, 2, 2] :
     memref<30x4xi16> into memref<3x5x2x2x2xi16>
 
   %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
@@ -595,12 +596,13 @@ func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
     %offset0: index, %offset1: index, %offset2: index,
     %size0: index, %size1: index, %size2: index,
-    %stride0: index, %stride1: index, %stride2: index)
+    %stride0: index, %stride1: index, %stride2: index,
+    %sz0: index, %sz1: index)
     -> (memref<f32>, index,
        index, index, index, index, index, index, index, index,
        index, index, index, index, index, index, index, index) {
 
-  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
+  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
     memref<?x?xf32, strided<[?,?], offset: ?>> into
       memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 
@@ -643,7 +645,7 @@ func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
        index, index, index, index, index,
        index, index, index, index, index) {
 
-  %expand_shape = memref.expand_shape %arg[] :
+  %expand_shape = memref.expand_shape %arg[] output_shape [1, 1, 1, 1, 1] :
     memref<i16, strided<[], offset: ?>> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
 
   %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
@@ -1513,4 +1515,4 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index)
     %sizes, %strides :
       memref<f16,3>, index,
       index, index
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 5b853a6cc5a37a..254cd4015eed94 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -412,7 +412,7 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2]] : memref<12x32xf32> into memref<2x6x32xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32>
   %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
   return %1 : f32
 }
@@ -458,7 +458,7 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] : memref<12x32xf32> into memref<2x2x3x32xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32>
   %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
   return %1 : f32
 }
@@ -469,15 +469,17 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 // -----
 
 // CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
-  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
-// CHECK: return %[[LOAD]]
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+// CHECK: return %[[VAL_0]] : f32
 
 // -----
 
@@ -486,7 +488,7 @@ func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memr
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
   affine.for %arg3 = 0 to 1 {
     affine.for %arg4 = 0 to 1024 {
       affine.for %arg5 = 0 to 1020 {
@@ -515,7 +517,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
   affine.for %arg3 = 0 to 1 {
     affine.for %arg4 = 0 to 1024 {
       affine.for %arg5 = 0 to 1020 {
@@ -544,7 +546,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
   %cst = arith.constant 0 : index
   affine.for %arg3 = 0 to 1 {
     affine.for %arg4 = 0 to 1024 {
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 1aef417549d9a1..21bbffc5b5a9c5 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -392,9 +392,9 @@ func.func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
 
 // -----
 
-func.func @expand_shape(%arg0: memref<?x?xf32>) {
+func.func @expand_shape(%arg0: memref<?x?xf32>, %sz0: index, %sz1: index) {
   // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}}
-  %0 = memref.expand_shape %arg0 [[0, 1]] : memref<?x?xf32> into memref<?x5x?xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 5, %sz1] : memref<?x?xf32> into memref<?x5x?xf32>
   return
 }
 
@@ -402,7 +402,7 @@ func.func @expand_shape(%arg0: memref<?x?xf32>) {
 
 func.func @expand_shape(%arg0: memref<f32>) {
   // expected-error @+1 {{rank 0 memrefs can only be extended/collapsed with/from ones}}
-  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x2xf32>
+  %0 = memref.expand_shape %arg0 [] output_shape [1, 2] : memref<f32> into memref<1x2xf32>
   return
 }
 
@@ -415,9 +415,9 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {
 
 // -----
 
-func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
+func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>, %sz0: index) {
   // expected-error @+1 {{op reassociation index 2 is out of bounds}}
-  %0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
+  %0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [4, %sz0] : memref<?xf32> into memref<4x?xf32>
 }
 
 // -----
@@ -425,7 +425,7 @@ func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
 func.func @expand_shape_invalid_result_layout(
     %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
   // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>' but found 'memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>'}}
-  %0 = memref.expand_shape %arg0 [[0, 1], [2]] :
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 20] :
       memref<30x20xf32, strided<[4000, 2], offset: 100>>
       into memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>
 }
@@ -462,7 +462,7 @@ func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>)
 // like this. Verify that a sensible error is emitted in this case.
 func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
   // expected-error @+1 {{'memref.expand_shape' op has source rank 3 and result rank 2. This is not an expansion (3 > 2)}}
-  %0 = memref.expand_shape %arg0 [[0], [1], [1]] :
+  %0 = memref.expand_shape %arg0 [[0], [1], [1]] output_shape [2, 3] :
     memref<2x3x1xf32> into memref<2x3xf32>
 }
 
@@ -495,20 +495,10 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
-func.func @expand_shape_illegal_dynamic_memref
-  (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32> {
-  // expected-error @+1 {{at most one dimension in a reassociation group may be dynamic}}
-  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
-      : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
-  return %0 : memref<?x?x?x4x?xf32>
-}
-
-// -----
-
 func.func @expand_shape_illegal_static_memref
   (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
   // expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
-  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]] output_shape [2, 3, 2, 4, 5]
       : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
   return %0 : memref<2x3x2x4x5xf32>
 }
@@ -525,30 +515,30 @@ func.func @collapse_shape_illegal_static_memref
 
 // -----
 
-func.func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
+func.func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>, %sz0: index)
     -> memref<?x4x5xf32> {
   // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
-  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, 5]
       : memref<?x?xf32> into memref<?x4x5xf32>
   return %0 : memref<?x4x5xf32>
 }
 
 // -----
 
-func.func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
+func.func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>, %sz0: index)
     -> memref<?x4x5xf32> {
   // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
-  %0 = memref.expand_shape %arg0 [[0], [1, 2]]
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5]
       : memref<?x?xf32> into memref<?x4x5xf32>
   return %0 : memref<?x4x5xf32>
 }
 
 // -----
 
-func.func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>)
+func.func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>, %sz0: index)
     -> memref<?x4x5xf32> {
   // expected-error @+1 {{collapsed dim size (21) must equal reassociation group size (20)}}
-  %0 = memref.expand_shape %arg0 [[0], [1, 2]]
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5]
       : memref<?x21xf32> into memref<?x4x5xf32>
   return %0 : memref<?x4x5xf32>
 }
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 2d69904f27db5e..60fb0ffeee2403 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -106,9 +106,9 @@ func.func @expand_collapse_shape_static(
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
     memref<3x4x5xf32> into memref<12x5xf32>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [3, 4, 5]
 //  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
-  %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+  %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 5] :
     memref<12x5xf32> into memref<3x4x5xf32>
 
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
@@ -116,9 +116,9 @@ func.func @expand_collapse_shape_static(
   %1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
     memref<3x4x5xf32> into memref<3x20xf32>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [3, 4, 5]
 //  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
-  %r1 = memref.expand_shape %1 [[0], [1, 2]] :
+  %r1 = memref.expand_shape %1 [[0], [1, 2]] output_shape [3, 4, 5] :
     memref<3x20xf32> into memref<3x4x5xf32>
 
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
@@ -126,29 +126,29 @@ func.func @expand_collapse_shape_static(
   %2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
     memref<3x4x5xf32> into memref<60xf32>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] output_shape [3, 4, 5]
 //  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
-  %r2 = memref.expand_shape %2 [[0, 1, 2]] :
+  %r2 = memref.expand_shape %2 [[0, 1, 2]] output_shape [3, 4, 5] :
       memref<60xf32> into memref<3x4x5xf32>
 
-//       CHECK:   memref.expand_shape {{.*}} []
+//       CHECK:   memref.expand_shape {{.*}} [] output_shape [1, 1]
 //  CHECK-SAME:     memref<f32> into memref<1x1xf32>
-  %r5 = memref.expand_shape %arg5 [] :
+  %r5 = memref.expand_shape %arg5 [] output_shape [1, 1] :
       memref<f32> into memref<1x1xf32>
 
 // Reshapes with a custom layout map.
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
-  %l0 = memref.expand_shape %arg3 [[0], [1, 2]] :
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [30, 4, 5]
+  %l0 = memref.expand_shape %arg3 [[0], [1, 2]] output_shape [30, 4, 5] :
       memref<30x20xf32, strided<[4000, 2], offset: 100>>
       into memref<30x4x5xf32, strided<[4000, 10, 2], offset: 100>>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-  %l1 = memref.expand_shape %arg3 [[0, 1], [2]] :
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [2, 15, 20]
+  %l1 = memref.expand_shape %arg3 [[0, 1], [2]] output_shape [2, 15, 20] :
       memref<30x20xf32, strided<[4000, 2], offset: 100>>
       into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
-  %r4 = memref.expand_shape %arg4 [[0], [1, 2]] :
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [1, 1, 5]
+  %r4 = memref.expand_shape %arg4 [[0], [1, 2]] output_shape [1, 1, 5] :
       memref<1x5xf32, strided<[5, 1], offset: ?>> into
       memref<1x1x5xf32, strided<[5, 5, 1], offset: ?>>
 
@@ -164,9 +164,9 @@ func.func @expand_collapse_shape_static(
       memref<2049xi64, strided<[?], offset: ?>>
 
   // Reshapes that expand and collapse back a contiguous buffer with some 1's.
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-  %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
+  %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
 
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
@@ -176,15 +176,18 @@ func.func @expand_collapse_shape_static(
 
   // Reshapes on tensors.
 //       CHECK:   tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-  %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] :
+  %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
     tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
 
 //       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
   %rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] :
     tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
 
+//       CHECK:   tensor.dim %arg2, {{.*}} : tensor<3x?x5xf32>
 //       CHECK:   tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-  %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] :
+  %c1 = arith.constant 1 : index
+  %sz1 = tensor.dim %arg2, %c1 : tensor<3x?x5xf32>
+  %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] output_shape [1, 3, %sz1, 1, 5] :
     tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
 
 //       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
@@ -197,15 +200,18 @@ func.func @expand_collapse_shape_static(
 func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>>,
          %arg2: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>,
-         %arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>) {
+         %arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>,
+         %arg4: index,
+         %arg5: index,
+         %arg6: index) {
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
     memref<?x?x?xf32> into memref<?x?xf32>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
 //  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
-  %r0 = memref.expand_shape %0 [[0, 1], [2]] :
+  %r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
     memref<?x?xf32> into memref<?x4x?xf32>
 
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
@@ -214,9 +220,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
     memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>> into
     memref<?x?xf32, strided<[?, 1], offset: 0>>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
 //  CHECK-SAME:     memref<?x?xf32, strided<[?, 1]>> into memref<?x4x?xf32, strided<[?, ?, 1]>>
-  %r1 = memref.expand_shape %1 [[0, 1], [2]] :
+  %r1 = memref.expand_shape %1 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
     memref<?x?xf32, strided<[?, 1], offset: 0>> into
     memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>
 
@@ -226,9 +232,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
     memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> into
     memref<?x?xf32, strided<[?, 1], offset: ?>>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
 //  CHECK-SAME:     memref<?x?xf32, strided<[?, 1], offset: ?>> into memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>
-  %r2 = memref.expand_shape %2 [[0, 1], [2]] :
+  %r2 = memref.expand_shape %2 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
     memref<?x?xf32, strided<[?, 1], offset: ?>> into
     memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>
 
@@ -238,9 +244,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
     memref<?x42xf32, strided<[42, 1], offset: 0>> into
     memref<?xf32, strided<[1]>>
 
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1]]
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1]] output_shape [%arg6, 42]
 //  CHECK-SAME:     memref<?xf32, strided<[1]>> into memref<?x42xf32>
-  %r3 = memref.expand_shape %3 [[0, 1]] :
+  %r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] :
     memref<?xf32, strided<[1]>> into memref<?x42xf32>
   return
 }
@@ -248,12 +254,12 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
 func.func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
     -> (memref<f32>, memref<1x1xf32>) {
   %0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
-  %1 = memref.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
+  %1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
   return %0, %1 : memref<f32>, memref<1x1xf32>
 }
 // CHECK-LABEL: func @expand_collapse_shape_zero_dim
 //       CHECK:   memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
-//       CHECK:   memref.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+//       CHECK:   memref.expand_shape %{{.*}} [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
 
 func.func @collapse_shape_to_dynamic
   (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32> {
@@ -270,16 +276,18 @@ func.func @collapse_shape_to_dynamic
 // CHECK-LABEL: func @expand_collapse_shape_transposed_layout
 func.func @expand_collapse_shape_transposed_layout(
     %m0: memref<?x?xf32, strided<[1, 10], offset: 0>>,
-    %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>) {
+    %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>,
+    %sz0: index,
+    %sz1: index) {
 
-  %r0 = memref.expand_shape %m0 [[0], [1, 2]] :
+  %r0 = memref.expand_shape %m0 [[0], [1, 2]] output_shape [%sz0, %sz1, 5] :
     memref<?x?xf32, strided<[1, 10], offset: 0>> into
     memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>>
   %rr0 = memref.collapse_shape %r0 [[0], [1, 2]] :
     memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>> into
     memref<?x?xf32, strided<[1, 10], offset: 0>>
 
-  %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] :
+  %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] output_shape [2, 2, 5, 2, 3] :
     memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into
     memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>>
   %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] :
diff --git a/mlir/test/Dialect/MemRef/runtime-verification.mlir b/mlir/test/Dialect/MemRef/runtime-verification.mlir
index 4d7fcf6ac7cbbc..28777a3e886722 100644
--- a/mlir/test/Dialect/MemRef/runtime-verification.mlir
+++ b/mlir/test/Dialect/MemRef/runtime-verification.mlir
@@ -2,13 +2,14 @@
 
 // CHECK-LABEL: func @expand_shape(
 //  CHECK-SAME:     %[[m:.*]]: memref<?xf32>
+//  CHECK-SAME:     %[[sz0:.*]]: index
 //   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
 //   CHECK-DAG:   %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
 //       CHECK:   %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
 //       CHECK:   %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
 //       CHECK:   cf.assert %[[cmpi]], "ERROR: Runtime op verification failed
-func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
-  %0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
+func.func @expand_shape(%m: memref<?xf32>, %sz0: index) -> memref<?x5xf32> {
+  %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz0, 5] : memref<?xf32> into memref<?x5xf32>
   return %0 : memref<?x5xf32>
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
index edb53fa024c26b..c96f9c31443db3 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -12,7 +12,7 @@
 //
 // CHECK-ROUND-LABEL: func.func @sparse_expand(
 // CHECK-ROUND-SAME:  %[[A:.*]]: tensor<100xf64, #sparse{{[0-9]*}}>) -> tensor<10x10xf64, #sparse{{[0-9]*}}>
-//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}>
+//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [10, 10] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}>
 //      CHECK-ROUND:  return %[[E]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
 //
 // CHECK-LABEL:   func.func @sparse_expand(
@@ -39,7 +39,7 @@
 // CHECK:         return %[[NT1]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
 //
 func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
-  %0 = tensor.expand_shape %arg0 [[0, 1]] :
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [10, 10] :
     tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
   return %0 : tensor<10x10xf64, #SparseMatrix>
 }
@@ -94,8 +94,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 // roundtrip:
 //
 // CHECK-ROUND-LABEL: func.func @dynamic_sparse_expand(
-// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<?xf64, #sparse{{[0-9]*}}>) -> tensor<?x10xf64, #sparse{{[0-9]*}}>
-//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<?xf64, #sparse{{[0-9]*}}> into tensor<?x10xf64, #sparse{{[0-9]*}}>
+// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<?xf64, #sparse{{[0-9]*}}>, %[[SZ0:.*]]: index) -> tensor<?x10xf64, #sparse{{[0-9]*}}>
+//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [%[[SZ0]], 10] : tensor<?xf64, #sparse{{[0-9]*}}> into tensor<?x10xf64, #sparse{{[0-9]*}}>
 //      CHECK-ROUND:  return %[[E]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
 //
 // CHECK-LABEL:   func.func @dynamic_sparse_expand(
@@ -127,8 +127,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
 // CHECK-NOT:     sparse_tensor.convert
 // CHECK:         return %[[NT1]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
 //
-func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
-  %0 = tensor.expand_shape %arg0 [[0, 1]] :
+func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>, %sz0: index) -> tensor<?x10xf64, #SparseMatrix> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 10] :
     tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix>
   return %0 : tensor<?x10xf64, #SparseMatrix>
 }
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 815bc383af95a6..4f553adcc500fb 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -367,11 +367,14 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
 
 // CHECK-LABEL: func @tensor.expand_shape(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x10xf32>
-func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
+func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
-  // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
-  %0 = tensor.expand_shape %t1 [[0, 1], [2]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[VAL_1:.*]] = arith.divui %[[DIM]], %[[C2]] : index
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
       : tensor<?x10xf32> into tensor<2x?x10xf32>
 
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
@@ -384,14 +387,15 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
 func.func @tensor.expand_shape_of_slice(
-    %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
+    %t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
       tensor<?x20xf32> to tensor<?x10xf32>
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
-  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
-  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
+  // CHECK: %[[C7:.*]] = arith.constant 7 : index
+  // CHECK: %[[VAL_1:.*]] = arith.divui %{{.*}}, %[[C7]] : index
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
       tensor<?x10xf32> into tensor<?x7x2x5xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
@@ -407,8 +411,8 @@ func.func @tensor.expand_shape_of_scalar_slice(
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] :  memref<?xf32> to memref<f32, strided<[], offset: ?>>
   %0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
-  %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
+  %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
   return %1 : tensor<1xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ac365c9d297e88..23921a824f2136 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -4,7 +4,7 @@
 // CHECK-LABEL: expand_shape_identity_fold
 // CHECK-NEXT: return
 func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0]] : tensor<5xf32> into tensor<5xf32>
+  %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32>
   return %0 : tensor<5xf32>
 }
 
@@ -13,7 +13,7 @@ func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
 // CHECK-LABEL: expand_shape_rank0_identity_fold
 // CHECK-NEXT: return
 func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
-  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<f32>
+  %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor<f32> into tensor<f32>
   return %0 : tensor<f32>
 }
 
@@ -1051,29 +1051,28 @@ func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4
 
 // -----
 
-func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>)
+func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
     -> tensor<?x6x4x?x5xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
-  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]]
-      : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
   return %1 : tensor<?x6x4x?x5xf32>
 }
 // CHECK-LABEL: compose_expand_of_expand
-//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
+//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5]
 //   CHECK-NOT:   tensor.expand_shape
 
 // -----
 
 func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
     -> tensor<1x1x1xf32> {
-  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
-  %1 = tensor.expand_shape %0 [[0, 1, 2]]
+  %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
+  %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
       : tensor<1xf32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
 // CHECK-LABEL: compose_expand_of_expand_of_zero_dim
-//       CHECK:   tensor.expand_shape %{{.*}} []
+//       CHECK:   tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1]
 //  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>
 
 // -----
@@ -1093,7 +1092,7 @@ func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
 // -----
 
 func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
       : tensor<12x4xf32> into tensor<3x4x4xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<3x4x4xf32> into tensor<12x4xf32>
@@ -1104,9 +1103,9 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
 
 // -----
 
-func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>)
+func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x?xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1121,7 +1120,7 @@ func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
     -> tensor<24x5x42x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
       : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
-  %1 = tensor.expand_shape %0 [[0, 1, 2, 3]]
+  %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8]
       : tensor<40320xf32> into tensor<24x5x42x8xf32>
   return %1 : tensor<24x5x42x8xf32>
 }
@@ -1137,7 +1136,7 @@ func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
     -> tensor<2x3x4x5x6x7x8xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
       : tensor<24x5x42x8xf32> into tensor<40320xf32>
-  %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]]
+  %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8]
       : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
   return %1 : tensor<2x3x4x5x6x7x8xf32>
 }
@@ -1149,16 +1148,16 @@ func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
 
 // -----
 
-func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
+func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?xi64> {
-  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
+  %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1]
     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
   return %1 : tensor<?x?xi64>
 }
 // CHECK-LABEL: func @compose_collapse_of_expand
-//       CHECK:   (%[[ARG:.*]]: tensor<?x?x?xi64>)
+//       CHECK:   (%[[ARG:.*]]: tensor<?x?x?xi64>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
 //  CHECK-NEXT: tensor.collapse_shape %[[ARG]]
 //  CHECK-SAME:   [0, 1], [2]
 //  CHECK-SAME:   : tensor<?x?x?xi64> into tensor<?x?xi64>
@@ -1167,14 +1166,14 @@ func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>)
 
 func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
     -> tensor<4x512xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]]
+  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512]
     : tensor<2048xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
   return %1 : tensor<4x512xf32>
 }
 //       CHECK: func @compose_collapse_of_expand_1D
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
+//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512]
 //  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>
 
 // -----
@@ -1183,14 +1182,14 @@ func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>
     -> tensor<1x1x1x1xf32> {
   %0 = tensor.collapse_shape %arg0 []
       : tensor<1x1x1xf32> into tensor<f32>
-  %1 = tensor.expand_shape %0 []
+  %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1]
       : tensor<f32> into tensor<1x1x1x1xf32>
   return %1 : tensor<1x1x1x1xf32>
 }
 //      CHECK: func @compose_expand_of_collapse_0_rank_to_expand
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x1xf32>
 //      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:     [0], [1], [2, 3]
+// CHECK-SAME:     {{\[}}[0], [1], [2, 3]] output_shape [1, 1, 1, 1]
 //      CHECK:   return %[[RESULT]]
 
 // -----
@@ -1199,7 +1198,7 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
     -> tensor<1x1x1xf32> {
   %0 = tensor.collapse_shape %arg0 []
       : tensor<1x1x1x1xf32> into tensor<f32>
-  %1 = tensor.expand_shape %0 []
+  %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1]
       : tensor<f32> into tensor<1x1x1xf32>
   return %1 : tensor<1x1x1xf32>
 }
@@ -1214,8 +1213,8 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
-  %0 = tensor.expand_shape %arg0 [] : tensor<f32> into tensor<1xf32>
-  %1 = tensor.expand_shape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
+  %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
+  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32>
   %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
   return %2 : tensor<f32>
 }
@@ -1250,7 +1249,7 @@ func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
 // -----
 
 func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]]
+  %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512]
     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
     : tensor<1x4x1x512xf32> into tensor<2048xf32>
@@ -1264,42 +1263,40 @@ func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048x
 
 func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
     -> tensor<4x512x1x1xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]]
-    : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
     : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
   return %1 : tensor<4x512x1x1xf32>
 }
 //       CHECK: func @fold_collapse_of_expand_unit_dims
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]]
+//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1]
 //  CHECK-SAME:   tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
 
 // -----
 
 func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
     -> tensor<4x512x1x512x4xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]]
-    : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
 //       CHECK: func @compose_collapse_of_expand_unit_dims
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4]
 //  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
 
 // -----
 
 func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
     -> tensor<2x1xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
+  %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1]
       : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
 //       CHECK: func @compose_collapse_of_expand_trailing_unit_dims
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
+//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
@@ -1321,14 +1318,13 @@ func.func @compose_collapse_of_collapse_unit_dims_dynamic(
 
 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
     -> tensor<2x1xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1, 2]]
-      : tensor<2xf32> into tensor<2x1x1xf32>
+  %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
       : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
 //       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
-//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]]
+//       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
 
 // -----
@@ -1349,8 +1345,7 @@ func.func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
 
 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
     -> tensor<12x42xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]]
-      : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
+  %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]]
       : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
   return %1 : tensor<12x42xf32>
@@ -1361,9 +1356,9 @@ func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf3
 
 // -----
 
-func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>)
+func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
     -> tensor<?x?xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
+  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
@@ -1378,7 +1373,7 @@ func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>
 
 func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
     -> tensor<2x6x16xf32> {
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]]
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8]
       : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]]
       : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
@@ -1392,7 +1387,7 @@ func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
 
 func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
     -> tensor<12x1xf32> {
-  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
+  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
       : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
       : tensor<3x2x2x1xf32> into tensor<12x1xf32>
@@ -1401,7 +1396,7 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
 //      CHECK: func @no_fold_collapse_of_expand_empty_expr
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
 //      CHECK:    %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
-// CHECK-SAME:      [0], [1], [2, 3]
+// CHECK-SAME:      {{\[}}[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
 //      CHECK:    %[[RES:.+]] = tensor.collapse_shape %[[RARG0]]
 // CHECK-SAME:      [0, 1, 2], [3]
 //      CHECK:    return %[[RES:.+]] : tensor<12x1xf32>
@@ -1410,7 +1405,7 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
 
 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
   %c0 = arith.constant dense<42> : tensor<2x8xi32>
-  %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
       : tensor<2x8xi32> into tensor<2x4x2xi32>
   return %0 : tensor<2x4x2xi32>
 }
@@ -1421,7 +1416,7 @@ func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
 // -----
 func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
   %c0 = tensor.splat %arg : tensor<2x4xf32>
-  %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2]
       : tensor<2x4xf32> into tensor<2x2x2xf32>
   return %0 : tensor<2x2x2xf32>
 }
@@ -1434,13 +1429,12 @@ func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
 // -----
 
 // CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
-// CHECK-SAME: %[[F:.+]]: f32
-// CHECK-SAME: %[[M:.+]]: index
-func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index) -> tensor<2x2x?xf32> {
-  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
+// CHECK-SAME: (%[[F:.+]]: f32, %[[M:.+]]: index, %[[SZ0:.+]]: index)
+func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> {
+  // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32>
   // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
   %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
-  %0 = tensor.expand_shape %c0 [[0], [1, 2]] : tensor<2x?xf32> into tensor<2x2x?xf32>
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32>
   return %0 : tensor<2x2x?xf32>
 }
 
@@ -1475,7 +1469,7 @@ func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x
 
 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
   %c0 = arith.constant dense<42> : tensor<2x8xi16>
-  %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
       : tensor<2x8xi16> into tensor<2x4x2xi16>
   return %0 : tensor<2x4x2xi16>
 }
@@ -1488,7 +1482,7 @@ func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
 
 func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
   %c0 = arith.constant dense<42.0> : tensor<2x8xf32>
-  %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
       : tensor<2x8xf32> into tensor<2x4x2xf32>
   return %0 : tensor<2x4x2xf32>
 }
@@ -1501,7 +1495,7 @@ func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
 
 func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
   %c0 = arith.constant dense<42.0> : tensor<2x8xf64>
-  %0 = tensor.expand_shape %c0 [[0], [1, 2]]
+  %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
       : tensor<2x8xf64> into tensor<2x4x2xf64>
   return %0 : tensor<2x4x2xf64>
 }
@@ -1851,7 +1845,7 @@ func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
   // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32>
   // CHECK: return %[[FROM]] : tensor<1xi32>
   %0 = tensor.from_elements %arg0 : tensor<i32>
-  %1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
+  %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<i32> into tensor<1xi32>
   return %1 : tensor<1xi32>
 }
 
@@ -2073,9 +2067,9 @@ func.func @empty_tensor_canonicalize(%i : index) {
 //       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
 //       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
 //       CHECK:   return %[[apply]]
-func.func @dim_of_expand_shape(%t: tensor<?x?xf32>) -> index {
+func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
   %c2 = arith.constant 2 : index
-  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
+  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
       : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
   %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
   return %1 : index
@@ -2107,9 +2101,9 @@ func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
 // CHECK-LABEL: func @collapse_expand_fold_to_cast(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
 //       CHECK:   return %[[t]]
-func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>) -> (tensor<?xf32>)
+func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>, %sz0: index) -> (tensor<?xf32>)
 {
-  %0 = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 15f841f2128edb..e200a4f8926130 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -13,10 +13,9 @@ module attributes {transform.with_named_sequence} {
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
-func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
   %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
-  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 625408dfefe216..d3ac6ce792f365 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -11,9 +11,11 @@ func.func @expand_shape_of_rank_reducing_extract(
 {
   %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
       : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
-  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
+  %c0 = arith.constant 0 : index
+  %sz0 = tensor.dim %0, %c0 : tensor<?x1x5xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%sz0, 1, 1, 5]
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
-  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
+  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]] output_shape [%sz0, 1, 1, 5]
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
   return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
 }
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 79ca0de68a1e9b..3617ed5d61afee 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -273,21 +273,10 @@ func.func @insert_slice_wrong_dynamic_type(%t1: tensor<?x4x4xf32>, %t2: tensor<8
 
 // -----
 
-func.func @illegal_expanding_reshape_dynamic_tensor
-  (%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?x4x?xf32> {
-  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
-  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]]
-      : tensor<?x?x?xf32> into tensor<?x?x?x4x?xf32>
-  return %0 : tensor<?x?x?x4x?xf32>
-}
-
-// -----
-
-
 func.func @illegal_expanding_reshape_static_tensor
     (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> {
   // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
-  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]]
+  %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]] output_shape [2, 3, 2, 4, 5]
       : tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32>
   return %0 : tensor<2x3x2x4x5xf32>
 }
@@ -304,20 +293,20 @@ func.func @illegal_collapsing_reshape_static_tensor
 
 // -----
 
-func.func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>)
+func.func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>, %sz0: index)
     -> tensor<?x4x5xf32> {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
-  %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, 5]
       : tensor<?x?xf32> into tensor<?x4x5xf32>
   return %0 : tensor<?x4x5xf32>
 }
 
 // -----
 
-func.func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>)
+func.func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>, %sz0: index)
     -> tensor<?x4x5xf32> {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
-  %0 = tensor.expand_shape %arg0 [[0], [1, 2]]
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5]
       : tensor<?x?xf32> into tensor<?x4x5xf32>
   return %0 : tensor<?x4x5xf32>
 }
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 2b0a74acce0826..378137a14b59ff 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -194,12 +194,26 @@ func.func @insert_slice(
 func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
     -> (tensor<f32>, tensor<1x1xf32>) {
   %0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
-  %1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
+  %1 = tensor.expand_shape %0 [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
   return %0, %1 : tensor<f32>, tensor<1x1xf32>
 }
 // CHECK-LABEL: func @tensor_reshape_zero_dim
 //       CHECK:   tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
-//       CHECK:   tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
+//       CHECK:   tensor.expand_shape %{{.*}} [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
+
+// -----
+
+func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor<?x?xf32>, %sz0 : index, %sz1 : index, %sz2 : index)
+    -> (tensor<5x?x?x?xf32>) {
+  %1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [5, %sz0, %sz1, %sz2] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+  return %1 : tensor<5x?x?x?xf32>
+}
+
+// CHECK-LABEL:  func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> {
+//       CHECK:    %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} output_shape [5, %arg1, %arg2, %arg3] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
+//       CHECK:    return %expanded : tensor<5x?x?x?xf32>
+//       CHECK:  }
+
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index 9948c0246e6ed6..5a2eade0ecccf1 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -2,7 +2,7 @@
 
 // CHECK-LABEL: func.func @single_dim_packing(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<256xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<8x32xf32>
 func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
   %empty = tensor.empty() : tensor<8x32xf32>
@@ -27,7 +27,7 @@ func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x3
 
 // CHECK-LABEL: func.func @single_last_inner_dim_packing(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
 func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
   %empty = tensor.empty() : tensor<5x8x32xf32>
@@ -39,7 +39,7 @@ func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8
 
 // CHECK-LABEL: func.func @pack_1d_with_outer_dims_perm(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<64xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<64xf32> into tensor<2x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [2, 32] : tensor<64xf32> into tensor<2x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<2x32xf32>
 func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf32> {
   %empty = tensor.empty() :  tensor<2x32xf32>
@@ -51,7 +51,7 @@ func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf3
 
 // CHECK-LABEL: func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(
 // CHECK-SAME:    %[[ARG0:.+]]: tensor<5x256xf32>)
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
 // CHECK:         return %[[EXPANDED]] : tensor<5x8x32xf32>
 func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
   %empty = tensor.empty() : tensor<5x8x32xf32>
@@ -85,7 +85,7 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
 
 // CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 32, 1, 1]
 // CHECK:         return %[[EXPANDED]]
 func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
   %empty = tensor.empty() : tensor<1x32x1x1xf32>
@@ -98,7 +98,7 @@ func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf3
 
 // CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 16, 1, 2]
 // CHECK:         return %[[EXPANDED]]
 func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
   %empty = tensor.empty() : tensor<1x16x1x2xf32>
@@ -111,7 +111,7 @@ func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf3
 
 // CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]
-// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK:         %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 16, 2, 1]
 // CHECK:         return %[[EXPANDED]]
 func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
   %empty = tensor.empty() : tensor<1x16x2x1xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 58538b66c5e0c7..2589044c38ad8c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3791,6 +3791,7 @@ cc_library(
     includes = ["include"],
     deps = [
         ":DialectUtilsIncGen",
+        ":ArithDialect",
         ":IR",
         ":Support",
         "//llvm:Support",



More information about the llvm-commits mailing list