[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Sep 30 11:33:45 PDT 2025
================
@@ -250,6 +261,249 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
}
};
+/// Flattens memref global ops with more than 1 dimension to 1 dimension.
+///
+/// Only globals whose type is a MemRefType with an identity layout and rank
+/// greater than 1 are rewritten. The symbol name, visibility, constness,
+/// memory space and alignment are carried over to the new 1-D global.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Returns `value` re-typed to `newType`.
+  ///
+  /// A null attribute and a `UnitAttr` are returned unchanged since they
+  /// carry no shaped data. A null attribute is returned for any other
+  /// attribute kind that this helper does not know how to flatten.
+  static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+    // Nothing to reshape: either there is no initial value at all, or the
+    // initial value is a UnitAttr, which must be kept as is so that the
+    // rewritten global does not lose its initial-value marker.
+    if (!value || llvm::isa<UnitAttr>(value))
+      return value;
+    // SplatElementsAttr is a DenseElementsAttr, so this also covers splats.
+    if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value))
+      return denseAttr.reshape(newType);
+    // Resource-backed data is shared by handle; only the type changes.
+    if (auto denseResourceAttr =
+            llvm::dyn_cast<DenseResourceElementsAttr>(value))
+      return DenseResourceElementsAttr::get(newType,
+                                            denseResourceAttr.getRawHandle());
+    return {};
+  }
+
+  /// Replaces a multi-dimensional `memref.global` with a 1-D global holding
+  /// the same number of elements. Fails to match if the global is not a
+  /// rank > 1 identity-layout memref, or if its initial value cannot be
+  /// flattened.
+  LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+                                PatternRewriter &rewriter) const override {
+    auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+    if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+      return failure();
+
+    const int64_t numElements = oldType.getNumElements();
+    auto tensorType =
+        RankedTensorType::get({numElements}, oldType.getElementType());
+    auto memRefType =
+        MemRefType::get({numElements}, oldType.getElementType(), AffineMap(),
+                        oldType.getMemorySpace());
+
+    Attribute oldInitialValue = globalOp.getInitialValueAttr();
+    Attribute newInitialValue = flattenAttribute(oldInitialValue, tensorType);
+    // Refuse to rewrite rather than silently dropping an initial value that
+    // could not be flattened.
+    if (oldInitialValue && !newInitialValue)
+      return rewriter.notifyMatchFailure(
+          globalOp, "unsupported initial value attribute");
+
+    // Alignment does not depend on the shape, so it is preserved.
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+        memRefType, newInitialValue, globalOp.getConstant(),
+        globalOp.getAlignmentAttr());
+    return success();
+  }
+};
+
+struct FlattenCollapseShape final
+ : public OpRewritePattern<memref::CollapseShapeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ memref::ExtractStridedMetadataOp metadata =
+ memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+ SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> origStrides =
+ metadata.getConstifiedMixedStrides();
+ OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+ SmallVector<OpFoldResult> collapsedSizes;
+ SmallVector<OpFoldResult> collapsedStrides;
+ unsigned numGroups = op.getReassociationIndices().size();
+ collapsedSizes.reserve(numGroups);
+ collapsedStrides.reserve(numGroups);
+ for (unsigned i = 0; i < numGroups; ++i) {
+ SmallVector<OpFoldResult> groupSizes =
+ memref::getCollapsedSize(op, rewriter, origSizes, i);
+ SmallVector<OpFoldResult> groupStrides =
+ memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+ collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+ collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, op.getType(), op.getSrc(), offset, collapsedSizes,
+ collapsedStrides);
+ return success();
----------------
krzysz00 wrote:
Why are you rewriting collapse_shape to a reinterpret_cast? Shouldn't it be a no-op after flattening?
https://github.com/llvm/llvm-project/pull/159841
More information about the Mlir-commits
mailing list