[Mlir-commits] [mlir] [MLIR][Memref] Improve `expand-strided-metadata` pass (PR #129642)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 14 12:20:41 PDT 2025
================
@@ -334,74 +312,56 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
unsigned groupId) {
- SmallVector<int64_t, 2> reassocGroup =
- expandShape.getReassociationIndices()[groupId];
+ auto reassocIndices = expandShape.getReassociationIndices();
+ unsigned currIdx = 0;
+ for (unsigned i = 0; i < groupId; i++)
+ currIdx += reassocIndices[i].size();
+ SmallVector<int64_t, 2> reassocGroup = reassocIndices[groupId];
assert(!reassocGroup.empty() &&
"Reassociation group should have at least one dimension");
unsigned groupSize = reassocGroup.size();
MemRefType expandShapeType = expandShape.getResultType();
-
- std::optional<int64_t> dynSizeIdx;
-
// Fill up the expanded strides, with the information we can deduce from the
// resulting shape.
- uint64_t currentStride = 1;
+ Location loc = expandShape.getLoc();
SmallVector<OpFoldResult> expandedStrides(groupSize);
- for (int i = groupSize - 1; i >= 0; --i) {
- expandedStrides[i] = builder.getIndexAttr(currentStride);
- uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
- if (ShapedType::isDynamic(dimSize)) {
- assert(!dynSizeIdx && "There must be at most one dynamic size per group");
- dynSizeIdx = i;
- continue;
- }
-
- currentStride *= dimSize;
+ DenseMap<int, Value> dynSizes;
+ unsigned dynCount = 0;
+ Operation::operand_range dynOutShapes = expandShape.getOutputShape();
+ for (unsigned i = 0, e = expandShapeType.getRank(); i < e; i++) {
+ if (expandShapeType.isDynamicDim(i))
+ dynSizes[i] = dynOutShapes[dynCount++];
}
-
- // Collect the statically known information about the original stride.
- Value source = expandShape.getSrc();
- auto sourceType = cast<MemRefType>(source.getType());
- auto [strides, offset] = sourceType.getStridesAndOffset();
-
- OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
- ? origStrides[groupId]
- : builder.getIndexAttr(strides[groupId]);
-
- // Apply the original stride to all the strides.
- int64_t doneStrideIdx = 0;
- // If we saw a dynamic dimension, we need to fix-up all the strides up to
- // that dimension with the dynamic size.
- if (dynSizeIdx) {
- int64_t productOfAllStaticSizes = currentStride;
- assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
- "We shouldn't be able to change dynamicity");
- OpFoldResult origSize = origSizes[groupId];
-
- AffineExpr s0 = builder.getAffineSymbolExpr(0);
- AffineExpr s1 = builder.getAffineSymbolExpr(1);
- for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
- int64_t baseExpandedStride =
- cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
- .getInt();
- expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
- builder, expandShape.getLoc(),
- (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
- {origSize, origStride});
- }
- }
-
- // Now apply the origStride to the remaining dimensions.
+ OpFoldResult origStride = origStrides[groupId];
AffineExpr s0 = builder.getAffineSymbolExpr(0);
- for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
- int64_t baseExpandedStride =
- cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
- .getInt();
- expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
- builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+ AffineExpr s1 = builder.getAffineSymbolExpr(1);
+ int64_t resultOffset;
+ SmallVector<int64_t, 4> resultStrides;
+ (void)expandShapeType.getStridesAndOffset(resultStrides, resultOffset);
+ expandedStrides[groupSize - 1] =
+ !ShapedType::isDynamic(resultStrides[currIdx + groupSize - 1])
+ ? builder.getIndexAttr(resultStrides[currIdx + groupSize - 1])
+ : origStride;
+ OpFoldResult currentStride = builder.getIndexAttr(1);
+ for (int i = groupSize - 2; i >= 0; i--) {
----------------
MaheshRavishankar wrote:
I think this also gets simpler if you drop `dynSizes` below and just use `expandedSizes` (a `SmallVector<OpFoldResult>`), as suggested above.
https://github.com/llvm/llvm-project/pull/129642
More information about the Mlir-commits mailing list