[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Sep 30 11:33:45 PDT 2025


================
@@ -250,6 +261,249 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
   }
 };
 
+/// Flattens memref global ops with more than 1 dimension to 1 dimension.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+    if (!value)
+      return value;
+    if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(value)) {
+      return splatAttr.reshape(newType);
+    } else if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value)) {
+      return denseAttr.reshape(newType);
+    } else if (auto denseResourceAttr =
+                   llvm::dyn_cast<DenseResourceElementsAttr>(value)) {
+      return DenseResourceElementsAttr::get(newType,
+                                            denseResourceAttr.getRawHandle());
+    }
+    return {};
+  }
+
+  LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+                                PatternRewriter &rewriter) const override {
+    auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+    if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+      return failure();
+
+    auto tensorType = RankedTensorType::get({oldType.getNumElements()},
+                                            oldType.getElementType());
+    auto memRefType =
+        MemRefType::get({oldType.getNumElements()}, oldType.getElementType(),
+                        AffineMap(), oldType.getMemorySpace());
+    auto newInitialValue =
+        flattenAttribute(globalOp.getInitialValueAttr(), tensorType);
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+        memRefType, newInitialValue, globalOp.getConstant(),
+        /*alignment=*/IntegerAttr());
+    return success();
+  }
+};
+
+struct FlattenCollapseShape final
+    : public OpRewritePattern<memref::CollapseShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    memref::ExtractStridedMetadataOp metadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+    SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
+    OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+    SmallVector<OpFoldResult> collapsedSizes;
+    SmallVector<OpFoldResult> collapsedStrides;
+    unsigned numGroups = op.getReassociationIndices().size();
+    collapsedSizes.reserve(numGroups);
+    collapsedStrides.reserve(numGroups);
+    for (unsigned i = 0; i < numGroups; ++i) {
+      SmallVector<OpFoldResult> groupSizes =
+          memref::getCollapsedSize(op, rewriter, origSizes, i);
+      SmallVector<OpFoldResult> groupStrides =
+          memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i);
+      collapsedSizes.append(groupSizes.begin(), groupSizes.end());
+      collapsedStrides.append(groupStrides.begin(), groupStrides.end());
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), op.getSrc(), offset, collapsedSizes,
+        collapsedStrides);
+    return success();
+  }
+};
+
+struct FlattenExpandShape final
+    : public OpRewritePattern<memref::ExpandShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    memref::ExtractStridedMetadataOp metadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
+
+    SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
+    OpFoldResult offset = metadata.getConstifiedMixedOffset();
+
+    SmallVector<OpFoldResult> expandedSizes;
+    SmallVector<OpFoldResult> expandedStrides;
+    unsigned numGroups = op.getReassociationIndices().size();
+    expandedSizes.reserve(op.getResultType().getRank());
+    expandedStrides.reserve(op.getResultType().getRank());
+
+    for (unsigned i = 0; i < numGroups; ++i) {
+      SmallVector<OpFoldResult> groupSizes =
+          memref::getExpandedSizes(op, rewriter, origSizes, i);
+      SmallVector<OpFoldResult> groupStrides =
+          memref::getExpandedStrides(op, rewriter, origSizes, origStrides, i);
+      expandedSizes.append(groupSizes.begin(), groupSizes.end());
+      expandedStrides.append(groupStrides.begin(), groupStrides.end());
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), op.getSrc(), offset, expandedSizes, expandedStrides);
+    return success();
+  }
+};
+
+// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+    if (!sourceType || sourceType.getRank() <= 1)
+      return failure();
+    if (!checkLayout(sourceType))
+      return failure();
+
+    MemRefType resultType = op.getType();
+    if (resultType.getRank() <= 1 || !checkLayout(resultType))
+      return failure();
+
+    unsigned elementBitWidth = sourceType.getElementTypeBitWidth();
+    if (!elementBitWidth)
+      return failure();
+
+    Location loc = op.getLoc();
+
+    // Materialize offsets as values so they can participate in linearization.
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
+
+    SmallVector<Value> offsetValues;
+    offsetValues.reserve(mixedOffsets.size());
+    for (OpFoldResult ofr : mixedOffsets)
+      offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
+
+    auto [flatSource, linearOffset] = getFlattenMemrefAndOffset(
+        rewriter, loc, op.getSource(), ValueRange(offsetValues));
+
+    memref::ExtractStridedMetadataOp sourceMetadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
+
+    SmallVector<OpFoldResult> sourceStrides =
+        sourceMetadata.getConstifiedMixedStrides();
+    OpFoldResult sourceOffset = sourceMetadata.getConstifiedMixedOffset();
+
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> resultSizes;
+    SmallVector<OpFoldResult> resultStrides;
+    resultSizes.reserve(resultType.getRank());
+    resultStrides.reserve(resultType.getRank());
+
+    OpFoldResult resultOffset = sourceOffset;
+    for (auto zipped : llvm::enumerate(llvm::zip_equal(
+             mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) {
+      auto idx = zipped.index();
+      auto it = zipped.value();
+      auto offsetOfr = std::get<0>(it);
+      auto strideOfr = std::get<1>(it);
+      auto sizeOfr = std::get<2>(it);
+      auto relativeStrideOfr = std::get<3>(it);
+      OpFoldResult contribution = [&]() -> OpFoldResult {
+        if (Attribute offsetAttr = dyn_cast<Attribute>(offsetOfr)) {
+          if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+            auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+            auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+            return rewriter.getIndexAttr(offsetInt * strideInt);
+          }
+        }
+        Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offsetOfr);
+        Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+        return rewriter.create<arith::MulIOp>(loc, offsetVal, strideVal)
+            .getResult();
+      }();
+      resultOffset = [&]() -> OpFoldResult {
+        if (Attribute offsetAttr = dyn_cast<Attribute>(resultOffset)) {
+          if (Attribute contribAttr = dyn_cast<Attribute>(contribution)) {
+            auto offsetInt = cast<IntegerAttr>(offsetAttr).getInt();
+            auto contribInt = cast<IntegerAttr>(contribAttr).getInt();
+            return rewriter.getIndexAttr(offsetInt + contribInt);
+          }
+        }
+        Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
+        Value contribVal =
+            getValueFromOpFoldResult(rewriter, loc, contribution);
+        return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
+            .getResult();
+      }();
+
+      if (droppedDims.test(idx))
+        continue;
+
+      resultSizes.push_back(sizeOfr);
+      OpFoldResult combinedStride = [&]() -> OpFoldResult {
+        if (Attribute relStrideAttr = dyn_cast<Attribute>(relativeStrideOfr)) {
+          if (Attribute strideAttr = dyn_cast<Attribute>(strideOfr)) {
+            auto relStrideInt = cast<IntegerAttr>(relStrideAttr).getInt();
+            auto strideInt = cast<IntegerAttr>(strideAttr).getInt();
+            return rewriter.getIndexAttr(relStrideInt * strideInt);
+          }
+        }
+        Value relStrideVal =
+            getValueFromOpFoldResult(rewriter, loc, relativeStrideOfr);
+        Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr);
+        return rewriter.create<arith::MulIOp>(loc, relStrideVal, strideVal)
+            .getResult();
+      }();
+      resultStrides.push_back(combinedStride);
+    }
+
+    memref::LinearizedMemRefInfo linearizedInfo;
+    [[maybe_unused]] OpFoldResult linearizedIndex;
+    std::tie(linearizedInfo, linearizedIndex) =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, resultOffset,
+                                                 resultSizes, resultStrides);
+
+    Value flattenedSize =
+        getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize);
+    Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
+
+    Value flattenedSubview = memref::SubViewOp::create(
+        rewriter, loc, flatSource, ValueRange{linearOffset},
+        ValueRange{flattenedSize}, ValueRange{strideOne});
+
+    Value replacement = memref::ReinterpretCastOp::create(
+        rewriter, loc, resultType, flattenedSubview, resultOffset, resultSizes,
+        resultStrides);
+
+    rewriter.replaceOp(op, replacement);
+    return success();
----------------
krzysz00 wrote:

... Something here feels extremely dubious. Maybe it'll make more sense when I look at the tests

I'd say that all of this should be run after `fold-memref-alias-ops`, and so expand_shape, collapse_shape, and subview shouldn't *exist* at the time this code is run.

https://github.com/llvm/llvm-project/pull/159841


More information about the Mlir-commits mailing list