[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 19 13:03:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Alan Li (lialan)

<details>
<summary>Changes</summary>

This patch makes the `flatten-memref` pass more complete by adding support for memref operands of more ops.

Some patterns are from: https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp

---

Patch is 34.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159841.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (+99) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+18-54) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp (+296-5) 
- (modified) mlir/test/Dialect/MemRef/flatten_memref.mlir (+128) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..562b8c11225e8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -14,10 +14,15 @@
 #ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
 
 namespace mlir {
+class Location;
 class OpBuilder;
 class RewritePatternSet;
 class RewriterBase;
@@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class CollapseShapeOp;
 class DeallocOp;
+class ExpandShapeOp;
 
 //===----------------------------------------------------------------------===//
 // Patterns
@@ -213,6 +220,98 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
 memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes holds the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// sizes#i =
+///     baseSizes#groupId / product(expandShapeSizes#j,
+///                                  for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// strides#i =
+///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///                                    reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+///   the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+///   element.)
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId);
+
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false
+/// for it, values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+                   ArrayRef<int64_t> maybeConstants,
+                   ArrayRef<OpFoldResult> values,
+                   llvm::function_ref<bool(int64_t)> isDynamic);
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes holds the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
+                   ArrayRef<OpFoldResult> origSizes,
+                   ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..6b69d0e366903 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
@@ -35,6 +36,7 @@ namespace memref {
 
 using namespace mlir;
 using namespace mlir::affine;
+using namespace mlir::memref;
 
 namespace {
 
@@ -250,23 +252,12 @@ struct ExtractStridedMetadataOpSubviewFolder
   }
 };
 
-/// Compute the expanded sizes of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+} // namespace
+
+namespace mlir {
+namespace memref {
+SmallVector<OpFoldResult>
+getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
                  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
       expandShape.getReassociationIndices()[groupId];
@@ -305,31 +296,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
   return expandedSizes;
 }
 
-/// Compute the expanded strides of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
-///
-/// strides#i =
-///     origStrides#reassDim * product(expandShapeSizes#j, for j in
-///                                    reassIdx#i+1..reassIdx#i+group.size-1)
-///
-/// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
                                              OpBuilder &builder,
                                              ArrayRef<OpFoldResult> origSizes,
                                              ArrayRef<OpFoldResult> origStrides,
@@ -405,14 +372,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   return expandedStrides;
 }
 
-/// Produce an OpFoldResult object with \p builder at \p loc representing
-/// `prod(valueOrConstant#i, for i in {indices})`,
-/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
-/// values[i] otherwise.
-///
-/// \pre for all index in indices: index < values.size()
-/// \pre for all index in indices: index < maybeConstants.size()
-static OpFoldResult
+OpFoldResult
 getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                    ArrayRef<int64_t> maybeConstants,
                    ArrayRef<OpFoldResult> values,
@@ -450,8 +410,8 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
                  ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
   SmallVector<OpFoldResult> collapsedSize;
 
@@ -491,8 +451,8 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
 ///
 /// \post result.size() == 1, in other words, each group collapse to one
 /// dimension.
-static SmallVector<OpFoldResult>
-getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+SmallVector<OpFoldResult>
+getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
                    ArrayRef<OpFoldResult> origSizes,
                    ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
@@ -546,6 +506,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
 
   return {lastValidStride};
 }
+} // namespace memref
+} // namespace mlir
+
+namespace {
 
 /// From `reshape_like(memref, subSizes, subStrides))` compute
 ///
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fddf37e0b..43a67f1fab2be 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,11 +21,13 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace mlir {
@@ -46,6 +48,7 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
   return cast<Value>(in);
 }
 
+
 /// Returns a collapsed memref and the linearized index to access the element
 /// at the specified indices.
 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -89,17 +92,21 @@ static bool needFlattening(Value val) {
   return type.getRank() > 1;
 }
 
-static bool checkLayout(Value val) {
-  auto type = cast<MemRefType>(val.getType());
+/// Returns true if \p type has either an identity layout or an explicit
+/// StridedLayoutAttr. Patterns such as FlattenSubView bail out (return
+/// failure) when this predicate is false.
+static bool checkLayout(MemRefType type) {
   return type.getLayout().isIdentity() ||
          isa<StridedLayoutAttr>(type.getLayout());
 }
 
+/// Convenience overload for a memref-typed value. \p val must have MemRefType;
+/// this is enforced by `cast`, which asserts otherwise.
+static bool checkLayout(Value val) {
+  return checkLayout(cast<MemRefType>(val.getType()));
+}
+
 namespace {
 static Value getTargetMemref(Operation *op) {
   return llvm::TypeSwitch<Operation *, Value>(op)
       .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
-                     memref::AllocOp>([](auto op) { return op.getMemref(); })
+                     memref::AllocOp, memref::DeallocOp>(
+          [](auto op) { return op.getMemref(); })
       .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
                      vector::MaskedStoreOp, vector::TransferReadOp,
                      vector::TransferWriteOp>(
@@ -189,6 +196,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
             rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
         rewriter.replaceOp(op, newTransferWrite);
       })
+      .template Case<memref::DeallocOp>([&](auto op) {
+        auto newDealloc = memref::DeallocOp::create(rewriter, loc, flatMemref);
+        rewriter.replaceOp(op, newDealloc);
+      })
       .Default([&](auto op) {
         op->emitOpError("unimplemented: do not know how to replace op.");
       });
@@ -197,7 +208,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
 template <typename T>
 static ValueRange getIndices(T op) {
   if constexpr (std::is_same_v<T, memref::AllocaOp> ||
-                std::is_same_v<T, memref::AllocOp>) {
+                std::is_same_v<T, memref::AllocOp> ||
+                std::is_same_v<T, memref::DeallocOp>) {
     return ValueRange{};
   } else {
     return op.getIndices();
@@ -250,6 +262,243 @@ struct MemRefRewritePattern : public OpRewritePattern<T> {
   }
 };
 
+/// Flattens memref.global ops whose type has an identity layout and rank > 1
+/// into an equivalent rank-1 global holding the same number of elements.
+/// Symbol name, visibility, constness and alignment are preserved.
+struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Reshapes the initial value \p value of a global to \p newType.
+  ///  - A null attribute (external declaration) is returned unchanged.
+  ///  - A UnitAttr (the `uninitialized` marker) carries no shape and is
+  ///    returned unchanged, so the global stays a definition.
+  ///  - Dense (including splat) and dense-resource elements are reshaped.
+  ///  - Any other attribute kind yields a null attribute; callers must treat
+  ///    "non-null in, null out" as an unsupported initializer.
+  static Attribute flattenAttribute(Attribute value, ShapedType newType) {
+    if (!value || llvm::isa<UnitAttr>(value))
+      return value;
+    // SplatElementsAttr is a DenseElementsAttr, so this covers splats too.
+    if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(value))
+      return denseAttr.reshape(newType);
+    if (auto denseResourceAttr =
+            llvm::dyn_cast<DenseResourceElementsAttr>(value))
+      return DenseResourceElementsAttr::get(newType,
+                                            denseResourceAttr.getRawHandle());
+    return {};
+  }
+
+  LogicalResult
+  matchAndRewrite(memref::GlobalOp globalOp,
+                  PatternRewriter &rewriter) const override {
+    auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
+    if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
+      return failure();
+
+    int64_t numElements = oldType.getNumElements();
+    auto tensorType =
+        RankedTensorType::get({numElements}, oldType.getElementType());
+    auto memRefType =
+        MemRefType::get({numElements}, oldType.getElementType(), AffineMap(),
+                        oldType.getMemorySpace());
+
+    Attribute oldInitialValue = globalOp.getInitialValueAttr();
+    Attribute newInitialValue = flattenAttribute(oldInitialValue, tensorType);
+    // Dropping an initializer we cannot reshape would silently turn a defined
+    // global into an external declaration; refuse to rewrite instead.
+    if (oldInitialValue && !newInitialValue)
+      return rewriter.notifyMatchFailure(
+          globalOp, "unsupported initial value attribute");
+
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(),
+        memRefType, newInitialValue, globalOp.getConstant(),
+        globalOp.getAlignmentAttr());
+    return success();
+  }
+};
+
+/// Rewrites memref.collapse_shape as a memref.reinterpret_cast of its source.
+/// The source's offset, sizes and strides are read back through
+/// memref.extract_strided_metadata; each reassociation group's sizes and
+/// strides are then produced by memref::getCollapsedSize and
+/// memref::getCollapsedStride, appended in group order.
+struct FlattenCollapseShape final
+    : public OpRewritePattern<memref::CollapseShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CollapseShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value source = op.getSrc();
+    auto metadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), source);
+    SmallVector<OpFoldResult> srcSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> srcStrides = metadata.getConstifiedMixedStrides();
+    OpFoldResult srcOffset = metadata.getConstifiedMixedOffset();
+
+    unsigned groupCount = op.getReassociationIndices().size();
+    SmallVector<OpFoldResult> newSizes;
+    SmallVector<OpFoldResult> newStrides;
+    newSizes.reserve(groupCount);
+    newStrides.reserve(groupCount);
+    for (unsigned group = 0; group != groupCount; ++group) {
+      // Sizes are computed before strides for each group, as the helpers may
+      // materialize IR at the rewriter's insertion point.
+      llvm::append_range(
+          newSizes, memref::getCollapsedSize(op, rewriter, srcSizes, group));
+      llvm::append_range(newStrides,
+                         memref::getCollapsedStride(op, rewriter, srcSizes,
+                                                    srcStrides, group));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), source, srcOffset, newSizes, newStrides);
+    return success();
+  }
+};
+
+/// Rewrites memref.expand_shape as a memref.reinterpret_cast of its source.
+/// The source's offset, sizes and strides are read back through
+/// memref.extract_strided_metadata; each reassociation group is expanded with
+/// memref::getExpandedSizes and memref::getExpandedStrides, appended in group
+/// order.
+struct FlattenExpandShape final
+    : public OpRewritePattern<memref::ExpandShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value source = op.getSrc();
+    auto metadata =
+        memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), source);
+    SmallVector<OpFoldResult> srcSizes = metadata.getConstifiedMixedSizes();
+    SmallVector<OpFoldResult> srcStrides = metadata.getConstifiedMixedStrides();
+    OpFoldResult srcOffset = metadata.getConstifiedMixedOffset();
+
+    // The result holds one size and one stride per result dimension.
+    int64_t resultRank = op.getResultType().getRank();
+    SmallVector<OpFoldResult> newSizes;
+    SmallVector<OpFoldResult> newStrides;
+    newSizes.reserve(resultRank);
+    newStrides.reserve(resultRank);
+
+    unsigned groupCount = op.getReassociationIndices().size();
+    for (unsigned group = 0; group != groupCount; ++group) {
+      llvm::append_range(
+          newSizes, memref::getExpandedSizes(op, rewriter, srcSizes, group));
+      llvm::append_range(newStrides,
+                         memref::getExpandedStrides(op, rewriter, srcSizes,
+                                                    srcStrides, group));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, op.getType(), source, srcOffset, newSizes, newStrides);
+    return success();
+  }
+};
+
+
+// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
+struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+    if (!sourceType || sourceType.getRank() <= 1)
+      return failure();
+    if (!checkLayout(sourceType))
+      return failure();
+
+    MemRefType resultType = op.getType();
+    if (resultType.getRank() <= 1 || !checkLayout(res...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/159841


More information about the Mlir-commits mailing list