[Mlir-commits] [mlir] [mlir][memref]: Add control options to FoldMemRefAliasOps (PR #178405)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 28 04:03:58 PST 2026
https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/178405
>From 008d01cff9b7cec821c625862227c269bf8570c9 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <fabrizio.indirli at arm.com>
Date: Wed, 28 Jan 2026 10:29:45 +0000
Subject: [PATCH] [mlir][memref]: Add control options to FoldMemRefAliasOps
Add an "excludedPatterns" argument to a new constructor of the
`FoldMemRefAliasOps` pass, and an optional "controlFn" argument to both that
constructor and `memref::populateFoldMemRefAliasOpPatterns()`, to allow:
- disabling specific patterns by name,
- passing a control function that decides whether a pattern should
skip the current operation.
Signed-off-by: Fabrizio Indirli <fabrizio.indirli at arm.com>
---
.../mlir/Dialect/MemRef/Transforms/Passes.h | 11 ++
.../Dialect/MemRef/Transforms/Transforms.h | 10 +-
.../MemRef/Transforms/FoldMemRefAliasOps.cpp | 113 ++++++++++++++----
.../MemRef/fold-memref-alias-ops-options.mlir | 34 ++++++
mlir/test/lib/Dialect/MemRef/CMakeLists.txt | 1 +
.../MemRef/TestFoldMemRefAliasOptions.cpp | 101 ++++++++++++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
7 files changed, 249 insertions(+), 23 deletions(-)
create mode 100644 mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir
create mode 100644 mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 748248d45df26..633950e8d54dc 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -14,6 +14,9 @@
#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/StringRef.h"
namespace mlir {
@@ -45,6 +48,14 @@ namespace memref {
#define GEN_PASS_DECL
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+/// Additional constructor for FoldMemRefAliasOps that allows disabling
+/// patterns by name and gating folding via a callback. `controlFn` receives
+/// the user operation of the aliasing op (e.g., a load/store that uses the
+/// result of a memref.subview); the callable must outlive the returned pass.
+std::unique_ptr<Pass> createFoldMemRefAliasOpsPass(
+ ArrayRef<StringRef> excludedPatterns,
+ function_ref<bool(Operation *)> controlFn = nullptr);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..91d84d1c1d9ff 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -21,6 +21,7 @@ namespace mlir {
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
+class Operation;
class Value;
class ValueRange;
class ReifyRankedShapedTypeOpInterface;
@@ -43,8 +44,13 @@ class DeallocOp;
void populateExpandOpsPatterns(RewritePatternSet &patterns);
/// Appends patterns for folding memref aliasing ops into consumer load/store
-/// ops into `patterns`.
-void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
+/// ops into `patterns`. If `controlFn` is provided, each pattern invokes it on
+/// the user operation of the aliasing op (e.g., a load/store that uses the
+/// result of a memref.subview) and bails out when it returns false. The
+/// callable behind `controlFn` must outlive the populated patterns.
+void populateFoldMemRefAliasOpPatterns(
+ RewritePatternSet &patterns,
+ function_ref<bool(Operation *)> controlFn = nullptr);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 3cacb7e29263b..50cfebbdc66ce 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -21,11 +21,14 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include <functional>
+#include <string>
#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -82,11 +85,29 @@ static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
//===----------------------------------------------------------------------===//
namespace {
+using ControlFunction = std::function<bool(Operation *)>;
+
+template <typename OpTy>
+class FoldMemRefAliasPattern : public OpRewritePattern<OpTy> {
+public:
+ FoldMemRefAliasPattern(MLIRContext *context,
+ ControlFunction controlFn = ControlFunction())
+ : OpRewritePattern<OpTy>(context), controlFn(std::move(controlFn)) {}
+
+protected:
+ bool shouldRewrite(Operation *op) const {
+ return !controlFn || controlFn(op);
+ }
+
+private:
+ ControlFunction controlFn;
+};
+
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
-class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
+class LoadOpOfSubViewOpFolder final : public FoldMemRefAliasPattern<OpTy> {
public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
+ using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const override;
@@ -94,9 +115,9 @@ class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
-class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
+class LoadOpOfExpandShapeOpFolder final : public FoldMemRefAliasPattern<OpTy> {
public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
+ using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const override;
@@ -104,9 +125,10 @@ class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
-class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
+class LoadOpOfCollapseShapeOpFolder final
+ : public FoldMemRefAliasPattern<OpTy> {
public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
+ using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const override;
@@ -114,9 +136,9 @@ class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
-class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
+class StoreOpOfSubViewOpFolder final : public FoldMemRefAliasPattern<OpTy> {
public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
+ using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const override;
@@ -124,9 +146,9 @@ class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
/// Merges expand_shape operation with store/transferWriteOp operation.
template <typename OpTy>
-class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
+class StoreOpOfExpandShapeOpFolder final : public FoldMemRefAliasPattern<OpTy> {
public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
+ using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const override;
@@ -134,21 +156,26 @@ class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
/// Merges collapse_shape operation with store/transferWriteOp operation.
template <typename OpTy>
-class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
+class StoreOpOfCollapseShapeOpFolder final
+ : public FoldMemRefAliasPattern<OpTy> {
public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
+ using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const override;
};
/// Folds subview(subview(x)) to a single subview(x).
-class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
+class SubViewOfSubViewFolder
+ : public FoldMemRefAliasPattern<memref::SubViewOp> {
public:
- using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
+ using FoldMemRefAliasPattern<memref::SubViewOp>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(memref::SubViewOp subView,
PatternRewriter &rewriter) const override {
+ if (!this->shouldRewrite(subView))
+ return failure();
+
auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
if (!srcSubView)
return failure();
@@ -188,9 +215,10 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// is folds subview on src and dst memref of the copy.
class NVGPUAsyncCopyOpSubViewOpFolder final
- : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
+ : public FoldMemRefAliasPattern<nvgpu::DeviceAsyncCopyOp> {
public:
- using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
+ using FoldMemRefAliasPattern<
+ nvgpu::DeviceAsyncCopyOp>::FoldMemRefAliasPattern;
LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
PatternRewriter &rewriter) const override;
@@ -234,6 +262,8 @@ static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
OpTy loadOp, PatternRewriter &rewriter) const {
+ if (!this->shouldRewrite(loadOp))
+ return failure();
auto subViewOp =
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
@@ -290,6 +320,8 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
OpTy loadOp, PatternRewriter &rewriter) const {
+ if (!this->shouldRewrite(loadOp))
+ return failure();
auto expandShapeOp =
getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
@@ -351,6 +383,8 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
OpTy loadOp, PatternRewriter &rewriter) const {
+ if (!this->shouldRewrite(loadOp))
+ return failure();
auto collapseShapeOp = getMemRefOperand(loadOp)
.template getDefiningOp<memref::CollapseShapeOp>();
@@ -383,6 +417,8 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
OpTy storeOp, PatternRewriter &rewriter) const {
+ if (!this->shouldRewrite(storeOp))
+ return failure();
auto subViewOp =
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
@@ -435,6 +471,8 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
OpTy storeOp, PatternRewriter &rewriter) const {
+ if (!this->shouldRewrite(storeOp))
+ return failure();
auto expandShapeOp =
getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
@@ -470,6 +508,8 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
OpTy storeOp, PatternRewriter &rewriter) const {
+ if (!this->shouldRewrite(storeOp))
+ return failure();
auto collapseShapeOp = getMemRefOperand(storeOp)
.template getDefiningOp<memref::CollapseShapeOp>();
@@ -501,6 +541,8 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
+ if (!this->shouldRewrite(copyOp))
+ return failure();
LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");
@@ -550,7 +592,12 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
return success();
}
-void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
+void memref::populateFoldMemRefAliasOpPatterns(
+ RewritePatternSet &patterns, function_ref<bool(Operation *)> controlFn) {
+ ControlFunction controlFnStorage;
+ if (controlFn)
+ controlFnStorage = controlFn;
+
patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
LoadOpOfSubViewOpFolder<vector::LoadOp>,
@@ -576,7 +623,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
- patterns.getContext());
+ patterns.getContext(), controlFnStorage);
}
//===----------------------------------------------------------------------===//
@@ -587,13 +634,37 @@ namespace {
struct FoldMemRefAliasOpsPass final
: public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
+ FoldMemRefAliasOpsPass() = default;
+ FoldMemRefAliasOpsPass(ArrayRef<StringRef> disabledPatterns,
+ function_ref<bool(Operation *)> controlFn = nullptr)
+ : disabledPatternNames(disabledPatterns.begin(), disabledPatterns.end()) {
+ if (controlFn)
+ controlFunction = controlFn;
+ }
+
void runOnOperation() override;
+
+private:
+ SmallVector<std::string> disabledPatternNames;
+ ControlFunction controlFunction;
};
} // namespace
void FoldMemRefAliasOpsPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- memref::populateFoldMemRefAliasOpPatterns(patterns);
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ RewritePatternSet owningPatterns(&getContext());
+ function_ref<bool(Operation *)> controlFnRef;
+ if (controlFunction)
+ controlFnRef = controlFunction;
+ memref::populateFoldMemRefAliasOpPatterns(owningPatterns, controlFnRef);
+
+ FrozenRewritePatternSet patterns(std::move(owningPatterns),
+ disabledPatternNames);
+ (void)applyPatternsGreedily(getOperation(), patterns);
+}
+
+std::unique_ptr<Pass> mlir::memref::createFoldMemRefAliasOpsPass(
+ ArrayRef<StringRef> excludedPatterns,
+ function_ref<bool(Operation *)> controlFn) {
+ return std::make_unique<FoldMemRefAliasOpsPass>(excludedPatterns, controlFn);
}
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir
new file mode 100644
index 0000000000000..bd4edd9ed65eb
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt --test-fold-memref-alias-options="exclude-pattern=load-subview" -split-input-file %s | FileCheck %s --check-prefix=EXCLUDE
+// RUN: mlir-opt --test-fold-memref-alias-options="control-attr=no_fold" -split-input-file %s | FileCheck %s --check-prefix=CONTROL
+
+// -----
+
+// Excluding the load-subview pattern keeps the subview + load untouched.
+func.func @exclude_load_subview(%arg0: memref<4xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %sv = memref.subview %arg0[0] [4] [1] : memref<4xf32> to memref<4xf32, strided<[1], offset: 0>>
+ %v = memref.load %sv[%c0] : memref<4xf32, strided<[1], offset: 0>>
+ return %v : f32
+}
+
+// EXCLUDE-LABEL: func.func @exclude_load_subview
+// EXCLUDE: %[[SV:.*]] = memref.subview
+// EXCLUDE: memref.load %[[SV]]
+// EXCLUDE-NOT: memref.load %arg0
+
+// -----
+
+// Control callback rejects ops carrying the attribute; the plain load is still
+// folded through the subview.
+func.func @control_attr(%arg0: memref<4xf32>) -> (f32, f32) {
+ %c0 = arith.constant 0 : index
+ %sv = memref.subview %arg0[0] [4] [1] : memref<4xf32> to memref<4xf32, strided<[1], offset: 0>>
+ %blocked = memref.load %sv[%c0] {no_fold} : memref<4xf32, strided<[1], offset: 0>>
+ %folded = memref.load %sv[%c0] : memref<4xf32, strided<[1], offset: 0>>
+ return %blocked, %folded : f32, f32
+}
+
+// CONTROL-LABEL: func.func @control_attr
+// CONTROL: %[[SV:.*]] = memref.subview
+// CONTROL: %[[A:.*]] = memref.load %[[SV]][%c0] {no_fold}
+// CONTROL: %[[B:.*]] = memref.load %arg0[%c0]
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
index 39457ab2d0bf7..4a707f719a317 100644
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRMemRefTestPasses
TestComposeSubView.cpp
TestEmulateNarrowType.cpp
+ TestFoldMemRefAliasOptions.cpp
TestMultiBuffer.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp b/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp
new file mode 100644
index 0000000000000..aa4d366262ffa
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp
@@ -0,0 +1,101 @@
+//===- TestFoldMemRefAliasOptions.cpp - Test FoldMemRefAlias options ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass to exercise the optional arguments of
+// FoldMemRefAliasOps (excluded patterns and control callback).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/TypeName.h"
+
+using namespace mlir;
+
+namespace {
+struct TestFoldMemRefAliasOptionsPass
+ : public PassWrapper<TestFoldMemRefAliasOptionsPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldMemRefAliasOptionsPass)
+
+ TestFoldMemRefAliasOptionsPass() = default;
+ TestFoldMemRefAliasOptionsPass(const TestFoldMemRefAliasOptionsPass &pass)
+ : PassWrapper(pass) {}
+
+ StringRef getArgument() const final {
+ return "test-fold-memref-alias-options";
+ }
+ StringRef getDescription() const final {
+ return "Test FoldMemRefAliasOps optional arguments";
+ }
+
+ ListOption<std::string> excludedPatternTokens{
+ *this, "exclude-pattern",
+ llvm::cl::desc("Comma-separated tokens to exclude certain patterns "
+ "(e.g., load-subview)")};
+ Option<std::string> controlAttr{
+ *this, "control-attr",
+ llvm::cl::desc(
+ "Attribute name that disables rewrites when present on the "
+ "matched operation"),
+ llvm::cl::init("")};
+
+ void runOnOperation() override;
+};
+
+void TestFoldMemRefAliasOptionsPass::runOnOperation() {
+ // Custom version of "FoldMemRefAliasOps" to test its options, by:
+ // 1) Excluding patterns that fold memref.subview into load ops
+ // 2) Ignoring user ops that have a specific attribute.
+
+ // Map friendly tokens to concrete pattern names expected by the exclusion
+ // mechanism.
+ SmallVector<std::string> disabledPatternNames;
+ if (llvm::is_contained(excludedPatternTokens, "load-subview")) {
+ // Resolve the debug name of the first pattern rooted at memref.load.
+ RewritePatternSet patternsSet(&getContext());
+ memref::populateFoldMemRefAliasOpPatterns(patternsSet);
+ for (auto &pattern : patternsSet.getNativePatterns()) {
+ std::optional<OperationName> rootKind = pattern->getRootKind();
+ if (rootKind &&
+ rootKind->getStringRef() == memref::LoadOp::getOperationName()) {
+ disabledPatternNames.push_back(pattern->getDebugName().str());
+ break;
+ }
+ }
+ }
+
+ std::function<bool(Operation *)> controlFnStorage;
+ function_ref<bool(Operation *)> controlFnRef;
+ if (!controlAttr.empty()) {
+ StringAttr attrName = StringAttr::get(&getContext(), controlAttr);
+ controlFnStorage = [attrName](Operation *op) {
+ return !op->hasAttr(attrName);
+ };
+ controlFnRef = controlFnStorage;
+ }
+
+ RewritePatternSet owningPatterns(&getContext());
+ memref::populateFoldMemRefAliasOpPatterns(owningPatterns, controlFnRef);
+ FrozenRewritePatternSet patterns(std::move(owningPatterns),
+ disabledPatternNames);
+ (void)applyPatternsGreedily(getOperation(), patterns);
+}
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestFoldMemRefAliasOptionsPass() {
+ PassRegistration<TestFoldMemRefAliasOptionsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a427132247e6d..d674114d4a18b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -133,6 +133,7 @@ void registerTestMemRefToLLVMWithTransforms();
void registerTestReshardingPartitionPass();
void registerTestShardSimplificationsPass();
void registerTestMultiBuffering();
+void registerTestFoldMemRefAliasOptionsPass();
void registerTestNextAccessPass();
void registerTestNVGPULowerings();
void registerTestOpenACC();
@@ -248,6 +249,7 @@ static void registerTestPasses() {
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestRemarkPass();
+ mlir::test::registerTestFoldMemRefAliasOptionsPass();
mlir::test::registerTestEmulateNarrowTypePass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();
More information about the Mlir-commits
mailing list