[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Sep 30 11:33:45 PDT 2025
================
@@ -250,6 +261,249 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
}
};
+/// Flattens memref global ops with more than 1 dimensions to 1 dimension.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+ if (!value)
+ return value;
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(value)) {
+ return splatAttr.reshape(newType);
+ } else if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
+ return denseAttr.reshape(newType);
+ } else if (auto denseResourceAttr =
+ llvm::dyn_cast<DenseResourceElementsAttr>(value)) {
+ return DenseResourceElementsAttr::get(newType,
+ denseResourceAttr.getRawHandle());
+ }
+ return {};
+ }
+
+ LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+ PatternRewriter &rewriter) const override {
+ auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+ if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+ return failure();
+
+ auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+ oldType.getElementType());
+ auto memRefType =
+ MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+ AffineMap(), oldType.getMemorySpace());
+ auto newInitialValue =
+ flattenAttribute(globalOp.getInitialValueAttr(), tensorType);
+ rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+ globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+ memRefType, newInitialValue, globalOp.getConstant(),
+ /*alignment=*/IntegerAttr());
+ return success();
+ }
+};
+
+struct FlattenCollapseShape final
+ : public OpRewritePattern<memref::CollapseShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> origStrides =
+ metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> collapsedSizes;
+ SmallVector<OpFoldResult> collapsedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ collapsedSizes.reserve(numGroups);
+ collapsedStrides.reserve(numGroups);
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ memref::getCollapsedSize(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+ collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+ collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, collapsedSizes,
+ collapsedStrides);
+ return success();
+ }
+};
+
+struct FlattenExpandShape final
+ : public OpRewritePattern<memref::ExpandShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> origStrides =
+ metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> expandedSizes;
+ SmallVector<OpFoldResult> expandedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ expandedSizes.reserve(op.getResultType().getRank());
+ expandedStrides.reserve(op.getResultType().getRank());
+
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ memref::getExpandedSizes(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ memref::getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+ expandedSizes.append(groupSizes.begin(), groupSizes.end());
+ expandedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, expandedSizes, expandedStrides);
+ return success();
+ }
+};
+
+// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
+// NOTE(review): the result is a reinterpret_cast of a 1-D subview of the
+// flattened source; both `linearOffset` (applied by the subview) and
+// `resultOffset` (applied by the cast) enter the result -- confirm against
+// the tests that the offset is not double-counted.
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::SubViewOp op,
+ PatternRewriter &rewriter) const override {
+ // Only multi-dimensional sources with a layout accepted by checkLayout
+ // are handled; everything else is left for other patterns.
+ auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+ if (!sourceType || sourceType.getRank() <= 1)
+ return failure();
+ if (!checkLayout(sourceType))
+ return failure();
+
+ MemRefType resultType = op.getType();
+ if (resultType.getRank() <= 1 || !checkLayout(resultType))
+ return failure();
+
+ // Bail on element types without a known bit width (e.g. non-scalar
+ // elements), since linearization below is expressed in bits.
+ unsigned elementBitWidth = sourceType.getElementTypeBitWidth();
+ if (!elementBitWidth)
+ return failure();
+
+ Location loc = op.getLoc();
+
+ // Materialize offsets as values so they can participate in linearization.
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
+
+ SmallVector<Value> offsetValues;
+ offsetValues.reserve(mixedOffsets.size());
+ for (OpFoldResult ofr : mixedOffsets)
+ offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
+
+ // Collapse the source to 1-D and get the linearized position of the
+ // subview's origin within it.
+ auto [flatSource, linearOffset] = getFlattenMemrefAndOffset(
+ rewriter, loc, op.getSource(), ValueRange(offsetValues));
+
+ memref::ExtractStridedMetadataOp sourceMetadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
+
+ SmallVector<OpFoldResult> sourceStrides =
+ sourceMetadata.getConstifiedMixedStrides();
+ OpFoldResult sourceOffset = sourceMetadata.getConstifiedMixedOffset();
+
+ llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+ SmallVector<OpFoldResult> resultSizes;
+ SmallVector<OpFoldResult> resultStrides;
+ resultSizes.reserve(resultType.getRank());
+ resultStrides.reserve(resultType.getRank());
+
+ // Accumulate offset contributions per source dim; sizes/strides are only
+ // recorded for dims that survive rank reduction. Each lambda folds to a
+ // constant IndexAttr when both operands are static, otherwise emits
+ // arith ops.
+ OpFoldResult resultOffset = sourceOffset;
+ for (auto zipped : llvm::enumerate(llvm::zip_equal(
+ mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
+ auto idx = zipped.index();
+ auto it = zipped.value();
+ auto offsetOfr = std::get<0>(it);
+ auto strideOfr = std::get<1>(it);
+ auto sizeOfr = std::get<2>(it);
+ auto relativeStrideOfr = std::get<3>(it);
+ // contribution = subview offset in this dim * source stride.
+ OpFoldResult contribution = [&]() -> OpFoldResult {
+ if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
+ if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+ auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+ auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+ return rewriter.getIndexAttr(offsetInt * strideInt);
+ }
+ }
+ Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offsetOfr);
+ Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+ return rewriter.create<arith::MulIOp>(loc, offsetVal, strideVal)
+ .getResult();
+ }();
+ // resultOffset += contribution (constant-folded when possible).
+ resultOffset = [&]() -> OpFoldResult {
+ if (Attribute offsetAttr = dyn_cast<Attribute>(resultOffset)) {
+ if (Attribute contribAttr = dyn_cast<Attribute>(contribution)) {
+ auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+ auto contribInt = cast<IntegerAttr>(contribAttr).getInt();
+ return rewriter.getIndexAttr(offsetInt + contribInt);
+ }
+ }
+ Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
+ Value contribVal =
+ getValueFromOpFoldResult(rewriter, loc, contribution);
+ return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
+ .getResult();
+ }();
+
+ // Dropped (rank-reduced) dims contribute to the offset above but get no
+ // size/stride entry in the result.
+ if (droppedDims.test(idx))
+ continue;
+
+ resultSizes.push_back(sizeOfr);
+ // Result stride = subview's relative stride * source stride.
+ OpFoldResult combinedStride = [&]() -> OpFoldResult {
+ if (Attribute relStrideAttr = dyn_cast<Attribute>(relativeStrideOfr)) {
+ if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+ auto relStrideInt = cast<IntegerAttr>(relStrideAttr).getInt();
+ auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+ return rewriter.getIndexAttr(relStrideInt * strideInt);
+ }
+ }
+ Value relStrideVal =
+ getValueFromOpFoldResult(rewriter, loc, relativeStrideOfr);
+ Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+ return rewriter.create<arith::MulIOp>(loc, relStrideVal, strideVal)
+ .getResult();
+ }();
+ resultStrides.push_back(combinedStride);
+ }
+
+ // Compute the total linearized extent of the result view; the linearized
+ // index output is unused here.
+ memref::LinearizedMemRefInfo linearizedInfo;
+ [[maybe_unused]] OpFoldResult linearizedIndex;
+ std::tie(linearizedInfo, linearizedIndex) =
+ memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+ elementBitWidth, resultOffset,
+ resultSizes, resultStrides);
+
+ Value flattenedSize =
+ getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize);
+ Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
+
+ // 1-D window into the flattened source, then reinterpret back to the
+ // original multi-dim result type.
+ Value flattenedSubview = memref::SubViewOp::create(
+ rewriter, loc, flatSource, ValueRange{linearOffset},
+ ValueRange{flattenedSize}, ValueRange{strideOne});
+
+ Value replacement = memref::ReinterpretCastOp::create(
+ rewriter, loc, resultType, flattenedSubview, resultOffset, resultSizes,
+ resultStrides);
+
+ rewriter.replaceOp(op, replacement);
+ return success();
----------------
krzysz00 wrote:
... Something here feels extremely dubious. Maybe it'll make more sense when I look at the tests
I'd say that this all should be run after `fold-memref-alias-ops` and so expand_shape, collapse_shape, and subview shouldn't *exist* at the time this code is run.
https://github.com/llvm/llvm-project/pull/159841
More information about the Mlir-commits
mailing list