[Mlir-commits] [mlir] [mlir][memref]: Add control options to FoldMemRefAliasOps (PR #178405)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 28 03:53:13 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: None (fabrizio-indirli)

<details>
<summary>Changes</summary>

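Add two optional arguments, `excludedPatterns` and `controlFn`, to the constructor of the `FoldMemRefAliasOps` pass (exposed through a new `memref::createFoldMemRefAliasOpsPass()` overload), and add the optional `controlFn` argument to `memref::populateFoldMemRefAliasOpPatterns()`. This allows:
- disabling specific patterns by name;
- passing a control function that decides whether a pattern should skip the current operation (the pattern bails out when the function returns `false`).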

---

Patch is 20.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178405.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h (+7) 
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (+8-2) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+92-21) 
- (added) mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir (+34) 
- (modified) mlir/test/lib/Dialect/MemRef/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp (+98) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 748248d45df26..daa790a4d78b3 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -14,6 +14,9 @@
 #define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H
 
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 
@@ -45,6 +48,10 @@ namespace memref {
 #define GEN_PASS_DECL
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 
+std::unique_ptr<Pass> createFoldMemRefAliasOpsPass(
+    ArrayRef<StringRef> excludedPatterns,
+    function_ref<bool(Operation *)> controlFn = nullptr);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..91d84d1c1d9ff 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -21,6 +21,7 @@ namespace mlir {
 class OpBuilder;
 class RewritePatternSet;
 class RewriterBase;
+class Operation;
 class Value;
 class ValueRange;
 class ReifyRankedShapedTypeOpInterface;
@@ -43,8 +44,13 @@ class DeallocOp;
 void populateExpandOpsPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns for folding memref aliasing ops into consumer load/store
-/// ops into `patterns`.
-void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
+/// ops into `patterns`. If `controlFn` is provided, each pattern invokes it and
+/// bails out when it returns false.
+/// `controlFn(Operation* userOp)` will be passed the user operation of the
+/// aliasing op (e.g., a load/store that uses the result of a memref.subview).
+void populateFoldMemRefAliasOpPatterns(
+    RewritePatternSet &patterns,
+    function_ref<bool(Operation *)> controlFn = nullptr);
 
 /// Appends patterns that resolve `memref.dim` operations with values that are
 /// defined by operations that implement the
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 3cacb7e29263b..50cfebbdc66ce 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -21,11 +21,14 @@
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include <functional>
+#include <string>
 
 #define DEBUG_TYPE "fold-memref-alias-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -82,11 +85,29 @@ static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
 //===----------------------------------------------------------------------===//
 
 namespace {
+using ControlFunction = std::function<bool(Operation *)>;
+
+template <typename OpTy>
+class FoldMemRefAliasPattern : public OpRewritePattern<OpTy> {
+public:
+  FoldMemRefAliasPattern(MLIRContext *context,
+                         ControlFunction controlFn = ControlFunction())
+      : OpRewritePattern<OpTy>(context), controlFn(std::move(controlFn)) {}
+
+protected:
+  bool shouldRewrite(Operation *op) const {
+    return !controlFn || controlFn(op);
+  }
+
+private:
+  ControlFunction controlFn;
+};
+
 /// Merges subview operation with load/transferRead operation.
 template <typename OpTy>
-class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
+class LoadOpOfSubViewOpFolder final : public FoldMemRefAliasPattern<OpTy> {
 public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+  using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(OpTy loadOp,
                                 PatternRewriter &rewriter) const override;
@@ -94,9 +115,9 @@ class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
 
 /// Merges expand_shape operation with load/transferRead operation.
 template <typename OpTy>
-class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
+class LoadOpOfExpandShapeOpFolder final : public FoldMemRefAliasPattern<OpTy> {
 public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+  using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(OpTy loadOp,
                                 PatternRewriter &rewriter) const override;
@@ -104,9 +125,10 @@ class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
 
 /// Merges collapse_shape operation with load/transferRead operation.
 template <typename OpTy>
-class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
+class LoadOpOfCollapseShapeOpFolder final
+    : public FoldMemRefAliasPattern<OpTy> {
 public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+  using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(OpTy loadOp,
                                 PatternRewriter &rewriter) const override;
@@ -114,9 +136,9 @@ class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
 
 /// Merges subview operation with store/transferWriteOp operation.
 template <typename OpTy>
-class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
+class StoreOpOfSubViewOpFolder final : public FoldMemRefAliasPattern<OpTy> {
 public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+  using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(OpTy storeOp,
                                 PatternRewriter &rewriter) const override;
@@ -124,9 +146,9 @@ class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
 
 /// Merges expand_shape operation with store/transferWriteOp operation.
 template <typename OpTy>
-class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
+class StoreOpOfExpandShapeOpFolder final : public FoldMemRefAliasPattern<OpTy> {
 public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+  using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(OpTy storeOp,
                                 PatternRewriter &rewriter) const override;
@@ -134,21 +156,26 @@ class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
 
 /// Merges collapse_shape operation with store/transferWriteOp operation.
 template <typename OpTy>
-class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
+class StoreOpOfCollapseShapeOpFolder final
+    : public FoldMemRefAliasPattern<OpTy> {
 public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
+  using FoldMemRefAliasPattern<OpTy>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(OpTy storeOp,
                                 PatternRewriter &rewriter) const override;
 };
 
 /// Folds subview(subview(x)) to a single subview(x).
-class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
+class SubViewOfSubViewFolder
+    : public FoldMemRefAliasPattern<memref::SubViewOp> {
 public:
-  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
+  using FoldMemRefAliasPattern<memref::SubViewOp>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                 PatternRewriter &rewriter) const override {
+    if (!this->shouldRewrite(subView))
+      return failure();
+
     auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
     if (!srcSubView)
       return failure();
@@ -188,9 +215,10 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
 /// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
 /// is folds subview on src and dst memref of the copy.
 class NVGPUAsyncCopyOpSubViewOpFolder final
-    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
+    : public FoldMemRefAliasPattern<nvgpu::DeviceAsyncCopyOp> {
 public:
-  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
+  using FoldMemRefAliasPattern<
+      nvgpu::DeviceAsyncCopyOp>::FoldMemRefAliasPattern;
 
   LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                 PatternRewriter &rewriter) const override;
@@ -234,6 +262,8 @@ static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
 template <typename OpTy>
 LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     OpTy loadOp, PatternRewriter &rewriter) const {
+  if (!this->shouldRewrite(loadOp))
+    return failure();
   auto subViewOp =
       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
 
@@ -290,6 +320,8 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
 template <typename OpTy>
 LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     OpTy loadOp, PatternRewriter &rewriter) const {
+  if (!this->shouldRewrite(loadOp))
+    return failure();
   auto expandShapeOp =
       getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
 
@@ -351,6 +383,8 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
 template <typename OpTy>
 LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
     OpTy loadOp, PatternRewriter &rewriter) const {
+  if (!this->shouldRewrite(loadOp))
+    return failure();
   auto collapseShapeOp = getMemRefOperand(loadOp)
                              .template getDefiningOp<memref::CollapseShapeOp>();
 
@@ -383,6 +417,8 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
 template <typename OpTy>
 LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     OpTy storeOp, PatternRewriter &rewriter) const {
+  if (!this->shouldRewrite(storeOp))
+    return failure();
   auto subViewOp =
       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
 
@@ -435,6 +471,8 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
 template <typename OpTy>
 LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     OpTy storeOp, PatternRewriter &rewriter) const {
+  if (!this->shouldRewrite(storeOp))
+    return failure();
   auto expandShapeOp =
       getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
 
@@ -470,6 +508,8 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
 template <typename OpTy>
 LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
     OpTy storeOp, PatternRewriter &rewriter) const {
+  if (!this->shouldRewrite(storeOp))
+    return failure();
   auto collapseShapeOp = getMemRefOperand(storeOp)
                              .template getDefiningOp<memref::CollapseShapeOp>();
 
@@ -501,6 +541,8 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
 
 LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
     nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
+  if (!this->shouldRewrite(copyOp))
+    return failure();
 
   LLVM_DEBUG(DBGS() << "copyOp       : " << copyOp << "\n");
 
@@ -550,7 +592,12 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
   return success();
 }
 
-void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
+void memref::populateFoldMemRefAliasOpPatterns(
+    RewritePatternSet &patterns, function_ref<bool(Operation *)> controlFn) {
+  ControlFunction controlFnStorage;
+  if (controlFn)
+    controlFnStorage = controlFn;
+
   patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
                LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
                LoadOpOfSubViewOpFolder<vector::LoadOp>,
@@ -576,7 +623,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
                StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
                SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
-      patterns.getContext());
+      patterns.getContext(), controlFnStorage);
 }
 
 //===----------------------------------------------------------------------===//
@@ -587,13 +634,37 @@ namespace {
 
 struct FoldMemRefAliasOpsPass final
     : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
+  FoldMemRefAliasOpsPass() = default;
+  FoldMemRefAliasOpsPass(ArrayRef<StringRef> disabledPatterns,
+                         function_ref<bool(Operation *)> controlFn = nullptr)
+      : disabledPatternNames(disabledPatterns.begin(), disabledPatterns.end()) {
+    if (controlFn)
+      controlFunction = controlFn;
+  }
+
   void runOnOperation() override;
+
+private:
+  SmallVector<std::string> disabledPatternNames;
+  ControlFunction controlFunction;
 };
 
 } // namespace
 
 void FoldMemRefAliasOpsPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  memref::populateFoldMemRefAliasOpPatterns(patterns);
-  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  RewritePatternSet owningPatterns(&getContext());
+  function_ref<bool(Operation *)> controlFnRef;
+  if (controlFunction)
+    controlFnRef = controlFunction;
+  memref::populateFoldMemRefAliasOpPatterns(owningPatterns, controlFnRef);
+
+  FrozenRewritePatternSet patterns(std::move(owningPatterns),
+                                   disabledPatternNames);
+  (void)applyPatternsGreedily(getOperation(), patterns);
+}
+
+std::unique_ptr<Pass> mlir::memref::createFoldMemRefAliasOpsPass(
+    ArrayRef<StringRef> excludedPatterns,
+    function_ref<bool(Operation *)> controlFn) {
+  return std::make_unique<FoldMemRefAliasOpsPass>(excludedPatterns, controlFn);
 }
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir
new file mode 100644
index 0000000000000..bd4edd9ed65eb
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt --test-fold-memref-alias-options="exclude-pattern=load-subview" -split-input-file %s | FileCheck %s --check-prefix=EXCLUDE
+// RUN: mlir-opt --test-fold-memref-alias-options="control-attr=no_fold" -split-input-file %s | FileCheck %s --check-prefix=CONTROL
+
+// -----
+
+// Excluding the load-subview pattern keeps the subview + load untouched.
+func.func @exclude_load_subview(%arg0: memref<4xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %sv = memref.subview %arg0[0] [4] [1] : memref<4xf32> to memref<4xf32, strided<[1], offset: 0>>
+  %v = memref.load %sv[%c0] : memref<4xf32, strided<[1], offset: 0>>
+  return %v : f32
+}
+
+// EXCLUDE-LABEL: func.func @exclude_load_subview
+// EXCLUDE: %[[SV:.*]] = memref.subview
+// EXCLUDE: memref.load %[[SV]]
+// EXCLUDE-NOT: memref.load %arg0
+
+// -----
+
+// Control callback rejects ops carrying the attribute; the plain load is still
+// folded through the subview.
+func.func @control_attr(%arg0: memref<4xf32>) -> (f32, f32) {
+  %c0 = arith.constant 0 : index
+  %sv = memref.subview %arg0[0] [4] [1] : memref<4xf32> to memref<4xf32, strided<[1], offset: 0>>
+  %blocked = memref.load %sv[%c0] {no_fold} : memref<4xf32, strided<[1], offset: 0>>
+  %folded = memref.load %sv[%c0] : memref<4xf32, strided<[1], offset: 0>>
+  return %blocked, %folded : f32, f32
+}
+
+// CONTROL-LABEL: func.func @control_attr
+// CONTROL: %[[SV:.*]] = memref.subview
+// CONTROL: %[[A:.*]] = memref.load %[[SV]][%c0] {no_fold}
+// CONTROL: %[[B:.*]] = memref.load %arg0[%c0]
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
index 39457ab2d0bf7..4a707f719a317 100644
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
   TestEmulateNarrowType.cpp
+  TestFoldMemRefAliasOptions.cpp
   TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp b/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp
new file mode 100644
index 0000000000000..84c4c8161644b
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp
@@ -0,0 +1,98 @@
+//===- TestFoldMemRefAliasOptions.cpp - Test FoldMemRefAlias options ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass to exercise the optional arguments of
+// FoldMemRefAliasOps (excluded patterns and control callback).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/TypeName.h"
+
+using namespace mlir;
+
+namespace {
+struct TestFoldMemRefAliasOptionsPass
+    : public PassWrapper<TestFoldMemRefAliasOptionsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldMemRefAliasOptionsPass)
+
+  TestFoldMemRefAliasOptionsPass() = default;
+  TestFoldMemRefAliasOptionsPass(const TestFoldMemRefAliasOptionsPass &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final {
+    return "test-fold-memref-alias-options";
+  }
+  StringRef getDescription() const final {
+    return "Test FoldMemRefAliasOps optional arguments";
+  }
+
+  ListOption<std::string> excludedPatternTokens{
+      *this, "exclude-pattern",
+      llvm::cl::desc("Comma-separated tokens to exclude certain patterns "
+                     "(e.g., load-subview)")};
+  Option<std::string> controlAttr{
+      *this, "control-attr",
+      llvm::cl::desc(
+          "Attribute name that disables rewrites when present on the "
+          "matched operation"),
+      llvm::cl::init("")};
+
+  void runOnOperation() override;
+};
+
+void TestFoldMemRefAliasOptionsPass::runOnOperation() {
+  // Map friendly tokens to concrete pattern names expected by the exclusion
+  // mechanism.
+  SmallVector<std::string> disabledPatternNames;
+  if (llvm::is_contained(excludedPatternTokens, "load-subview")) {
+    // Resolve pattern debug names from a populated set so we don't rely on type
+    // names leaking from another translation unit.
+    RewritePatternSet probe(&getContext());
+    memref::populateFoldMemRefAliasOpPatterns(probe);
+    for (auto &pattern : probe.getNativePatterns()) {
+      std::optional<OperationName> rootKind = pattern->getRootKind();
+      if (rootKind &&
+          rootKind->getStringRef() == memref::LoadOp::getOperationName()) {
+        disabledPatternNames.push_back(pattern->getDebugName().str());
+        break;
+      }
+    }
+  }
+
+  std::function<bool(Operation *)> controlFnStorage;
+  function_ref<bool(Operation *)> controlFnRef;
+  if (!controlAttr.empty()) {
+    StringAttr attrName = StringAttr::get(&getContext(), controlAttr);
+    controlFnStorage = [attrName](Operation *op) {
+      return !op->hasAttr(attrName);
+    };
+    controlFnRef = controlFnStorage;
+  }
+
+  RewritePatternSet owningPatterns(&getContext());
+  memref::populateFoldMemRefAliasOpPatterns(owningPatterns, controlFnRef);
+  FrozenRewritePatternSet patterns(std::move(owningPatterns),
+                                   disabledPatternNames);
+  (void)applyPatternsGreedily(getOperation(), patterns);
+}
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestFoldMemRefAliasOptionsPass() {
+  PassRegistration<TestFoldMemRefAliasOptionsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a427132247e...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/178405


More information about the Mlir-commits mailing list