[Mlir-commits] [mlir] [mlir][memref] Simplify expand_shape size/stride computation using output_shape (PR #187844)

Longsheng Mou llvmlistbot at llvm.org
Sat Mar 21 03:36:28 PDT 2026


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/187844

>From d58fc8bcf150db3a7f2da6c1ecf2bcc6fa0b9bc2 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 21 Mar 2026 15:08:46 +0800
Subject: [PATCH 1/6] [mlir][memref] Simplify expand_shape size/stride
 computation using output_shape
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR refactors `getExpandedSizes` and `getExpandedStrides` to compute their results directly from the `output_shape` of `memref.expand_shape`. Instead of reconstructing expanded sizes/strides through manual inference, we now rely on the operation’s explicit shape information.
The previous implementation imposed the restriction that there must be at most one dynamic size per reassociation group. This limitation is removed by the new approach: any number of dynamic dimensions within a group is now supported, as long as they are represented in the `output_shape`.
As a result, the code becomes both simpler and more expressive, while better matching the semantics of `memref.expand_shape`.
---
 .../Transforms/ExpandStridedMetadata.cpp      | 106 ++++--------------
 1 file changed, 22 insertions(+), 84 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index c9352e8f700d7..261bfca195130 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -255,10 +255,7 @@ struct ExtractStridedMetadataOpSubviewFolder
 /// \p origSizes hold the sizes of the source shape as values.
 /// This is used to compute the new sizes in cases of dynamic shapes.
 ///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+/// sizes#i = expandOutputShape#reassIdx#i (reassIdx#i: i-th index of group)
 ///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
@@ -273,34 +270,10 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
   assert(!reassocGroup.empty() &&
          "Reassociation group should have at least one dimension");
 
-  unsigned groupSize = reassocGroup.size();
-  SmallVector<OpFoldResult> expandedSizes(groupSize);
-
-  uint64_t productOfAllStaticSizes = 1;
-  std::optional<unsigned> dynSizeIdx;
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  // Fill up all the statically known sizes.
-  for (unsigned i = 0; i < groupSize; ++i) {
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-    productOfAllStaticSizes *= dimSize;
-    expandedSizes[i] = builder.getIndexAttr(dimSize);
-  }
-
-  // Compute the dynamic size using the original size and all the other known
-  // static sizes:
-  // expandSize = origSize / productOfAllStaticSizes.
-  if (dynSizeIdx) {
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
-        origSizes[groupId]);
-  }
+  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
+  SmallVector<OpFoldResult> expandedSizes;
+  for (auto index : reassocGroup)
+    expandedSizes.push_back(outputShape[index]);
 
   return expandedSizes;
 }
@@ -313,16 +286,12 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 /// dynamic stride for this reassociation group.
 ///
 /// strides#i =
-///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///     origStrides#reassDim * product(expandOutputShape#j, for j in
 ///                                    reassIdx#i+1..reassIdx#i+group.size-1)
 ///
 /// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
+/// and expandOutputShape#j is the j-th entry of the expand_shape op's mixed
+/// (static and dynamic) output shape.
 ///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
@@ -340,24 +309,22 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
          "Reassociation group should have at least one dimension");
 
   unsigned groupSize = reassocGroup.size();
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  std::optional<int64_t> dynSizeIdx;
-
+  Location loc = expandShape.getLoc();
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(builder, loc, s0 * s1,
+                                                 {v1, v2});
+  };
+
+  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
   // Fill up the expanded strides, with the information we can deduce from the
   // resulting shape.
-  uint64_t currentStride = 1;
+  OpFoldResult currentStride = builder.getIndexAttr(1);
   SmallVector<OpFoldResult> expandedStrides(groupSize);
   for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = builder.getIndexAttr(currentStride);
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-
-    currentStride *= dimSize;
+    expandedStrides[i] = currentStride;
+    currentStride = mul(currentStride, outputShape[reassocGroup[i]]);
   }
 
   // Collect the statically known information about the original stride.
@@ -370,37 +337,8 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                 : builder.getIndexAttr(strides[groupId]);
 
   // Apply the original stride to all the strides.
-  int64_t doneStrideIdx = 0;
-  // If we saw a dynamic dimension, we need to fix-up all the strides up to
-  // that dimension with the dynamic size.
-  if (dynSizeIdx) {
-    int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
-           "We shouldn't be able to change dynamicity");
-    OpFoldResult origSize = origSizes[groupId];
-
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    AffineExpr s1 = builder.getAffineSymbolExpr(1);
-    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      int64_t baseExpandedStride =
-          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-              .getInt();
-      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-          builder, expandShape.getLoc(),
-          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
-          {origSize, origStride});
-    }
-  }
-
-  // Now apply the origStride to the remaining dimensions.
-  AffineExpr s0 = builder.getAffineSymbolExpr(0);
-  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride =
-        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-            .getInt();
-    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
-  }
+  for (int64_t i = 0; i < groupSize; ++i)
+    expandedStrides[i] = mul(origStride, expandedStrides[i]);
 
   return expandedStrides;
 }

>From 0abba85bb102a47c5d3750ff6e7ca2c7c8942c9d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 21 Mar 2026 15:10:59 +0800
Subject: [PATCH 2/6] update test

---
 .../expand-then-convert-to-llvm.mlir          | 122 ++++++++----------
 1 file changed, 55 insertions(+), 67 deletions(-)

diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index e0e4a61e821ce..c2c93525b6509 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -593,38 +593,31 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<
 }
 
 // CHECK-LABEL:   func.func @expand_shape_dynamic(
-// CHECK-SAME:              %[[ARG:.*]]: memref<1x?xf32>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32> {
-// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
-// CHECK:           %[[FINAL_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C2]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// In this example stride1 and size2 are the same.
-// Hence with CSE, we get the same SSA value.
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
-// CHECK:           return %[[RES]] : memref<1x2x?xf32>
+// CHECK-SAME:      %[[ARG0:.*]]: memref<1x?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<1x2x?xf32> {
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[EXTRACTVALUE_2:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_2]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[MLIR_1]], %[[INSERTVALUE_3]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_3:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[EXTRACTVALUE_2]], %[[INSERTVALUE_5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_4:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_8]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_9]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[INSERTVALUE_10]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
+// CHECK:           return %[[UNREALIZED_CONVERSION_CAST_2]] : memref<1x2x?xf32>
 // CHECK:         }
 
 // -----
@@ -638,41 +631,36 @@ func.func @expand_shape_dynamic_with_non_identity_layout(
   return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 }
 // CHECK-LABEL:   func.func @expand_shape_dynamic_with_non_identity_layout(
-// CHECK-SAME:        %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
-// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
-// CHECK:           %[[TMP_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[TMP_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[FINAL_STRIDE1:.*]] = llvm.mul %[[TMP_SIZE2]], %[[STRIDE1]]
-// CHECK:           %[[STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_STRIDE1]] : i64 to index
-// CHECK:           %[[FINAL_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[STRIDE1_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C2]], %[[DESC5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_STRIDE1]], %[[DESC6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC9]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
-// CHECK:           return %[[RES]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK-SAME:      %[[ARG0:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[EXTRACTVALUE_2:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_3:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_4:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MUL_0:.*]] = llvm.mul %[[EXTRACTVALUE_4]], %[[UNREALIZED_CONVERSION_CAST_0]] overflow<nsw> : i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[MUL_0]] : i64 to index
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_3:.*]] = builtin.unrealized_conversion_cast %[[UNREALIZED_CONVERSION_CAST_2]] : index to i64
+// CHECK:           %[[MLIR_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_2]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[EXTRACTVALUE_2]], %[[INSERTVALUE_3]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_3:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[EXTRACTVALUE_3]], %[[INSERTVALUE_5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_4:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_3]], %[[INSERTVALUE_7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_8]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[EXTRACTVALUE_4]], %[[INSERTVALUE_9]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_4:.*]] = builtin.unrealized_conversion_cast %[[INSERTVALUE_10]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK:           return %[[UNREALIZED_CONVERSION_CAST_4]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 // CHECK:         }
 
 // -----

>From b45055faf8b78dc1afdd79b93c263a204c0ce7c0 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 21 Mar 2026 15:15:36 +0800
Subject: [PATCH 3/6] update test

---
 .../MemRef/expand-strided-metadata.mlir       | 114 ++++++------------
 1 file changed, 34 insertions(+), 80 deletions(-)

diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 4ed8d4b20229f..502a2f5d009a3 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -357,9 +357,7 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 //
 // Here we have:
 // For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
+// size 0 = %sz0
 // size 1 = 7
 // size 2 = 8
 // size 3 = 9
@@ -373,63 +371,51 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 // For the group applying to dim1:
 // size 4 = 10
 // size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
+// size 6 = %sz1
 // size 7 = 3
 // stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-//          and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
+//          = baseStrides#1 * 2 * %sz1 * 3
+//          = baseStrides#1 * %sz1 * 6
 // stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
+//          = baseStrides#1 * %sz1 * 3
 // stride 6 = baseStrides#1 * size 7
 //          = baseStrides#1 * 3
 // stride 7 = baseStrides#1
 //
 // Base and offset are unchanged.
 //
-//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
 //   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
 //   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
 //   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 6))>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 3))>
 //   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
 // CHECK-LABEL: func @simplify_expand_shape
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//  CHECK-SAME: %[[SIZE0:.*]]: index,  %[[SIZE1:.*]]: index)
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
 //   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE1]]]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE1]]]
 //   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
 //
-//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE0]], 7, 8, 9, 10, 2, %[[SIZE1]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
 //
 //   CHECK: return %[[REINTERPRET_CAST]]
 func.func @simplify_expand_shape(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
-    %offset0: index, %offset1: index, %offset2: index,
-    %size0: index, %size1: index, %size2: index,
-    %stride0: index, %stride1: index, %stride2: index,
     %sz0: index, %sz1: index)
     -> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {
 
-  %subview = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
+  %expand_shape = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
     memref<?x?xf32, strided<[?,?], offset: ?>> into
       memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 
-  return %subview :
+  return %expand_shape :
     memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 }
 
@@ -439,12 +425,6 @@ func.func @simplify_expand_shape(
 // into:
 // baseBuffer, baseOffset, baseSizes, baseStrides =
 //     extract_strided_metadata(memref)
-// sizes#reassIdx =
-//     baseSizes#reassDim / product(expandShapeSizes#j,
-//                                  for j in group excluding reassIdx)
-// strides#reassIdx =
-//     baseStrides#reassDim * product(expandShapeSizes#j, for j in
-//                                    reassIdx+1..reassIdx+group.size)
 //
 // Here we have:
 // For the group applying to dim0:
@@ -516,25 +496,15 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 // See extract_strided_metadata_of_expand_shape_all_static for an explanation
 // of the expansion.
 //
-// One of the important characteristic of this test is that the dynamic
-// dimensions produced by the expand_shape appear both in the first dimension
-// (for group 1) and the non-first dimension (second dimension for group 2.)
-// The idea is to make sure that:
-// 1. We properly account for dynamic shapes even when the strides are not
-//    affected by them. (When the dynamic dimension is the first one.)
-// 2. We properly compute the strides affected by dynamic shapes. (When the
-//    dynamic dimension is not the first one.)
 //
 // Here we have:
 // For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
-// size 1 = 7
+// size 0 = %sz0
+// size 1 = %sz1
 // size 2 = 8
 // size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-//          = baseStrides#0 * 504
+// stride 0 = baseStrides#0 * %sz1 * 8 * 9
+//          = baseStrides#0 * %sz1 * 72
 // stride 1 = baseStrides#0 * 8 * 9
 //          = baseStrides#0 * 72
 // stride 2 = baseStrides#0 * 9
@@ -542,72 +512,57 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 //
 // For the group applying to dim1:
 // size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
+// size 5 = %sz2
+// size 6 = %sz3
 // size 7 = 3
 // stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-//          and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
+//          = baseStrides#1 * %sz2 * %sz3 * 3
 // stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
+//          = baseStrides#1 * %sz3 * 3
 // stride 6 = baseStrides#1 * size 7
 //          = baseStrides#1 * 3
 // stride 7 = baseStrides#1
 //
 // Base and offset are unchanged.
 //
-//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 72))>
 //   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
 //   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * ((s1 * s2) * 3))>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 3))>
 //   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
 // CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//  CHECK-SAME: %[[SIZE0:.*]]: index,  %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index,  %[[SIZE3:.*]]: index)
 //
 //   CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
 //   CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
 //   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
 //   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0, %[[SIZE1]]]
 //   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE2]], %[[SIZE3]]]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE3]]]
 //   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
 
-//   CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
+//   CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE0]], %[[SIZE1]], %[[C8]], %[[C9]], %[[C10]], %[[SIZE2]], %[[SIZE3]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
 func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
-    %offset0: index, %offset1: index, %offset2: index,
-    %size0: index, %size1: index, %size2: index,
-    %stride0: index, %stride1: index, %stride2: index,
-    %sz0: index, %sz1: index)
+    %sz0: index, %sz1: index, %sz2: index, %sz3: index)
     -> (memref<f32>, index,
        index, index, index, index, index, index, index, index,
        index, index, index, index, index, index, index, index) {
 
-  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
+  %expand_shape = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, %sz1, 8, 9, 10, %sz2, %sz3, 3] :
     memref<?x?xf32, strided<[?,?], offset: ?>> into
-      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+      memref<?x?x8x9x10x?x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
 
-  %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %subview :
-    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+  %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %expand_shape :
+    memref<?x?x8x9x10x?x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
     -> memref<f32>, index,
        index, index, index, index, index, index, index, index,
        index, index, index, index, index, index, index, index
@@ -620,7 +575,6 @@ func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
       index, index, index, index, index, index, index, index
 }
 
-
 // -----
 
 // Check that we properly handle extract_strided_metadata of expand_shape for

>From fbb9dcc3126cee85a929f8c087f5ef55b815383d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 21 Mar 2026 15:30:46 +0800
Subject: [PATCH 4/6] code format

---
 mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 261bfca195130..dd95194d4d69e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -290,8 +290,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 ///                                    reassIdx#i+1..reassIdx#i+group.size-1)
 ///
 /// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandOutputShape#j is taken directly from the mixed (static and dynamic)
-/// output shape
+/// and expandOutputShape#j is taken directly from the mixed (static and
+/// dynamic) output shape.
 ///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///

>From d1b1aaf57db16b88906f9d28a4ba9fe162af0abc Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 21 Mar 2026 18:34:08 +0800
Subject: [PATCH 5/6] optimize code

---
 .../Transforms/ExpandStridedMetadata.cpp      | 21 +++++++------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index dd95194d4d69e..d6d5e90275015 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -317,16 +317,6 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                                  {v1, v2});
   };
 
-  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
-  // Fill up the expanded strides, with the information we can deduce from the
-  // resulting shape.
-  OpFoldResult currentStride = builder.getIndexAttr(1);
-  SmallVector<OpFoldResult> expandedStrides(groupSize);
-  for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = currentStride;
-    currentStride = mul(currentStride, outputShape[reassocGroup[i]]);
-  }
-
   // Collect the statically known information about the original stride.
   Value source = expandShape.getSrc();
   auto sourceType = cast<MemRefType>(source.getType());
@@ -336,9 +326,14 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                 ? origStrides[groupId]
                                 : builder.getIndexAttr(strides[groupId]);
 
-  // Apply the original stride to all the strides.
-  for (int64_t i = 0; i < groupSize; ++i)
-    expandedStrides[i] = mul(origStride, expandedStrides[i]);
+  // Stride i = origStride * product of the sizes of later dims in the group.
+  OpFoldResult currentStride = origStride;
+  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
+  SmallVector<OpFoldResult> expandedStrides(groupSize);
+  for (int i = groupSize - 1; i >= 0; --i) {
+    expandedStrides[i] = currentStride;
+    currentStride = mul(currentStride, outputShape[reassocGroup[i]]);
+  }
 
   return expandedStrides;
 }

>From ae2c81a57393565b58c914225441512dcad07d0d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 21 Mar 2026 18:36:16 +0800
Subject: [PATCH 6/6] update test

---
 .../MemRef/expand-strided-metadata.mlir       | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 502a2f5d009a3..70c5e1aee85dc 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -387,8 +387,8 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 //   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
 //   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
 //   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 6))>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 3))>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 6)>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 3)>
 //   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
 // CHECK-LABEL: func @simplify_expand_shape
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
@@ -399,8 +399,8 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 //   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE1]]]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE1]]]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#1]
 //   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
 //
 //   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE0]], 7, 8, 9, 10, 2, %[[SIZE1]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
@@ -525,11 +525,11 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 //
 // Base and offset are unchanged.
 //
-//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 72))>
+//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 72)>
 //   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
 //   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * ((s1 * s2) * 3))>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * (s1 * 3))>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1, s2] -> (((s1 * s2) * s0) * 3)>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 3)>
 //   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
 // CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
 //  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
@@ -542,11 +542,11 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0, %[[SIZE1]]]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[SIZE1]], %[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE2]], %[[SIZE3]]]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[STRIDES]]#1, %[[SIZE3]]]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZE2]], %[[SIZE3]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZE3]], %[[STRIDES]]#1]
 //   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
 
 //   CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE0]], %[[SIZE1]], %[[C8]], %[[C9]], %[[C10]], %[[SIZE2]], %[[SIZE3]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index



More information about the Mlir-commits mailing list