[Mlir-commits] [mlir] [MLIR] Add more ops support for flattening memref operands (PR #159841)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 19 13:07:06 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
``````````
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 562b8c112..d40cb5ee8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -18,8 +18,8 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallVector.h"
namespace mlir {
class Location;
@@ -236,9 +236,10 @@ memref::AllocaOp allocToAlloca(
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId);
/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
@@ -277,11 +278,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
- ArrayRef<int64_t> maybeConstants,
- ArrayRef<OpFoldResult> values,
- llvm::function_ref<bool(int64_t)> isDynamic);
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+ Location loc, ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic);
/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
@@ -291,9 +291,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId);
/// Compute the collapsed stride of the given \p collpaseShape for the
/// \p groupId-th reassociation group.
@@ -307,10 +308,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapse to one
/// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes,
- ArrayRef<OpFoldResult> origStrides, unsigned groupId);
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId);
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 6b69d0e36..96bceae88 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -21,10 +21,10 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
@@ -256,9 +256,10 @@ struct ExtractStridedMetadataOpSubviewFolder
namespace mlir {
namespace memref {
-SmallVector<OpFoldResult>
-getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
expandShape.getReassociationIndices()[groupId];
assert(!reassocGroup.empty() &&
@@ -372,11 +373,10 @@ SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
return expandedStrides;
}
-OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
- ArrayRef<int64_t> maybeConstants,
- ArrayRef<OpFoldResult> values,
- llvm::function_ref<bool(int64_t)> isDynamic) {
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+ Location loc, ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
AffineExpr productOfValues = builder.getAffineConstantExpr(1);
SmallVector<OpFoldResult> inputValues;
unsigned numberOfSymbols = 0;
@@ -410,9 +410,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult>
-getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ unsigned groupId) {
SmallVector<OpFoldResult> collapsedSize;
MemRefType collapseShapeType = collapseShape.getResultType();
@@ -451,10 +452,11 @@ getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
///
/// \post result.size() == 1, in other words, each group collapse to one
/// dimension.
-SmallVector<OpFoldResult>
-getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
- ArrayRef<OpFoldResult> origSizes,
- ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
collapseShape.getReassociationIndices()[groupId];
assert(!reassocGroup.empty() &&
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 43a67f1fa..ea401c212 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -21,9 +21,9 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -48,7 +48,6 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
return cast<Value>(in);
}
-
/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
@@ -281,9 +280,8 @@ struct FlattenGlobal final : public OpRewritePattern<memref::GlobalOp> {
return {};
}
- LogicalResult
- matchAndRewrite(memref::GlobalOp globalOp,
- PatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewrite(memref::GlobalOp globalOp,
+ PatternRewriter &rewriter) const override {
auto oldType = llvm::dyn_cast<MemRefType>(globalOp.getType());
if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1)
return failure();
@@ -314,7 +312,8 @@ struct FlattenCollapseShape final
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
- SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ SmallVector<OpFoldResult> origStrides =
+ metadata.getConstifiedMixedStrides();
OpFoldResult offset = metadata.getConstifiedMixedOffset();
SmallVector<OpFoldResult> collapsedSizes;
@@ -338,7 +337,8 @@ struct FlattenCollapseShape final
}
};
-struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp> {
+struct FlattenExpandShape final
+ : public OpRewritePattern<memref::ExpandShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::ExpandShapeOp op,
@@ -348,7 +348,8 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc());
SmallVector<OpFoldResult> origSizes = metadata.getConstifiedMixedSizes();
- SmallVector<OpFoldResult> origStrides = metadata.getConstifiedMixedStrides();
+ SmallVector<OpFoldResult> origStrides =
+ metadata.getConstifiedMixedStrides();
OpFoldResult offset = metadata.getConstifiedMixedOffset();
SmallVector<OpFoldResult> expandedSizes;
@@ -372,7 +373,6 @@ struct FlattenExpandShape final : public OpRewritePattern<memref::ExpandShapeOp>
}
};
-
// Flattens memref subview ops with more than 1 dimension into 1-D accesses.
struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
@@ -405,9 +405,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
for (OpFoldResult ofr : mixedOffsets)
offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr));
- auto [flatSource, linearOffset] =
- getFlattenMemrefAndOffset(rewriter, loc, op.getSource(),
- ValueRange(offsetValues));
+ auto [flatSource, linearOffset] = getFlattenMemrefAndOffset(
+ rewriter, loc, op.getSource(), ValueRange(offsetValues));
memref::ExtractStridedMetadataOp sourceMetadata =
memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource());
@@ -449,7 +448,8 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
}
}
Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset);
- Value contribVal = getValueFromOpFoldResult(rewriter, loc, contribution);
+ Value contribVal =
+ getValueFromOpFoldResult(rewriter, loc, contribution);
return rewriter.create<arith::AddIOp>(loc, offsetVal, contribVal)
.getResult();
}();
@@ -478,12 +478,12 @@ struct FlattenSubView final : public OpRewritePattern<memref::SubViewOp> {
memref::LinearizedMemRefInfo linearizedInfo;
[[maybe_unused]] OpFoldResult linearizedIndex;
std::tie(linearizedInfo, linearizedIndex) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, elementBitWidth, elementBitWidth, resultOffset,
- resultSizes, resultStrides);
+ memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+ elementBitWidth, resultOffset,
+ resultSizes, resultStrides);
- Value flattenedSize = getValueFromOpFoldResult(
- rewriter, loc, linearizedInfo.linearizedSize);
+ Value flattenedSize =
+ getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize);
Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value flattenedSubview = memref::SubViewOp::create(
@@ -524,10 +524,11 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::GetGlobalOp op,
- PatternRewriter &rewriter) const override {
+ PatternRewriter &rewriter) const override {
// Check if this get_global references a multi-dimensional global
auto module = op->template getParentOfType<ModuleOp>();
- auto globalOp = module.template lookupSymbol<memref::GlobalOp>(op.getName());
+ auto globalOp =
+ module.template lookupSymbol<memref::GlobalOp>(op.getName());
if (!globalOp) {
return failure();
}
@@ -537,12 +538,13 @@ struct FlattenGetGlobal : public OpRewritePattern<memref::GetGlobalOp> {
// Only apply if the global has been flattened but the get_global hasn't
if (globalType.getRank() == 1 && resultType.getRank() > 1) {
- auto newGetGlobal = memref::GetGlobalOp::create(
- rewriter, op.getLoc(), globalType, op.getName());
+ auto newGetGlobal = memref::GetGlobalOp::create(rewriter, op.getLoc(),
+ globalType, op.getName());
// Cast the flattened result back to the original shape
memref::ExtractStridedMetadataOp stridedMetadata =
- memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), op.getResult());
+ memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(),
+ op.getResult());
auto castResult = memref::ReinterpretCastOp::create(
rewriter, op.getLoc(), resultType, newGetGlobal,
/*offset=*/rewriter.getIndexAttr(0),
@@ -572,13 +574,9 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<memref::StoreOp>,
MemRefRewritePattern<memref::AllocOp>,
MemRefRewritePattern<memref::AllocaOp>,
- MemRefRewritePattern<memref::DeallocOp>,
- FlattenExpandShape,
- FlattenCollapseShape,
- FlattenSubView,
- FlattenGetGlobal,
- FlattenGlobal>(
- patterns.getContext());
+ MemRefRewritePattern<memref::DeallocOp>, FlattenExpandShape,
+ FlattenCollapseShape, FlattenSubView, FlattenGetGlobal,
+ FlattenGlobal>(patterns.getContext());
}
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/159841
More information about the Mlir-commits mailing list