[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Sep 30 11:33:45 PDT 2025


================
@@ -250,6 +261,249 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
   }
 };
 
+/// Flattens memref global ops with more than one dimension to a single
+/// dimension. The initial value, if present, is reshaped to the flattened
+/// shape. Globals whose initializer cannot be reshaped are left untouched
+/// rather than silently losing their initializer.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Reshapes the initial-value attribute `value` to `newType`.
+  ///  - A null `value` (external declaration) is returned unchanged.
+  ///  - A UnitAttr (the `uninitialized` marker) is returned unchanged, since
+  ///    it carries no data to reshape.
+  ///  - Dense (including splat) and dense-resource attributes are reshaped.
+  ///  - Any other attribute kind yields a null attribute; callers must treat
+  ///    that as "cannot flatten" when `value` was non-null.
+  static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+    if (!value)
+      return value;
+    if (llvm::isa<UnitAttr>(value))
+      return value;
+    // SplatElementsAttr is a DenseElementsAttr, so this covers splats too.
+    if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value))
+      return denseAttr.reshape(newType);
+    if (auto denseResourceAttr =
+            llvm::dyn_cast<DenseResourceElementsAttr>(value))
+      return DenseResourceElementsAttr::get(newType,
+                                            denseResourceAttr.getRawHandle());
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+                                PatternRewriter &rewriter) const override {
+    auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+    // getNumElements() below is only meaningful for static shapes.
+    if (!oldType || !oldType.getLayout().isIdentity() ||
+        oldType.getRank() <= 1 || !oldType.hasStaticShape())
+      return failure();
+
+    auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+                                            oldType.getElementType());
+    auto memRefType =
+        MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+                        AffineMap(), oldType.getMemorySpace());
+
+    Attribute oldInitialValue = globalOp.getInitialValueAttr();
+    Attribute newInitialValue = flattenAttribute(oldInitialValue, tensorType);
+    // A non-null initializer that could not be reshaped would otherwise turn
+    // a defined global into an external declaration; refuse instead.
+    if (oldInitialValue && !newInitialValue)
+      return rewriter.notifyMatchFailure(
+          globalOp, "unsupported initial value attribute for flattening");
+
+    // Forward the original alignment so the flattened global keeps it.
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+        memRefType, newInitialValue, globalOp.getConstant(),
+        globalOp.getAlignmentAttr());
+    return success();
+  }
+};
+
+struct FlattenCollapseShape final
+    : public OpRewritePattern<memref::CollapseShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    memref::ExtractStridedMetadataOp metadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+    SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
+    OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+    SmallVector<OpFoldResult> collapsedSizes;
+    SmallVector<OpFoldResult> collapsedStrides;
+    unsigned numGroups = op.getReassociationIndices().size();
+    collapsedSizes.reserve(numGroups);
+    collapsedStrides.reserve(numGroups);
+    for (unsigned i = 0; i < numGroups; ++i) {
+      SmallVector<OpFoldResult> groupSizes =
+          memref::getCollapsedSize(op, rewriter, origSizes, i);
+      SmallVector<OpFoldResult> groupStrides =
+          memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+      collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+      collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), op.getSrc(), offset, collapsedSizes,
+        collapsedStrides);
+    return success();
----------------
krzysz00 wrote:

Why are you rewriting collapse_shape to a reinterpret_cast? Shouldn't it be a no-op after flattening?

https://github.com/llvm/llvm-project/pull/159841


More information about the Mlir-commits mailing list