[Mlir-commits] [mlir] [MLIR][Memref] Improve `expand-strided-metadata` pass (PR #129642)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 14 12:20:41 PDT 2025


================
@@ -334,74 +312,56 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                              ArrayRef<OpFoldResult> origSizes,
                                              ArrayRef<OpFoldResult> origStrides,
                                              unsigned groupId) {
-  SmallVector<int64_t, 2> reassocGroup =
-      expandShape.getReassociationIndices()[groupId];
+  auto reassocIndices = expandShape.getReassociationIndices();
+  unsigned currIdx = 0;
+  for (unsigned i = 0; i < groupId; i++)
+    currIdx += reassocIndices[i].size();
+  SmallVector<int64_t, 2> reassocGroup = reassocIndices[groupId];
   assert(!reassocGroup.empty() &&
          "Reassociation group should have at least one dimension");
 
   unsigned groupSize = reassocGroup.size();
   MemRefType expandShapeType = expandShape.getResultType();
-
-  std::optional<int64_t> dynSizeIdx;
-
   // Fill up the expanded strides, with the information we can deduce from the
   // resulting shape.
-  uint64_t currentStride = 1;
+  Location loc = expandShape.getLoc();
   SmallVector<OpFoldResult> expandedStrides(groupSize);
-  for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = builder.getIndexAttr(currentStride);
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-
-    currentStride *= dimSize;
+  DenseMap<int, Value> dynSizes;
+  unsigned dynCount = 0;
+  Operation::operand_range dynOutShapes = expandShape.getOutputShape();
+  for (unsigned i = 0, e = expandShapeType.getRank(); i < e; i++) {
+    if (expandShapeType.isDynamicDim(i))
+      dynSizes[i] = dynOutShapes[dynCount++];
   }
-
-  // Collect the statically known information about the original stride.
-  Value source = expandShape.getSrc();
-  auto sourceType = cast<MemRefType>(source.getType());
-  auto [strides, offset] = sourceType.getStridesAndOffset();
-
-  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
-                                ? origStrides[groupId]
-                                : builder.getIndexAttr(strides[groupId]);
-
-  // Apply the original stride to all the strides.
-  int64_t doneStrideIdx = 0;
-  // If we saw a dynamic dimension, we need to fix-up all the strides up to
-  // that dimension with the dynamic size.
-  if (dynSizeIdx) {
-    int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
-           "We shouldn't be able to change dynamicity");
-    OpFoldResult origSize = origSizes[groupId];
-
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    AffineExpr s1 = builder.getAffineSymbolExpr(1);
-    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      int64_t baseExpandedStride =
-          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-              .getInt();
-      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-          builder, expandShape.getLoc(),
-          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
-          {origSize, origStride});
-    }
-  }
-
-  // Now apply the origStride to the remaining dimensions.
+  OpFoldResult origStride = origStrides[groupId];
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
-  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride =
-        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-            .getInt();
-    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  int64_t resultOffset;
+  SmallVector<int64_t, 4> resultStrides;
+  (void)expandShapeType.getStridesAndOffset(resultStrides, resultOffset);
+  expandedStrides[groupSize - 1] =
+      !ShapedType::isDynamic(resultStrides[currIdx + groupSize - 1])
+          ? builder.getIndexAttr(resultStrides[currIdx + groupSize - 1])
+          : origStride;
+  OpFoldResult currentStride = builder.getIndexAttr(1);
+  for (int i = groupSize - 2; i >= 0; i--) {
----------------
MaheshRavishankar wrote:

I think this would also become simpler if you avoided `dynSizes` below and instead used `expandedSizes` (a `SmallVector<OpFoldResult>`), as suggested above.

https://github.com/llvm/llvm-project/pull/129642


More information about the Mlir-commits mailing list