[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 19 13:07:06 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 562b8c112..d40cb5ee8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -18,8 +18,8 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 class Location;
@@ -236,9 +236,10 @@ memref::AllocaOp allocToAlloca(
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId);
 
 /// Compute the expanded strides of the given \p expandShape for the
 /// \p groupId-th reassociation group.
@@ -277,11 +278,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
 ///
 /// \pre for all index in indices: index < values.size()
 /// \pre for all index in indices: index < maybeConstants.size()
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
-                   ArrayRef<int64_t> maybeConstants,
-                   ArrayRef<OpFoldResult> values,
-                   llvm::function_ref<bool(int64_t)> isDynamic);
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+                                Location loc, ArrayRef<int64_t> maybeConstants,
+                                ArrayRef<OpFoldResult> values,
+                                llvm::function_ref<bool(int64_t)> isDynamic);
 
 /// Compute the collapsed size of the given \p collapseShape for the
 /// \p groupId-th reassociation group.
@@ -291,9 +291,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId);
 
 /// Compute the collapsed stride of the given \p collpaseShape for the
 /// \p groupId-th reassociation group.
@@ -307,10 +308,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
 ///
 /// \post result.size() == 1, in other words, each group collapse to one
 /// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
-                   ArrayRef<OpFoldResult> origSizes,
-                   ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId);
 
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 6b69d0e36..96bceae88 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -21,10 +21,10 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "mlir/IR/OpDefinition.h"
 #include <optional>
 
 namespace mlir {
@@ -256,9 +256,10 @@ struct ExtractStridedMetadataOpSubviewFolder
 
 namespace mlir {
 namespace memref {
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
       expandShape.getReassociationIndices()[groupId];
   assert(!reassocGroup.empty() &&
@@ -372,11 +373,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
   return expandedStrides;
 }
 
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
-                   ArrayRef<int64_t> maybeConstants,
-                   ArrayRef<OpFoldResult> values,
-                   llvm::function_ref<bool(int64_t)> isDynamic) {
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+                                Location loc, ArrayRef<int64_t> maybeConstants,
+                                ArrayRef<OpFoldResult> values,
+                                llvm::function_ref<bool(int64_t)> isDynamic) {
   AffineExpr productOfValues = builder.getAffineConstantExpr(1);
   SmallVector<OpFoldResult> inputValues;
   unsigned numberOfSymbols = 0;
@@ -410,9 +410,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId) {
   SmallVector<OpFoldResult> collapsedSize;
 
   MemRefType collapseShapeType = collapseShape.getResultType();
@@ -451,10 +452,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
 ///
 /// \post result.size() == 1, in other words, each group collapse to one
 /// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
-                   ArrayRef<OpFoldResult> origSizes,
-                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
       collapseShape.getReassociationIndices()[groupId];
   assert(!reassocGroup.empty() &&
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 43a67f1fa..ea401c212 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,9 +21,9 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,7 +48,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
   return cast<Value>(in);
 }
 
-
 /// Returns a collapsed memref and the linearized index to access the element
 /// at the specified indices.
 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -281,9 +280,8 @@ struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
     return {};
   }
 
-  LogicalResult
-  matchAndRewrite(memref::GlobalOp globalOp,
-                  PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+                                PatternRewriter &rewriter) const override {
     auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
     if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
       return failure();
@@ -314,7 +312,8 @@ struct FlattenCollapseShape final
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
 
     SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
-    SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
     OpFoldResult offset = metadata.getConstifiedMixedOffset();
 
     SmallVector<OpFoldResult> collapsedSizes;
@@ -338,7 +337,8 @@ struct FlattenCollapseShape final
   }
 };
 
-struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
+struct FlattenExpandShape final
+    : public OpRewritePattern<memref::ExpandShapeOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
@@ -348,7 +348,8 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
 
     SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
-    SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> origStrides =
+        metadata.getConstifiedMixedStrides();
     OpFoldResult offset = metadata.getConstifiedMixedOffset();
 
     SmallVector<OpFoldResult> expandedSizes;
@@ -372,7 +373,6 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
   }
 };
 
-
 // Flattens memref subview ops with more than 1 dimension into 1-D accesses.
 struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -405,9 +405,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
     for (OpFoldResult ofr : mixedOffsets)
       offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
 
-    auto [flatSource, linearOffset] =
-        getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
-                                  ValueRange(offsetValues));
+    auto [flatSource, linearOffset] = getFlattenMemrefAndOffset(
+        rewriter, loc, op.getSource(), ValueRange(offsetValues));
 
     memref::ExtractStridedMetadataOp sourceMetadata =
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
@@ -449,7 +448,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
           }
         }
         Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
-        Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
+        Value contribVal =
+            getValueFromOpFoldResult(rewriter, loc, contribution);
         return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
             .getResult();
       }();
@@ -478,12 +478,12 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     [[maybe_unused]] OpFoldResult linearizedIndex;
     std::tie(linearizedInfo, linearizedIndex) =
-        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
-            resultSizes, resultStrides);
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, resultOffset,
+                                                 resultSizes, resultStrides);
 
-    Value flattenedSize = getValueFromOpFoldResult(
-        rewriter, loc, linearizedInfo.linearizedSize);
+    Value flattenedSize =
+        getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize);
     Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
 
     Value flattenedSubview = memref::SubViewOp::create(
@@ -524,10 +524,11 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::GetGlobalOp op,
-                               PatternRewriter &rewriter) const override {
+                                PatternRewriter &rewriter) const override {
     // Check if this get_global references a multi-dimensional global
     auto module = op->template getParentOfType<ModuleOp>();
-    auto globalOp = module.template lookupSymbol<memref::GlobalOp>(op.getName());
+    auto globalOp =
+        module.template lookupSymbol<memref::GlobalOp>(op.getName());
     if (!globalOp) {
       return failure();
     }
@@ -537,12 +538,13 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
 
     // Only apply if the global has been flattened but the get_global hasn't
     if (globalType.getRank() == 1 && resultType.getRank() > 1) {
-      auto newGetGlobal = memref::GetGlobalOp::create(
-          rewriter, op.getLoc(), globalType, op.getName());
+      auto newGetGlobal = memref::GetGlobalOp::create(rewriter, op.getLoc(),
+                                                      globalType, op.getName());
 
       // Cast the flattened result back to the original shape
       memref::ExtractStridedMetadataOp stridedMetadata =
-          memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), op.getResult());
+          memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(),
+                                                   op.getResult());
       auto castResult = memref::ReinterpretCastOp::create(
           rewriter, op.getLoc(), resultType, newGetGlobal,
           /*offset=*/rewriter.getIndexAttr(0),
@@ -572,13 +574,9 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<memref::StoreOp>,
                   MemRefRewritePattern<memref::AllocOp>,
                   MemRefRewritePattern<memref::AllocaOp>,
-                  MemRefRewritePattern<memref::DeallocOp>,
-                  FlattenExpandShape,
-                  FlattenCollapseShape,
-                  FlattenSubView,
-                  FlattenGetGlobal,
-                  FlattenGlobal>(
-      patterns.getContext());
+                  MemRefRewritePattern<memref::DeallocOp>, FlattenExpandShape,
+                  FlattenCollapseShape, FlattenSubView, FlattenGetGlobal,
+                  FlattenGlobal>(patterns.getContext());
 }
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/159841


More information about the Mlir-commits mailing list