[Mlir-commits] [mlir] eca7d83 - [mlir][memref] Simplify expand_shape size/stride computation using output_shape (#187844)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 24 20:00:01 PDT 2026


Author: Longsheng Mou
Date: 2026-03-25T10:59:56+08:00
New Revision: eca7d833a6158194078635c1054ada81324e9b97

URL: https://github.com/llvm/llvm-project/commit/eca7d833a6158194078635c1054ada81324e9b97
DIFF: https://github.com/llvm/llvm-project/commit/eca7d833a6158194078635c1054ada81324e9b97.diff

LOG: [mlir][memref] Simplify expand_shape size/stride computation using output_shape (#187844)

This PR refactors `getExpandedSizes` and `getExpandedStrides` to compute
their results directly from the `output_shape` of `memref.expand_shape`.
Instead of reconstructing expanded sizes/strides through manual
inference, we now rely on the operation’s explicit shape information.
The previous implementation imposed the restriction that there must be
at most one dynamic size per reassociation group. This limitation is
removed by the new approach: any number of dynamic dimensions within a
group is now supported, as long as they are represented in the
`output_shape`.
As a result, the code becomes both simpler and more expressive, while
better matching the semantics of `memref.expand_shape`.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
    mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
    mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index c9352e8f700d7..d6d5e90275015 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -255,10 +255,7 @@ struct ExtractStridedMetadataOpSubviewFolder
 /// \p origSizes hold the sizes of the source shape as values.
 /// This is used to compute the new sizes in cases of dynamic shapes.
 ///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+/// sizes#i = expandOutputShape#i
 ///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
@@ -273,34 +270,10 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
   assert(!reassocGroup.empty() &&
          "Reassociation group should have at least one dimension");
 
-  unsigned groupSize = reassocGroup.size();
-  SmallVector<OpFoldResult> expandedSizes(groupSize);
-
-  uint64_t productOfAllStaticSizes = 1;
-  std::optional<unsigned> dynSizeIdx;
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  // Fill up all the statically known sizes.
-  for (unsigned i = 0; i < groupSize; ++i) {
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-    productOfAllStaticSizes *= dimSize;
-    expandedSizes[i] = builder.getIndexAttr(dimSize);
-  }
-
-  // Compute the dynamic size using the original size and all the other known
-  // static sizes:
-  // expandSize = origSize / productOfAllStaticSizes.
-  if (dynSizeIdx) {
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
-        origSizes[groupId]);
-  }
+  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
+  SmallVector<OpFoldResult> expandedSizes;
+  for (auto index : reassocGroup)
+    expandedSizes.push_back(outputShape[index]);
 
   return expandedSizes;
 }
@@ -313,16 +286,12 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 /// dynamic stride for this reassociation group.
 ///
 /// strides#i =
-///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///     origStrides#reassDim * product(expandOutputShape#j, for j in
 ///                                    reassIdx#i+1..reassIdx#i+group.size-1)
 ///
 /// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
+/// and expandOutputShape#j is taken directly from the mixed (static and
+/// dynamic) output shape.
 ///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
@@ -340,25 +309,13 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
          "Reassociation group should have at least one dimension");
 
   unsigned groupSize = reassocGroup.size();
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  std::optional<int64_t> dynSizeIdx;
-
-  // Fill up the expanded strides, with the information we can deduce from the
-  // resulting shape.
-  uint64_t currentStride = 1;
-  SmallVector<OpFoldResult> expandedStrides(groupSize);
-  for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = builder.getIndexAttr(currentStride);
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-
-    currentStride *= dimSize;
-  }
+  Location loc = expandShape.getLoc();
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(builder, loc, s0 * s1,
+                                                 {v1, v2});
+  };
 
   // Collect the statically known information about the original stride.
   Value source = expandShape.getSrc();
@@ -369,37 +326,13 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                 ? origStrides[groupId]
                                 : builder.getIndexAttr(strides[groupId]);
 
-  // Apply the original stride to all the strides.
-  int64_t doneStrideIdx = 0;
-  // If we saw a dynamic dimension, we need to fix-up all the strides up to
-  // that dimension with the dynamic size.
-  if (dynSizeIdx) {
-    int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
-           "We shouldn't be able to change dynamicity");
-    OpFoldResult origSize = origSizes[groupId];
-
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    AffineExpr s1 = builder.getAffineSymbolExpr(1);
-    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      int64_t baseExpandedStride =
-          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-              .getInt();
-      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-          builder, expandShape.getLoc(),
-          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
-          {origSize, origStride});
-    }
-  }
-
-  // Now apply the origStride to the remaining dimensions.
-  AffineExpr s0 = builder.getAffineSymbolExpr(0);
-  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride =
-        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-            .getInt();
-    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+  // Fill up the expanded strides.
+  OpFoldResult currentStride = origStride;
+  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
+  SmallVector<OpFoldResult> expandedStrides(groupSize);
+  for (int i = groupSize - 1; i >= 0; --i) {
+    expandedStrides[i] = currentStride;
+    currentStride = mul(currentStride, outputShape[reassocGroup[i]]);
   }
 
   return expandedStrides;

diff  --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index e0e4a61e821ce..c2c93525b6509 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -593,38 +593,31 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<
 }
 
 // CHECK-LABEL:   func.func @expand_shape_dynamic(
-// CHECK-SAME:              %[[ARG:.*]]: memref<1x?xf32>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32> {
-// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
-// CHECK:           %[[FINAL_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C2]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// In this example stride1 and size2 are the same.
-// Hence with CSE, we get the same SSA value.
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
-// CHECK:           return %[[RES]] : memref<1x2x?xf32>
+// CHECK-SAME:      %[[ARG0:.*]]: memref<1x?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<1x2x?xf32> {
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[EXTRACTVALUE_2:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_2]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[MLIR_1]], %[[INSERTVALUE_3]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_3:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[EXTRACTVALUE_2]], %[[INSERTVALUE_5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_4:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_8]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_9]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[INSERTVALUE_10]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
+// CHECK:           return %[[UNREALIZED_CONVERSION_CAST_2]] : memref<1x2x?xf32>
 // CHECK:         }
 
 // -----
@@ -638,41 +631,36 @@ func.func @expand_shape_dynamic_with_non_identity_layout(
   return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 }
 // CHECK-LABEL:   func.func @expand_shape_dynamic_with_non_identity_layout(
-// CHECK-SAME:        %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
-// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
-// CHECK:           %[[TMP_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[TMP_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[FINAL_STRIDE1:.*]] = llvm.mul %[[TMP_SIZE2]], %[[STRIDE1]]
-// CHECK:           %[[STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_STRIDE1]] : i64 to index
-// CHECK:           %[[FINAL_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[STRIDE1_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C2]], %[[DESC5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_STRIDE1]], %[[DESC6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC9]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
-// CHECK:           return %[[RES]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK-SAME:      %[[ARG0:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[EXTRACTVALUE_2:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_3:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_4:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MUL_0:.*]] = llvm.mul %[[EXTRACTVALUE_4]], %[[UNREALIZED_CONVERSION_CAST_0]] overflow<nsw> : i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[MUL_0]] : i64 to index
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_3:.*]] = builtin.unrealized_conversion_cast %[[UNREALIZED_CONVERSION_CAST_2]] : index to i64
+// CHECK:           %[[MLIR_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_2]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[EXTRACTVALUE_2]], %[[INSERTVALUE_3]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_3:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[EXTRACTVALUE_3]], %[[INSERTVALUE_5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_4:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_3]], %[[INSERTVALUE_7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_8]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[EXTRACTVALUE_4]], %[[INSERTVALUE_9]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_4:.*]] = builtin.unrealized_conversion_cast %[[INSERTVALUE_10]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK:           return %[[UNREALIZED_CONVERSION_CAST_4]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 // CHECK:         }
 
 // -----

diff  --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 4ed8d4b20229f..70c5e1aee85dc 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -357,9 +357,7 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 //
 // Here we have:
 // For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
+// size 0 = %sz0
 // size 1 = 7
 // size 2 = 8
 // size 3 = 9
@@ -373,63 +371,51 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 // For the group applying to dim1:
 // size 4 = 10
 // size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
+// size 6 = %sz1
 // size 7 = 3
 // stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-//          and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
+//          = baseStrides#1 * 2 * %sz1 * 3
+//          = baseStrides#1 * %sz1 * 6
 // stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
+//          = baseStrides#1 * %sz1 * 3
 // stride 6 = baseStrides#1 * size 7
 //          = baseStrides#1 * 3
 // stride 7 = baseStrides#1
 //
 // Base and offset are unchanged.
 //
-//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
 //   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
 //   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
 //   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 6)>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 3)>
 //   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
 // CHECK-LABEL: func @simplify_expand_shape
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//  CHECK-SAME: %[[SIZE0:.*]]: index,  %[[SIZE1:.*]]: index)
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
 //   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#1]
 //   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
 //
-//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE0]], 7, 8, 9, 10, 2, %[[SIZE1]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
 //
 //   CHECK: return %[[REINTERPRET_CAST]]
 func.func @simplify_expand_shape(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
-    %offset0: index, %offset1: index, %offset2: index,
-    %size0: index, %size1: index, %size2: index,
-    %stride0: index, %stride1: index, %stride2: index,
     %sz0: index, %sz1: index)
     -> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {
 
-  %subview = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
+  %expand_shape = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
     memref<?x?xf32, strided<[?,?], offset: ?>> into
       memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 
-  return %subview :
+  return %expand_shape :
     memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 }
 
@@ -439,12 +425,6 @@ func.func @simplify_expand_shape(
 // into:
 // baseBuffer, baseOffset, baseSizes, baseStrides =
 //     extract_strided_metadata(memref)
-// sizes#reassIdx =
-//     baseSizes#reassDim / product(expandShapeSizes#j,
-//                                  for j in group excluding reassIdx)
-// strides#reassIdx =
-//     baseStrides#reassDim * product(expandShapeSizes#j, for j in
-//                                    reassIdx+1..reassIdx+group.size)
 //
 // Here we have:
 // For the group applying to dim0:
@@ -516,25 +496,15 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 // See extract_strided_metadata_of_expand_shape_all_static for an explanation
 // of the expansion.
 //
-// One of the important characteristic of this test is that the dynamic
-// dimensions produced by the expand_shape appear both in the first dimension
-// (for group 1) and the non-first dimension (second dimension for group 2.)
-// The idea is to make sure that:
-// 1. We properly account for dynamic shapes even when the strides are not
-//    affected by them. (When the dynamic dimension is the first one.)
-// 2. We properly compute the strides affected by dynamic shapes. (When the
-//    dynamic dimension is not the first one.)
 //
 // Here we have:
 // For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
-// size 1 = 7
+// size 0 = %sz0
+// size 1 = %sz1
 // size 2 = 8
 // size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-//          = baseStrides#0 * 504
+// stride 0 = baseStrides#0 * %sz1 * 8 * 9
+//          = baseStrides#0 * %sz1 * 72
 // stride 1 = baseStrides#0 * 8 * 9
 //          = baseStrides#0 * 72
 // stride 2 = baseStrides#0 * 9
@@ -542,72 +512,57 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 //
 // For the group applying to dim1:
 // size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
+// size 5 = %sz2
+// size 6 = %sz3
 // size 7 = 3
 // stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-//          and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
+//          = baseStrides#1 * %sz2 * %sz3 * 3
 // stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
+//          = baseStrides#1 * %sz3 * 3
 // stride 6 = baseStrides#1 * size 7
 //          = baseStrides#1 * 3
 // stride 7 = baseStrides#1
 //
 // Base and offset are unchanged.
 //
-//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 72)>
 //   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
 //   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1, s2] -> (((s1 * s2) * s0) * 3)>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 3)>
 //   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
 // CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//  CHECK-SAME: %[[SIZE0:.*]]: index,  %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index,  %[[SIZE3:.*]]: index)
 //
 //   CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
 //   CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
 //   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
 //   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZE2]], %[[SIZE3]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZE3]], %[[STRIDES]]#1]
 //   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
 
-//   CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
+//   CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE0]], %[[SIZE1]], %[[C8]], %[[C9]], %[[C10]], %[[SIZE2]], %[[SIZE3]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
 func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
-    %offset0: index, %offset1: index, %offset2: index,
-    %size0: index, %size1: index, %size2: index,
-    %stride0: index, %stride1: index, %stride2: index,
-    %sz0: index, %sz1: index)
+    %sz0: index, %sz1: index, %sz2: index, %sz3: index)
     -> (memref<f32>, index,
        index, index, index, index, index, index, index, index,
        index, index, index, index, index, index, index, index) {
 
-  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
+  %expand_shape = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, %sz1, 8, 9, 10, %sz2, %sz3, 3] :
     memref<?x?xf32, strided<[?,?], offset: ?>> into
-      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+      memref<?x?x8x9x10x?x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 
-  %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %subview :
-    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+  %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %expand_shape :
+    memref<?x?x8x9x10x?x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
     -> memref<f32>, index,
        index, index, index, index, index, index, index, index,
        index, index, index, index, index, index, index, index
@@ -620,7 +575,6 @@ func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
       index, index, index, index, index, index, index, index
 }
 
-
 // -----
 
 // Check that we properly handle extract_strided_metadata of expand_shape for


        


More information about the Mlir-commits mailing list