[Mlir-commits] [mlir] [mlir][memref] Simplify expand_shape size/stride computation using output_shape (PR #187844)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 21 00:20:38 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR refactors `getExpandedSizes` and `getExpandedStrides` to compute their results directly from the `output_shape` of `memref.expand_shape`. Instead of reconstructing the expanded sizes/strides by inference (floor-dividing the source size by the product of the static sizes in the reassociation group to recover the single dynamic size), we now rely on the operation’s explicit shape information.
The previous implementation imposed the restriction that there must be at most one dynamic size per reassociation group. This limitation is removed by the new approach: any number of dynamic dimensions within a group is now supported, as long as they are represented in the `output_shape`.
As a result, the code becomes both simpler and more expressive, while better matching the semantics of `memref.expand_shape`.

---

Patch is 35.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/187844.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+22-84) 
- (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+55-67) 
- (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+34-80) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index c9352e8f700d7..261bfca195130 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -255,10 +255,7 @@ struct ExtractStridedMetadataOpSubviewFolder
 /// \p origSizes hold the sizes of the source shape as values.
 /// This is used to compute the new sizes in cases of dynamic shapes.
 ///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+/// sizes#i = expandOutputShape#i
 ///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
@@ -273,34 +270,10 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
   assert(!reassocGroup.empty() &&
          "Reassociation group should have at least one dimension");
 
-  unsigned groupSize = reassocGroup.size();
-  SmallVector<OpFoldResult> expandedSizes(groupSize);
-
-  uint64_t productOfAllStaticSizes = 1;
-  std::optional<unsigned> dynSizeIdx;
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  // Fill up all the statically known sizes.
-  for (unsigned i = 0; i < groupSize; ++i) {
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-    productOfAllStaticSizes *= dimSize;
-    expandedSizes[i] = builder.getIndexAttr(dimSize);
-  }
-
-  // Compute the dynamic size using the original size and all the other known
-  // static sizes:
-  // expandSize = origSize / productOfAllStaticSizes.
-  if (dynSizeIdx) {
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
-        origSizes[groupId]);
-  }
+  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
+  SmallVector<OpFoldResult> expandedSizes;
+  for (auto index : reassocGroup)
+    expandedSizes.push_back(outputShape[index]);
 
   return expandedSizes;
 }
@@ -313,16 +286,12 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 /// dynamic stride for this reassociation group.
 ///
 /// strides#i =
-///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///     origStrides#reassDim * product(expandOutputShape#j, for j in
 ///                                    reassIdx#i+1..reassIdx#i+group.size-1)
 ///
 /// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
+/// and expandOutputShape#j is taken directly from the mixed (static and dynamic)
+/// output shape
 ///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
@@ -340,24 +309,22 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
          "Reassociation group should have at least one dimension");
 
   unsigned groupSize = reassocGroup.size();
-  MemRefType expandShapeType = expandShape.getResultType();
-
-  std::optional<int64_t> dynSizeIdx;
-
+  Location loc = expandShape.getLoc();
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(builder, loc, s0 * s1,
+                                                 {v1, v2});
+  };
+
+  SmallVector<OpFoldResult> outputShape = expandShape.getMixedOutputShape();
   // Fill up the expanded strides, with the information we can deduce from the
   // resulting shape.
-  uint64_t currentStride = 1;
+  OpFoldResult currentStride = builder.getIndexAttr(1);
   SmallVector<OpFoldResult> expandedStrides(groupSize);
   for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = builder.getIndexAttr(currentStride);
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-
-    currentStride *= dimSize;
+    expandedStrides[i] = currentStride;
+    currentStride = mul(currentStride, outputShape[reassocGroup[i]]);
   }
 
   // Collect the statically known information about the original stride.
@@ -370,37 +337,8 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                 : builder.getIndexAttr(strides[groupId]);
 
   // Apply the original stride to all the strides.
-  int64_t doneStrideIdx = 0;
-  // If we saw a dynamic dimension, we need to fix-up all the strides up to
-  // that dimension with the dynamic size.
-  if (dynSizeIdx) {
-    int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
-           "We shouldn't be able to change dynamicity");
-    OpFoldResult origSize = origSizes[groupId];
-
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    AffineExpr s1 = builder.getAffineSymbolExpr(1);
-    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      int64_t baseExpandedStride =
-          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-              .getInt();
-      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-          builder, expandShape.getLoc(),
-          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
-          {origSize, origStride});
-    }
-  }
-
-  // Now apply the origStride to the remaining dimensions.
-  AffineExpr s0 = builder.getAffineSymbolExpr(0);
-  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride =
-        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-            .getInt();
-    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
-  }
+  for (int64_t i = 0; i < groupSize; ++i)
+    expandedStrides[i] = mul(origStride, expandedStrides[i]);
 
   return expandedStrides;
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index e0e4a61e821ce..c2c93525b6509 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -593,38 +593,31 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<
 }
 
 // CHECK-LABEL:   func.func @expand_shape_dynamic(
-// CHECK-SAME:              %[[ARG:.*]]: memref<1x?xf32>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32> {
-// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
-// CHECK:           %[[FINAL_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C2]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// In this example stride1 and size2 are the same.
-// Hence with CSE, we get the same SSA value.
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
-// CHECK:           return %[[RES]] : memref<1x2x?xf32>
+// CHECK-SAME:      %[[ARG0:.*]]: memref<1x?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<1x2x?xf32> {
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[EXTRACTVALUE_2:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_2:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_2]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_3:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_2]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_4:.*]] = llvm.insertvalue %[[MLIR_1]], %[[INSERTVALUE_3]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_3:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK:           %[[INSERTVALUE_5:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_6:.*]] = llvm.insertvalue %[[EXTRACTVALUE_2]], %[[INSERTVALUE_5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MLIR_4:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[INSERTVALUE_7:.*]] = llvm.insertvalue %[[MLIR_4]], %[[INSERTVALUE_6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_8:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_9:.*]] = llvm.insertvalue %[[UNREALIZED_CONVERSION_CAST_0]], %[[INSERTVALUE_8]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[INSERTVALUE_10:.*]] = llvm.insertvalue %[[MLIR_3]], %[[INSERTVALUE_9]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[INSERTVALUE_10]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
+// CHECK:           return %[[UNREALIZED_CONVERSION_CAST_2]] : memref<1x2x?xf32>
 // CHECK:         }
 
 // -----
@@ -638,41 +631,36 @@ func.func @expand_shape_dynamic_with_non_identity_layout(
   return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 }
 // CHECK-LABEL:   func.func @expand_shape_dynamic_with_non_identity_layout(
-// CHECK-SAME:        %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
-// CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]]  : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]]  : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]]  : i64
-// CHECK:           %[[TMP_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[TMP_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[FINAL_STRIDE1:.*]] = llvm.mul %[[TMP_SIZE2]], %[[STRIDE1]]
-// CHECK:           %[[STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_STRIDE1]] : i64 to index
-// CHECK:           %[[FINAL_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[STRIDE1_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C2]], %[[DESC5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_STRIDE1]], %[[DESC6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC9]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
-// CHECK:           return %[[RES]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK-SAME:      %[[ARG0:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i64
+// CHECK:           %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_0:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_1:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[MLIR_0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_0:.*]] = llvm.insertvalue %[[EXTRACTVALUE_0]], %[[MLIR_0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[INSERTVALUE_1:.*]] = llvm.insertvalue %[[EXTRACTVALUE_1]], %[[INSERTVALUE_0]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[MLIR_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[EXTRACTVALUE_2:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_3:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[EXTRACTVALUE_4:.*]] = llvm.extractvalue %[[UNREALIZED_CONVERSION_CAST_1]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/187844


More information about the Mlir-commits mailing list