[Mlir-commits] [mlir] [mlir][affine] Improve --affine-scalrep to identify reduction variables (PR #138005)

Clément Fournier llvmlistbot at llvm.org
Wed Apr 30 11:04:52 PDT 2025


https://github.com/oowekyala created https://github.com/llvm/llvm-project/pull/138005

Note: this reopens #118987, which I inadvertently closed.

---

Improve the affine scalar replacement pass to identify memref accesses that are used as reduction variables and turn them into `iter_args` variables. For instance, in:
```mlir
%x = memref.alloc(): memref<10x10xf32>
%min = memref.alloc(): memref<10xf32>
// initialize %min
affine.for %i = 0 to 10 {
   affine.for %j = 0 to 10 {
      %0 = memref.load %min[%i]: memref<10xf32>
      %1 = memref.load %x[%i, %j]: memref<10x10xf32>
      %2 = arith.minimumf %0, %1: f32
      memref.store %2, %min[%i] : memref<10xf32>
   }
}
```
the load/store pattern on `%min` in the inner loop is characteristic of a reduction. The memory location `%min[%i]` is invariant with respect to the inner loop's induction variable, so it is effectively used as a scalar. We can rewrite this loop as follows:
```mlir
%x = memref.alloc(): memref<10x10xf32>
%min = memref.alloc(): memref<10xf32>
// initialize %min
affine.for %i = 0 to 10 {
  %0 = memref.load %min[%i]: memref<10xf32>
  %1 = affine.for %j = 0 to 10 iter_args(%acc = %0) -> f32 {
    %2 = memref.load %x[%i, %j]: memref<10x10xf32>
    %3 = arith.minimumf %acc, %2: f32
    affine.yield %3 : f32
  }
  memref.store %1, %min[%i] : memref<10xf32>
}
```
where this memory location is "scalarized" as an `iter_args` variable. This allows existing affine passes to apply more optimizations to the reduction loop; e.g., it can be vectorized or turned into an `affine.parallel` loop with a combiner for the reduction.

This code pattern often appears in affine loops generated from linalg code, so I think it is very useful to support it.

I expect some pushback on why I put this into the scalar replacement pass instead of a new pass. I think this is justified because:
1. This transformation moves some loads and stores out of the loop, and these may then be forwardable by the existing scalar replacement transformations. Conversely, forwarding some loads and stores may remove dependencies that would otherwise prevent this new loop rewriting pattern from applying. So, to me, these transformations are tightly related, and they might even be combined into a fixed-point loop within the scalrep pass.
2. This transformation effectively replaces buffer accesses with a scalar `iter_args` variable. So even if it seems unrelated to the load/store forwarding that the pass currently performs, I think it still fits within the scope of `--affine-scalrep`.

Thanks for reading!


>From 249e3135602cc076bf77f7a68bdc1fb2d4ee9d33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Mon, 28 Oct 2024 16:51:26 +0100
Subject: [PATCH 01/23] Add --affine-raise-from-memref

Restrict isValidDim to induction vars, and not iter_args
---
 mlir/include/mlir/Dialect/Affine/Passes.h     |  10 ++
 mlir/include/mlir/Dialect/Affine/Passes.td    |  12 ++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  13 +-
 .../Dialect/Affine/Transforms/CMakeLists.txt  |   1 +
 .../Affine/Transforms/DecomposeAffineOps.cpp  |  11 ++
 .../Affine/Transforms/RaiseMemrefDialect.cpp  | 168 ++++++++++++++++++
 mlir/test/Dialect/Affine/raise-memref.mlir    | 130 ++++++++++++++
 7 files changed, 339 insertions(+), 6 deletions(-)
 create mode 100644 mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
 create mode 100644 mlir/test/Dialect/Affine/raise-memref.mlir

diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 96bd3c6a9a7bc..2f70f24dd3ef2 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -23,6 +23,9 @@ namespace mlir {
 namespace func {
 class FuncOp;
 } // namespace func
+namespace memref {
+class MemRefDialect;
+} // namespace memref
 
 namespace affine {
 class AffineForOp;
@@ -45,6 +48,13 @@ createSimplifyAffineStructuresPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createAffineLoopInvariantCodeMotionPass();
 
+/// Creates a pass to convert all parallel affine.for's into 1-d affine.parallel
+/// ops.
+std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
+
+/// Creates a pass that converts some memref operators to affine operators.
+std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
+
 /// Apply normalization transformations to affine loop-like ops. If
 /// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
 /// loop is replaced by its loop body).
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 0b8d5b7d94861..d47569968d901 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -396,6 +396,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
   let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
 }
 
+def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
+  let summary = "Turn some memref operators to affine operators where supported";
+  let description = [{
+    Raise memref.load and memref.store to affine.store and affine.load, inferring
+    the affine map of those operators if needed. This allows passes like --affine-scalrep
+    to optimize those loads and stores (forwarding them or eliminating them).
+    They can be turned back to memref dialect ops with --lower-affine.
+  }];
+  let constructor = "mlir::affine::createRaiseMemrefToAffine()";
+  let dependentDialects = ["memref::MemRefDialect"];
+}
+
 def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
   let summary = "Simplify affine expressions in maps/sets and normalize "
                 "memrefs";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8acb21d5074b4..11a087f59b072 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -294,10 +294,12 @@ bool mlir::affine::isValidDim(Value value) {
     return isValidDim(value, getAffineScope(defOp));
 
   // This value has to be a block argument for an op that has the
-  // `AffineScope` trait or for an affine.for or affine.parallel.
+  // `AffineScope` trait or an induction var of an affine.for or
+  // affine.parallel.
+  if (isAffineInductionVar(value))
+    return true;
   auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
-  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
-                      isa<AffineForOp, AffineParallelOp>(parentOp));
+  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
 }
 
 // Value can be used as a dimension id iff it meets one of the following
@@ -316,10 +318,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
 
   auto *op = value.getDefiningOp();
   if (!op) {
-    // This value has to be a block argument for an affine.for or an
+    // This value has to be an induction var for an affine.for or an
     // affine.parallel.
-    auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
-    return isa<AffineForOp, AffineParallelOp>(parentOp);
+    return isAffineInductionVar(value);
   }
 
   // Affine apply operation is ok if all of its operands are ok.
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c42789b01bc9f..1c82822b2bd7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   LoopUnroll.cpp
   LoopUnrollAndJam.cpp
   PipelineDataTransfer.cpp
+  RaiseMemrefDialect.cpp
   ReifyValueBounds.cpp
   SuperVectorize.cpp
   SimplifyAffineStructures.cpp
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index f28fb3acb7db7..4d5ff5765ccc9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -13,9 +13,20 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <iterator>
 
 using namespace mlir;
 using namespace mlir::affine;
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
new file mode 100644
index 0000000000000..2fd4754900000
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -0,0 +1,168 @@
+
+
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <optional>
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_RAISEMEMREFDIALECT
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "raise-memref-to-affine"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+static std::optional<size_t>
+findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
+                const std::function<bool(Value)> &isValidElement) {
+
+  Value *loopIV = std::find(dims.begin(), dims.end(), value);
+  if (loopIV != dims.end()) {
+    // found an IV that already has an index
+    return {std::distance(dims.begin(), loopIV)};
+  }
+  if (isValidElement(value)) {
+    // push this IV in the parameters
+    size_t idx = dims.size();
+    dims.push_back(value);
+    return idx;
+  }
+  return std::nullopt;
+}
+
+static LogicalResult toAffineExpr(Value value, AffineExpr &result,
+                                  llvm::SmallVectorImpl<Value> &affineDims,
+                                  llvm::SmallVectorImpl<Value> &affineSymbols) {
+  using namespace matchers;
+  IntegerAttr::ValueType cst;
+  if (matchPattern(value, m_ConstantInt(&cst))) {
+    result = getAffineConstantExpr(cst.getSExtValue(), value.getContext());
+    return success();
+  }
+  Value lhs;
+  Value rhs;
+  if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
+      matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
+    AffineExpr lhsE;
+    AffineExpr rhsE;
+    if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) &&
+        succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) {
+      AffineExprKind kind;
+      if (isa<arith::AddIOp>(value.getDefiningOp())) {
+        kind = mlir::AffineExprKind::Add;
+      } else {
+        kind = mlir::AffineExprKind::Mul;
+      }
+      result = getAffineBinaryOpExpr(kind, lhsE, rhsE);
+      return success();
+    }
+  }
+
+  if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
+        return affine::isValidSymbol(v);
+      })) {
+    result = getAffineSymbolExpr(*dimIx, value.getContext());
+    return success();
+  }
+
+  if (auto dimIx = findInListOrAdd(
+          value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
+
+    result = getAffineDimExpr(*dimIx, value.getContext());
+    return success();
+  }
+
+  return failure();
+}
+
+static LogicalResult
+computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
+                        llvm::SmallVectorImpl<Value> &mapArgs) {
+  llvm::SmallVector<AffineExpr> results;
+  llvm::SmallVector<Value, 2> symbols;
+  llvm::SmallVector<Value, 8> dims;
+
+  for (auto indexExpr : indices) {
+    if (failed(
+            toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) {
+      return failure();
+    }
+  }
+
+  map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+  dims.append(symbols);
+  mapArgs.swap(dims);
+  return success();
+}
+
+struct RaiseMemrefDialect
+    : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    Operation *op = getOperation();
+    IRRewriter rewriter(ctx);
+    AffineMap map;
+    SmallVector<Value> mapArgs;
+    op->walk([&](Operation *op) {
+      rewriter.setInsertionPoint(op);
+      if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
+
+        if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
+                                              mapArgs))) {
+          rewriter.replaceOpWithNewOp<AffineStoreOp>(
+              op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
+        } else {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "[affine] Cannot raise memref op: " << op << "\n");
+        }
+
+      } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
+
+        if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
+                                              mapArgs))) {
+          rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
+                                                    mapArgs);
+        } else {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "[affine] Cannot raise memref op: " << op << "\n");
+        }
+      }
+    });
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::affine::createRaiseMemrefToAffine() {
+  return std::make_unique<RaiseMemrefDialect>();
+}
diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
new file mode 100644
index 0000000000000..d529e2c0c907a
--- /dev/null
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
+
+// CHECK-LABEL:    func @reduce_window_max() {
+func.func @reduce_window_max() {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.alloc() : memref<1x8x8x64xf32>
+  %1 = memref.alloc() : memref<1x18x18x64xf32>
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 8 {
+      affine.for %arg2 = 0 to 8 {
+        affine.for %arg3 = 0 to 64 {
+          memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+        }
+      }
+    }
+  }
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 8 {
+      affine.for %arg2 = 0 to 8 {
+        affine.for %arg3 = 0 to 64 {
+          affine.for %arg4 = 0 to 1 {
+            affine.for %arg5 = 0 to 3 {
+              affine.for %arg6 = 0 to 3 {
+                affine.for %arg7 = 0 to 1 {
+                  %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+                  %21 = arith.addi %arg0, %arg4 : index
+                  %22 = arith.constant 2 : index
+                  %23 = arith.muli %arg1, %22 : index
+                  %24 = arith.addi %23, %arg5 : index
+                  %25 = arith.muli %arg2, %22 : index
+                  %26 = arith.addi %25, %arg6 : index
+                  %27 = arith.addi %arg3, %arg7 : index
+                  %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32>
+                  %4 = arith.cmpf ogt, %2, %3 : f32
+                  %5 = arith.select %4, %2, %3 : f32
+                  memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return
+}
+
+// CHECK:        %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
+// CHECK:        %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
+// CHECK:        affine.for %[[arg0:.*]] = 0 to 1 {
+// CHECK:          affine.for %[[arg1:.*]] = 0 to 8 {
+// CHECK:            affine.for %[[arg2:.*]] = 0 to 8 {
+// CHECK:              affine.for %[[arg3:.*]] = 0 to 64 {
+// CHECK:                affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
+// CHECK:              }
+// CHECK:            }
+// CHECK:          }
+// CHECK:        }
+// CHECK:        affine.for %[[a0:.*]] = 0 to 1 {
+// CHECK:          affine.for %[[a1:.*]] = 0 to 8 {
+// CHECK:            affine.for %[[a2:.*]] = 0 to 8 {
+// CHECK:              affine.for %[[a3:.*]] = 0 to 64 {
+// CHECK:                affine.for %[[a4:.*]] = 0 to 1 {
+// CHECK:                  affine.for %[[a5:.*]] = 0 to 3 {
+// CHECK:                    affine.for %[[a6:.*]] = 0 to 3 {
+// CHECK:                      affine.for %[[a7:.*]] = 0 to 1 {
+// CHECK:                        %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
+// CHECK:                        %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
+// CHECK:                        %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
+// CHECK:                        %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
+// CHECK:                        affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
+// CHECK:                      }
+// CHECK:                    }
+// CHECK:                  }
+// CHECK:                }
+// CHECK:              }
+// CHECK:            }
+// CHECK:          }
+// CHECK:        }
+// CHECK:      }
+
+func.func @symbols(%N : index) {
+  %0 = memref.alloc() : memref<1024x1024xf32>
+  %1 = memref.alloc() : memref<1024x1024xf32>
+  %2 = memref.alloc() : memref<1024x1024xf32>
+  %cst1 = arith.constant 1 : index
+  %cst2 = arith.constant 2 : index
+  affine.for %i = 0 to %N {
+    affine.for %j = 0 to %N {
+      %7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
+      %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
+        %12 = arith.muli %N, %cst2 : index
+        %13 = arith.addi %12, %cst1 : index
+        %14 = arith.addi %13, %j : index
+        %5 = memref.load %0[%i, %12] : memref<1024x1024xf32>
+        %6 = memref.load %1[%14, %j] : memref<1024x1024xf32>
+        %8 = arith.mulf %5, %6 : f32
+        %9 = arith.addf %7, %8 : f32
+        %4 = arith.addi %N, %cst1 : index
+        %11 = arith.addi %ax, %cst1 : index
+        memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
+        memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered
+        %something = "ab.v"() : () -> index
+        memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered
+        affine.yield %11 : index
+      }
+    }
+  }
+  return
+}
+
+// CHECK:          %[[cst1:.*]] = arith.constant 1 : index
+// CHECK:          %[[v0:.*]] = memref.alloc() : memref<
+// CHECK:          %[[v1:.*]] = memref.alloc() : memref<
+// CHECK:          %[[v2:.*]] = memref.alloc() : memref<
+// CHECK:          affine.for %[[a1:.*]] = 0 to %arg0 {
+// CHECK-NEXT:        affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK-NEXT:           %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK-NEXT:           affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK-NEXT:             %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32>
+// CHECK-NEXT:             %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32>
+// CHECK-NEXT:             %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK-NEXT:             %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK-NEXT:             %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK-NEXT:             affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32>
+// CHECK-NEXT:             memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32>
+// CHECK-NEXT:             %[[lhs7:.*]] = "ab.v"
+// CHECK-NEXT:             memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32>
+// CHECK-NEXT:             affine.yield %[[lhs6]]

>From e61f0baf02f08f2a6d2af9dd792f4d44116953a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 29 Nov 2024 13:50:04 +0100
Subject: [PATCH 02/23] Address review comments

---
 mlir/include/mlir/Dialect/Affine/Passes.td    |  2 +-
 .../Affine/Transforms/DecomposeAffineOps.cpp  | 11 ---
 .../Affine/Transforms/RaiseMemrefDialect.cpp  | 92 ++++++++++---------
 mlir/test/Dialect/Affine/raise-memref.mlir    | 78 +++++++---------
 4 files changed, 82 insertions(+), 101 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index d47569968d901..67f9138589c47 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -405,7 +405,7 @@ def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
     They can be turned back to memref dialect ops with --lower-affine.
   }];
   let constructor = "mlir::affine::createRaiseMemrefToAffine()";
-  let dependentDialects = ["memref::MemRefDialect"];
+  let dependentDialects = ["affine::AffineDialect"];
 }
 
 def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 4d5ff5765ccc9..f28fb3acb7db7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -13,20 +13,9 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/LogicalResult.h"
-#include <algorithm>
-#include <cstddef>
-#include <functional>
-#include <iterator>
 
 using namespace mlir;
 using namespace mlir::affine;
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
index 2fd4754900000..a6e961a6d6439 100644
--- a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -1,29 +1,27 @@
-
+//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functionality to convert memref load and store ops to
+// the corresponding affine ops, inferring the affine map as needed.
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Support/LogicalResult.h"
-#include <algorithm>
-#include <cstddef>
-#include <functional>
-#include <iterator>
-#include <memory>
-#include <optional>
 
 namespace mlir {
 namespace affine {
@@ -39,17 +37,24 @@ using namespace mlir::affine;
 
 namespace {
 
+/// Find the index of the given value in the `dims` list,
+/// and append it if it was not already in the list. The
+/// dims list is a list of symbols or dimensions of the
+/// affine map. Within the results of an affine map, they
+/// are identified by their index, which is why we need
+/// this function.
 static std::optional<size_t>
 findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
-                const std::function<bool(Value)> &isValidElement) {
+                function_ref<bool(Value)> isValidElement) {
 
   Value *loopIV = std::find(dims.begin(), dims.end(), value);
   if (loopIV != dims.end()) {
-    // found an IV that already has an index
+    // We found an IV that already has an index, return that index.
     return {std::distance(dims.begin(), loopIV)};
   }
   if (isValidElement(value)) {
-    // push this IV in the parameters
+    // This is a valid element for the dim/symbol list, push this as a
+    // parameter.
     size_t idx = dims.size();
     dims.push_back(value);
     return idx;
@@ -57,14 +62,15 @@ findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
   return std::nullopt;
 }
 
-static LogicalResult toAffineExpr(Value value, AffineExpr &result,
-                                  llvm::SmallVectorImpl<Value> &affineDims,
-                                  llvm::SmallVectorImpl<Value> &affineSymbols) {
+/// Convert a value to an affine expr if possible. Adds dims and symbols
+/// if needed.
+static AffineExpr toAffineExpr(Value value,
+                               llvm::SmallVectorImpl<Value> &affineDims,
+                               llvm::SmallVectorImpl<Value> &affineSymbols) {
   using namespace matchers;
   IntegerAttr::ValueType cst;
   if (matchPattern(value, m_ConstantInt(&cst))) {
-    result = getAffineConstantExpr(cst.getSExtValue(), value.getContext());
-    return success();
+    return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
   }
   Value lhs;
   Value rhs;
@@ -72,48 +78,46 @@ static LogicalResult toAffineExpr(Value value, AffineExpr &result,
       matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
     AffineExpr lhsE;
     AffineExpr rhsE;
-    if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) &&
-        succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) {
+    if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) &&
+        (rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) {
       AffineExprKind kind;
       if (isa<arith::AddIOp>(value.getDefiningOp())) {
         kind = mlir::AffineExprKind::Add;
       } else {
         kind = mlir::AffineExprKind::Mul;
       }
-      result = getAffineBinaryOpExpr(kind, lhsE, rhsE);
-      return success();
+      return getAffineBinaryOpExpr(kind, lhsE, rhsE);
     }
   }
 
   if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
         return affine::isValidSymbol(v);
       })) {
-    result = getAffineSymbolExpr(*dimIx, value.getContext());
-    return success();
+    return getAffineSymbolExpr(*dimIx, value.getContext());
   }
 
   if (auto dimIx = findInListOrAdd(
           value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
 
-    result = getAffineDimExpr(*dimIx, value.getContext());
-    return success();
+    return getAffineDimExpr(*dimIx, value.getContext());
   }
 
-  return failure();
+  return {};
 }
 
 static LogicalResult
 computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
                         llvm::SmallVectorImpl<Value> &mapArgs) {
-  llvm::SmallVector<AffineExpr> results;
-  llvm::SmallVector<Value, 2> symbols;
-  llvm::SmallVector<Value, 8> dims;
+  SmallVector<AffineExpr> results;
+  SmallVector<Value> symbols;
+  SmallVector<Value> dims;
 
-  for (auto indexExpr : indices) {
-    if (failed(
-            toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) {
+  for (Value indexExpr : indices) {
+    AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
+    if (!res) {
       return failure();
     }
+    results.push_back(res);
   }
 
   map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
@@ -140,21 +144,21 @@ struct RaiseMemrefDialect
                                               mapArgs))) {
           rewriter.replaceOpWithNewOp<AffineStoreOp>(
               op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
-        } else {
-          LLVM_DEBUG(llvm::dbgs()
-                     << "[affine] Cannot raise memref op: " << op << "\n");
+          return;
         }
 
-      } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "[affine] Cannot raise memref op: " << op << "\n");
 
+      } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
         if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
                                               mapArgs))) {
           rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
                                                     mapArgs);
-        } else {
-          LLVM_DEBUG(llvm::dbgs()
-                     << "[affine] Cannot raise memref op: " << op << "\n");
+          return;
         }
+        LLVM_DEBUG(llvm::dbgs()
+                   << "[affine] Cannot raise memref op: " << op << "\n");
       }
     });
   }
diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
index d529e2c0c907a..d8f2aaab4839e 100644
--- a/mlir/test/Dialect/Affine/raise-memref.mlir
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
 
-// CHECK-LABEL:    func @reduce_window_max() {
+// CHECK-LABEL:    func @reduce_window_max(
 func.func @reduce_window_max() {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = memref.alloc() : memref<1x8x8x64xf32>
@@ -45,41 +45,29 @@ func.func @reduce_window_max() {
   return
 }
 
-// CHECK:        %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[cst:.*]] = arith.constant 0
 // CHECK:        %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
 // CHECK:        %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
-// CHECK:        affine.for %[[arg0:.*]] = 0 to 1 {
-// CHECK:          affine.for %[[arg1:.*]] = 0 to 8 {
-// CHECK:            affine.for %[[arg2:.*]] = 0 to 8 {
-// CHECK:              affine.for %[[arg3:.*]] = 0 to 64 {
-// CHECK:                affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
-// CHECK:              }
-// CHECK:            }
-// CHECK:          }
-// CHECK:        }
-// CHECK:        affine.for %[[a0:.*]] = 0 to 1 {
-// CHECK:          affine.for %[[a1:.*]] = 0 to 8 {
-// CHECK:            affine.for %[[a2:.*]] = 0 to 8 {
-// CHECK:              affine.for %[[a3:.*]] = 0 to 64 {
-// CHECK:                affine.for %[[a4:.*]] = 0 to 1 {
-// CHECK:                  affine.for %[[a5:.*]] = 0 to 3 {
-// CHECK:                    affine.for %[[a6:.*]] = 0 to 3 {
-// CHECK:                      affine.for %[[a7:.*]] = 0 to 1 {
-// CHECK:                        %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
-// CHECK:                        %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
+// CHECK:        affine.for %[[arg0:.*]] =
+// CHECK:          affine.for %[[arg1:.*]] =
+// CHECK:            affine.for %[[arg2:.*]] =
+// CHECK:              affine.for %[[arg3:.*]] =
+// CHECK:                affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
+// CHECK:        affine.for %[[a0:.*]] =
+// CHECK:          affine.for %[[a1:.*]] =
+// CHECK:            affine.for %[[a2:.*]] =
+// CHECK:              affine.for %[[a3:.*]] =
+// CHECK:                affine.for %[[a4:.*]] =
+// CHECK:                  affine.for %[[a5:.*]] =
+// CHECK:                    affine.for %[[a6:.*]] =
+// CHECK:                      affine.for %[[a7:.*]] =
+// CHECK:                        %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+// CHECK:                        %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
 // CHECK:                        %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
 // CHECK:                        %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
-// CHECK:                        affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
-// CHECK:                      }
-// CHECK:                    }
-// CHECK:                  }
-// CHECK:                }
-// CHECK:              }
-// CHECK:            }
-// CHECK:          }
-// CHECK:        }
-// CHECK:      }
+// CHECK:                        affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
 
+// CHECK-LABEL:    func @symbols(
 func.func @symbols(%N : index) {
   %0 = memref.alloc() : memref<1024x1024xf32>
   %1 = memref.alloc() : memref<1024x1024xf32>
@@ -100,7 +88,7 @@ func.func @symbols(%N : index) {
         %4 = arith.addi %N, %cst1 : index
         %11 = arith.addi %ax, %cst1 : index
         memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
-        memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered
+        memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised
         %something = "ab.v"() : () -> index
         memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered
         affine.yield %11 : index
@@ -115,16 +103,16 @@ func.func @symbols(%N : index) {
 // CHECK:          %[[v1:.*]] = memref.alloc() : memref<
 // CHECK:          %[[v2:.*]] = memref.alloc() : memref<
 // CHECK:          affine.for %[[a1:.*]] = 0 to %arg0 {
-// CHECK-NEXT:        affine.for %[[a2:.*]] = 0 to %arg0 {
-// CHECK-NEXT:           %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
-// CHECK-NEXT:           affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
-// CHECK-NEXT:             %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32>
-// CHECK-NEXT:             %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32>
-// CHECK-NEXT:             %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
-// CHECK-NEXT:             %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
-// CHECK-NEXT:             %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
-// CHECK-NEXT:             affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32>
-// CHECK-NEXT:             memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32>
-// CHECK-NEXT:             %[[lhs7:.*]] = "ab.v"
-// CHECK-NEXT:             memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32>
-// CHECK-NEXT:             affine.yield %[[lhs6]]
+// CHECK:             affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK:                %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK:                affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK:                  %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] :
+// CHECK:                  %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] :
+// CHECK:                  %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK:                  %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK:                  %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK:                  affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] :
+// CHECK:                  memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] :
+// CHECK:                  %[[lhs7:.*]] = "ab.v"
+// CHECK:                  memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] :
+// CHECK:                  affine.yield %[[lhs6]]

>From f181b3b777dfb3247c2ceeb5a9d39e5ee5562f7b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Sat, 1 Jun 2024 19:23:45 +0200
Subject: [PATCH 03/23] [mlir] Fix #93973 - linalg::ReduceOp verifier crash

---
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 16 ++++---
 mlir/test/Dialect/Linalg/roundtrip.mlir       | 42 +++++++++++++++++++
 3 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index e4dd458eaff84..f51cf6b97bb83 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -318,7 +318,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
-    SameVariadicOperandSize,
+    AttrSizedOperandSegments,
     SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Reduce operator";
   let description = [{
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..cd4f9d20730c6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1354,11 +1354,12 @@ LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 static ParseResult parseDstStyleOp(
     OpAsmParser &parser, OperationState &result,
     function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
-        nullptr) {
+        nullptr,
+    bool addOperandSegmentSizes = false) {
   // Parse `ins` and `outs`.
   SmallVector<Type, 4> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
-                                   /*addOperandSegmentSizes=*/false))
+                                   addOperandSegmentSizes))
     return failure();
 
   // Add result types.
@@ -1707,9 +1708,12 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
   }
 
   if (parseDstStyleOp(
-          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+          parser, result,
+          [&](OpAsmParser &parser, NamedAttrList &attributes) {
             return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
-          }))
+          },
+          /*addOperandSegmentSizes=*/true))
+
     return failure();
 
   if (payloadOpName.has_value()) {
@@ -1744,7 +1748,9 @@ void ReduceOp::print(OpAsmPrinter &p) {
 
   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
-  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+  p.printOptionalAttrDict(
+      (*this)->getAttrs(),
+      {getDimensionsAttrName(), getOperandSegmentSizesAttrName()});
   if (!payloadOp) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index dc556761b09e5..9459309eb4c0d 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -497,6 +497,48 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
 
 // -----
 
+func.func @reduce_asymmetric(%input: tensor<16x32x64xi32>, %input2: tensor<16x32x64xi32>,
+                  %init: tensor<16x64xi32>) -> tensor<16x64xi32> {
+  %reduce = linalg.reduce
+      ins(%input, %input2:tensor<16x32x64xi32>, tensor<16x32x64xi32>)
+      outs(%init:tensor<16x64xi32>)
+      dimensions = [1]
+      (%in: i32, %in2: i32, %out: i32) {
+        %0 = arith.muli %in, %in2: i32
+        %1 = arith.addi %out, %0: i32
+        linalg.yield %1: i32
+      }
+  func.return %reduce : tensor<16x64xi32>
+}
+// CHECK-LABEL: func @reduce_asymmetric
+//       CHECK:   linalg.reduce ins(%{{.*}}, %{{.*}}: tensor<16x32x64xi32>, tensor<16x32x64xi32>)
+//  CHECK-NOT:    operandSegmentSize
+//  CHECK-SAME:   outs(%{{.*}}: tensor<16x64xi32>)
+//  CHECK-SAME:   dimensions = [1]
+
+// -----
+
+func.func @reduce_asymmetric_memref(%input: memref<16x32x64xi32>, %input2: memref<16x32x64xi32>,
+                  %init: memref<16x64xi32>) {
+  linalg.reduce
+      ins(%input, %input2:memref<16x32x64xi32>, memref<16x32x64xi32>)
+      outs(%init:memref<16x64xi32>)
+      dimensions = [1]
+      (%in: i32, %in2: i32, %out: i32) {
+        %0 = arith.muli %in, %in2: i32
+        %1 = arith.addi %out, %0: i32
+        linalg.yield %1: i32
+      }
+  func.return
+}
+// CHECK-LABEL: func @reduce_asymmetric_memref
+//       CHECK:   linalg.reduce ins(%{{.*}}, %{{.*}}: memref<16x32x64xi32>, memref<16x32x64xi32>)
+//  CHECK-NOT:    operandSegmentSize
+//  CHECK-SAME:   outs(%{{.*}}: memref<16x64xi32>)
+//  CHECK-SAME:   dimensions = [1]
+
+// -----
+
 func.func @transpose(%input: tensor<16x32x64xf32>,
                      %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
   %transpose = linalg.transpose

>From 49dcae22912557f8071124d6ea00728a9560b922 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 25 Oct 2024 14:38:50 +0200
Subject: [PATCH 04/23] Make affine and bufferization passes applicable to any
 AffineScope/AutomaticAllocationScope op

This doesn't affect their behavior when they are just
called from the command line, since these passes will
still target the ops nested within the module (which
previously were func.func ops, but can now be other ops).

Fix affine tests

Fix anchor for some passes

TODO: A better fix would be to improve the PassManager
to support passes that can run at an arbitrary nesting
level. These passes could then continue to target the
root op properly. The workaround with Nesting::ImplicitAny
is problematic because it changes the anchor of
OperationPass<> passes (i.e. passes that may run on any op)
so that they target the children of the root instead of
the root itself.

Fix pass manager nesting again
---
 mlir/include/mlir/Dialect/Affine/LoopUtils.h  |   5 +-
 mlir/include/mlir/Dialect/Affine/Passes.h     |  40 +-
 mlir/include/mlir/Dialect/Affine/Passes.td    |  31 +-
 mlir/include/mlir/Dialect/Affine/Utils.h      |   2 +-
 .../Dialect/Bufferization/Transforms/Passes.h |  12 +
 .../Bufferization/Transforms/Passes.td        | 100 ++-
 .../Dialect/Transform/Transforms/Passes.td    |   8 +-
 mlir/include/mlir/Pass/PassManager.h          |   5 +-
 mlir/include/mlir/Transforms/Passes.td        |   2 +-
 .../Transforms/AffineDataCopyGeneration.cpp   |  16 +-
 .../AffineLoopInvariantCodeMotion.cpp         |   4 +-
 .../Affine/Transforms/AffineLoopNormalize.cpp |   4 +-
 .../Affine/Transforms/AffineParallelize.cpp   |   9 +-
 .../Transforms/AffineScalarReplacement.cpp    |  10 +-
 .../Affine/Transforms/LoopCoalescing.cpp      |   5 +-
 .../Dialect/Affine/Transforms/LoopTiling.cpp  |   4 +-
 .../Dialect/Affine/Transforms/LoopUnroll.cpp  |  18 +-
 .../Affine/Transforms/LoopUnrollAndJam.cpp    |   6 +-
 .../Transforms/PipelineDataTransfer.cpp       |   4 +-
 .../Transforms/SimplifyAffineStructures.cpp   |  15 +-
 .../Affine/Transforms/SuperVectorize.cpp      |  13 +-
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp   |   8 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |   9 +-
 .../Transforms/BufferDeallocation.cpp         | 693 ++++++++++++++++++
 .../Bufferization/Transforms/CMakeLists.txt   |   1 +
 mlir/lib/Pass/Pass.cpp                        |  19 +-
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       |   2 +-
 .../lib/Dialect/Affine/TestAccessAnalysis.cpp |   5 +-
 .../lib/Dialect/Affine/TestAffineDataCopy.cpp |   3 +-
 .../Affine/TestAffineLoopParametricTiling.cpp |   3 +-
 .../Dialect/Affine/TestDecomposeAffineOps.cpp |   4 +-
 .../lib/Dialect/Affine/TestLoopFusion.cpp     |   3 +-
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  16 +-
 .../Dialect/Affine/TestVectorizationUtils.cpp |  20 +-
 34 files changed, 958 insertions(+), 141 deletions(-)
 create mode 100644 mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp

diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 7fe1f6d48ceeb..1822d535dfe25 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -16,6 +16,7 @@
 #define MLIR_DIALECT_AFFINE_LOOPUTILS_H
 
 #include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include <optional>
@@ -101,7 +102,7 @@ LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
 /// Identify valid and profitable bands of loops to tile. This is currently just
 /// a temporary placeholder to test the mechanics of tiled code generation.
 /// Returns all maximal outermost perfect loop nests to tile.
-void getTileableBands(func::FuncOp f,
+void getTileableBands(Operation *f,
                       std::vector<SmallVector<AffineForOp, 6>> *bands);
 
 /// Tiles the specified band of perfectly nested loops creating tile-space loops
@@ -272,7 +273,7 @@ void mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorId,
                            ArrayRef<Value> numProcessors);
 
 /// Gathers all AffineForOps in 'func.func' grouped by loop depth.
-void gatherLoops(func::FuncOp func,
+void gatherLoops(Operation* func,
                  std::vector<SmallVector<AffineForOp, 2>> &depthToLoops);
 
 /// Creates an AffineForOp while ensuring that the lower and upper bounds are
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 2f70f24dd3ef2..e580d73d83a8a 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_AFFINE_PASSES_H
 #define MLIR_DIALECT_AFFINE_PASSES_H
 
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include <limits>
@@ -30,6 +31,16 @@ class MemRefDialect;
 namespace affine {
 class AffineForOp;
 
+
+class AffineScopePassBase : public OperationPass<> {
+  using OperationPass<>::OperationPass;
+
+  bool canScheduleOn(RegisteredOperationName opInfo) const final {
+    return opInfo.hasTrait<OpTrait::AffineScope>() &&
+           opInfo.getStringRef() != ModuleOp::getOperationName();
+  }
+};
+
 /// Fusion mode to attempt. The default mode `Greedy` does both
 /// producer-consumer and sibling fusion.
 enum FusionMode { Greedy, ProducerConsumer, Sibling };
@@ -40,47 +51,46 @@ enum FusionMode { Greedy, ProducerConsumer, Sibling };
 /// Creates a simplification pass for affine structures (maps and sets). In
 /// addition, this pass also normalizes memrefs to have the trivial (identity)
 /// layout map.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 createSimplifyAffineStructuresPass();
 
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// operations out of affine loops.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 createAffineLoopInvariantCodeMotionPass();
 
 /// Creates a pass to convert all parallel affine.for's into 1-d affine.parallel
 /// ops.
-std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
+std::unique_ptr<AffineScopePassBase> createAffineParallelizePass();
 
 /// Creates a pass that converts some memref operators to affine operators.
-std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
+std::unique_ptr<AffineScopePassBase> createRaiseMemrefToAffine();
 
 /// Apply normalization transformations to affine loop-like ops. If
 /// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
 /// loop is replaced by its loop body).
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 createAffineLoopNormalizePass(bool promoteSingleIter = false);
 
 /// Performs packing (or explicit copying) of accessed memref regions into
 /// buffers in the specified faster memory space through either pointwise copies
 /// or DMA operations.
-std::unique_ptr<OperationPass<func::FuncOp>> createAffineDataCopyGenerationPass(
+std::unique_ptr<AffineScopePassBase> createAffineDataCopyGenerationPass(
     unsigned slowMemorySpace, unsigned fastMemorySpace,
     unsigned tagMemorySpace = 0, int minDmaTransferSize = 1024,
     uint64_t fastMemCapacityBytes = std::numeric_limits<uint64_t>::max());
 /// Overload relying on pass options for initialization.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 createAffineDataCopyGenerationPass();
 
 /// Creates a pass to replace affine memref accesses by scalars using store to
 /// load forwarding and redundant load elimination; consequently also eliminate
 /// dead allocs.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createAffineScalarReplacementPass();
+std::unique_ptr<AffineScopePassBase> createAffineScalarReplacementPass();
 
 /// Creates a pass that transforms perfectly nested loops with independent
 /// bounds into a single loop.
-std::unique_ptr<OperationPass<func::FuncOp>> createLoopCoalescingPass();
+std::unique_ptr<AffineScopePassBase> createLoopCoalescingPass();
 
 /// Creates a loop fusion pass which fuses affine loop nests at the top-level of
 /// the operation the pass is created on according to the type of fusion
@@ -93,10 +103,10 @@ createLoopFusionPass(unsigned fastMemorySpace = 0,
                      enum FusionMode fusionMode = FusionMode::Greedy);
 
 /// Creates a pass to perform tiling on loop nests.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 createLoopTilingPass(uint64_t cacheSizeBytes);
 /// Overload relying on pass options for initialization.
-std::unique_ptr<OperationPass<func::FuncOp>> createLoopTilingPass();
+std::unique_ptr<AffineScopePassBase> createLoopTilingPass();
 
 /// Creates a loop unrolling pass with the provided parameters.
 /// 'getUnrollFactor' is a function callback for clients to supply a function
@@ -104,7 +114,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLoopTilingPass();
 /// factors supplied through other means. If -1 is passed as the unrollFactor
 /// and no callback is provided, anything passed from the command-line (if at
 /// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor).
-std::unique_ptr<InterfacePass<FunctionOpInterface>> createLoopUnrollPass(
+std::unique_ptr<AffineScopePassBase> createLoopUnrollPass(
     int unrollFactor = -1, bool unrollUpToFactor = false,
     bool unrollFull = false,
     const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr);
@@ -112,12 +122,12 @@ std::unique_ptr<InterfacePass<FunctionOpInterface>> createLoopUnrollPass(
 /// Creates a loop unroll jam pass to unroll jam by the specified factor. A
 /// factor of -1 lets the pass use the default factor or the one on the command
 /// line if provided.
-std::unique_ptr<InterfacePass<FunctionOpInterface>>
+std::unique_ptr<AffineScopePassBase>
 createLoopUnrollAndJamPass(int unrollJamFactor = -1);
 
 /// Creates a pass to pipeline explicit movement of data across levels of the
 /// memory hierarchy.
-std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
+std::unique_ptr<AffineScopePassBase> createPipelineDataTransferPass();
 
 /// Creates a pass to expand affine index operations into more fundamental
 /// operations (not necessarily restricted to Affine dialect).
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 67f9138589c47..f54c8efc43a70 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -15,7 +15,10 @@
 
 include "mlir/Pass/PassBase.td"
 
-def AffineDataCopyGeneration : Pass<"affine-data-copy-generate", "func::FuncOp"> {
+class AffineScopePass<string name> 
+  : PassBase<name, "::mlir::affine::AffineScopePassBase">;
+
+def AffineDataCopyGeneration : AffineScopePass<"affine-data-copy-generate"> {
   let summary = "Generate explicit copying for affine memory operations";
   let constructor = "mlir::affine::createAffineDataCopyGenerationPass()";
   let dependentDialects = ["memref::MemRefDialect"];
@@ -43,7 +46,7 @@ def AffineDataCopyGeneration : Pass<"affine-data-copy-generate", "func::FuncOp">
   ];
 }
 
-def AffineLoopFusion : Pass<"affine-loop-fusion"> {
+def AffineLoopFusion : AffineScopePass<"affine-loop-fusion"> {
   let summary = "Fuse affine loop nests";
   let description = [{
     This pass performs fusion of loop nests using a slicing-based approach. The
@@ -178,12 +181,12 @@ def AffineLoopFusion : Pass<"affine-loop-fusion"> {
 }
 
 def AffineLoopInvariantCodeMotion
-    : Pass<"affine-loop-invariant-code-motion", "func::FuncOp"> {
+    : AffineScopePass<"affine-loop-invariant-code-motion"> {
   let summary = "Hoist loop invariant instructions outside of affine loops";
   let constructor = "mlir::affine::createAffineLoopInvariantCodeMotionPass()";
 }
 
-def AffineLoopTiling : Pass<"affine-loop-tile", "func::FuncOp"> {
+def AffineLoopTiling : AffineScopePass<"affine-loop-tile"> {
   let summary = "Tile affine loop nests";
   let constructor = "mlir::affine::createLoopTilingPass()";
   let options = [
@@ -199,7 +202,7 @@ def AffineLoopTiling : Pass<"affine-loop-tile", "func::FuncOp"> {
   ];
 }
 
-def AffineLoopUnroll : InterfacePass<"affine-loop-unroll", "FunctionOpInterface"> {
+def AffineLoopUnroll : AffineScopePass<"affine-loop-unroll"> {
   let summary = "Unroll affine loops";
   let constructor = "mlir::affine::createLoopUnrollPass()";
   let options = [
@@ -219,7 +222,7 @@ def AffineLoopUnroll : InterfacePass<"affine-loop-unroll", "FunctionOpInterface"
   ];
 }
 
-def AffineLoopUnrollAndJam : InterfacePass<"affine-loop-unroll-jam", "FunctionOpInterface"> {
+def AffineLoopUnrollAndJam : AffineScopePass<"affine-loop-unroll-jam"> {
   let summary = "Unroll and jam affine loops";
   let constructor = "mlir::affine::createLoopUnrollAndJamPass()";
   let options = [
@@ -230,7 +233,7 @@ def AffineLoopUnrollAndJam : InterfacePass<"affine-loop-unroll-jam", "FunctionOp
 }
 
 def AffinePipelineDataTransfer
-    : Pass<"affine-pipeline-data-transfer", "func::FuncOp"> {
+    : AffineScopePass<"affine-pipeline-data-transfer"> {
   let summary = "Pipeline non-blocking data transfers between explicitly "
                 "managed levels of the memory hierarchy";
   let description = [{
@@ -298,7 +301,7 @@ def AffinePipelineDataTransfer
   let constructor = "mlir::affine::createPipelineDataTransferPass()";
 }
 
-def AffineScalarReplacement : Pass<"affine-scalrep", "func::FuncOp"> {
+def AffineScalarReplacement : AffineScopePass<"affine-scalrep"> {
   let summary = "Replace affine memref accesses by scalars by forwarding stores "
                 "to loads and eliminating redundant loads";
   let description = [{
@@ -344,7 +347,7 @@ def AffineScalarReplacement : Pass<"affine-scalrep", "func::FuncOp"> {
   let constructor = "mlir::affine::createAffineScalarReplacementPass()";
 }
 
-def AffineVectorize : Pass<"affine-super-vectorize", "func::FuncOp"> {
+def AffineVectorize : AffineScopePass<"affine-super-vectorize"> {
   let summary = "Vectorize to a target independent n-D vector abstraction";
   let dependentDialects = ["vector::VectorDialect"];
   let options = [
@@ -368,7 +371,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "func::FuncOp"> {
   ];
 }
 
-def AffineParallelize : Pass<"affine-parallelize", "func::FuncOp"> {
+def AffineParallelize : AffineScopePass<"affine-parallelize"> {
   let summary = "Convert affine.for ops into 1-D affine.parallel";
   let options = [
     Option<"maxNested", "max-nested", "unsigned", /*default=*/"-1u",
@@ -380,7 +383,7 @@ def AffineParallelize : Pass<"affine-parallelize", "func::FuncOp"> {
   ];
 }
 
-def AffineLoopNormalize : Pass<"affine-loop-normalize", "func::FuncOp"> {
+def AffineLoopNormalize : AffineScopePass<"affine-loop-normalize"> {
   let summary = "Apply normalization transformations to affine loop-like ops";
   let constructor = "mlir::affine::createAffineLoopNormalizePass()";
   let options = [
@@ -389,14 +392,14 @@ def AffineLoopNormalize : Pass<"affine-loop-normalize", "func::FuncOp"> {
   ];
 }
 
-def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
+def LoopCoalescing : AffineScopePass<"affine-loop-coalescing"> {
   let summary = "Coalesce nested loops with independent bounds into a single "
                 "loop";
   let constructor = "mlir::affine::createLoopCoalescingPass()";
   let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
 }
 
-def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
+def RaiseMemrefDialect : AffineScopePass<"affine-raise-from-memref"> {
   let summary = "Turn some memref operators to affine operators where supported";
   let description = [{
     Raise memref.load and memref.store to affine.store and affine.load, inferring
@@ -408,7 +411,7 @@ def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
   let dependentDialects = ["affine::AffineDialect"];
 }
 
-def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
+def SimplifyAffineStructures : AffineScopePass<"affine-simplify-structures"> {
   let summary = "Simplify affine expressions in maps/sets and normalize "
                 "memrefs";
   let constructor = "mlir::affine::createSimplifyAffineStructuresPass()";
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index ff1900bc8f2eb..250c28d0c9d41 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -105,7 +105,7 @@ struct VectorizationStrategy {
 /// Replace affine store and load accesses by scalars by forwarding stores to
 /// loads and eliminate invariant affine loads; consequently, eliminate dead
 /// allocs.
-void affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
+void affineScalarReplace(Operation* parentOp, DominanceInfo &domInfo,
                          PostDominanceInfo &postDomInfo,
                          AliasAnalysis &analysis);
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 493180cd54e5b..50b2fac4ba994 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -4,6 +4,8 @@
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/OpDefinition.h>
 
 namespace mlir {
 class FunctionOpInterface;
@@ -23,6 +25,16 @@ struct OneShotBufferizationOptions;
 /// Maps from symbol table to its corresponding dealloc helper function.
 using DeallocHelperMap = llvm::DenseMap<Operation *, func::FuncOp>;
 
+
+class BufferScopePassBase : public OperationPass<> {
+  using OperationPass<>::OperationPass;
+
+  bool canScheduleOn(RegisteredOperationName opInfo) const final {
+    return opInfo.hasTrait<OpTrait::AutomaticAllocationScope>() &&
+           opInfo.getStringRef() != ModuleOp::getOperationName();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 3bbb8b02c644e..a4863a1b5c6d7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -11,8 +11,85 @@
 
 include "mlir/Pass/PassBase.td"
 
-def OwnershipBasedBufferDeallocationPass
-    : Pass<"ownership-based-buffer-deallocation"> {
+class BufferScopePass<string name> 
+  : PassBase<name, "::mlir::bufferization::BufferScopePassBase">;
+
+
+def BufferDeallocation : BufferScopePass<"buffer-deallocation"> {
+  let summary = "Adds all required dealloc operations for all allocations in "
+                "the input program";
+  let description = [{
+    This pass implements an algorithm to automatically introduce all required
+    deallocation operations for all buffers in the input program. This ensures
+    that the resulting program does not have any memory leaks.
+
+
+    Input
+
+    ```mlir
+    #map0 = affine_map<(d0) -> (d0)>
+    module {
+      func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+        cf.cond_br %arg0, ^bb1, ^bb2
+      ^bb1:
+        cf.br ^bb3(%arg1 : memref<2xf32>)
+      ^bb2:
+        %0 = memref.alloc() : memref<2xf32>
+        linalg.generic {
+          indexing_maps = [#map0, #map0],
+          iterator_types = ["parallel"]} %arg1, %0 {
+        ^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
+          %tmp1 = exp %gen1_arg0 : f32
+          linalg.yield %tmp1 : f32
+        }: memref<2xf32>, memref<2xf32>
+        cf.br ^bb3(%0 : memref<2xf32>)
+      ^bb3(%1: memref<2xf32>):
+        "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+        return
+      }
+    }
+
+    ```
+
+    Output
+
+    ```mlir
+    #map0 = affine_map<(d0) -> (d0)>
+    module {
+      func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
+        cf.cond_br %arg0, ^bb1, ^bb2
+      ^bb1:  // pred: ^bb0
+        %0 = memref.alloc() : memref<2xf32>
+        memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
+        cf.br ^bb3(%0 : memref<2xf32>)
+      ^bb2:  // pred: ^bb0
+        %1 = memref.alloc() : memref<2xf32>
+        linalg.generic {
+          indexing_maps = [#map0, #map0],
+          iterator_types = ["parallel"]} %arg1, %1 {
+        ^bb0(%arg3: f32, %arg4: f32):
+          %4 = exp %arg3 : f32
+          linalg.yield %4 : f32
+        }: memref<2xf32>, memref<2xf32>
+        %2 = memref.alloc() : memref<2xf32>
+        memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
+        dealloc %1 : memref<2xf32>
+        cf.br ^bb3(%2 : memref<2xf32>)
+      ^bb3(%3: memref<2xf32>):  // 2 preds: ^bb1, ^bb2
+        memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
+        dealloc %3 : memref<2xf32>
+        return
+      }
+
+    }
+    ```
+
+  }];
+  let constructor = "mlir::bufferization::createBufferDeallocationPass()";
+}
+
+def OwnershipBasedBufferDeallocation : BufferScopePass<
+    "ownership-based-buffer-deallocation"> {
   let summary = "Adds all required dealloc operations for all allocations in "
                 "the input program";
   let description = [{
@@ -152,8 +229,8 @@ def OwnershipBasedBufferDeallocationPass
   ];
 }
 
-def BufferDeallocationSimplificationPass
-    : Pass<"buffer-deallocation-simplification"> {
+def BufferDeallocationSimplification
+    : BufferScopePass<"buffer-deallocation-simplification"> {
   let summary = "Optimizes `bufferization.dealloc` operation for more "
                 "efficient codegen";
   let description = [{
@@ -169,8 +246,8 @@ def BufferDeallocationSimplificationPass
   ];
 }
 
-def OptimizeAllocationLivenessPass
-    : Pass<"optimize-allocation-liveness", "func::FuncOp"> {
+def OptimizeAllocationLiveness
+    : BufferScopePass<"optimize-allocation-liveness"> {
   let summary = "This pass optimizes the liveness of temp allocations in the "
                 "input function";
   let description = [{
@@ -184,7 +261,7 @@ def OptimizeAllocationLivenessPass
   let dependentDialects = ["mlir::memref::MemRefDialect"];
 }
 
-def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> {
+def LowerDeallocations : BufferScopePass<"bufferization-lower-deallocations"> {
   let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`"
                 "operations";
   let description = [{
@@ -204,7 +281,7 @@ def LowerDeallocationsPass : Pass<"bufferization-lower-deallocations"> {
   ];
 }
 
-def BufferHoistingPass : Pass<"buffer-hoisting", "func::FuncOp"> {
+def BufferHoisting : BufferScopePass<"buffer-hoisting"> {
   let summary = "Optimizes placement of allocation operations by moving them "
                 "into common dominators and out of nested regions";
   let description = [{
@@ -213,7 +290,7 @@ def BufferHoistingPass : Pass<"buffer-hoisting", "func::FuncOp"> {
   }];
 }
 
-def BufferLoopHoistingPass : Pass<"buffer-loop-hoisting", "func::FuncOp"> {
+def BufferLoopHoisting : BufferScopePass<"buffer-loop-hoisting"> {
   let summary = "Optimizes placement of allocation operations by moving them "
                 "out of loop nests";
   let description = [{
@@ -462,8 +539,7 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> {
   ];
 }
 
-def PromoteBuffersToStackPass
-    : Pass<"promote-buffers-to-stack", "func::FuncOp"> {
+def PromoteBuffersToStack : BufferScopePass<"promote-buffers-to-stack"> {
   let summary = "Promotes heap-based allocations to automatically managed "
                 "stack-based allocations";
   let description = [{
@@ -483,7 +559,7 @@ def PromoteBuffersToStackPass
   ];
 }
 
-def EmptyTensorEliminationPass : Pass<"eliminate-empty-tensors"> {
+def EmptyTensorElimination : BufferScopePass<"eliminate-empty-tensors"> {
   let summary = "Try to eliminate all tensor.empty ops.";
   let description = [{
     Try to eliminate "tensor.empty" ops inside `op`. This transformation looks
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
index 86a2b3c21faf0..d134d1d8acff0 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
@@ -11,7 +11,7 @@
 
 include "mlir/Pass/PassBase.td"
 
-def CheckUsesPass : Pass<"transform-dialect-check-uses"> {
+def CheckUsesPass : Pass<"transform-dialect-check-uses", "mlir::ModuleOp"> {
   let summary = "warn about potential use-after-free in the transform dialect";
   let description = [{
     This pass analyzes operations from the transform dialect and its extensions
@@ -32,7 +32,7 @@ def CheckUsesPass : Pass<"transform-dialect-check-uses"> {
   }];
 }
 
-def InferEffectsPass : Pass<"transform-infer-effects"> {
+def InferEffectsPass : Pass<"transform-infer-effects", "mlir::ModuleOp"> {
   let summary = "infer transform side effects for symbols";
   let description = [{
     This pass analyzes the definitions of transform dialect callable symbol
@@ -42,7 +42,7 @@ def InferEffectsPass : Pass<"transform-infer-effects"> {
   }];
 }
 
-def PreloadLibraryPass : Pass<"transform-preload-library"> {
+def PreloadLibraryPass : Pass<"transform-preload-library", "mlir::ModuleOp"> {
   let summary = "preload transform dialect library";
   let description = [{
     This pass preloads a transform library and makes it available to subsequent
@@ -61,7 +61,7 @@ def PreloadLibraryPass : Pass<"transform-preload-library"> {
   ];
 }
 
-def InterpreterPass : Pass<"transform-interpreter"> {
+def InterpreterPass : Pass<"transform-interpreter", "mlir::ModuleOp"> {
   let summary = "transform dialect interpreter";
   let description = [{
     This pass runs the transform dialect interpreter and applies the named
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index d9bab431e2e0c..950f3e9c547eb 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -54,7 +54,10 @@ class OpPassManager {
     Implicit,
     /// Explicit nesting behavior. This requires that any passes added to this
     /// pass manager support its operation type.
-    Explicit
+    Explicit,
+    /// Implicitly add an "any" nesting level when scheduling a pass that is
+    /// op-agnostic (i.e. not anchored on a specific operation type).
+    ImplicitAny,
   };
 
   /// Construct a new op-agnostic ("any") pass manager with the given operation
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..b7cf19bfb790a 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -273,7 +273,7 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
   let constructor = "mlir::createGenerateRuntimeVerificationPass()";
 }
 
-def Inliner : Pass<"inline"> {
+def Inliner : Pass<"inline", "mlir::ModuleOp"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
   let options = [
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 4d30213cc6ec2..6ea9e1d02bcb6 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/Support/CommandLine.h"
@@ -86,8 +87,8 @@ struct AffineDataCopyGeneration
 
 /// Generates copies for memref's living in 'slowMemorySpace' into newly created
 /// buffers in 'fastMemorySpace', and replaces memory operations to the former
 /// by the latter.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createAffineDataCopyGenerationPass(
     unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace,
     int minDmaTransferSize, uint64_t fastMemCapacityBytes) {
@@ -95,7 +96,7 @@ mlir::affine::createAffineDataCopyGenerationPass(
       slowMemorySpace, fastMemorySpace, tagMemorySpace, minDmaTransferSize,
       fastMemCapacityBytes);
 }
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createAffineDataCopyGenerationPass() {
   return std::make_unique<AffineDataCopyGeneration>();
 }
@@ -203,9 +204,9 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
 }
 
 void AffineDataCopyGeneration::runOnOperation() {
-  func::FuncOp f = getOperation();
-  OpBuilder topBuilder(f.getBody());
-  zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
+  Operation *f = getOperation();
+  OpBuilder topBuilder(f->getRegion(0));
+  zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f->getLoc(), 0);
 
   // Nests that are copy-in's or copy-out's; the root AffineForOps of those
   // nests are stored herein.
@@ -214,7 +215,7 @@ void AffineDataCopyGeneration::runOnOperation() {
   // Clear recorded copy nests.
   copyNests.clear();
 
-  for (auto &block : f)
+  for (auto &block : f->getRegion(0))
     runOnBlock(&block, copyNests);
 
   // Promote any single iteration loops in the copy nests and collect
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
index e3f316443161f..8f59a2fe54a76 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -200,10 +200,10 @@ void LoopInvariantCodeMotion::runOnOperation() {
   // Walk through all loops in a function in innermost-loop-first order.  This
   // way, we first LICM from the inner loop, and place the ops in
   // the outer loop, which in turn can be further LICM'ed.
-  getOperation().walk([&](AffineForOp op) { runOnAffineForOp(op); });
+  getOperation()->walk([&](AffineForOp op) { runOnAffineForOp(op); });
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createAffineLoopInvariantCodeMotionPass() {
   return std::make_unique<LoopInvariantCodeMotion>();
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
index 5cc38f7051726..be295d9f25b22 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
@@ -38,7 +38,7 @@ struct AffineLoopNormalizePass
   }
 
   void runOnOperation() override {
-    getOperation().walk([&](Operation *op) {
+    getOperation()->walk([&](Operation *op) {
       if (auto affineParallel = dyn_cast<AffineParallelOp>(op))
         normalizeAffineParallel(affineParallel);
       else if (auto affineFor = dyn_cast<AffineForOp>(op))
@@ -49,7 +49,7 @@ struct AffineLoopNormalizePass
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createAffineLoopNormalizePass(bool promoteSingleIter) {
   return std::make_unique<AffineLoopNormalizePass>(promoteSingleIter);
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
index fa0676b206826..de5b55db3fe36 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineParallelize.cpp
@@ -60,12 +60,12 @@ struct ParallelizationCandidate {
 } // namespace
 
 void AffineParallelize::runOnOperation() {
-  func::FuncOp f = getOperation();
+  Operation *f = getOperation();
 
   // The walker proceeds in pre-order to process the outer loops first
   // and control the number of outer parallel loops.
   std::vector<ParallelizationCandidate> parallelizableLoops;
-  f.walk<WalkOrder::PreOrder>([&](AffineForOp loop) {
+  f->walk<WalkOrder::PreOrder>([&](AffineForOp loop) {
     SmallVector<LoopReduction> reductions;
     if (isLoopParallel(loop, parallelReductions ? &reductions : nullptr))
       parallelizableLoops.emplace_back(loop, std::move(reductions));
@@ -92,3 +92,8 @@ void AffineParallelize::runOnOperation() {
     }
   }
 }
+
+std::unique_ptr<AffineScopePassBase>
+mlir::affine::createAffineParallelizePass() {
+  return std::make_unique<AffineParallelize>();
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
index 16c0d3d8be3dc..98338dc473f1b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
@@ -16,7 +16,9 @@
 #include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpDefinition.h"
 #include <algorithm>
 
 namespace mlir {
@@ -40,13 +42,13 @@ struct AffineScalarReplacement
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createAffineScalarReplacementPass() {
   return std::make_unique<AffineScalarReplacement>();
 }
 
 void AffineScalarReplacement::runOnOperation() {
   affineScalarReplace(getOperation(), getAnalysis<DominanceInfo>(),
                       getAnalysis<PostDominanceInfo>(),
                       getAnalysis<AliasAnalysis>());
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
index 05c77070a70c1..4006577cdf05e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -36,8 +36,7 @@ struct LoopCoalescingPass
     : public affine::impl::LoopCoalescingBase<LoopCoalescingPass> {
 
   void runOnOperation() override {
-    func::FuncOp func = getOperation();
-    func.walk<WalkOrder::PreOrder>([](Operation *op) {
+    getOperation()->walk<WalkOrder::PreOrder>([](Operation *op) {
       if (auto scfForOp = dyn_cast<scf::ForOp>(op))
         (void)coalescePerfectlyNestedSCFForLoops(scfForOp);
       else if (auto affineForOp = dyn_cast<AffineForOp>(op))
@@ -48,7 +47,7 @@ struct LoopCoalescingPass
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createLoopCoalescingPass() {
   return std::make_unique<LoopCoalescingPass>();
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
index c8400dfe8cd5c..628816d1d1eeb 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -64,11 +64,11 @@ struct LoopTiling : public affine::impl::AffineLoopTilingBase<LoopTiling> {
 
 /// Creates a pass to perform loop tiling on all suitable loop nests of a
 /// Function.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createLoopTilingPass(uint64_t cacheSizeBytes) {
   return std::make_unique<LoopTiling>(cacheSizeBytes);
 }
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createLoopTilingPass() {
   return std::make_unique<LoopTiling>();
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
index 7ff77968c61ad..1c95331a27841 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -82,27 +83,25 @@ static bool isInnermostAffineForOp(AffineForOp op) {
 }
 
 /// Gathers loops that have no affine.for's nested within.
-static void gatherInnermostLoops(FunctionOpInterface f,
+static void gatherInnermostLoops(Operation *f,
                                  SmallVectorImpl<AffineForOp> &loops) {
-  f.walk([&](AffineForOp forOp) {
+  f->walk([&](AffineForOp forOp) {
     if (isInnermostAffineForOp(forOp))
       loops.push_back(forOp);
   });
 }
 
 void LoopUnroll::runOnOperation() {
-  FunctionOpInterface func = getOperation();
-  if (func.isExternal())
-    return;
+  Operation *func = getOperation();
 
   if (unrollFull && unrollFullThreshold.hasValue()) {
     // Store short loops as we walk.
     SmallVector<AffineForOp, 4> loops;
 
     // Gathers all loops with trip count <= minTripCount. Do a post order walk
     // so that loops are gathered from innermost to outermost (or else
     // unrolling an outer one may delete gathered inner ones).
-    getOperation().walk([&](AffineForOp forOp) {
+    getOperation()->walk([&](AffineForOp forOp) {
       std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
       if (tripCount && *tripCount <= unrollFullThreshold)
         loops.push_back(forOp);
@@ -145,8 +144,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
                             cleanUpUnroll);
 }
 
-std::unique_ptr<InterfacePass<FunctionOpInterface>>
-mlir::affine::createLoopUnrollPass(
+std::unique_ptr<AffineScopePassBase> mlir::affine::createLoopUnrollPass(
     int unrollFactor, bool unrollUpToFactor, bool unrollFull,
     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
   return std::make_unique<LoopUnroll>(
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
index 13640f085951e..442fbf66fefef 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
@@ -75,7 +75,7 @@ struct LoopUnrollAndJam
 };
 } // namespace
 
-std::unique_ptr<InterfacePass<FunctionOpInterface>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createLoopUnrollAndJamPass(int unrollJamFactor) {
   return std::make_unique<LoopUnrollAndJam>(
       unrollJamFactor == -1 ? std::nullopt
@@ -83,13 +83,17 @@ mlir::affine::createLoopUnrollAndJamPass(int unrollJamFactor) {
 }
 
 void LoopUnrollAndJam::runOnOperation() {
-  if (getOperation().isExternal())
+  // The anchor may be any affine scope op. Bail out if it has no body (e.g. an
+  // external function) or an empty entry block (possible for ops whose regions
+  // need no terminator): there is then no loop nest to unroll-and-jam.
+  Region &body = getOperation()->getRegion(0);
+  if (body.empty() || body.front().empty())
     return;
 
   // Currently, just the outermost loop from the first loop nest is
   // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
   // any for operation.
-  auto &entryBlock = getOperation().front();
+  auto &entryBlock = body.front();
   if (auto forOp = dyn_cast<AffineForOp>(entryBlock.front()))
     (void)loopUnrollJamByFactor(forOp, unrollJamFactor);
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index 4be99aa197380..4199c4d10aaff 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -52,7 +52,7 @@ struct PipelineDataTransfer
 
 /// Creates a pass to pipeline explicit movement of data across levels of the
 /// memory hierarchy.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createPipelineDataTransferPass() {
   return std::make_unique<PipelineDataTransfer>();
 }
@@ -142,7 +142,7 @@ void PipelineDataTransfer::runOnOperation() {
   // gets deleted and replaced by a prologue, a new steady-state loop and an
   // epilogue).
   forOps.clear();
-  getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
+  getOperation()->walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
   for (auto forOp : forOps)
     runOnAffineForOp(forOp);
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 31711ade3153b..337cfc8308439 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -81,24 +82,24 @@ struct SimplifyAffineStructures
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<AffineScopePassBase>
 mlir::affine::createSimplifyAffineStructuresPass() {
   return std::make_unique<SimplifyAffineStructures>();
 }
 
 void SimplifyAffineStructures::runOnOperation() {
-  auto func = getOperation();
+  Operation *func = getOperation();
   simplifiedAttributes.clear();
-  RewritePatternSet patterns(func.getContext());
-  AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
-  AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
-  AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
+  RewritePatternSet patterns(func->getContext());
+  AffineApplyOp::getCanonicalizationPatterns(patterns, func->getContext());
+  AffineForOp::getCanonicalizationPatterns(patterns, func->getContext());
+  AffineIfOp::getCanonicalizationPatterns(patterns, func->getContext());
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
 
   // The simplification of affine attributes will likely simplify the op. Try to
   // fold/apply canonicalization patterns when we have affine dialect ops.
   SmallVector<Operation *> opsToSimplify;
-  func.walk([&](Operation *op) {
+  func->walk([&](Operation *op) {
     for (auto attr : op->getAttrs()) {
       if (auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue()))
         simplifyAndUpdateAttribute(op, attr.getName(), mapAttr);
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index eaaafaf68767e..f78651d27f735 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
@@ -1748,21 +1749,21 @@ static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops,
-/// Applies vectorization to the current function by searching over a bunch of
-/// predetermined patterns.
+/// Applies vectorization to the current affine scope op by searching over a
+/// bunch of predetermined patterns.
 void Vectorize::runOnOperation() {
-  func::FuncOp f = getOperation();
+  Operation *f = getOperation();
   if (!fastestVaryingPattern.empty() &&
       fastestVaryingPattern.size() != vectorSizes.size()) {
-    f.emitRemark("Fastest varying pattern specified with different size than "
-                 "the vector size.");
+    f->emitRemark("Fastest varying pattern specified with different size than "
+                  "the vector size.");
     return signalPassFailure();
   }
 
   if (vectorizeReductions && vectorSizes.size() != 1) {
-    f.emitError("Vectorizing reductions is supported only for 1-D vectors.");
+    f->emitError("Vectorizing reductions is supported only for 1-D vectors.");
     return signalPassFailure();
   }
 
   if (llvm::any_of(vectorSizes, [](int64_t size) { return size <= 0; })) {
-    f.emitError("Vectorization factor must be greater than zero.");
+    f->emitError("Vectorization factor must be greater than zero.");
     return signalPassFailure();
   }
 
@@ -1772,7 +1773,7 @@ void Vectorize::runOnOperation() {
   // If 'vectorize-reduction=true' is provided, we also populate the
   // `reductionLoops` map.
   if (vectorizeReductions) {
-    f.walk([&parallelLoops, &reductionLoops](AffineForOp loop) {
+    f->walk([&parallelLoops, &reductionLoops](AffineForOp loop) {
       SmallVector<LoopReduction, 2> reductions;
       if (isLoopParallel(loop, &reductions)) {
         parallelLoops.insert(loop);
@@ -1782,7 +1783,7 @@ void Vectorize::runOnOperation() {
       }
     });
   } else {
-    f.walk([&parallelLoops](AffineForOp loop) {
+    f->walk([&parallelLoops](AffineForOp loop) {
       if (isLoopParallel(loop))
         parallelLoops.insert(loop);
     });
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 5c94ec2985c3d..bcd7db3d21052 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -871,10 +871,10 @@ void mlir::affine::getPerfectlyNestedLoops(
 /// a temporary placeholder to test the mechanics of tiled code generation.
 /// Returns all maximal outermost perfect loop nests to tile.
 void mlir::affine::getTileableBands(
-    func::FuncOp f, std::vector<SmallVector<AffineForOp, 6>> *bands) {
+    Operation *f, std::vector<SmallVector<AffineForOp, 6>> *bands) {
   // Get maximal perfect nest of 'affine.for' insts starting from root
   // (inclusive).
-  for (AffineForOp forOp : f.getOps<AffineForOp>()) {
+  for (AffineForOp forOp : f->getRegion(0).getOps<AffineForOp>()) {
     SmallVector<AffineForOp, 6> band;
     getPerfectlyNestedLoops(band, forOp);
     bands->push_back(band);
@@ -2543,8 +2543,8 @@ gatherLoopsInBlock(Block *block, unsigned currLoopDepth,
 
-/// Gathers all AffineForOps in 'func.func' grouped by loop depth.
+/// Gathers all AffineForOps in 'func' (first region), grouped by loop depth.
 void mlir::affine::gatherLoops(
-    func::FuncOp func, std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
-  for (auto &block : func)
+    Operation *func, std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
+  for (auto &block : func->getRegion(0))
     gatherLoopsInBlock(&block, /*currLoopDepth=*/0, depthToLoops);
 
   // Remove last loop level from output since it's empty.
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 2723cff6900d0..543bff20a4199 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/LogicalResult.h"
 #include <optional>
@@ -1036,7 +1037,8 @@ static void loadCSE(AffineReadOpInterface loadA,
 // currently only eliminates the stores only if no other loads/uses (other
 // than dealloc) remain.
 //
-void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
+void mlir::affine::affineScalarReplace(Operation *parentOp,
+                                       DominanceInfo &domInfo,
                                        PostDominanceInfo &postDomInfo,
                                        AliasAnalysis &aliasAnalysis) {
   // Load op's whose results were replaced by those forwarded from stores.
@@ -1050,7 +1052,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
   };
 
   // Walk all load's and perform store to load forwarding.
-  f.walk([&](AffineReadOpInterface loadOp) {
+  parentOp->walk([&](AffineReadOpInterface loadOp) {
     forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias);
   });
   for (auto *op : opsToErase)
@@ -1058,7 +1060,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
   opsToErase.clear();
 
   // Walk all store's and perform unused store elimination
-  f.walk([&](AffineWriteOpInterface storeOp) {
+  parentOp->walk([&](AffineWriteOpInterface storeOp) {
     findUnusedStore(storeOp, opsToErase, postDomInfo, mayAlias);
   });
   for (auto *op : opsToErase)
@@ -1091,7 +1093,7 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
   // To eliminate as many loads as possible, run load CSE after eliminating
   // stores. Otherwise, some stores are wrongly seen as having an intervening
   // effect.
-  f.walk([&](AffineReadOpInterface loadOp) {
+  parentOp->walk([&](AffineReadOpInterface loadOp) {
     loadCSE(loadOp, opsToErase, domInfo, mayAlias);
   });
   for (auto *op : opsToErase)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
new file mode 100644
index 0000000000000..73d2d4e4ca427
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -0,0 +1,693 @@
+//===- BufferDeallocation.cpp - the impl for buffer deallocation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for computing correct alloc and dealloc positions.
+// Furthermore, buffer deallocation also adds required new clone operations to
+// ensure that all buffers are deallocated. The underlying algorithm is
+// implemented by the BufferDeallocation class and driven by the
+// BufferDeallocationPass class. To put allocations and deallocations at safe
+// positions, it is essential to place them in the correct blocks. However, the
+// liveness analysis does not pay attention to aliases, which can occur due to
+// branches (and their associated block arguments) in general. For this purpose,
+// BufferDeallocation first finds all possible aliases for a single value
+// (using the BufferViewFlowAnalysis class). Consider the following example:
+//
+// ^bb0(%arg0):
+//   cf.cond_br %cond, ^bb1, ^bb2
+// ^bb1:
+//   cf.br ^exit(%arg0)
+// ^bb2:
+//   %new_value = ...
+//   cf.br ^exit(%new_value)
+// ^exit(%arg1):
+//   return %arg1;
+//
+// Ideally, the dealloc for %new_value would be placed in exit. However, we have
+// to free the buffer in the same block, because it cannot be freed in the post
+// dominator. This, in turn, requires a new clone buffer for %arg1 that will
+// contain the actual contents. Using the class BufferViewFlowAnalysis, we
+// will find out that %new_value has a potential alias %arg1. In order to find
+// the dealloc position we have to find all potential aliases, iterate over
+// their uses and find the common post-dominator block (note that additional
+// clones and buffers remove potential aliases and will influence the placement
+// of the deallocs). In all cases, the computed block can be safely used to free
+// the %new_value buffer (may be exit or bb2) as it will die and we can use
+// liveness information to determine the exact operation after which we have to
+// insert the dealloc. However, the algorithm supports introducing clone buffers
+// and placing deallocs in safe locations to ensure that all buffers will be
+// freed in the end.
+//
+// TODO:
+// The current implementation does not support explicit-control-flow loops and
+// the resulting code will be invalid with respect to program semantics.
+// However, structured control-flow loops are fully supported. Furthermore, it
+// doesn't accept functions which return buffers already.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "llvm/ADT/SetOperations.h"
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_BUFFERDEALLOCATION
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+/// Applies `func` to every block terminator in `region` that implements
+/// RegionBranchTerminatorOpInterface (the region's return-like terminators);
+/// all other terminators are ignored. Returns failure as soon as `func`
+/// reports one, success otherwise.
+static LogicalResult walkReturnOperations(
+    Region *region,
+    llvm::function_ref<LogicalResult(RegionBranchTerminatorOpInterface)> func) {
+  for (Block &block : *region) {
+    auto returnLike =
+        dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
+    if (!returnLike)
+      continue;
+    if (failed(func(returnLike)))
+      return failure();
+  }
+  return success();
+}
+
+/// Checks if all operations that have at least one attached region implement
+/// the RegionBranchOpInterface. This is not required in edge cases, where we
+/// have a single attached region and the parent operation has no results.
+/// Only operations nested inside a func::FuncOp are checked. An error is
+/// emitted on every offending operation, and false is returned if at least one
+/// was found.
+static bool validateSupportedControlFlow(Operation *op) {
+  bool isSupported = true;
+  op->walk([&](Operation *operation) {
+    // Only check ops that are inside a function.
+    if (!operation->getParentOfType<func::FuncOp>())
+      return;
+
+    // An op must implement the RegionBranchOpInterface if it has more than one
+    // region, or a single region and at least one result. If there is an
+    // operation that does not fulfill this condition, we cannot apply the
+    // deallocation steps. Ops with a single region and no results are
+    // accepted, since in that case the intra-region control flow does not
+    // affect the transformation.
+    size_t numRegions = operation->getNumRegions();
+    bool needsInterface = numRegions > 1 ||
+                          (numRegions == 1 && operation->getNumResults() != 0);
+    if (needsInterface && !isa<RegionBranchOpInterface>(operation)) {
+      operation->emitError("All operations with attached regions need to "
+                           "implement the RegionBranchOpInterface.");
+      // Keep walking so that every unsupported op gets a diagnostic, but
+      // record the failure so that the caller actually stops; emitting the
+      // error alone would let the pass run on IR it cannot handle.
+      isSupported = false;
+    }
+  });
+  return isSupported;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Backedges analysis
+//===----------------------------------------------------------------------===//
+
+/// A straight-forward program analysis which detects loop backedges induced by
+/// explicit control flow. It performs a depth-first traversal over block
+/// successors (and into nested regions), recording an edge as a backedge
+/// whenever it targets a block that is still on the traversal stack.
+class Backedges {
+public:
+  using BlockSetT = SmallPtrSet<Block *, 16>;
+  using BackedgeSetT = llvm::DenseSet<std::pair<Block *, Block *>>;
+
+public:
+  /// Runs the analysis on `op` and everything nested inside it.
+  Backedges(Operation *op) { recurse(op); }
+
+  /// Number of distinct (source, target) backedges that were found.
+  size_t size() const { return edgeSet.size(); }
+
+  /// Iterator to the first recorded backedge.
+  BackedgeSetT::const_iterator begin() const { return edgeSet.begin(); }
+
+  /// Past-the-end iterator of the recorded backedges.
+  BackedgeSetT::const_iterator end() const { return edgeSet.end(); }
+
+private:
+  /// Pushes `current` onto the traversal stack. If it is already there, the
+  /// edge `predecessor` -> `current` closes a cycle and is recorded as a
+  /// backedge. Returns true iff `current` was newly pushed.
+  bool enter(Block &current, Block *predecessor) {
+    if (visited.insert(&current).second)
+      return true;
+    edgeSet.insert(std::make_pair(predecessor, &current));
+    return false;
+  }
+
+  /// Pops `current` from the traversal stack.
+  void exit(Block &current) { visited.erase(&current); }
+
+  /// Visits `op`: follows successor edges of its block when `op` is a branch,
+  /// and descends into the entry block of each non-empty attached region.
+  void recurse(Operation *op) {
+    Block *owner = op->getBlock();
+    // A branch may form cycles through the successors of its block.
+    if (isa<BranchOpInterface>(op))
+      for (Block *successor : owner->getSuccessors())
+        recurse(*successor, owner);
+    // Nested regions may contain explicit control-flow loops of their own.
+    for (Region &nested : op->getRegions())
+      if (!nested.empty())
+        recurse(nested.front(), owner);
+  }
+
+  /// Depth-first traversal of the block successor relation from `block`.
+  void recurse(Block &block, Block *predecessor) {
+    // Block already on the stack: enter() recorded the backedge, stop here.
+    if (!enter(block, predecessor))
+      return;
+    for (Operation &nestedOp : block)
+      recurse(&nestedOp);
+    exit(block);
+  }
+
+  /// Blocks currently on the traversal stack.
+  BlockSetT visited;
+
+  /// All backedges found, stored as (source, target).
+  BackedgeSetT edgeSet;
+};
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocation
+//===----------------------------------------------------------------------===//
+
+/// The buffer deallocation transformation which ensures that all allocs in the
+/// program have a corresponding de-allocation. As a side-effect, it might also
+/// introduce clones that in turn leads to additional deallocations.
+///
+/// NOTE(review): `allocs`, `aliases` and `liveness` are not declared in this
+/// class; they presumably come from BufferPlacementTransformationBase, which is
+/// initialized from the same `op` (see BufferUtils.h).
+class BufferDeallocation : public BufferPlacementTransformationBase {
+public:
+  using AliasAllocationMapT =
+      llvm::DenseMap<Value, bufferization::AllocationOpInterface>;
+
+  BufferDeallocation(Operation *op)
+      : BufferPlacementTransformationBase(op), dominators(op),
+        postDominators(op) {}
+
+  /// Checks if all allocation operations either provide an already existing
+  /// deallocation operation or implement the AllocationOpInterface. In
+  /// addition, this method initializes the internal alias to
+  /// AllocationOpInterface mapping in order to get compatible
+  /// AllocationOpInterface implementations for aliases.
+  ///
+  /// Emits an error on the defining op of the first allocation that satisfies
+  /// neither condition and returns failure.
+  LogicalResult prepare() {
+    for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
+      // Get the defining allocation operation.
+      Value alloc = std::get<0>(entry);
+      auto allocationInterface =
+          alloc.getDefiningOp<bufferization::AllocationOpInterface>();
+      // If there is no existing deallocation operation and no implementation of
+      // the AllocationOpInterface, we cannot apply the BufferDeallocation pass.
+      if (!std::get<1>(entry) && !allocationInterface) {
+        return alloc.getDefiningOp()->emitError(
+            "Allocation is not deallocated explicitly nor does the operation "
+            "implement the AllocationOpInterface.");
+      }
+
+      // Register the current allocation interface implementation. This may be
+      // a null interface when the allocation already has an explicit dealloc;
+      // buildDealloc/buildClone treat null entries like unknown sources.
+      aliasToAllocations[alloc] = allocationInterface;
+
+      // Get the alias information for the current allocation node.
+      for (Value alias : aliases.resolve(alloc)) {
+        // TODO: check for incompatible implementations of the
+        // AllocationOpInterface. This could be realized by promoting the
+        // AllocationOpInterface to a DialectInterface.
+        aliasToAllocations[alias] = allocationInterface;
+      }
+    }
+    return success();
+  }
+
+  /// Performs the actual placement/creation of all temporary clone and dealloc
+  /// nodes.
+  LogicalResult deallocate() {
+    // Add additional clones that are required.
+    if (failed(introduceClones()))
+      return failure();
+
+    // Place deallocations for all allocation entries.
+    return placeDeallocs();
+  }
+
+private:
+  /// Introduces required clone operations to avoid memory leaks.
+  LogicalResult introduceClones() {
+    // Initialize the set of values that require a dedicated memory free
+    // operation since their operands cannot be safely deallocated in a post
+    // dominator.
+    SetVector<Value> valuesToFree;
+    // (value, block) pairs that were already queued; prevents re-queueing the
+    // same pair so that the fix-point loop below terminates.
+    llvm::SmallDenseSet<std::tuple<Value, Block *>> visitedValues;
+    SmallVector<std::tuple<Value, Block *>, 8> toProcess;
+
+    // Check dominance relation for proper dominance properties. If the given
+    // value node does not dominate an alias, we will have to create a clone in
+    // order to free all buffers that can potentially leak into a post
+    // dominator.
+    auto findUnsafeValues = [&](Value source, Block *definingBlock) {
+      auto it = aliases.find(source);
+      if (it == aliases.end())
+        return;
+      for (Value value : it->second) {
+        if (valuesToFree.count(value) > 0)
+          continue;
+        Block *parentBlock = value.getParentBlock();
+        // Check whether we have to free this particular block argument or
+        // generic value. We have to free the current alias if it is either
+        // defined in a non-dominated block or it is defined in the same block
+        // but the current value is not dominated by the source value.
+        if (!dominators.dominates(definingBlock, parentBlock) ||
+            (definingBlock == parentBlock && isa<BlockArgument>(value))) {
+          toProcess.emplace_back(value, parentBlock);
+          valuesToFree.insert(value);
+        } else if (visitedValues.insert(std::make_tuple(value, definingBlock))
+                       .second)
+          toProcess.emplace_back(value, definingBlock);
+      }
+    };
+
+    // Detect possibly unsafe aliases starting from all allocations.
+    for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
+      Value allocValue = std::get<0>(entry);
+      findUnsafeValues(allocValue, allocValue.getDefiningOp()->getBlock());
+    }
+    // Try to find block arguments that require an explicit free operation
+    // until we reach a fix point.
+    while (!toProcess.empty()) {
+      auto current = toProcess.pop_back_val();
+      findUnsafeValues(std::get<0>(current), std::get<1>(current));
+    }
+
+    // Update buffer aliases to ensure that we free all buffers and block
+    // arguments at the correct locations.
+    aliases.remove(valuesToFree);
+
+    // Add new allocs and additional clone operations.
+    for (Value value : valuesToFree) {
+      if (failed(isa<BlockArgument>(value)
+                     ? introduceBlockArgCopy(cast<BlockArgument>(value))
+                     : introduceValueCopyForRegionResult(value)))
+        return failure();
+
+      // Register the value to require a final dealloc. Note that we do not have
+      // to assign a block here since we do not want to move the allocation node
+      // to another location.
+      allocs.registerAlloc(std::make_tuple(value, nullptr));
+    }
+    return success();
+  }
+
+  /// Introduces temporary clones in all predecessors and copies the source
+  /// values into the newly allocated buffers.
+  LogicalResult introduceBlockArgCopy(BlockArgument blockArg) {
+    // Allocate a buffer for the current block argument in the block of
+    // the associated value (which will be a predecessor block by
+    // definition).
+    Block *block = blockArg.getOwner();
+    for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
+      // Get the terminator and the value that will be passed to our
+      // argument.
+      Operation *terminator = (*it)->getTerminator();
+      auto branchInterface = cast<BranchOpInterface>(terminator);
+      SuccessorOperands operands =
+          branchInterface.getSuccessorOperands(it.getSuccessorIndex());
+
+      // Query the associated source value.
+      Value sourceValue = operands[blockArg.getArgNumber()];
+      if (!sourceValue) {
+        return failure();
+      }
+      // Wire new clone and successor operand.
+      // Create a new clone at the current location of the terminator.
+      auto clone = introduceCloneBuffers(sourceValue, terminator);
+      if (failed(clone))
+        return failure();
+      operands.slice(blockArg.getArgNumber(), 1).assign(*clone);
+    }
+
+    // Check whether the block argument has implicitly defined predecessors via
+    // the RegionBranchOpInterface. This can be the case if the current block
+    // argument belongs to the first block in a region and the parent operation
+    // implements the RegionBranchOpInterface.
+    Region *argRegion = block->getParent();
+    Operation *parentOp = argRegion->getParentOp();
+    RegionBranchOpInterface regionInterface;
+    if (&argRegion->front() != block ||
+        !(regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)))
+      return success();
+
+    if (failed(introduceClonesForRegionSuccessors(
+            regionInterface, argRegion->getParentOp()->getRegions(), blockArg,
+            [&](RegionSuccessor &successorRegion) {
+              // Find a predecessor of our argRegion.
+              return successorRegion.getSuccessor() == argRegion;
+            })))
+      return failure();
+
+    // Check whether the block argument belongs to an entry region of the
+    // parent operation. In this case, we have to introduce an additional clone
+    // for buffer that is passed to the argument.
+    SmallVector<RegionSuccessor, 2> successorRegions;
+    regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
+                                        successorRegions);
+    auto *it =
+        llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
+          return successorRegion.getSuccessor() == argRegion;
+        });
+    if (it == successorRegions.end())
+      return success();
+
+    // Determine the actual operand to introduce a clone for and rewire the
+    // operand to point to the clone instead.
+    auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
+    size_t operandIndex =
+        llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
+        operands.getBeginOperandIndex();
+    Value operand = parentOp->getOperand(operandIndex);
+    assert(operand ==
+               operands[operandIndex - operands.getBeginOperandIndex()] &&
+           "region interface operands don't match parentOp operands");
+    auto clone = introduceCloneBuffers(operand, parentOp);
+    if (failed(clone))
+      return failure();
+
+    parentOp->setOperand(operandIndex, *clone);
+    return success();
+  }
+
+  /// Introduces temporary clones in front of all associated nested-region
+  /// terminators and copies the source values into the newly allocated buffers.
+  LogicalResult introduceValueCopyForRegionResult(Value value) {
+    // Get the actual result index in the scope of the parent terminator.
+    Operation *operation = value.getDefiningOp();
+    auto regionInterface = cast<RegionBranchOpInterface>(operation);
+    // Filter successors that return to the parent operation.
+    auto regionPredicate = [&](RegionSuccessor &successorRegion) {
+      // If the RegionSuccessor has no associated successor, it will return to
+      // its parent operation.
+      return !successorRegion.getSuccessor();
+    };
+    // Introduce a clone for all region "results" that are returned to the
+    // parent operation. This is required since the parent's result value has
+    // been considered critical. Therefore, the algorithm assumes that a clone
+    // of a previously allocated buffer is returned by the operation (like in
+    // the case of a block argument).
+    return introduceClonesForRegionSuccessors(
+        regionInterface, operation->getRegions(), value, regionPredicate);
+  }
+
+  /// Introduces buffer clones for all terminators in the given regions. The
+  /// regionPredicate is applied to every successor region in order to restrict
+  /// the clones to specific regions.
+  template <typename TPredicate>
+  LogicalResult introduceClonesForRegionSuccessors(
+      RegionBranchOpInterface regionInterface, MutableArrayRef<Region> regions,
+      Value argValue, const TPredicate &regionPredicate) {
+    for (Region &region : regions) {
+      // Query the regionInterface to get all successor regions of the current
+      // one.
+      SmallVector<RegionSuccessor, 2> successorRegions;
+      regionInterface.getSuccessorRegions(region, successorRegions);
+      // Try to find a matching region successor.
+      RegionSuccessor *regionSuccessor =
+          llvm::find_if(successorRegions, regionPredicate);
+      if (regionSuccessor == successorRegions.end())
+        continue;
+      // Get the operand index in the context of the current successor input
+      // bindings.
+      size_t operandIndex =
+          llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
+              .getIndex();
+
+      // Iterate over all immediate terminator operations to introduce
+      // new buffer allocations. Thereby, the appropriate terminator operand
+      // will be adjusted to point to the newly allocated buffer instead.
+      if (failed(walkReturnOperations(
+              &region, [&](RegionBranchTerminatorOpInterface terminator) {
+                // Get the actual mutable operands for this terminator op.
+                auto terminatorOperands =
+                    terminator.getMutableSuccessorOperands(*regionSuccessor);
+                // Extract the source value from the current terminator.
+                // This conversion needs to exist on a separate line due to a
+                // bug in GCC conversion analysis.
+                OperandRange immutableTerminatorOperands = terminatorOperands;
+                Value sourceValue = immutableTerminatorOperands[operandIndex];
+                // Create a new clone at the current location of the terminator.
+                auto clone = introduceCloneBuffers(sourceValue, terminator);
+                if (failed(clone))
+                  return failure();
+                // Wire clone and terminator operand.
+                terminatorOperands.slice(operandIndex, 1).assign(*clone);
+                return success();
+              })))
+        return failure();
+    }
+    return success();
+  }
+
+  /// Creates a new memory allocation for the given source value and clones
+  /// its content into the newly allocated buffer. The terminator operation is
+  /// used to insert the clone operation at the right place.
+  FailureOr<Value> introduceCloneBuffers(Value sourceValue,
+                                         Operation *terminator) {
+    // Avoid multiple clones of the same source value. This can happen in the
+    // presence of loops when a branch acts as a backedge while also having
+    // another successor that returns to its parent operation. Note: that
+    // copying copied buffers can introduce memory leaks since the invariant of
+    // BufferDeallocation assumes that a buffer will be only cloned once into a
+    // temporary buffer. Hence, the construction of clone chains introduces
+    // additional allocations that are not tracked automatically by the
+    // algorithm.
+    if (clonedValues.contains(sourceValue))
+      return sourceValue;
+    // Create a new clone operation that copies the contents of the old
+    // buffer to the new one.
+    auto clone = buildClone(terminator, sourceValue);
+    if (succeeded(clone)) {
+      // Remember the clone itself (not the source), so that a clone that is
+      // passed here again is returned unchanged instead of being re-cloned.
+      clonedValues.insert(*clone);
+    }
+    return clone;
+  }
+
+  /// Finds correct dealloc positions according to the algorithm described at
+  /// the top of the file for all alloc nodes and block arguments that can be
+  /// handled by this analysis.
+  LogicalResult placeDeallocs() {
+    // Move or insert deallocs using the previously computed information.
+    // These deallocations will be linked to their associated allocation nodes
+    // since they don't have any aliases that can (potentially) increase their
+    // liveness.
+    for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
+      Value alloc = std::get<0>(entry);
+      auto aliasesSet = aliases.resolve(alloc);
+      assert(!aliasesSet.empty() && "must contain at least one alias");
+
+      // Determine the actual block to place the dealloc and get liveness
+      // information.
+      Block *placementBlock =
+          findCommonDominator(alloc, aliasesSet, postDominators);
+      // NOTE(review): livenessInfo is dereferenced below without a null
+      // check; this assumes liveness always has an entry for placementBlock.
+      const LivenessBlockInfo *livenessInfo =
+          liveness.getLiveness(placementBlock);
+
+      // We have to ensure that the dealloc will be after the last use of all
+      // aliases of the given value. We first assume that there are no uses in
+      // the placementBlock and that we can safely place the dealloc at the
+      // beginning.
+      Operation *endOperation = &placementBlock->front();
+
+      // Iterate over all aliases and ensure that the endOperation will point
+      // to the last operation of all potential aliases in the placementBlock.
+      for (Value alias : aliasesSet) {
+        // Ensure that the start operation is at least the defining operation of
+        // the current alias to avoid invalid placement of deallocs for aliases
+        // without any uses.
+        Operation *beforeOp = endOperation;
+        if (alias.getDefiningOp() &&
+            !(beforeOp = placementBlock->findAncestorOpInBlock(
+                  *alias.getDefiningOp())))
+          continue;
+
+        Operation *aliasEndOperation =
+            livenessInfo->getEndOperation(alias, beforeOp);
+        // Check whether the aliasEndOperation lies in the desired block and
+        // whether it is behind the current endOperation. If yes, this will be
+        // the new endOperation.
+        if (aliasEndOperation->getBlock() == placementBlock &&
+            endOperation->isBeforeInBlock(aliasEndOperation))
+          endOperation = aliasEndOperation;
+      }
+      // endOperation is the last operation behind which we can safely store
+      // the dealloc taking all potential aliases into account.
+
+      // If there is an existing dealloc, move it to the right place.
+      Operation *deallocOperation = std::get<1>(entry);
+      if (deallocOperation) {
+        deallocOperation->moveAfter(endOperation);
+      } else {
+        // If the Dealloc position is at the terminator operation of the
+        // block, then the value should escape from a deallocation.
+        Operation *nextOp = endOperation->getNextNode();
+        if (!nextOp)
+          continue;
+        // If there is no dealloc node, insert one in the right place.
+        if (failed(buildDealloc(nextOp, alloc)))
+          return failure();
+      }
+    }
+    return success();
+  }
+
+  /// Builds a deallocation operation compatible with the given allocation
+  /// value. If there is no registered AllocationOpInterface implementation for
+  /// the given value (e.g. in the case of a function parameter), this method
+  /// builds a memref::DeallocOp.
+  LogicalResult buildDealloc(Operation *op, Value alloc) {
+    OpBuilder builder(op);
+    auto it = aliasToAllocations.find(alloc);
+    // prepare() registers a null interface for allocations that already come
+    // with an explicit dealloc (and for their aliases). Such entries carry no
+    // implementation, so they must take the default path below instead of
+    // calling through a null interface.
+    if (it != aliasToAllocations.end() && it->second) {
+      // Call the allocation op interface to build a supported and
+      // compatible deallocation operation.
+      auto dealloc = it->second.buildDealloc(builder, alloc);
+      if (!dealloc)
+        return op->emitError()
+               << "allocations without compatible deallocations are "
+                  "not supported";
+    } else {
+      // Build a "default" DeallocOp for unknown allocation sources.
+      builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+    }
+    return success();
+  }
+
+  /// Builds a clone operation compatible with the given allocation value. If
+  /// there is no registered AllocationOpInterface implementation for the given
+  /// value (e.g. in the case of a function parameter), this method builds a
+  /// bufferization::CloneOp. Emits an error and returns failure if the
+  /// registered implementation cannot build a clone.
+  FailureOr<Value> buildClone(Operation *op, Value alloc) {
+    OpBuilder builder(op);
+    auto it = aliasToAllocations.find(alloc);
+    // Null interfaces registered by prepare() fall through to the default
+    // clone below (see buildDealloc).
+    if (it != aliasToAllocations.end() && it->second) {
+      // Call the allocation op interface to build a supported and
+      // compatible clone operation.
+      auto clone = it->second.buildClone(builder, alloc);
+      if (clone)
+        return *clone;
+      op->emitError() << "allocations without compatible clone ops "
+                         "are not supported";
+      return failure();
+    }
+    // Build a "default" CloneOp for unknown allocation sources.
+    return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
+        .getResult();
+  }
+
+  /// The dominator info to find the appropriate start operation to move the
+  /// allocs.
+  DominanceInfo dominators;
+
+  /// The post dominator info to move the dependent allocs in the right
+  /// position.
+  PostDominanceInfo postDominators;
+
+  /// Stores already cloned buffers to avoid additional clones of clones.
+  ValueSetT clonedValues;
+
+  /// Maps aliases to their source allocation interfaces (inverse mapping).
+  /// Entries may hold a null interface (see prepare()).
+  AliasAllocationMapT aliasToAllocations;
+};
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocationPass
+//===----------------------------------------------------------------------===//
+
+/// The actual buffer deallocation pass that inserts and moves dealloc nodes
+/// into the right positions. Furthermore, it inserts additional clones if
+/// necessary. It uses the algorithm described at the top of the file.
+struct BufferDeallocationPass
+    : public bufferization::impl::BufferDeallocationBase<
+          BufferDeallocationPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect>();
+    registry.insert<memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    // Nothing to do for an op without a body (e.g. an external function
+    // declaration). The anchor op is only known as an Operation* here, so it
+    // may have no regions at all; check that before calling getRegion(0),
+    // which would otherwise index out of bounds.
+    if (op->getNumRegions() == 0 || op->getRegion(0).empty())
+      return;
+
+    if (failed(deallocateBuffers(op)))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Entry point of the buffer deallocation transformation. For a ModuleOp,
+/// every nested func::FuncOp is processed on its own and processing stops at
+/// the first function that fails. Any other op is processed directly.
+/// Processing rejects IR with explicit control-flow backedges or unsupported
+/// region-carrying ops, then places the required clones and deallocs.
+LogicalResult bufferization::deallocateBuffers(Operation *op) {
+  if (isa<ModuleOp>(op)) {
+    WalkResult walkResult = op->walk([&](func::FuncOp funcOp) {
+      return succeeded(deallocateBuffers(funcOp)) ? WalkResult::advance()
+                                                  : WalkResult::interrupt();
+    });
+    return failure(walkResult.wasInterrupted());
+  }
+
+  // Only structured control-flow loops can be handled; explicit ones are
+  // detected as backedges and rejected.
+  if (Backedges(op).size() != 0) {
+    op->emitError("Only structured control-flow loops are supported.");
+    return failure();
+  }
+
+  // Every region-carrying op must be understood by the analysis.
+  if (!validateSupportedControlFlow(op))
+    return failure();
+
+  // Collect allocations and aliases, make sure each allocation can be
+  // deallocated, then insert the required clones and deallocs.
+  BufferDeallocation deallocation(op);
+  if (failed(deallocation.prepare()))
+    return failure();
+  return deallocation.deallocate();
+}
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocationPass construction
+//===----------------------------------------------------------------------===//
+
+/// Creates a fresh instance of the buffer deallocation pass.
+std::unique_ptr<Pass> mlir::bufferization::createBufferDeallocationPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<BufferDeallocationPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 7c38621be1bb5..50104e8f8346b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRBufferizationTransforms
   Bufferize.cpp
+  BufferDeallocation.cpp
   BufferDeallocationSimplification.cpp
   BufferOptimizations.cpp
   BufferResultsToOutParams.cpp
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 67c18189b85e0..e10aa05b1aec7 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Threading.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
@@ -206,7 +207,7 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
 
 OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) {
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
-  addPass(std::unique_ptr<Pass>(adaptor));
+  passes.emplace_back(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
 }
 
@@ -216,12 +217,20 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
   std::optional<StringRef> pmOpName = getOpName();
   std::optional<StringRef> passOpName = pass->getOpName();
   if (pmOpName && passOpName && *pmOpName != *passOpName) {
-    if (nesting == OpPassManager::Nesting::Implicit)
-      return nest(*passOpName).addPass(std::move(pass));
+    if (nesting != OpPassManager::Nesting::Explicit)
+      return nest(OpPassManager(*passOpName, OpPassManager::Nesting::Explicit))
+          .addPass(std::move(pass));
     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
                              "' restricted to '" + *passOpName +
                              "' on a PassManager intended to run on '" +
                              getOpAnchorName() + "', did you intend to nest?");
+  } else if (pmOpName && !passOpName &&
+             nesting == OpPassManager::Nesting::ImplicitAny) {
+    nesting = OpPassManager::Nesting::Implicit;
+    nest(OpPassManager(OpPassManager::Nesting::Explicit))
+        .addPass(std::move(pass));
+    nesting = OpPassManager::Nesting::ImplicitAny;
+    return;
   }
 
   passes.emplace_back(std::move(pass));
@@ -464,7 +473,6 @@ llvm::hash_code OpPassManager::hash() {
   return hashCode;
 }
 
-
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
@@ -869,7 +877,8 @@ LogicalResult PassManager::run(Operation *op) {
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();
   llvm::hash_code pipelineKey = hash();
-  if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) {
+  if (newInitKey != initializationKey ||
+      pipelineKey != pipelineInitializationKey) {
     if (failed(initialize(context, impl->initializationGeneration + 1)))
       return failure();
     initializationKey = newInitKey;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 9bbf91de18305..b352bc319bd62 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -460,7 +460,7 @@ performActions(raw_ostream &os,
   context->enableMultithreading(wasThreadingEnabled);
 
   // Prepare the pass manager, applying command-line and reproducer options.
-  PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
+  PassManager pm(op.get()->getName(), PassManager::Nesting::ImplicitAny);
   pm.enableVerifier(config.shouldVerifyPasses());
   if (failed(applyPassManagerCLOptions(pm)))
     return failure();
diff --git a/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp b/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp
index 751302550092d..3407a043332dc 100644
--- a/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp
@@ -12,7 +12,6 @@
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
 #define PASS_NAME "test-affine-access-analysis"
@@ -23,7 +22,7 @@ using namespace mlir::affine;
 namespace {
 
 struct TestAccessAnalysis
-    : public PassWrapper<TestAccessAnalysis, OperationPass<func::FuncOp>> {
+    : public PassWrapper<TestAccessAnalysis, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAccessAnalysis)
 
   StringRef getArgument() const final { return PASS_NAME; }
@@ -52,7 +51,7 @@ void TestAccessAnalysis::runOnOperation() {
   SmallVector<AffineForOp> enclosingOps;
   // Go over all top-level affine.for ops and test each contained affine
   // access's contiguity along every surrounding loop IV.
-  for (auto forOp : getOperation().getOps<AffineForOp>()) {
+  for (auto forOp : getOperation()->getRegion(0).getOps<AffineForOp>()) {
     loadStores.clear();
     gatherLoadsAndStores(forOp, loadStores);
     for (Operation *memOp : loadStores) {
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 404f34ebee17a..12b832950ba85 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -28,7 +27,7 @@ using namespace mlir::affine;
 namespace {
 
 struct TestAffineDataCopy
-    : public PassWrapper<TestAffineDataCopy, OperationPass<func::FuncOp>> {
+    : public PassWrapper<TestAffineDataCopy, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineDataCopy)
 
   StringRef getArgument() const final { return PASS_NAME; }
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
index f8e76356c4321..0543769768aea 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
@@ -23,8 +23,7 @@ using namespace mlir::affine;
 
 namespace {
 struct TestAffineLoopParametricTiling
-    : public PassWrapper<TestAffineLoopParametricTiling,
-                         OperationPass<func::FuncOp>> {
+    : public PassWrapper<TestAffineLoopParametricTiling, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineLoopParametricTiling)
 
   StringRef getArgument() const final { return "test-affine-parametric-tile"; }
diff --git a/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp b/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp
index 429784f26e038..1faf01f51ec25 100644
--- a/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp
@@ -26,7 +26,7 @@ using namespace mlir::affine;
 namespace {
 
 struct TestDecomposeAffineOps
-    : public PassWrapper<TestDecomposeAffineOps, OperationPass<func::FuncOp>> {
+    : public PassWrapper<TestDecomposeAffineOps, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeAffineOps)
 
   StringRef getArgument() const final { return PASS_NAME; }
@@ -43,7 +43,7 @@ struct TestDecomposeAffineOps
 
 void TestDecomposeAffineOps::runOnOperation() {
   IRRewriter rewriter(&getContext());
-  this->getOperation().walk([&](AffineApplyOp op) {
+  this->getOperation()->walk([&](AffineApplyOp op) {
     rewriter.setInsertionPoint(op);
     reorderOperandsByHoistability(rewriter, op);
     (void)decompose(rewriter, op);
diff --git a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
index 19011803a793a..7ea47bac8d562 100644
--- a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
@@ -25,7 +26,7 @@ using namespace mlir::affine;
 namespace {
 
 struct TestLoopFusion
-    : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
+    : public PassWrapper<TestLoopFusion, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
 
   StringRef getArgument() const final { return "test-loop-fusion"; }
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 891b3bab8629d..ed05c92d48491 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
@@ -31,8 +32,7 @@ namespace {
 
 /// This pass applies the permutation on the first maximal perfect nest.
 struct TestReifyValueBounds
-    : public PassWrapper<TestReifyValueBounds,
-                         InterfacePass<FunctionOpInterface>> {
+    : public PassWrapper<TestReifyValueBounds, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
 
   StringRef getArgument() const final { return PASS_NAME; }
@@ -76,11 +76,11 @@ invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
 
 /// Look for "test.reify_bound" ops in the input and replace their results with
 /// the reified values.
-static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
+static LogicalResult testReifyValueBounds(Operation* funcOp,
                                           bool reifyToFuncArgs,
                                           bool useArithOps) {
-  IRRewriter rewriter(funcOp.getContext());
-  WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
+  IRRewriter rewriter(funcOp->getContext());
+  WalkResult result = funcOp->walk([&](test::ReifyBoundOp op) {
     auto boundType = op.getBoundType();
     Value value = op.getVar();
     std::optional<int64_t> dim = op.getDim();
@@ -158,9 +158,9 @@ static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
 }
 
 /// Look for "test.compare" ops and emit errors/remarks.
-static LogicalResult testEquality(FunctionOpInterface funcOp) {
-  IRRewriter rewriter(funcOp.getContext());
-  WalkResult result = funcOp.walk([&](test::CompareOp op) {
+static LogicalResult testEquality(Operation* funcOp) {
+  IRRewriter rewriter(funcOp->getContext());
+  WalkResult result = funcOp->walk([&](test::CompareOp op) {
     auto cmpType = op.getComparisonOperator();
     if (op.getCompose()) {
       if (cmpType != ValueBoundsConstraintSet::EQ) {
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index c32bd24014215..70cbefe9144ea 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -24,7 +24,6 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/Passes.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
@@ -39,7 +38,7 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
 
 namespace {
 struct VectorizerTestPass
-    : public PassWrapper<VectorizerTestPass, OperationPass<func::FuncOp>> {
+    : public PassWrapper<VectorizerTestPass, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizerTestPass)
 
   static constexpr auto kTestAffineMapOpName = "test_affine_map";
@@ -53,7 +52,7 @@ struct VectorizerTestPass
   }
 
   VectorizerTestPass() = default;
-  VectorizerTestPass(const VectorizerTestPass &pass) : PassWrapper(pass){};
+  VectorizerTestPass(const VectorizerTestPass &pass) : PassWrapper(pass) {};
 
   ListOption<int> clTestVectorShapeRatio{
       *this, "vector-shape-ratio",
@@ -97,11 +96,12 @@ struct VectorizerTestPass
 } // namespace
 
 void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
-  auto f = getOperation();
+  auto *f = getOperation();
   using affine::matcher::Op;
   SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(),
                                 clTestVectorShapeRatio.end());
-  auto subVectorType = VectorType::get(shape, Float32Type::get(f.getContext()));
+  auto subVectorType =
+      VectorType::get(shape, Float32Type::get(f->getContext()));
   // Only filter operations that operate on a strict super-vector and have one
   // return. This makes testing easier.
   auto filter = [&](Operation &op) {
@@ -147,7 +147,7 @@ static NestedPattern patternTestSlicingOps() {
 }
 
 void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
-  auto f = getOperation();
+  auto f = cast<func::FuncOp>(getOperation());
   outs << "\n" << f.getName();
 
   SmallVector<NestedMatch, 8> matches;
@@ -163,7 +163,7 @@ void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) {
 }
 
 void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) {
-  auto f = getOperation();
+  auto f = cast<func::FuncOp>(getOperation());
   outs << "\n" << f.getName();
 
   SmallVector<NestedMatch, 8> matches;
@@ -179,7 +179,7 @@ void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) {
 }
 
 void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) {
-  auto f = getOperation();
+  auto f = cast<func::FuncOp>(getOperation());
   outs << "\n" << f.getName();
 
   SmallVector<NestedMatch, 8> matches;
@@ -198,7 +198,7 @@ static bool customOpWithAffineMapAttribute(Operation &op) {
 }
 
 void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
-  auto f = getOperation();
+  auto f = cast<func::FuncOp>(getOperation());
 
   using affine::matcher::Op;
   auto pattern = Op(customOpWithAffineMapAttribute);
@@ -252,7 +252,7 @@ void VectorizerTestPass::testVecAffineLoopNest(llvm::raw_ostream &outs) {
 
 void VectorizerTestPass::runOnOperation() {
   // Only support single block functions at this point.
-  func::FuncOp f = getOperation();
+  func::FuncOp f = cast<func::FuncOp>(getOperation());
   if (!llvm::hasSingleElement(f))
     return;
 

>From 1d3db2073ac5f2520eee859fbec70ac482edfa5a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Sat, 26 Oct 2024 19:36:42 +0200
Subject: [PATCH 05/23] Add CLI flag to select anchor

---
 .../include/mlir/Tools/mlir-opt/MlirOptMain.h | 19 ++++++
 mlir/lib/Pass/Pass.cpp                        | 19 ++----
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       | 58 ++++++++++++++-----
 3 files changed, 67 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 09bd86b9581df..24270f8f40c0f 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -14,12 +14,15 @@
 #define MLIR_TOOLS_MLIROPT_MLIROPTMAIN_H
 
 #include "mlir/Debug/CLOptionsSetup.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/ToolUtilities.h"
 #include "llvm/ADT/StringRef.h"
 
 #include <cstdlib>
 #include <functional>
 #include <memory>
+#include <optional>
+#include <string>
 
 namespace llvm {
 class raw_ostream;
@@ -141,6 +144,18 @@ class MlirOptMainConfig {
   }
   bool shouldListPasses() const { return listPassesFlag; }
 
+  /// Set the operation name that anchors the CLI pass pipeline.
+  MlirOptMainConfig &setPassPipelineAnchor(std::string name) {
+    passPipelineAnchorFlag = std::move(name);
+    return *this;
+  }
+  /// Return the pipeline anchor name, or std::nullopt to use the IR root.
+  std::optional<StringRef> getPassPipelineAnchor() const {
+    if (passPipelineAnchorFlag.empty())
+      return std::nullopt;
+    return passPipelineAnchorFlag;
+  }
+
   /// Enable running the reproducer information stored in resources (if
   /// present).
   MlirOptMainConfig &runReproducer(bool enableReproducer) {
@@ -274,6 +289,10 @@ class MlirOptMainConfig {
   /// Merge output chunks into one file using the given marker.
   std::string outputSplitMarkerFlag = "";
 
+  /// Specify an operation name as the anchor for the CLI pass pipeline.
+  /// By default the pipeline is anchored on the root of the IR.
+  std::string passPipelineAnchorFlag = "";
+
   /// Use an explicit top-level module op during parsing.
   bool useExplicitModuleFlag = false;
 
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index e10aa05b1aec7..67c18189b85e0 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -17,7 +17,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Threading.h"
 #include "mlir/IR/Verifier.h"
-#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
@@ -207,7 +206,7 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
 
 OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) {
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
-  passes.emplace_back(std::unique_ptr<Pass>(adaptor));
+  addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
 }
 
@@ -217,20 +216,12 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
   std::optional<StringRef> pmOpName = getOpName();
   std::optional<StringRef> passOpName = pass->getOpName();
   if (pmOpName && passOpName && *pmOpName != *passOpName) {
-    if (nesting != OpPassManager::Nesting::Explicit)
-      return nest(OpPassManager(*passOpName, OpPassManager::Nesting::Explicit))
-          .addPass(std::move(pass));
+    if (nesting == OpPassManager::Nesting::Implicit)
+      return nest(*passOpName).addPass(std::move(pass));
     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
                              "' restricted to '" + *passOpName +
                              "' on a PassManager intended to run on '" +
                              getOpAnchorName() + "', did you intend to nest?");
-  } else if (pmOpName && !passOpName &&
-             nesting == OpPassManager::Nesting::ImplicitAny) {
-    nesting = OpPassManager::Nesting::Implicit;
-    nest(OpPassManager(OpPassManager::Nesting::Explicit))
-        .addPass(std::move(pass));
-    nesting = OpPassManager::Nesting::ImplicitAny;
-    return;
   }
 
   passes.emplace_back(std::move(pass));
@@ -473,6 +464,7 @@ llvm::hash_code OpPassManager::hash() {
   return hashCode;
 }
 
+
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
@@ -877,8 +869,7 @@ LogicalResult PassManager::run(Operation *op) {
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();
   llvm::hash_code pipelineKey = hash();
-  if (newInitKey != initializationKey ||
-      pipelineKey != pipelineInitializationKey) {
+  if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) {
     if (failed(initialize(context, impl->initializationGeneration + 1)))
       return failure();
     initializationKey = newInitKey;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index b352bc319bd62..598d9d2d4355d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -27,6 +27,8 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -37,6 +39,7 @@
 #include "mlir/Tools/ParseUtilities.h"
 #include "mlir/Tools/Plugins/DialectPlugin.h"
 #include "mlir/Tools/Plugins/PassPlugin.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileUtilities.h"
@@ -192,6 +195,12 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
     static cl::list<std::string> passPlugins(
         "load-pass-plugin", cl::desc("Load passes from plugin library"));
 
+    static cl::opt<std::string, /*ExternalStorage=*/true> passPipelineAnchor{
+        "pass-pipeline-anchor", llvm::cl::ValueOptional,
+        cl::desc("Specify an operation name that will be used as the anchor of "
+                 "the CLI pass pipeline"),
+        cl::location(passPipelineAnchorFlag), cl::init("")};
+
     static cl::opt<std::string, /*ExternalStorage=*/true>
         generateReproducerFile(
             "mlir-generate-reproducer",
@@ -413,17 +422,17 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
   auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true);
   return success(succeeded(txtStatus) && succeeded(bcStatus));
 }
-
-/// Perform the actions on the input file indicated by the command line flags
-/// within the specified context.
-///
-/// This typically parses the main source file, runs zero or more optimization
-/// passes, then prints the output.
-///
-static LogicalResult
-performActions(raw_ostream &os,
-               const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
-               MLIRContext *context, const MlirOptMainConfig &config) {
+
+/// Perform the actions on the input file indicated by the command line flags
+/// within the specified context.
+///
+/// This typically parses the main source file, runs zero or more optimization
+/// passes, then prints the output.
+///
+static LogicalResult
+performActions(raw_ostream &os,
+               const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+               MLIRContext *context, const MlirOptMainConfig &config) {
   DefaultTimingManager tm;
   applyDefaultTimingManagerCLOptions(tm);
   TimingScope timing = tm.getRootScope();
@@ -460,7 +469,11 @@ performActions(raw_ostream &os,
   context->enableMultithreading(wasThreadingEnabled);
 
   // Prepare the pass manager, applying command-line and reproducer options.
-  PassManager pm(op.get()->getName(), PassManager::Nesting::ImplicitAny);
+  StringRef rootName = op.get()->getName().getStringRef();
+  StringRef passPipelineAnchor =
+      config.getPassPipelineAnchor().value_or(rootName);
+
+  PassManager pm(context, passPipelineAnchor, PassManager::Nesting::Implicit);
   pm.enableVerifier(config.shouldVerifyPasses());
   if (failed(applyPassManagerCLOptions(pm)))
     return failure();
@@ -470,9 +483,24 @@ performActions(raw_ostream &os,
   if (failed(config.setupPassPipeline(pm)))
     return failure();
 
-  // Run the pipeline.
-  if (failed(pm.run(*op)))
-    return failure();
+  if (std::optional<StringRef> anchorName = config.getPassPipelineAnchor()) {
+    // Pre-order walk so `skip` prunes nested anchors. TODO parallelize
+    auto result = op->walk<WalkOrder::PreOrder>([&](Operation *anchor) {
+      if (anchor->getName().getStringRef() == *anchorName) {
+        if (failed(pm.run(anchor)))
+          return WalkResult::interrupt();
+        return WalkResult::skip();
+      }
+      return WalkResult::advance();
+    });
+
+    if (result.wasInterrupted())
+      return failure();
+  } else {
+    // Run the pipeline on the root
+    if (failed(pm.run(*op)))
+      return failure();
+  }
 
   // Generate reproducers if requested
   if (!config.getReproducerFilename().empty()) {

>From 617d2ad3de05a0d59f4b2bd421723dd8e2a94c67 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Sat, 26 Oct 2024 20:52:45 +0200
Subject: [PATCH 06/23] Fix nesting problem for affine passes

---
 mlir/include/mlir/Dialect/Affine/Passes.h     |  5 +++++
 mlir/include/mlir/Pass/Pass.h                 | 12 ++++++++++++
 mlir/lib/Pass/Pass.cpp                        | 19 +++++++++++++------
 .../lib/Dialect/Affine/TestAccessAnalysis.cpp |  3 ++-
 .../lib/Dialect/Affine/TestAffineDataCopy.cpp |  3 ++-
 .../Affine/TestAffineLoopParametricTiling.cpp |  2 +-
 .../Affine/TestAffineLoopUnswitching.cpp      |  3 ++-
 .../lib/Dialect/Affine/TestLoopFusion.cpp     |  2 +-
 .../lib/Dialect/Affine/TestLoopMapping.cpp    |  3 ++-
 .../Dialect/Affine/TestLoopPermutation.cpp    |  4 ++--
 .../Dialect/Affine/TestVectorizationUtils.cpp |  3 ++-
 11 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index e580d73d83a8a..dbca6821c82ef 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include <limits>
+#include <llvm/ADT/StringRef.h>
 
 namespace mlir {
 
@@ -39,6 +40,10 @@ class AffineScopePassBase : public OperationPass<> {
     return opInfo.hasTrait<OpTrait::AffineScope>() &&
            opInfo.getStringRef() != ModuleOp::getOperationName();
   }
+
+  bool shouldImplicitlyNestOn(llvm::StringRef anchorName) const final {
+    return anchorName == ModuleOp::getOperationName();
+  }
 };
 
 /// Fusion mode to attempt. The default mode `Greedy` does both
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 7725a3a2910bd..6b9eb2bd3d01c 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -14,6 +14,7 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "llvm/ADT/PointerIntPair.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
 #include <optional>
 
 namespace mlir {
@@ -193,6 +194,17 @@ class Pass {
   /// operations they operate on.
   virtual bool canScheduleOn(RegisteredOperationName opName) const = 0;
 
+  /// Indicate whether this pass should implicitly nest itself in the pass manager,
+  /// when there is a mismatch between the anchor type and this pass' anchor type.
+  /// By default passes that have a specific anchor name nest themselves, and passes
+  /// that can handle any anchor don't.
+  ///
+  /// This is only ever called if the PassManager uses implicit nesting. Passes are
+  /// also never implicitly nested on a pass manager with anchor "any".
+  virtual bool shouldImplicitlyNestOn(StringRef anchorName) const {
+    return getOpName() && *getOpName() != anchorName;
+  }
+
   /// Schedule an arbitrary pass pipeline on the provided operation.
   /// This can be invoke any time in a pass to dynamic schedule more passes.
   /// The provided operation must be the current one or one nested below.
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 67c18189b85e0..fda16bc15a62a 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -15,12 +15,14 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Threading.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/CrashRecoveryContext.h"
 #include "llvm/Support/Mutex.h"
@@ -28,6 +30,7 @@
 #include "llvm/Support/Threading.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include <optional>
+#include <utility>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -206,7 +209,7 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
 
 OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) {
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
-  addPass(std::unique_ptr<Pass>(adaptor));
+  passes.emplace_back(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
 }
 
@@ -215,9 +218,13 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
   // implicitly nest a pass manager for this operation if enabled.
   std::optional<StringRef> pmOpName = getOpName();
   std::optional<StringRef> passOpName = pass->getOpName();
-  if (pmOpName && passOpName && *pmOpName != *passOpName) {
-    if (nesting == OpPassManager::Nesting::Implicit)
-      return nest(*passOpName).addPass(std::move(pass));
+  // Only honor shouldImplicitlyNestOn when implicit nesting is enabled.
+  if (pmOpName && ((passOpName && *passOpName != *pmOpName) ||
+                   (nesting == OpPassManager::Nesting::Implicit &&
+                    pass->shouldImplicitlyNestOn(*pmOpName)))) {
+    if (nesting == OpPassManager::Nesting::Implicit)
+      return (passOpName ? nest(*passOpName) : nestAny())
+          .addPass(std::move(pass));
     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
                              "' restricted to '" + *passOpName +
                              "' on a PassManager intended to run on '" +
@@ -464,7 +471,6 @@ llvm::hash_code OpPassManager::hash() {
   return hashCode;
 }
 
-
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
@@ -869,7 +875,8 @@ LogicalResult PassManager::run(Operation *op) {
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();
   llvm::hash_code pipelineKey = hash();
-  if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) {
+  if (newInitKey != initializationKey ||
+      pipelineKey != pipelineInitializationKey) {
     if (failed(initialize(context, impl->initializationGeneration + 1)))
       return failure();
     initializationKey = newInitKey;
diff --git a/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp b/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp
index 3407a043332dc..f48091c9c0893 100644
--- a/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAccessAnalysis.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Pass/Pass.h"
 
 #define PASS_NAME "test-affine-access-analysis"
@@ -22,7 +23,7 @@ using namespace mlir::affine;
 namespace {
 
 struct TestAccessAnalysis
-    : public PassWrapper<TestAccessAnalysis, OperationPass<>> {
+    : public PassWrapper<TestAccessAnalysis, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAccessAnalysis)
 
   StringRef getArgument() const final { return PASS_NAME; }
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 12b832950ba85..742f9dae9e619 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -27,7 +28,7 @@ using namespace mlir::affine;
 namespace {
 
 struct TestAffineDataCopy
-    : public PassWrapper<TestAffineDataCopy, OperationPass<>> {
+    : public PassWrapper<TestAffineDataCopy, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineDataCopy)
 
   StringRef getArgument() const final { return PASS_NAME; }
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
index 0543769768aea..8e47e7c25da2e 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
@@ -23,7 +23,7 @@ using namespace mlir::affine;
 
 namespace {
 struct TestAffineLoopParametricTiling
-    : public PassWrapper<TestAffineLoopParametricTiling, OperationPass<>> {
+    : public PassWrapper<TestAffineLoopParametricTiling, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineLoopParametricTiling)
 
   StringRef getArgument() const final { return "test-affine-parametric-tile"; }
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
index 7e4a3ca7b7c72..87c10b74b0cb7 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopUnswitching.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
@@ -25,7 +26,7 @@ namespace {
 
 /// This pass applies the permutation on the first maximal perfect nest.
 struct TestAffineLoopUnswitching
-    : public PassWrapper<TestAffineLoopUnswitching, OperationPass<>> {
+    : public PassWrapper<TestAffineLoopUnswitching, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineLoopUnswitching)
 
   StringRef getArgument() const final { return PASS_NAME; }
diff --git a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
index 7ea47bac8d562..0352c84391d10 100644
--- a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
@@ -26,7 +26,7 @@ using namespace mlir::affine;
 namespace {
 
 struct TestLoopFusion
-    : public PassWrapper<TestLoopFusion, OperationPass<>> {
+    : public PassWrapper<TestLoopFusion, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
 
   StringRef getArgument() const final { return "test-loop-fusion"; }
diff --git a/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp b/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp
index 3dc7abb15af17..429f7ea42cbe8 100644
--- a/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestLoopMapping.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
@@ -22,7 +23,7 @@ using namespace mlir::affine;
 
 namespace {
 struct TestLoopMappingPass
-    : public PassWrapper<TestLoopMappingPass, OperationPass<>> {
+    : public PassWrapper<TestLoopMappingPass, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopMappingPass)
 
   StringRef getArgument() const final {
diff --git a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
index e708b7de690ec..7e8a6779ea7d7 100644
--- a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
-#include "mlir/Pass/Pass.h"
+#include "mlir/Dialect/Affine/Passes.h"
 
 #define PASS_NAME "test-loop-permutation"
 
@@ -24,7 +24,7 @@ namespace {
 
 /// This pass applies the permutation on the first maximal perfect nest.
 struct TestLoopPermutation
-    : public PassWrapper<TestLoopPermutation, OperationPass<>> {
+    : public PassWrapper<TestLoopPermutation, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopPermutation)
 
   StringRef getArgument() const final { return PASS_NAME; }
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index 70cbefe9144ea..d5827f8989e04 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -38,7 +39,7 @@ static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
 
 namespace {
 struct VectorizerTestPass
-    : public PassWrapper<VectorizerTestPass, OperationPass<>> {
+    : public PassWrapper<VectorizerTestPass, AffineScopePassBase> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizerTestPass)
 
   static constexpr auto kTestAffineMapOpName = "test_affine_map";

>From 66e94459e833c53a6268afe2d088a110d7c4cda8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Sun, 27 Oct 2024 15:18:07 +0100
Subject: [PATCH 07/23] Parallelize new pass-pipeline-anchor option

---
 .../Dialect/Bufferization/Transforms/Passes.h |   4 +
 mlir/include/mlir/Pass/PassManager.h          |  10 +-
 mlir/lib/Pass/Pass.cpp                        | 154 +++++++++++++-----
 mlir/lib/Pass/PassDetail.h                    |   2 +
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp       |  34 ++--
 mlir/lib/Transforms/SROA.cpp                  |   6 +
 6 files changed, 144 insertions(+), 66 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 50b2fac4ba994..429a76c45d092 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -33,6 +33,10 @@ class BufferScopePassBase : public OperationPass<> {
     return opInfo.hasTrait<OpTrait::AutomaticAllocationScope>() &&
            opInfo.getStringRef() != ModuleOp::getOperationName();
   }
+
+  bool shouldImplicitlyNestOn(llvm::StringRef anchorName) const final {
+    return anchorName == ModuleOp::getOperationName();
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 950f3e9c547eb..038671e7c15ca 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -55,9 +55,6 @@ class OpPassManager {
     /// Explicit nesting behavior. This requires that any passes added to this
     /// pass manager support its operation type.
     Explicit,
-    /// Implicitly add an "any" nesting level when scheduling a pass that handles 
-    /// "any" type.
-    ImplicitAny,
   };
 
   /// Construct a new op-agnostic ("any") pass manager with the given operation
@@ -168,6 +165,13 @@ class OpPassManager {
   /// Return the current nesting mode.
   Nesting getNesting();
 
+
+  /// Make the pass pipeline fetch its anchors by doing a recursive walk,
+  /// instead of being anchored on the root of the IR.
+  void setRecursiveAnchorFetching(bool enabled = true);
+
+  bool hasRecursiveAnchor() const;
+
 private:
   /// Initialize all of the passes within this pass manager with the given
   /// initialization generation. The initialization generation is used to detect
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index fda16bc15a62a..f4bb9952f5dcf 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -18,17 +18,22 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Threading.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/CrashRecoveryContext.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/Mutex.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/Support/Threading.h"
 #include "llvm/Support/ToolOutputFile.h"
+#include <functional>
 #include <optional>
 #include <utility>
 
@@ -111,16 +116,19 @@ namespace detail {
 struct OpPassManagerImpl {
   OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting)
       : name(opName.getStringRef().str()), opName(opName),
-        initializationGeneration(0), nesting(nesting) {}
+        initializationGeneration(0), nesting(nesting),
+        isRecursiveAnchor(false) {}
   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
       : name(name == OpPassManager::getAnyOpAnchorName() ? "" : name.str()),
-        initializationGeneration(0), nesting(nesting) {}
+        initializationGeneration(0), nesting(nesting),
+        isRecursiveAnchor(false) {}
   OpPassManagerImpl(OpPassManager::Nesting nesting)
-      : initializationGeneration(0), nesting(nesting) {}
+      : initializationGeneration(0), nesting(nesting),
+        isRecursiveAnchor(false) {}
   OpPassManagerImpl(const OpPassManagerImpl &rhs)
       : name(rhs.name), opName(rhs.opName),
         initializationGeneration(rhs.initializationGeneration),
-        nesting(rhs.nesting) {
+        nesting(rhs.nesting), isRecursiveAnchor(false) {
     for (const std::unique_ptr<Pass> &pass : rhs.passes) {
       std::unique_ptr<Pass> newPass = pass->clone();
       newPass->threadingSibling = pass.get();
@@ -196,12 +204,17 @@ struct OpPassManagerImpl {
   /// Control the implicit nesting of passes that mismatch the name set for this
   /// OpPassManager.
   OpPassManager::Nesting nesting;
+
+  /// Whether the anchor is recursive
+  bool isRecursiveAnchor;
 };
 } // namespace detail
 } // namespace mlir
 
 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
   assert(name == rhs.name && "merging unrelated pass managers");
+  assert(isRecursiveAnchor == rhs.isRecursiveAnchor &&
+         "anchor fetching method is different");
   for (auto &pass : passes)
     rhs.passes.push_back(std::move(pass));
   passes.clear();
@@ -433,6 +446,14 @@ void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
 
 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
 
+void OpPassManager::setRecursiveAnchorFetching(bool enabled) {
+  impl->isRecursiveAnchor = enabled;
+}
+
+bool OpPassManager::hasRecursiveAnchor() const {
+  return impl->isRecursiveAnchor;
+}
+
 LogicalResult OpPassManager::initialize(MLIRContext *context,
                                         unsigned newInitGeneration) {
   if (impl->initializationGeneration == newInitGeneration)
@@ -478,16 +499,21 @@ llvm::hash_code OpPassManager::hash() {
 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                                      AnalysisManager am, bool verifyPasses,
                                      unsigned parentInitGeneration) {
-  std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
-  if (!opInfo)
-    return op->emitOpError()
-           << "trying to schedule a pass on an unregistered operation";
-  if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
-    return op->emitOpError() << "trying to schedule a pass on an operation not "
-                                "marked as 'IsolatedFromAbove'";
-  if (!pass->canScheduleOn(*op->getName().getRegisteredInfo()))
-    return op->emitOpError()
-           << "trying to schedule a pass on an unsupported operation";
+  bool hasRecursiveAnchor = isa<OpToOpPassAdaptor>(pass) &&
+                            cast<OpToOpPassAdaptor>(pass)->hasRecursiveAnchor();
+  if (!hasRecursiveAnchor) {
+    std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
+    if (!opInfo)
+      return op->emitOpError()
+             << "trying to schedule a pass on an unregistered operation";
+    if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
+      return op->emitOpError()
+             << "trying to schedule a pass on an operation not "
+                "marked as 'IsolatedFromAbove'";
+    if (!pass->canScheduleOn(*op->getName().getRegisteredInfo()))
+      return op->emitOpError()
+             << "trying to schedule a pass on an unsupported operation";
+  }
 
   // Initialize the pass state with a callback for the pass to dynamically
   // execute a pipeline on the currently visited operation.
@@ -644,6 +670,10 @@ LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
   auto hasScheduleConflictWith = [&](OpPassManager &genericPM,
                                      MutableArrayRef<OpPassManager> otherPMs) {
     return llvm::any_of(otherPMs, [&](OpPassManager &pm) {
+      /// Anchor fetching methods must match
+      if (pm.hasRecursiveAnchor() != genericPM.hasRecursiveAnchor())
+        return true;
+
       // If this is a non-generic pass manager, a conflict will arise if a
       // non-generic pass manager's operation name can be scheduled on the
       // generic passmanager.
@@ -675,11 +705,11 @@ LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
     // into it.
     if (auto *existingPM =
             findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) {
-      pm.getImpl().mergeInto(existingPM->getImpl());
-    } else {
-      // Otherwise, add the given pass manager to the list.
-      rhs.mgrs.emplace_back(std::move(pm));
+      if (existingPM->hasRecursiveAnchor() == pm.hasRecursiveAnchor())
+        pm.getImpl().mergeInto(existingPM->getImpl());
     }
+    // Otherwise, add the given pass manager to the list.
+    rhs.mgrs.emplace_back(std::move(pm));
   }
   mgrs.clear();
 
@@ -708,6 +738,11 @@ std::string OpToOpPassAdaptor::getAdaptorName() {
   return name;
 }
 
+bool OpToOpPassAdaptor::hasRecursiveAnchor() {
+  return llvm::all_of(
+      mgrs, [](OpPassManager &pm) { return pm.hasRecursiveAnchor(); });
+}
+
 void OpToOpPassAdaptor::runOnOperation() {
   llvm_unreachable(
       "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
@@ -727,18 +762,35 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                         this};
   auto *instrumentor = am.getPassInstrumentor();
-  for (auto &region : getOperation()->getRegions()) {
-    for (auto &block : region) {
-      for (auto &op : block) {
-        auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext());
-        if (!mgr)
-          continue;
-
-        // Run the held pipeline over the current operation.
-        unsigned initGeneration = mgr->impl->initializationGeneration;
-        if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses,
-                               initGeneration, instrumentor, &parentInfo)))
-          signalPassFailure();
+  auto handleOp = [&](Operation &op) -> WalkResult {
+    auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext());
+    if (!mgr)
+      return WalkResult::advance();
+
+    // Run the held pipeline over the current operation.
+    unsigned initGeneration = mgr->impl->initializationGeneration;
+    if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses,
+                           initGeneration, instrumentor, &parentInfo))) {
+      signalPassFailure();
+      return WalkResult::interrupt();
+    }
+    return WalkResult::skip();
+  };
+
+  if (hasRecursiveAnchor()) {
+    for (auto &region : getOperation()->getRegions()) {
+      auto res = region.walk<WalkOrder::PreOrder>(
+          [&](Operation *op) -> WalkResult { return handleOp(*op); });
+      if (res.wasInterrupted())
+        return;
+    }
+  } else {
+    for (auto &region : getOperation()->getRegions()) {
+      for (auto &block : region) {
+        for (auto &op : block) {
+          if (handleOp(op).wasInterrupted())
+            return;
+        }
       }
     }
   }
@@ -782,18 +834,38 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
   // operation, as well as providing a queue of operations to execute over.
   std::vector<OpPMInfo> opInfos;
   DenseMap<OperationName, std::optional<unsigned>> knownOpPMIdx;
+
+  auto handleOp = [&](Operation &op) -> LogicalResult {
+    // Get the pass manager index for this operation type.
+    auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), std::nullopt);
+    if (pmIdxIt.second) {
+      if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context))
+        pmIdxIt.first->second = std::distance(mgrs.begin(), mgr);
+    }
+
+    // If this operation can be scheduled, add it to the list.
+    if (pmIdxIt.first->second) {
+      opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op));
+      return failure();
+    }
+    return success();
+  };
+
   for (auto &region : getOperation()->getRegions()) {
-    for (Operation &op : region.getOps()) {
-      // Get the pass manager index for this operation type.
-      auto pmIdxIt = knownOpPMIdx.try_emplace(op.getName(), std::nullopt);
-      if (pmIdxIt.second) {
-        if (auto *mgr = findPassManagerFor(mgrs, op.getName(), *context))
-          pmIdxIt.first->second = std::distance(mgrs.begin(), mgr);
-      }
 
-      // If this operation can be scheduled, add it to the list.
-      if (pmIdxIt.first->second)
-        opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op));
+    if (hasRecursiveAnchor()) {
+      // in that case the next nested ops to process are fetched recursively
+      region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+        if (succeeded(handleOp(*op))) {
+          return WalkResult::skip();
+        }
+        return WalkResult::advance();
+      });
+    } else {
+      // here they are only fetched from the children
+      for (Operation &op : region.getOps()) {
+        (void)handleOp(op);
+      }
     }
   }
 
@@ -853,7 +925,7 @@ void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
 LogicalResult PassManager::run(Operation *op) {
   MLIRContext *context = getContext();
   std::optional<OperationName> anchorOp = getOpName(*context);
-  if (anchorOp && anchorOp != op->getName())
+  if (anchorOp && anchorOp != op->getName() && !hasRecursiveAnchor())
     return emitError(op->getLoc())
            << "can't run '" << getOpAnchorName() << "' pass manager on '"
            << op->getName() << "' op";
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 5cc726295c9f1..02534a777641b 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -58,6 +58,8 @@ class OpToOpPassAdaptor
   /// Returns the adaptor pass name.
   std::string getAdaptorName();
 
+  bool hasRecursiveAnchor();
+
 private:
   /// Run this pass adaptor synchronously.
   void runOnOperationImpl(bool verifyPasses);
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 598d9d2d4355d..f99cb7fc01c23 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -294,7 +294,14 @@ MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
       emitError(UnknownLoc::get(pm.getContext())) << msg;
       return failure();
     };
-    if (failed(passPipeline.addToPipeline(pm, errorHandler)))
+
+    OpPassManager* oppm = &pm;
+    if (auto anchor = getPassPipelineAnchor()) {
+      oppm = &pm.nest(*anchor);
+      oppm->setRecursiveAnchorFetching(true);
+    }
+
+    if (failed(passPipeline.addToPipeline(*oppm, errorHandler)))
       return failure();
     if (this->shouldDumpPassPipeline()) {
 
@@ -470,10 +477,8 @@ static
 
   // Prepare the pass manager, applying command-line and reproducer options.
   StringRef rootName = op.get()->getName().getStringRef();
-  StringRef passPipelineAnchor =
-      config.getPassPipelineAnchor().value_or(rootName);
+  PassManager pm(context, rootName, PassManager::Nesting::Implicit);
 
-  PassManager pm(context, passPipelineAnchor, PassManager::Nesting::Implicit);
   pm.enableVerifier(config.shouldVerifyPasses());
   if (failed(applyPassManagerCLOptions(pm)))
     return failure();
@@ -483,24 +488,9 @@ static
   if (failed(config.setupPassPipeline(pm)))
     return failure();
 
-  if (config.getPassPipelineAnchor().has_value()) {
-    // Run the pipeline on each anchor. TODO parallelize
-    auto result = op->walk([&](Operation *anchor) {
-      if (anchor->getName().getStringRef() == *config.getPassPipelineAnchor()) {
-        if (failed(pm.run(anchor)))
-          return WalkResult::interrupt();
-        return WalkResult::skip();
-      }
-      return WalkResult::advance();
-    });
-
-    if (result.wasInterrupted())
-      return failure();
-  } else {
-    // Run the pipeline on the root
-    if (failed(pm.run(*op)))
-      return failure();
-  }
+  // Run the pipeline on the root
+  if (failed(pm.run(*op)))
+    return failure();
 
   // Generate reproducers if requested
   if (!config.getReproducerFilename().empty()) {
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index db8be38a51443..0e9a9dc43fe61 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -10,8 +10,10 @@
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/Passes.h"
+#include <llvm/ADT/StringRef.h>
 
 namespace mlir {
 #define GEN_PASS_DEF_SROA
@@ -248,6 +250,10 @@ namespace {
 struct SROA : public impl::SROABase<SROA> {
   using impl::SROABase<SROA>::SROABase;
 
+  bool shouldImplicitlyNestOn(llvm::StringRef name) const final {
+    return name == ModuleOp::getOperationName();
+  }
+
   void runOnOperation() override {
     Operation *scopeOp = getOperation();
 

>From 2f36f4c47bfa2455c23d6d1c49b005db3e72484b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Sun, 27 Oct 2024 15:45:03 +0100
Subject: [PATCH 08/23] Add textual syntax

---
 mlir/include/mlir/Pass/PassManager.h    |  2 +-
 mlir/lib/Pass/Pass.cpp                  |  7 +++++--
 mlir/lib/Pass/PassCrashRecovery.cpp     |  6 +++---
 mlir/lib/Pass/PassRegistry.cpp          | 21 +++++++++++++++-----
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 26 ++++++++++++-------------
 5 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 038671e7c15ca..4d688a049c194 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -229,7 +229,7 @@ using ReproducerStreamFactory =
     std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;
 
 std::string
-makeReproducer(StringRef anchorName,
+makeReproducer(StringRef anchorName, bool hasRecursiveAnchor,
                const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
                Operation *op, StringRef outputFile, bool disableThreads = false,
                bool verifyPasses = false);
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index f4bb9952f5dcf..d4f516b041878 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -410,8 +410,11 @@ StringRef OpPassManager::getOpAnchorName() const {
 /// Prints out the passes of the pass manager as the textual representation
 /// of pipelines.
 void printAsTextualPipeline(
-    raw_ostream &os, StringRef anchorName,
+    raw_ostream &os, StringRef anchorName, bool hasRecursiveAnchor,
     const llvm::iterator_range<OpPassManager::pass_iterator> &passes) {
+  if (hasRecursiveAnchor) {
+    os << "**";
+  }
   os << anchorName << "(";
   llvm::interleave(
       passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
@@ -421,7 +424,7 @@ void printAsTextualPipeline(
 void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
   StringRef anchorName = getOpAnchorName();
   ::printAsTextualPipeline(
-      os, anchorName,
+      os, anchorName, hasRecursiveAnchor(),
       {MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(),
        MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()});
 }
diff --git a/mlir/lib/Pass/PassCrashRecovery.cpp b/mlir/lib/Pass/PassCrashRecovery.cpp
index 8c6d865cb31dd..74d9dfa679728 100644
--- a/mlir/lib/Pass/PassCrashRecovery.cpp
+++ b/mlir/lib/Pass/PassCrashRecovery.cpp
@@ -442,11 +442,11 @@ makeReproducerStreamFactory(StringRef outputFile) {
 }
 
 void printAsTextualPipeline(
-    raw_ostream &os, StringRef anchorName,
+    raw_ostream &os, StringRef anchorName, bool hasRecursiveAnchor,
     const llvm::iterator_range<OpPassManager::pass_iterator> &passes);
 
 std::string mlir::makeReproducer(
-    StringRef anchorName,
+    StringRef anchorName, bool hasRecursiveAnchor,
     const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
     Operation *op, StringRef outputFile, bool disableThreads,
     bool verifyPasses) {
@@ -454,7 +454,7 @@ std::string mlir::makeReproducer(
   std::string description;
   std::string pipelineStr;
   llvm::raw_string_ostream passOS(pipelineStr);
-  ::printAsTextualPipeline(passOS, anchorName, passes);
+  ::printAsTextualPipeline(passOS, anchorName, hasRecursiveAnchor, passes);
   appendReproducer(description, op, makeReproducerStreamFactory(outputFile),
                    pipelineStr, disableThreads, verifyPasses);
   return description;
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index ece2fdaed0dfd..ada1da8050760 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -546,9 +546,17 @@ class TextualPipeline {
   /// the name is the name of a pass, the InnerPipeline is empty, since passes
   /// cannot contain inner pipelines.
   struct PipelineElement {
-    PipelineElement(StringRef name) : name(name) {}
+    PipelineElement(StringRef name) {
+      if (name.starts_with("**")) {
+        this->name = name.drop_front(2);
+        this->hasRecursiveAnchor = true;
+      } else {
+        this->name = name;
+      }
+    }
 
     StringRef name;
+    bool hasRecursiveAnchor = false;
     StringRef options;
     const PassRegistryEntry *registryEntry = nullptr;
     std::vector<PipelineElement> innerPipeline;
@@ -755,10 +763,13 @@ LogicalResult TextualPipeline::addToPipeline(
         return errorHandler("failed to add `" + elt.name + "` with options `" +
                             elt.options + "`");
       }
-    } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
-                                    errorHandler))) {
-      return errorHandler("failed to add `" + elt.name + "` with options `" +
-                          elt.options + "` to inner pipeline");
+    } else {
+      auto &nested = pm.nest(elt.name);
+      if (failed(addToPipeline(elt.innerPipeline, nested, errorHandler))) {
+        return errorHandler("failed to add `" + elt.name + "` with options `" +
+                            elt.options + "` to inner pipeline");
+      }
+      nested.setRecursiveAnchorFetching(elt.hasRecursiveAnchor);
     }
   }
   return success();
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index f99cb7fc01c23..7241e2386471d 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -295,7 +295,7 @@ MlirOptMainConfig &MlirOptMainConfig::setPassPipelineParser(
       return failure();
     };
 
-    OpPassManager* oppm = &pm;
+    OpPassManager *oppm = &pm;
     if (auto anchor = getPassPipelineAnchor()) {
       oppm = &pm.nest(*anchor);
       oppm->setRecursiveAnchorFetching(true);
@@ -429,17 +429,17 @@ static LogicalResult doVerifyRoundTrip(Operation *op,
   auto bcStatus = doVerifyRoundTrip(op, config, /*useBytecode=*/true);
   return success(succeeded(txtStatus) && succeeded(bcStatus));
 }
-static
-    /// Perform the actions on the input file indicated by the command line
-    /// flags within the specified context.
-    ///
-    /// This typically parses the main source file, runs zero or more
-    /// optimization passes, then prints the output.
-    ///
-    static LogicalResult
-    performActions(raw_ostream &os,
-                   const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
-                   MLIRContext *context, const MlirOptMainConfig &config) {
+
+/// Perform the actions on the input file indicated by the command line
+/// flags within the specified context.
+///
+/// This typically parses the main source file, runs zero or more
+/// optimization passes, then prints the output.
+///
+static LogicalResult
+performActions(raw_ostream &os,
+               const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+               MLIRContext *context, const MlirOptMainConfig &config) {
   DefaultTimingManager tm;
   applyDefaultTimingManagerCLOptions(tm);
   TimingScope timing = tm.getRootScope();
@@ -496,7 +496,7 @@ static
   if (!config.getReproducerFilename().empty()) {
     StringRef anchorName = pm.getAnyOpAnchorName();
     const auto &passes = pm.getPasses();
-    makeReproducer(anchorName, passes, op.get(),
+    makeReproducer(anchorName, pm.hasRecursiveAnchor(), passes, op.get(),
                    config.getReproducerFilename());
   }
 

>From 33b982bc1589ceb9e1ca319f3eff7e71f7c72d43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Mon, 28 Oct 2024 10:50:21 +0100
Subject: [PATCH 09/23] Fix merging bug

---
 mlir/lib/Pass/Pass.cpp | 35 +++++++++++++++++------------------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index d4f516b041878..ee4710120550b 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -502,21 +502,16 @@ llvm::hash_code OpPassManager::hash() {
 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                                      AnalysisManager am, bool verifyPasses,
                                      unsigned parentInitGeneration) {
-  bool hasRecursiveAnchor = isa<OpToOpPassAdaptor>(pass) &&
-                            cast<OpToOpPassAdaptor>(pass)->hasRecursiveAnchor();
-  if (!hasRecursiveAnchor) {
-    std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
-    if (!opInfo)
-      return op->emitOpError()
-             << "trying to schedule a pass on an unregistered operation";
-    if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
-      return op->emitOpError()
-             << "trying to schedule a pass on an operation not "
-                "marked as 'IsolatedFromAbove'";
-    if (!pass->canScheduleOn(*op->getName().getRegisteredInfo()))
-      return op->emitOpError()
-             << "trying to schedule a pass on an unsupported operation";
-  }
+  std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
+  if (!opInfo)
+    return op->emitOpError()
+           << "trying to schedule a pass on an unregistered operation";
+  if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
+    return op->emitOpError() << "trying to schedule a pass on an operation not "
+                                "marked as 'IsolatedFromAbove'";
+  if (!pass->canScheduleOn(*op->getName().getRegisteredInfo()))
+    return op->emitOpError()
+           << "trying to schedule a pass on an unsupported operation";
 
   // Initialize the pass state with a callback for the pass to dynamically
   // execute a pipeline on the currently visited operation.
@@ -708,8 +703,10 @@ LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
     // into it.
     if (auto *existingPM =
             findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) {
-      if (existingPM->hasRecursiveAnchor() == pm.hasRecursiveAnchor())
+      if (existingPM->hasRecursiveAnchor() == pm.hasRecursiveAnchor()) {
         pm.getImpl().mergeInto(existingPM->getImpl());
+        continue;
+      }
     }
     // Otherwise, add the given pass manager to the list.
     rhs.mgrs.emplace_back(std::move(pm));
@@ -777,6 +774,7 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
       signalPassFailure();
       return WalkResult::interrupt();
     }
+    // if we could run the pipeline, we skip exploration of its subtree.
     return WalkResult::skip();
   };
 
@@ -849,9 +847,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
     // If this operation can be scheduled, add it to the list.
     if (pmIdxIt.first->second) {
       opInfos.emplace_back(*pmIdxIt.first->second, &op, am.nest(&op));
-      return failure();
+      return success();
     }
-    return success();
+    return failure();
   };
 
   for (auto &region : getOperation()->getRegions()) {
@@ -860,6 +858,7 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
       // in that case the next nested ops to process are fetched recursively
       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
         if (succeeded(handleOp(*op))) {
+          // if we can run the pipeline, we skip exploration of its subtree.
           return WalkResult::skip();
         }
         return WalkResult::advance();

>From f9817c679ec4e3e9d35dee47422635e900b83b0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Mon, 28 Oct 2024 11:44:00 +0100
Subject: [PATCH 10/23] Add tests

---
 mlir/test/Pass/recursive-pipeline-anchor.mlir | 32 +++++++++++++++++++
 mlir/test/lib/Pass/TestPassManager.cpp        |  8 ++++-
 2 files changed, 39 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Pass/recursive-pipeline-anchor.mlir

diff --git a/mlir/test/Pass/recursive-pipeline-anchor.mlir b/mlir/test/Pass/recursive-pipeline-anchor.mlir
new file mode 100644
index 0000000000000..f8464fe77ae22
--- /dev/null
+++ b/mlir/test/Pass/recursive-pipeline-anchor.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(**func.func(test-function-pass))' -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -mlir-disable-threading -test-function-pass --pass-pipeline-anchor=func.func -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s
+
+// Same checks, with threading enabled.
+
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(**func.func(test-function-pass))' -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -test-function-pass --pass-pipeline-anchor=func.func -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s
+
+// Same input without a recursive anchor: only top-level functions are visited.
+
+// RUN: mlir-opt %s -mlir-disable-threading -test-function-pass -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s --check-prefix=NON_REC_CHECK
+// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(func.func(test-function-pass))' -verify-each=false -mlir-pass-statistics -mlir-pass-statistics-display=list 2>&1 | FileCheck %s --check-prefix=NON_REC_CHECK
+
+// @foo is nested directly in the top-level module, so it is visited by every
+// RUN line above, with or without a recursive anchor.
+func.func @foo() {
+  return
+}
+
+// @bar lives inside a nested module, so it is only visited when the
+// func.func anchor is recursive.
+module {
+  func.func @bar() {
+    return
+  }
+}
+
+// with recursive anchor the pass is executed on @foo and @bar
+
+// CHECK:                TestFunctionPass
+// CHECK-NEXT:              (S) 2 counter - Number of invocations
+
+// in non-recursive mode the pass is only executed on @foo
+
+// NON_REC_CHECK:        TestFunctionPass
+// NON_REC_CHECK-NEXT:      (S) 1 counter - Number of invocations
\ No newline at end of file
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 7afe2109f04db..9289300f5ba4c 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -29,8 +29,14 @@ struct TestModulePass
 struct TestFunctionPass
     : public PassWrapper<TestFunctionPass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFunctionPass)
+  TestFunctionPass() = default;
+  TestFunctionPass(const TestFunctionPass& pass) {}
 
-  void runOnOperation() final {}
+  Statistic callCount{this, "counter", "Number of invocations"};
+
+  void runOnOperation() final {
+    callCount++;
+  }
   StringRef getArgument() const final { return "test-function-pass"; }
   StringRef getDescription() const final {
     return "Test a function pass in the pass manager";

>From e3024632c4216ba669d3305311480383acfb6577 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 25 Oct 2024 14:41:08 +0200
Subject: [PATCH 11/23] Allow folding memref.load into a constant

And canonicalize a single-element memref copy
into a load and a store. This allows SROA to
scalarize the copied elements. Currently
SROA cannot handle forwarding of one memory slot
to another.
---
 .../Dialect/MemRef/IR/MemRefMemorySlot.cpp    |  3 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 69 ++++++++++++++++++-
 2 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index f630c48cdcaa1..d54590acf09e7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
@@ -309,7 +310,7 @@ struct MemRefDestructurableTypeExternalModel
     constexpr int64_t maxMemrefSizeForDestructuring = 16;
     if (!memrefType.hasStaticShape() ||
         memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
-        memrefType.getNumElements() == 1)
+        memrefType.getShape().empty())
       return {};
 
     DenseMap<Attribute, Type> destructured;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..74293241fbf0c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -11,18 +11,27 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cstdint>
+#include <iterator>
 
 using namespace mlir;
 using namespace mlir::memref;
@@ -850,11 +859,40 @@ struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
     return failure();
   }
 };
+
+struct DestructureSingleEltCopy final : public OpRewritePattern<CopyOp> {
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+
+  /// Rewrites a copy between two memrefs of the same statically shaped type
+  /// holding exactly one element into a memref.load of that element followed
+  /// by a memref.store into the target, which SROA can then scalarize.
+  LogicalResult matchAndRewrite(CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    auto ty = copyOp.getSource().getType();
+    // hasStaticShape() must be tested before getNumElements(), which asserts
+    // on dynamically shaped types; it also rejects unranked memrefs.
+    if (ty != copyOp.getTarget().getType() || !ty.hasStaticShape() ||
+        ty.getNumElements() != 1)
+      return failure();
+
+    Location loc = copyOp->getLoc();
+    rewriter.setInsertionPoint(copyOp);
+    // Every dimension has extent 1, so the single element sits at index 0 in
+    // each dimension; a rank-0 memref is accessed with no indices at all.
+    SmallVector<Value> indices;
+    if (!ty.getShape().empty()) {
+      Value zero = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(0));
+      indices.append(ty.getShape().size(), zero);
+    }
+    auto loaded =
+        rewriter.create<memref::LoadOp>(loc, copyOp.getSource(), indices);
+    rewriter.create<memref::StoreOp>(loc, loaded.getResult(),
+                                     copyOp.getTarget(), indices);
+    rewriter.eraseOp(copyOp);
+    return success();
+  }
+};
 } // namespace
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
+  // DestructureSingleEltCopy turns single-element copies into a load and a
+  // store, which SROA can scalarize (it cannot forward one slot to another).
+  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy,
+              DestructureSingleEltCopy>(context);
 }
 
 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
@@ -1676,6 +1714,35 @@ LogicalResult LoadOp::verify() {
 }
 
 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
+
+  // Fold a load whose indices are all constant from a constant memref.global
+  // into the corresponding element of the global's initializer.
+  if (auto getCst = getMemref().getDefiningOp<GetGlobalOp>()) {
+    auto global = mlir::dyn_cast_or_null<GlobalOp>(
+        SymbolTable::lookupNearestSymbolFrom(getCst, getCst.getNameAttr()));
+    if (global && global.getConstant() && global.getInitialValue()) {
+      auto constIndices = adaptor.getIndices();
+      // Non-constant operands show up as null attributes; a plain isa<>
+      // would assert on them.
+      if (llvm::all_of(constIndices, [](Attribute attr) {
+            return mlir::isa_and_nonnull<IntegerAttr>(attr);
+          })) {
+        SmallVector<uint64_t> index;
+        for (Attribute attr : constIndices) {
+          // Load indices are `index`-typed, which IntegerAttr::getUInt()
+          // asserts against; read the raw value instead. A negative index
+          // becomes a huge unsigned value that fails isValidIndex() below.
+          index.push_back(cast<IntegerAttr>(attr).getValue().getZExtValue());
+        }
+        if (auto constValue =
+                mlir::dyn_cast<ElementsAttr>(*global.getInitialValue())) {
+          if (constValue.isValidIndex(index)) {
+            auto values = constValue.getValues<Attribute>();
+            // Return the element at the flattened index, not the first one.
+            auto it = std::next(values.begin(),
+                                constValue.getFlattenedIndex(index));
+            if (it != values.end())
+              return OpFoldResult(*it);
+          }
+        }
+      }
+    }
+  }
+
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
     return getResult();

>From 3e7b76db0647a2fda1eb85daa16a23f05ff10ffb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 25 Oct 2024 19:16:53 +0200
Subject: [PATCH 12/23] Remove unneeded build dependency in TosaToLinalg

causing relinking of many libs when
a change to Pass.cpp is made
---
 mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index 5123c2a7cf916..8ad221387ef99 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -19,7 +19,6 @@ add_mlir_conversion_library(MLIRTosaToLinalg
   MLIRLinalgDialect
   MLIRLinalgUtils
   MLIRMathDialect
-  MLIRPass
   MLIRSupport
   MLIRTensorDialect
   MLIRTosaDialect

>From a107776e553079fe31038c56d5645b3f7fc1957e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 25 Oct 2024 21:24:50 +0200
Subject: [PATCH 13/23] Improve Affine scalrep to identify reduction variables

---
 .../Affine/IR/AffineMemoryOpInterfaces.td     |  10 ++
 .../mlir/Dialect/Affine/IR/AffineOps.td       |   3 +-
 .../Affine/Analysis/AffineAnalysis.cpp        |   8 ++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |   5 +
 .../Transforms/AffineScalarReplacement.cpp    |   2 -
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 124 +++++++++++++++++-
 mlir/test/Dialect/Affine/scalrep.mlir         |  47 ++++++-
 7 files changed, 189 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
index c07ab9deca48c..efbe15eb00d7a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -138,6 +138,16 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
         return $_op.getOperand($_op.getStoredValOperandIndex());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the operand holding the value to store, so that it can be updated in place.",
+      /*retTy=*/"::mlir::OpOperand&",
+      /*methodName=*/"getValueToStoreMutable",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return $_op->getOpOperand($_op.getStoredValOperandIndex());
+      }]
+    >,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 6cd3408e2b2e9..e6450defb4376 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -119,7 +119,8 @@ def AffineForOp : Affine_Op<"for",
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
-      "getLoopUpperBounds", "getYieldedValuesMutable",
+      "getLoopUpperBounds", "getYieldedValuesMutable", "getLoopResults",
+      "getInitsMutable", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 84b76d33c3e67..1518f5dbae749 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -491,6 +491,14 @@ LogicalResult MemRefAccess::getAccessRelation(IntegerRelation &rel) const {
   IntegerRelation domainRel = domain;
   if (rel.getSpace().isUsingIds() && !domainRel.getSpace().isUsingIds())
     domainRel.resetIds();
+
+  if (!rel.getSpace().isUsingIds()) {
+    assert(rel.getNumVars() == 0);
+    rel.resetIds();
+    if (!domainRel.getSpace().isUsingIds())
+      domainRel.resetIds();
+  }
+
   domainRel.appendVar(VarKind::Range, accessValueMap.getNumResults());
   domainRel.mergeAndAlignSymbols(rel);
   domainRel.mergeLocalVars(rel);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 11a087f59b072..f1dbdb8eee455 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -2483,6 +2484,10 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
+/// All results of an affine.for are the loop-carried (iter_args) values.
+std::optional<ResultRange> AffineForOp::getLoopResults() {
+  return getResults();
+}
+
 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
   return SmallVector<Value>{getInductionVar()};
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
index 98338dc473f1b..fe0a668a59b68 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineScalarReplacement.cpp
@@ -18,8 +18,6 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
-#include "mlir/IR/OpDefinition.h"
-#include <algorithm>
 
 namespace mlir {
 namespace affine {
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 543bff20a4199..bdc6d9763ae4a 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -27,9 +27,12 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/LogicalResult.h"
 #include <optional>
+#include <tuple>
 
 #define DEBUG_TYPE "affine-utils"
 
@@ -891,6 +894,8 @@ static void forwardStoreToLoad(
   // loads and stores.
   if (storeVal.getType() != loadOp.getValue().getType())
     return;
+  LLVM_DEBUG(llvm::dbgs() << "Erased load (forwarded from store): " << loadOp
+                          << "\n");
   loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
@@ -945,11 +950,122 @@ static void findUnusedStore(AffineWriteOpInterface writeA,
                                                              mayAlias))
       continue;
 
+    LLVM_DEBUG(llvm::dbgs() << "Erased store (unused): " << writeA << "\n");
     opsToErase.push_back(writeA);
     break;
   }
 }
 
+/// Attempts to find load-store pairs in the body of `loop` that can be
+/// replaced by an iter_args variable of the loop. A clone of the load is
+/// placed before the loop to provide the initial value, and the store is
+/// moved after the loop to write back the final value. For such a pair to be
+/// eligible:
+/// 1. the load and the store are both directly in the loop body block, with
+///    the load first;
+/// 2. the memref and the load's map operands are defined outside the loop,
+///    and the store accesses the same location with the same value type;
+/// 3. the pair are the only users of the memref inside the loop, so no other
+///    access in the loop observes memory that is no longer updated there;
+/// 4. nothing between the store and the next execution of the load may read
+///    the memref (checked with hasNoInterveningEffect).
+///
+/// This is a useful transformation as
+/// - it exposes reduction dependencies that can be extracted by
+///   --affine-parallelize;
+/// - it is a common pattern in code lowered from linalg;
+/// - it exposes more opportunities for forwarding of load/store by
+///   moving the load/store out of the loop and into a scope.
+///
+/// `postDominanceInfo` is currently unused.
+/// NOTE(review): accesses through *other* SSA values that may alias the
+/// memref are only covered by check 4; consider a full alias check.
+static void findReductionVariablesAndRewrite(
+    LoopLikeOpInterface loop, PostDominanceInfo &postDominanceInfo,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
+  if (!loop.getLoopResults())
+    return;
+
+  SmallVector<std::pair<AffineReadOpInterface, AffineWriteOpInterface>> result;
+  Block &block = loop.getLoopRegions()[0]->front();
+
+  for (Operation &op : block.without_terminator()) {
+    auto asLoad = dyn_cast<AffineReadOpInterface>(op);
+    if (!asLoad)
+      continue;
+
+    // The accessed location must be loop-invariant, and the clone of the
+    // load placed before the loop needs all of its operands available there.
+    Value memref = asLoad.getMemRef();
+    if (!loop.isDefinedOutsideOfLoop(memref) ||
+        !llvm::all_of(asLoad.getMapOperands(), [&](Value operand) {
+          return loop.isDefinedOutsideOfLoop(operand);
+        }))
+      continue;
+
+    // Find the single store to the memref in the loop. Any other user of the
+    // memref inside the loop (including in nested regions) disqualifies this
+    // load; this also guarantees each load and store is paired at most once.
+    AffineWriteOpInterface asStore;
+    bool onlyLoadAndStore = true;
+    for (Operation *user : memref.getUsers()) {
+      if (user == &op || !loop->isAncestor(user))
+        continue;
+      auto candidate = dyn_cast<AffineWriteOpInterface>(user);
+      if (!candidate || asStore || user->getBlock() != &block ||
+          user->isBeforeInBlock(&op)) {
+        onlyLoadAndStore = false;
+        break;
+      }
+      asStore = candidate;
+    }
+    if (!onlyLoadAndStore || !asStore)
+      continue;
+
+    // Both accesses must address the same element with the same value type,
+    // otherwise the iter_args type would not match the yielded value.
+    if (MemRefAccess(asLoad) != MemRefAccess(asStore) ||
+        asStore.getValueToStore().getType() != asLoad.getValue().getType())
+      continue;
+
+    // Check that nobody could be reading from the store before the next load,
+    // as we want to eliminate the store from the loop body.
+    if (!affine::hasNoInterveningEffect<MemoryEffects::Read>(
+            asStore.getOperation(), asLoad, mayAlias))
+      continue;
+
+    result.push_back({asLoad, asStore});
+  }
+  if (result.empty())
+    return;
+
+  SmallVector<Value> newInitOperands;
+  SmallVector<Value> newYieldOperands;
+  SmallVector<Operation *> hoistedLoads;
+  IRRewriter rewriter(loop->getContext());
+  rewriter.startOpModification(loop->getParentOp());
+  rewriter.setInsertionPoint(loop);
+  for (auto [load, store] : result) {
+    // The clone before the loop reads the initial value of the variable.
+    auto hoisted = cast<AffineReadOpInterface>(rewriter.clone(*load));
+    hoistedLoads.push_back(hoisted.getOperation());
+    newInitOperands.push_back(hoisted.getValue());
+    newYieldOperands.push_back(store.getValueToStore());
+  }
+
+  const size_t numResults = loop.getLoopResults()->size();
+  FailureOr<LoopLikeOpInterface> rewritten = loop.replaceWithAdditionalYields(
+      rewriter, newInitOperands, /*replaceInitOperandUsesInLoop=*/false,
+      [&](OpBuilder &, Location, ArrayRef<BlockArgument>) {
+        return newYieldOperands;
+      });
+  if (failed(rewritten)) {
+    // Do not leave the cloned loads behind as dead code.
+    for (Operation *hoisted : hoistedLoads)
+      rewriter.eraseOp(hoisted);
+    rewriter.cancelOpModification(loop->getParentOp());
+    return;
+  }
+  LoopLikeOpInterface newLoop = *rewritten;
+
+  // Feed each load from its new iter_args block argument, and move each store
+  // after the loop so that it writes back the matching loop result.
+  Operation *insertAfter = newLoop;
+  for (auto [loadStore, bbArg, loopRes] :
+       llvm::zip(result, newLoop.getRegionIterArgs().drop_front(numResults),
+                 newLoop.getLoopResults()->drop_front(numResults))) {
+    AffineReadOpInterface load = loadStore.first;
+    AffineWriteOpInterface store = loadStore.second;
+    rewriter.replaceOp(load, bbArg);
+    rewriter.moveOpAfter(store, insertAfter);
+    store.getValueToStoreMutable().set(loopRes);
+    insertAfter = store;
+  }
+
+  rewriter.finalizeOpModification(newLoop->getParentOp());
+  LLVM_DEBUG(llvm::dbgs() << "Replaced loop reduction variable: \n"
+                          << newLoop << "\n");
+}
+
 // The load to load forwarding / redundant load elimination is similar to the
 // store to load forwarding.
 // loadA will be be replaced with loadB if:
@@ -1037,7 +1153,8 @@ static void loadCSE(AffineReadOpInterface loadA,
 // currently only eliminates the stores only if no other loads/uses (other
 // than dealloc) remain.
 //
-void mlir::affine::affineScalarReplace(Operation* parentOp, DominanceInfo &domInfo,
+void mlir::affine::affineScalarReplace(Operation *parentOp,
+                                       DominanceInfo &domInfo,
                                        PostDominanceInfo &postDomInfo,
                                        AliasAnalysis &aliasAnalysis) {
   // Load op's whose results were replaced by those forwarded from stores.
@@ -1050,6 +1167,11 @@ void mlir::affine::affineScalarReplace(Operation* parentOp, DominanceInfo &domIn
     return !aliasAnalysis.alias(val1, val2).isNo();
   };
 
+  // scalarize reduction variables as iter_args
+  parentOp->walk([&](AffineForOp loop) {
+    findReductionVariablesAndRewrite(loop, postDomInfo, mayAlias);
+  });
+
   // Walk all load's and perform store to load forwarding.
   parentOp->walk([&](AffineReadOpInterface loadOp) {
     forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias);
diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir
index 092597860c8d9..6b9d4fc1ac15f 100644
--- a/mlir/test/Dialect/Affine/scalrep.mlir
+++ b/mlir/test/Dialect/Affine/scalrep.mlir
@@ -141,9 +141,14 @@ func.func @store_load_store_nested_no_fwd(%N : index) {
   affine.for %i0 = 0 to 10 {
     affine.store %cf7, %m[%i0] : memref<10xf32>
     affine.for %i1 = 0 to %N {
-      // CHECK: %{{[0-9]+}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+      // CHECK:      %[[C7:.*]] = arith.constant 7.0{{.*}}
+      // CHECK:      %[[C9:.*]] = arith.constant 9.0{{.*}}
+      // CHECK:      %{{[0-9]+}} = affine.for %{{.*}} = 0 to %{{.*}} iter_args(%[[A:.*]] = %[[C7]]) -> (f32)
+      // CHECK-NEXT:    %[[R:.*]] = arith.addf %[[A]], %[[A]] : f32
+      // CHECK:    affine.yield %[[C9]] : f32
       %v0 = affine.load %m[%i0] : memref<10xf32>
       %v1 = arith.addf %v0, %v0 : f32
+      "use"(%v1) : (f32) -> ()
       affine.store %cf9, %m[%i0] : memref<10xf32>
     }
   }
@@ -423,7 +428,8 @@ func.func @load_load_store_2_loops_no_cse(%N : index, %m : memref<10xf32>) {
     // CHECK:       affine.load
     %v0 = affine.load %m[%i0] : memref<10xf32>
     affine.for %i1 = 0 to %N {
-      // CHECK:       affine.load
+      // CHECK:       iter_args
+      // CHECK-NOT:       affine.load
       %v1 = affine.load %m[%i0] : memref<10xf32>
       %v2 = arith.addf %v0, %v1 : f32
       affine.store %v2, %m[%i0] : memref<10xf32>
@@ -556,10 +562,11 @@ func.func @reduction_multi_store() -> memref<1xf32> {
    "test.foo"(%m) : (f32) -> ()
   }
 
-// CHECK:       affine.for
-// CHECK:         affine.load
-// CHECK:         affine.store %[[S:.*]],
-// CHECK-NEXT:    "test.foo"(%[[S]])
+// CHECK:       affine.for {{.*}} 
+// CHECK-NEXT:    %[[A:.*]] = affine.load
+// CHECK-NEXT:    %[[X:.*]] = arith.addf %[[A]], 
+// CHECK-NEXT:    affine.store %[[X]]
+// CHECK-NEXT:    "test.foo"(%[[X]])
 
   return %A : memref<1xf32>
 }
@@ -890,6 +897,34 @@ func.func @parallel_surrounding_for() {
 // CHECK-NEXT:  return
 }
 
+// CHECK-LABEL: func @reduction_extraction
+// %b[] is read and written at the same loop-invariant location in both loops,
+// so it is expected to become an iter_args value of each loop; the
+// initializing store and the final load are then forwarded away.
+func.func @reduction_extraction(%x : memref<10x10xf32>) -> f32 {
+  %b = memref.alloc() : memref<f32>
+  %cst = arith.constant 0.0 : f32
+  affine.store %cst, %b[] : memref<f32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %v0 = affine.load %x[%i0,%i1] : memref<10x10xf32>
+      %acc = affine.load %b[] : memref<f32>
+      %v1 = arith.addf %acc, %v0 : f32
+      affine.store %v1, %b[] : memref<f32>
+    }
+  }
+  %x2 = affine.load %b[]: memref<f32>
+  return %x2 : f32
+// CHECK:       %[[I:.*]] = arith.constant 0{{.*}} : f32
+// CHECK-NEXT:  %[[SUM2:.*]] = affine.for %{{.*}} = 0 to 10 iter_args(%[[ACC2:.*]] = %[[I]]) -> (f32) {
+// CHECK-NEXT:    %[[SUM:.*]] = affine.for %{{.*}} = 0 to 10 iter_args(%[[ACC:.*]] = %[[ACC2]]) -> (f32) {
+// CHECK-NEXT:      %[[X:.*]] = affine.load {{.*}} : memref<10x10xf32>
+// CHECK-NEXT:      %[[Y:.*]] = arith.addf %[[ACC]], %[[X]] : f32
+// CHECK-NEXT:      affine.yield %[[Y]] : f32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    affine.yield %[[SUM]] : f32
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[SUM2]] : f32
+}
+
+
 // CHECK-LABEL: func.func @dead_affine_region_op
 func.func @dead_affine_region_op() {
   %c1 = arith.constant 1 : index

>From 29677768781dfb860c3cc34854e0c6984c8fc570 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 1 Nov 2024 14:46:04 +0100
Subject: [PATCH 14/23] Add missing maxnumf to vector reduce

---
 mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp |  2 ++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp            |  3 ++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp            | 11 ++++++++++-
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 1518f5dbae749..5eaa9d5eccd11 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -65,6 +65,8 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
               [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
           .Case(
               [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
+          .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
+          .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
           .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
           .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
           .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f1dbdb8eee455..984fed19dd788 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3895,8 +3895,9 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
   case arith::AtomicRMWKind::muli:
     return isa<IntegerType>(resultType);
   case arith::AtomicRMWKind::maximumf:
-    return isa<FloatType>(resultType);
   case arith::AtomicRMWKind::minimumf:
+  case arith::AtomicRMWKind::maxnumf:
+  case arith::AtomicRMWKind::minnumf:
     return isa<FloatType>(resultType);
   case arith::AtomicRMWKind::maxs: {
     auto intType = llvm::dyn_cast<IntegerType>(resultType);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..2dc07dba62d15 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -661,8 +661,17 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   case arith::AtomicRMWKind::ori:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::OR, vector);
-  // TODO: Add remaining reduction operations.
+  case arith::AtomicRMWKind::maxnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MAXNUMF, vector);
+  case arith::AtomicRMWKind::minnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MINNUMF, vector);
+  case arith::AtomicRMWKind::assign:
+    (void)emitOptionalError(loc, "Reduction operation type not supported (assign)");
+    break;
   default:
+    // Should this be an assert(false)?
     (void)emitOptionalError(loc, "Reduction operation type not supported");
     break;
   }

>From 76d4769b4a50653f47f4f9217224660a800da473 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 29 Nov 2024 12:09:27 +0100
Subject: [PATCH 15/23] Fix rebase

---
 mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp  | 3 +--
 .../Transforms/OptimizeAllocationLiveness.cpp              | 7 ++-----
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
index a6e961a6d6439..b9d1df054390d 100644
--- a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -166,7 +166,6 @@ struct RaiseMemrefDialect
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::affine::createRaiseMemrefToAffine() {
+/// Creates a new RaiseMemrefDialect pass instance, handed out through its
+/// AffineScopePassBase base class.
+std::unique_ptr<AffineScopePassBase> mlir::affine::createRaiseMemrefToAffine() {
   return std::make_unique<RaiseMemrefDialect>();
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
index 5178d4a62f374..0533358d11abc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp
@@ -95,14 +95,11 @@ struct OptimizeAllocationLiveness
   OptimizeAllocationLiveness() = default;
 
   void runOnOperation() override {
-    func::FuncOp func = getOperation();
-
-    if (func.isExternal())
-      return;
+    Operation* func = getOperation();
 
     BufferViewFlowAnalysis analysis = BufferViewFlowAnalysis(func);
 
-    func.walk([&](MemoryEffectOpInterface memEffectOp) -> WalkResult {
+    func->walk([&](MemoryEffectOpInterface memEffectOp) -> WalkResult {
       if (!hasMemoryAllocEffect(memEffectOp))
         return WalkResult::advance();
 

>From d4d40f211114e26c51d281408d71c86dde2b8108 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Mon, 23 Dec 2024 19:10:05 +0100
Subject: [PATCH 16/23] Fix scalrep on nested loops

---
 mlir/lib/Dialect/Affine/Utils/Utils.cpp | 68 ++++++++++++++++++-------
 1 file changed, 49 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index bdc6d9763ae4a..4e1b0665f7cd7 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -958,24 +958,25 @@ static void findUnusedStore(AffineWriteOpInterface writeA,
 
 /// This attempts to find load-store pairs in the body of the loop
 /// that could be replaced by an iter_args variable on the loop. The
-/// initial load and the final store are moved out of the loop. For 
+/// initial load and the final store are moved out of the loop. For
 /// such a pair to be eligible:
 /// 1. the load must be followed by the store
 /// 2. the memref must not be read again after the store
-/// 3. the indices of the load and store must match AND be 
+/// 3. the indices of the load and store must match AND be
 /// loop-invariant for the given loop.
 ///
 /// This is a useful transformation as
-/// - it exposes reduction dependencies that can be extracted by --affine-parallelize
+/// - it exposes reduction dependencies that can be extracted by
+///   --affine-parallelize
 /// - it is a common pattern in code lowered from linalg.
-/// - it exposes more opportunities for forwarding of load/store by 
+/// - it exposes more opportunities for forwarding of load/store by
 /// moving the load/store out of the loop and into a scope.
-/// 
-static void findReductionVariablesAndRewrite(
+///
+static bool findReductionVariablesAndRewrite(
     LoopLikeOpInterface loop, PostDominanceInfo &postDominanceInfo,
     llvm::function_ref<bool(Value, Value)> mayAlias) {
   if (!loop.getLoopResults())
-    return;
+    return false;
 
   SmallVector<std::pair<AffineReadOpInterface, AffineWriteOpInterface>> result;
   auto *region = loop.getLoopRegions()[0];
@@ -1017,13 +1018,13 @@ static void findReductionVariablesAndRewrite(
       if (!affine::hasNoInterveningEffect<MemoryEffects::Read>(
               asStore.getOperation(), asLoad, mayAlias))
         break;
-      
+
       // now let's just replace this pair of accesses with loop iter args
       result.push_back({asLoad, asStore});
     }
   }
   if (result.empty())
-    return;
+    return false;
   SmallVector<Value> newInitOperands;
   SmallVector<Value> newYieldOperands;
   IRRewriter rewriter(loop->getContext());
@@ -1043,7 +1044,7 @@ static void findReductionVariablesAndRewrite(
       });
   if (failed(rewritten)) {
     rewriter.cancelOpModification(loop->getParentOp());
-    return;
+    return false;
   }
   auto newLoop = *rewritten;
 
@@ -1064,6 +1065,7 @@ static void findReductionVariablesAndRewrite(
   rewriter.finalizeOpModification(newLoop->getParentOp());
   LLVM_DEBUG(llvm::dbgs() << "Replaced loop reduction variable: \n"
                           << newLoop << "\n");
+  return true;
 }
 
 // The load to load forwarding / redundant load elimination is similar to the
@@ -1153,24 +1155,52 @@ static void loadCSE(AffineReadOpInterface loadA,
 // currently only eliminates the stores only if no other loads/uses (other
 // than dealloc) remain.
 //
+static void doForwarding(Operation *parentOp, DominanceInfo &domInfo,
+                         PostDominanceInfo &postDomInfo,
+                         llvm::function_ref<bool(Value, Value)> mayAlias);
+
 void mlir::affine::affineScalarReplace(Operation *parentOp,
                                        DominanceInfo &domInfo,
                                        PostDominanceInfo &postDomInfo,
                                        AliasAnalysis &aliasAnalysis) {
-  // Load op's whose results were replaced by those forwarded from stores.
-  SmallVector<Operation *, 8> opsToErase;
-
-  // A list of memref's that are potentially dead / could be eliminated.
-  SmallPtrSet<Value, 4> memrefsToErase;
 
   auto mayAlias = [&](Value val1, Value val2) -> bool {
     return !aliasAnalysis.alias(val1, val2).isNo();
   };
 
-  // scalarize reduction variables as iter_args
-  parentOp->walk([&](AffineForOp loop) {
-    findReductionVariablesAndRewrite(loop, postDomInfo, mayAlias);
-  });
+  bool continueWalk;
+  do {
+    continueWalk = false;
+
+    // Walk loops and rewrite reduction variables. Once a loop has been
+    // rewritten, we need to perform forwarding to eliminate the new store and
+    // loads introduced before and after the new loop. Then we need to continue
+    // doing that loop by loop.
+    parentOp->walk([&](AffineForOp loop) {
+      Operation *loopParent = loop->getParentOp();
+      bool rewritten =
+          findReductionVariablesAndRewrite(loop, postDomInfo, mayAlias);
+      if (rewritten && loopParent != parentOp) {
+        doForwarding(loopParent, domInfo, postDomInfo, mayAlias);
+        continueWalk = true;
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+  } while (continueWalk);
+
+  // Finally, run forwarding over all of parentOp to clean up what remains.
+  doForwarding(parentOp, domInfo, postDomInfo, mayAlias);
+}
+
+void doForwarding(Operation *parentOp, DominanceInfo &domInfo,
+                  PostDominanceInfo &postDomInfo,
+                  llvm::function_ref<bool(Value, Value)> mayAlias) {
+  // Load op's whose results were replaced by those forwarded from stores.
+  SmallVector<Operation *, 8> opsToErase;
+
+  // A list of memref's that are potentially dead / could be eliminated.
+  SmallPtrSet<Value, 4> memrefsToErase;
 
   // Walk all load's and perform store to load forwarding.
   parentOp->walk([&](AffineReadOpInterface loadOp) {

>From 5b38210cbe77ec48680ebc5fd2d26b4948cdd421 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Mon, 10 Mar 2025 14:40:21 +0100
Subject: [PATCH 17/23] Fix rebase

---
 .../Transforms/BufferDeallocation.cpp         | 693 ------------------
 .../Bufferization/Transforms/CMakeLists.txt   |   1 -
 2 files changed, 694 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
deleted file mode 100644
index 73d2d4e4ca427..0000000000000
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ /dev/null
@@ -1,693 +0,0 @@
-//===- BufferDeallocation.cpp - the impl for buffer deallocation ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements logic for computing correct alloc and dealloc positions.
-// Furthermore, buffer deallocation also adds required new clone operations to
-// ensure that all buffers are deallocated. The main class is the
-// BufferDeallocationPass class that implements the underlying algorithm. In
-// order to put allocations and deallocations at safe positions, it is
-// significantly important to put them into the correct blocks. However, the
-// liveness analysis does not pay attention to aliases, which can occur due to
-// branches (and their associated block arguments) in general. For this purpose,
-// BufferDeallocation firstly finds all possible aliases for a single value
-// (using the BufferViewFlowAnalysis class). Consider the following example:
-//
-// ^bb0(%arg0):
-//   cf.cond_br %cond, ^bb1, ^bb2
-// ^bb1:
-//   cf.br ^exit(%arg0)
-// ^bb2:
-//   %new_value = ...
-//   cf.br ^exit(%new_value)
-// ^exit(%arg1):
-//   return %arg1;
-//
-// We should place the dealloc for %new_value in exit. However, we have to free
-// the buffer in the same block, because it cannot be freed in the post
-// dominator. However, this requires a new clone buffer for %arg1 that will
-// contain the actual contents. Using the class BufferViewFlowAnalysis, we
-// will find out that %new_value has a potential alias %arg1. In order to find
-// the dealloc position we have to find all potential aliases, iterate over
-// their uses and find the common post-dominator block (note that additional
-// clones and buffers remove potential aliases and will influence the placement
-// of the deallocs). In all cases, the computed block can be safely used to free
-// the %new_value buffer (may be exit or bb2) as it will die and we can use
-// liveness information to determine the exact operation after which we have to
-// insert the dealloc. However, the algorithm supports introducing clone buffers
-// and placing deallocs in safe locations to ensure that all buffers will be
-// freed in the end.
-//
-// TODO:
-// The current implementation does not support explicit-control-flow loops and
-// the resulting code will be invalid with respect to program semantics.
-// However, structured control-flow loops are fully supported. Furthermore, it
-// doesn't accept functions which return buffers already.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
-
-#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "llvm/ADT/SetOperations.h"
-
-namespace mlir {
-namespace bufferization {
-#define GEN_PASS_DEF_BUFFERDEALLOCATION
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
-} // namespace bufferization
-} // namespace mlir
-
-using namespace mlir;
-using namespace mlir::bufferization;
-
-/// Walks over all immediate return-like terminators in the given region.
-static LogicalResult walkReturnOperations(
-    Region *region,
-    llvm::function_ref<LogicalResult(RegionBranchTerminatorOpInterface)> func) {
-  for (Block &block : *region) {
-    Operation *terminator = block.getTerminator();
-    // Skip non region-return-like terminators.
-    if (auto regionTerminator =
-            dyn_cast<RegionBranchTerminatorOpInterface>(terminator)) {
-      if (failed(func(regionTerminator)))
-        return failure();
-    }
-  }
-  return success();
-}
-
-/// Checks if all operations that have at least one attached region implement
-/// the RegionBranchOpInterface. This is not required in edge cases, where we
-/// have a single attached region and the parent operation has no results.
-static bool validateSupportedControlFlow(Operation *op) {
-  WalkResult result = op->walk([&](Operation *operation) {
-    // Only check ops that are inside a function.
-    if (!operation->getParentOfType<func::FuncOp>())
-      return WalkResult::advance();
-
-    auto regions = operation->getRegions();
-    // Walk over all operations in a region and check if the operation has at
-    // least one region and implements the RegionBranchOpInterface. If there
-    // is an operation that does not fulfill this condition, we cannot apply
-    // the deallocation steps. Furthermore, we accept cases, where we have a
-    // region that returns no results, since, in that case, the intra-region
-    // control flow does not affect the transformation.
-    size_t size = regions.size();
-    if (((size == 1 && !operation->getResults().empty()) || size > 1) &&
-        !dyn_cast<RegionBranchOpInterface>(operation)) {
-      operation->emitError("All operations with attached regions need to "
-                           "implement the RegionBranchOpInterface.");
-    }
-
-    return WalkResult::advance();
-  });
-  return !result.wasSkipped();
-}
-
-namespace {
-
-//===----------------------------------------------------------------------===//
-// Backedges analysis
-//===----------------------------------------------------------------------===//
-
-/// A straight-forward program analysis which detects loop backedges induced by
-/// explicit control flow.
-class Backedges {
-public:
-  using BlockSetT = SmallPtrSet<Block *, 16>;
-  using BackedgeSetT = llvm::DenseSet<std::pair<Block *, Block *>>;
-
-public:
-  /// Constructs a new backedges analysis using the op provided.
-  Backedges(Operation *op) { recurse(op); }
-
-  /// Returns the number of backedges formed by explicit control flow.
-  size_t size() const { return edgeSet.size(); }
-
-  /// Returns the start iterator to loop over all backedges.
-  BackedgeSetT::const_iterator begin() const { return edgeSet.begin(); }
-
-  /// Returns the end iterator to loop over all backedges.
-  BackedgeSetT::const_iterator end() const { return edgeSet.end(); }
-
-private:
-  /// Enters the current block and inserts a backedge into the `edgeSet` if we
-  /// have already visited the current block. The inserted edge links the given
-  /// `predecessor` with the `current` block.
-  bool enter(Block &current, Block *predecessor) {
-    bool inserted = visited.insert(&current).second;
-    if (!inserted)
-      edgeSet.insert(std::make_pair(predecessor, &current));
-    return inserted;
-  }
-
-  /// Leaves the current block.
-  void exit(Block &current) { visited.erase(&current); }
-
-  /// Recurses into the given operation while taking all attached regions into
-  /// account.
-  void recurse(Operation *op) {
-    Block *current = op->getBlock();
-    // If the current op implements the `BranchOpInterface`, there can be
-    // cycles in the scope of all successor blocks.
-    if (isa<BranchOpInterface>(op)) {
-      for (Block *succ : current->getSuccessors())
-        recurse(*succ, current);
-    }
-    // Recurse into all distinct regions and check for explicit control-flow
-    // loops.
-    for (Region &region : op->getRegions()) {
-      if (!region.empty())
-        recurse(region.front(), current);
-    }
-  }
-
-  /// Recurses into explicit control-flow structures that are given by
-  /// the successor relation defined on the block level.
-  void recurse(Block &block, Block *predecessor) {
-    // Try to enter the current block. If this is not possible, we are
-    // currently processing this block and can safely return here.
-    if (!enter(block, predecessor))
-      return;
-
-    // Recurse into all operations and successor blocks.
-    for (Operation &op : block.getOperations())
-      recurse(&op);
-
-    // Leave the current block.
-    exit(block);
-  }
-
-  /// Stores all blocks that are currently visited and on the processing stack.
-  BlockSetT visited;
-
-  /// Stores all backedges in the format (source, target).
-  BackedgeSetT edgeSet;
-};
-
-//===----------------------------------------------------------------------===//
-// BufferDeallocation
-//===----------------------------------------------------------------------===//
-
-/// The buffer deallocation transformation which ensures that all allocs in the
-/// program have a corresponding de-allocation. As a side-effect, it might also
-/// introduce clones that in turn leads to additional deallocations.
-class BufferDeallocation : public BufferPlacementTransformationBase {
-public:
-  using AliasAllocationMapT =
-      llvm::DenseMap<Value, bufferization::AllocationOpInterface>;
-
-  BufferDeallocation(Operation *op)
-      : BufferPlacementTransformationBase(op), dominators(op),
-        postDominators(op) {}
-
-  /// Checks if all allocation operations either provide an already existing
-  /// deallocation operation or implement the AllocationOpInterface. In
-  /// addition, this method initializes the internal alias to
-  /// AllocationOpInterface mapping in order to get compatible
-  /// AllocationOpInterface implementations for aliases.
-  LogicalResult prepare() {
-    for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
-      // Get the defining allocation operation.
-      Value alloc = std::get<0>(entry);
-      auto allocationInterface =
-          alloc.getDefiningOp<bufferization::AllocationOpInterface>();
-      // If there is no existing deallocation operation and no implementation of
-      // the AllocationOpInterface, we cannot apply the BufferDeallocation pass.
-      if (!std::get<1>(entry) && !allocationInterface) {
-        return alloc.getDefiningOp()->emitError(
-            "Allocation is not deallocated explicitly nor does the operation "
-            "implement the AllocationOpInterface.");
-      }
-
-      // Register the current allocation interface implementation.
-      aliasToAllocations[alloc] = allocationInterface;
-
-      // Get the alias information for the current allocation node.
-      for (Value alias : aliases.resolve(alloc)) {
-        // TODO: check for incompatible implementations of the
-        // AllocationOpInterface. This could be realized by promoting the
-        // AllocationOpInterface to a DialectInterface.
-        aliasToAllocations[alias] = allocationInterface;
-      }
-    }
-    return success();
-  }
-
-  /// Performs the actual placement/creation of all temporary clone and dealloc
-  /// nodes.
-  LogicalResult deallocate() {
-    // Add additional clones that are required.
-    if (failed(introduceClones()))
-      return failure();
-
-    // Place deallocations for all allocation entries.
-    return placeDeallocs();
-  }
-
-private:
-  /// Introduces required clone operations to avoid memory leaks.
-  LogicalResult introduceClones() {
-    // Initialize the set of values that require a dedicated memory free
-    // operation since their operands cannot be safely deallocated in a post
-    // dominator.
-    SetVector<Value> valuesToFree;
-    llvm::SmallDenseSet<std::tuple<Value, Block *>> visitedValues;
-    SmallVector<std::tuple<Value, Block *>, 8> toProcess;
-
-    // Check dominance relation for proper dominance properties. If the given
-    // value node does not dominate an alias, we will have to create a clone in
-    // order to free all buffers that can potentially leak into a post
-    // dominator.
-    auto findUnsafeValues = [&](Value source, Block *definingBlock) {
-      auto it = aliases.find(source);
-      if (it == aliases.end())
-        return;
-      for (Value value : it->second) {
-        if (valuesToFree.count(value) > 0)
-          continue;
-        Block *parentBlock = value.getParentBlock();
-        // Check whether we have to free this particular block argument or
-        // generic value. We have to free the current alias if it is either
-        // defined in a non-dominated block or it is defined in the same block
-        // but the current value is not dominated by the source value.
-        if (!dominators.dominates(definingBlock, parentBlock) ||
-            (definingBlock == parentBlock && isa<BlockArgument>(value))) {
-          toProcess.emplace_back(value, parentBlock);
-          valuesToFree.insert(value);
-        } else if (visitedValues.insert(std::make_tuple(value, definingBlock))
-                       .second)
-          toProcess.emplace_back(value, definingBlock);
-      }
-    };
-
-    // Detect possibly unsafe aliases starting from all allocations.
-    for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
-      Value allocValue = std::get<0>(entry);
-      findUnsafeValues(allocValue, allocValue.getDefiningOp()->getBlock());
-    }
-    // Try to find block arguments that require an explicit free operation
-    // until we reach a fix point.
-    while (!toProcess.empty()) {
-      auto current = toProcess.pop_back_val();
-      findUnsafeValues(std::get<0>(current), std::get<1>(current));
-    }
-
-    // Update buffer aliases to ensure that we free all buffers and block
-    // arguments at the correct locations.
-    aliases.remove(valuesToFree);
-
-    // Add new allocs and additional clone operations.
-    for (Value value : valuesToFree) {
-      if (failed(isa<BlockArgument>(value)
-                     ? introduceBlockArgCopy(cast<BlockArgument>(value))
-                     : introduceValueCopyForRegionResult(value)))
-        return failure();
-
-      // Register the value to require a final dealloc. Note that we do not have
-      // to assign a block here since we do not want to move the allocation node
-      // to another location.
-      allocs.registerAlloc(std::make_tuple(value, nullptr));
-    }
-    return success();
-  }
-
-  /// Introduces temporary clones in all predecessors and copies the source
-  /// values into the newly allocated buffers.
-  LogicalResult introduceBlockArgCopy(BlockArgument blockArg) {
-    // Allocate a buffer for the current block argument in the block of
-    // the associated value (which will be a predecessor block by
-    // definition).
-    Block *block = blockArg.getOwner();
-    for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
-      // Get the terminator and the value that will be passed to our
-      // argument.
-      Operation *terminator = (*it)->getTerminator();
-      auto branchInterface = cast<BranchOpInterface>(terminator);
-      SuccessorOperands operands =
-          branchInterface.getSuccessorOperands(it.getSuccessorIndex());
-
-      // Query the associated source value.
-      Value sourceValue = operands[blockArg.getArgNumber()];
-      if (!sourceValue) {
-        return failure();
-      }
-      // Wire new clone and successor operand.
-      // Create a new clone at the current location of the terminator.
-      auto clone = introduceCloneBuffers(sourceValue, terminator);
-      if (failed(clone))
-        return failure();
-      operands.slice(blockArg.getArgNumber(), 1).assign(*clone);
-    }
-
-    // Check whether the block argument has implicitly defined predecessors via
-    // the RegionBranchOpInterface. This can be the case if the current block
-    // argument belongs to the first block in a region and the parent operation
-    // implements the RegionBranchOpInterface.
-    Region *argRegion = block->getParent();
-    Operation *parentOp = argRegion->getParentOp();
-    RegionBranchOpInterface regionInterface;
-    if (&argRegion->front() != block ||
-        !(regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)))
-      return success();
-
-    if (failed(introduceClonesForRegionSuccessors(
-            regionInterface, argRegion->getParentOp()->getRegions(), blockArg,
-            [&](RegionSuccessor &successorRegion) {
-              // Find a predecessor of our argRegion.
-              return successorRegion.getSuccessor() == argRegion;
-            })))
-      return failure();
-
-    // Check whether the block argument belongs to an entry region of the
-    // parent operation. In this case, we have to introduce an additional clone
-    // for buffer that is passed to the argument.
-    SmallVector<RegionSuccessor, 2> successorRegions;
-    regionInterface.getSuccessorRegions(/*point=*/RegionBranchPoint::parent(),
-                                        successorRegions);
-    auto *it =
-        llvm::find_if(successorRegions, [&](RegionSuccessor &successorRegion) {
-          return successorRegion.getSuccessor() == argRegion;
-        });
-    if (it == successorRegions.end())
-      return success();
-
-    // Determine the actual operand to introduce a clone for and rewire the
-    // operand to point to the clone instead.
-    auto operands = regionInterface.getEntrySuccessorOperands(argRegion);
-    size_t operandIndex =
-        llvm::find(it->getSuccessorInputs(), blockArg).getIndex() +
-        operands.getBeginOperandIndex();
-    Value operand = parentOp->getOperand(operandIndex);
-    assert(operand ==
-               operands[operandIndex - operands.getBeginOperandIndex()] &&
-           "region interface operands don't match parentOp operands");
-    auto clone = introduceCloneBuffers(operand, parentOp);
-    if (failed(clone))
-      return failure();
-
-    parentOp->setOperand(operandIndex, *clone);
-    return success();
-  }
-
-  /// Introduces temporary clones in front of all associated nested-region
-  /// terminators and copies the source values into the newly allocated buffers.
-  LogicalResult introduceValueCopyForRegionResult(Value value) {
-    // Get the actual result index in the scope of the parent terminator.
-    Operation *operation = value.getDefiningOp();
-    auto regionInterface = cast<RegionBranchOpInterface>(operation);
-    // Filter successors that return to the parent operation.
-    auto regionPredicate = [&](RegionSuccessor &successorRegion) {
-      // If the RegionSuccessor has no associated successor, it will return to
-      // its parent operation.
-      return !successorRegion.getSuccessor();
-    };
-    // Introduce a clone for all region "results" that are returned to the
-    // parent operation. This is required since the parent's result value has
-    // been considered critical. Therefore, the algorithm assumes that a clone
-    // of a previously allocated buffer is returned by the operation (like in
-    // the case of a block argument).
-    return introduceClonesForRegionSuccessors(
-        regionInterface, operation->getRegions(), value, regionPredicate);
-  }
-
-  /// Introduces buffer clones for all terminators in the given regions. The
-  /// regionPredicate is applied to every successor region in order to restrict
-  /// the clones to specific regions.
-  template <typename TPredicate>
-  LogicalResult introduceClonesForRegionSuccessors(
-      RegionBranchOpInterface regionInterface, MutableArrayRef<Region> regions,
-      Value argValue, const TPredicate &regionPredicate) {
-    for (Region &region : regions) {
-      // Query the regionInterface to get all successor regions of the current
-      // one.
-      SmallVector<RegionSuccessor, 2> successorRegions;
-      regionInterface.getSuccessorRegions(region, successorRegions);
-      // Try to find a matching region successor.
-      RegionSuccessor *regionSuccessor =
-          llvm::find_if(successorRegions, regionPredicate);
-      if (regionSuccessor == successorRegions.end())
-        continue;
-      // Get the operand index in the context of the current successor input
-      // bindings.
-      size_t operandIndex =
-          llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
-              .getIndex();
-
-      // Iterate over all immediate terminator operations to introduce
-      // new buffer allocations. Thereby, the appropriate terminator operand
-      // will be adjusted to point to the newly allocated buffer instead.
-      if (failed(walkReturnOperations(
-              &region, [&](RegionBranchTerminatorOpInterface terminator) {
-                // Get the actual mutable operands for this terminator op.
-                auto terminatorOperands =
-                    terminator.getMutableSuccessorOperands(*regionSuccessor);
-                // Extract the source value from the current terminator.
-                // This conversion needs to exist on a separate line due to a
-                // bug in GCC conversion analysis.
-                OperandRange immutableTerminatorOperands = terminatorOperands;
-                Value sourceValue = immutableTerminatorOperands[operandIndex];
-                // Create a new clone at the current location of the terminator.
-                auto clone = introduceCloneBuffers(sourceValue, terminator);
-                if (failed(clone))
-                  return failure();
-                // Wire clone and terminator operand.
-                terminatorOperands.slice(operandIndex, 1).assign(*clone);
-                return success();
-              })))
-        return failure();
-    }
-    return success();
-  }
-
-  /// Creates a new memory allocation for the given source value and clones
-  /// its content into the newly allocated buffer. The terminator operation is
-  /// used to insert the clone operation at the right place.
-  FailureOr<Value> introduceCloneBuffers(Value sourceValue,
-                                         Operation *terminator) {
-    // Avoid multiple clones of the same source value. This can happen in the
-    // presence of loops when a branch acts as a backedge while also having
-    // another successor that returns to its parent operation. Note: that
-    // copying copied buffers can introduce memory leaks since the invariant of
-    // BufferDeallocation assumes that a buffer will be only cloned once into a
-    // temporary buffer. Hence, the construction of clone chains introduces
-    // additional allocations that are not tracked automatically by the
-    // algorithm.
-    if (clonedValues.contains(sourceValue))
-      return sourceValue;
-    // Create a new clone operation that copies the contents of the old
-    // buffer to the new one.
-    auto clone = buildClone(terminator, sourceValue);
-    if (succeeded(clone)) {
-      // Remember the clone of original source value.
-      clonedValues.insert(*clone);
-    }
-    return clone;
-  }
-
-  /// Finds correct dealloc positions according to the algorithm described at
-  /// the top of the file for all alloc nodes and block arguments that can be
-  /// handled by this analysis.
-  LogicalResult placeDeallocs() {
-    // Move or insert deallocs using the previously computed information.
-    // These deallocations will be linked to their associated allocation nodes
-    // since they don't have any aliases that can (potentially) increase their
-    // liveness.
-    for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
-      Value alloc = std::get<0>(entry);
-      auto aliasesSet = aliases.resolve(alloc);
-      assert(!aliasesSet.empty() && "must contain at least one alias");
-
-      // Determine the actual block to place the dealloc and get liveness
-      // information.
-      Block *placementBlock =
-          findCommonDominator(alloc, aliasesSet, postDominators);
-      const LivenessBlockInfo *livenessInfo =
-          liveness.getLiveness(placementBlock);
-
-      // We have to ensure that the dealloc will be after the last use of all
-      // aliases of the given value. We first assume that there are no uses in
-      // the placementBlock and that we can safely place the dealloc at the
-      // beginning.
-      Operation *endOperation = &placementBlock->front();
-
-      // Iterate over all aliases and ensure that the endOperation will point
-      // to the last operation of all potential aliases in the placementBlock.
-      for (Value alias : aliasesSet) {
-        // Ensure that the start operation is at least the defining operation of
-        // the current alias to avoid invalid placement of deallocs for aliases
-        // without any uses.
-        Operation *beforeOp = endOperation;
-        if (alias.getDefiningOp() &&
-            !(beforeOp = placementBlock->findAncestorOpInBlock(
-                  *alias.getDefiningOp())))
-          continue;
-
-        Operation *aliasEndOperation =
-            livenessInfo->getEndOperation(alias, beforeOp);
-        // Check whether the aliasEndOperation lies in the desired block and
-        // whether it is behind the current endOperation. If yes, this will be
-        // the new endOperation.
-        if (aliasEndOperation->getBlock() == placementBlock &&
-            endOperation->isBeforeInBlock(aliasEndOperation))
-          endOperation = aliasEndOperation;
-      }
-      // endOperation is the last operation behind which we can safely store
-      // the dealloc taking all potential aliases into account.
-
-      // If there is an existing dealloc, move it to the right place.
-      Operation *deallocOperation = std::get<1>(entry);
-      if (deallocOperation) {
-        deallocOperation->moveAfter(endOperation);
-      } else {
-        // If the Dealloc position is at the terminator operation of the
-        // block, then the value should escape from a deallocation.
-        Operation *nextOp = endOperation->getNextNode();
-        if (!nextOp)
-          continue;
-        // If there is no dealloc node, insert one in the right place.
-        if (failed(buildDealloc(nextOp, alloc)))
-          return failure();
-      }
-    }
-    return success();
-  }
-
-  /// Builds a deallocation operation compatible with the given allocation
-  /// value. If there is no registered AllocationOpInterface implementation for
-  /// the given value (e.g. in the case of a function parameter), this method
-  /// builds a memref::DeallocOp.
-  LogicalResult buildDealloc(Operation *op, Value alloc) {
-    OpBuilder builder(op);
-    auto it = aliasToAllocations.find(alloc);
-    if (it != aliasToAllocations.end()) {
-      // Call the allocation op interface to build a supported and
-      // compatible deallocation operation.
-      auto dealloc = it->second.buildDealloc(builder, alloc);
-      if (!dealloc)
-        return op->emitError()
-               << "allocations without compatible deallocations are "
-                  "not supported";
-    } else {
-      // Build a "default" DeallocOp for unknown allocation sources.
-      builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
-    }
-    return success();
-  }
-
-  /// Builds a clone operation compatible with the given allocation value. If
-  /// there is no registered AllocationOpInterface implementation for the given
-  /// value (e.g. in the case of a function parameter), this method builds a
-  /// bufferization::CloneOp.
-  FailureOr<Value> buildClone(Operation *op, Value alloc) {
-    OpBuilder builder(op);
-    auto it = aliasToAllocations.find(alloc);
-    if (it != aliasToAllocations.end()) {
-      // Call the allocation op interface to build a supported and
-      // compatible clone operation.
-      auto clone = it->second.buildClone(builder, alloc);
-      if (clone)
-        return *clone;
-      return (LogicalResult)(op->emitError()
-                             << "allocations without compatible clone ops "
-                                "are not supported");
-    }
-    // Build a "default" CloneOp for unknown allocation sources.
-    return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
-        .getResult();
-  }
-
-  /// The dominator info to find the appropriate start operation to move the
-  /// allocs.
-  DominanceInfo dominators;
-
-  /// The post dominator info to move the dependent allocs in the right
-  /// position.
-  PostDominanceInfo postDominators;
-
-  /// Stores already cloned buffers to avoid additional clones of clones.
-  ValueSetT clonedValues;
-
-  /// Maps aliases to their source allocation interfaces (inverse mapping).
-  AliasAllocationMapT aliasToAllocations;
-};
-
-//===----------------------------------------------------------------------===//
-// BufferDeallocationPass
-//===----------------------------------------------------------------------===//
-
-/// The actual buffer deallocation pass that inserts and moves dealloc nodes
-/// into the right positions. Furthermore, it inserts additional clones if
-/// necessary. It uses the algorithm described at the top of the file.
-struct BufferDeallocationPass
-    : public bufferization::impl::BufferDeallocationBase<
-          BufferDeallocationPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<bufferization::BufferizationDialect>();
-    registry.insert<memref::MemRefDialect>();
-  }
-
-  void runOnOperation() override {
-    Operation* func = getOperation();
-    if (func->getRegion(0).empty())
-      return;
-
-    if (failed(deallocateBuffers(func)))
-      signalPassFailure();
-  }
-};
-
-} // namespace
-
-LogicalResult bufferization::deallocateBuffers(Operation *op) {
-  if (isa<ModuleOp>(op)) {
-    WalkResult result = op->walk([&](func::FuncOp funcOp) {
-      if (failed(deallocateBuffers(funcOp)))
-        return WalkResult::interrupt();
-      return WalkResult::advance();
-    });
-    return success(!result.wasInterrupted());
-  }
-
-  // Ensure that there are supported loops only.
-  Backedges backedges(op);
-  if (backedges.size()) {
-    op->emitError("Only structured control-flow loops are supported.");
-    return failure();
-  }
-
-  // Check that the control flow structures are supported.
-  if (!validateSupportedControlFlow(op))
-    return failure();
-
-  // Gather all required allocation nodes and prepare the deallocation phase.
-  BufferDeallocation deallocation(op);
-
-  // Check for supported AllocationOpInterface implementations and prepare the
-  // internal deallocation pass.
-  if (failed(deallocation.prepare()))
-    return failure();
-
-  // Place all required temporary clone and dealloc nodes.
-  if (failed(deallocation.deallocate()))
-    return failure();
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BufferDeallocationPass construction
-//===----------------------------------------------------------------------===//
-
-std::unique_ptr<Pass> mlir::bufferization::createBufferDeallocationPass() {
-  return std::make_unique<BufferDeallocationPass>();
-}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 50104e8f8346b..7c38621be1bb5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect_library(MLIRBufferizationTransforms
   Bufferize.cpp
-  BufferDeallocation.cpp
   BufferDeallocationSimplification.cpp
   BufferOptimizations.cpp
   BufferResultsToOutParams.cpp

>From 9bb5bd0f8e0c93d737e11322d32b29bcb131df67 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Mon, 10 Mar 2025 14:57:50 +0100
Subject: [PATCH 18/23] Fix rebase

---
 .../Bufferization/Transforms/Passes.td        | 73 -------------------
 .../Transforms/AffineDataCopyGeneration.cpp   |  5 --
 2 files changed, 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index a4863a1b5c6d7..d9e559d42df47 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -15,79 +15,6 @@ class BufferScopePass<string name>
   : PassBase<name, "::mlir::bufferization::BufferScopePassBase">;
 
 
-def BufferDeallocation : BufferScopePass<"buffer-deallocation"> {
-  let summary = "Adds all required dealloc operations for all allocations in "
-                "the input program";
-  let description = [{
-    This pass implements an algorithm to automatically introduce all required
-    deallocation operations for all buffers in the input program. This ensures
-    that the resulting program does not have any memory leaks.
-
-
-    Input
-
-    ```mlir
-    #map0 = affine_map<(d0) -> (d0)>
-    module {
-      func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
-        cf.cond_br %arg0, ^bb1, ^bb2
-      ^bb1:
-        cf.br ^bb3(%arg1 : memref<2xf32>)
-      ^bb2:
-        %0 = memref.alloc() : memref<2xf32>
-        linalg.generic {
-          indexing_maps = [#map0, #map0],
-          iterator_types = ["parallel"]} %arg1, %0 {
-        ^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
-          %tmp1 = exp %gen1_arg0 : f32
-          linalg.yield %tmp1 : f32
-        }: memref<2xf32>, memref<2xf32>
-        cf.br ^bb3(%0 : memref<2xf32>)
-      ^bb3(%1: memref<2xf32>):
-        "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
-        return
-      }
-    }
-
-    ```
-
-    Output
-
-    ```mlir
-    #map0 = affine_map<(d0) -> (d0)>
-    module {
-      func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
-        cf.cond_br %arg0, ^bb1, ^bb2
-      ^bb1:  // pred: ^bb0
-        %0 = memref.alloc() : memref<2xf32>
-        memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
-        cf.br ^bb3(%0 : memref<2xf32>)
-      ^bb2:  // pred: ^bb0
-        %1 = memref.alloc() : memref<2xf32>
-        linalg.generic {
-          indexing_maps = [#map0, #map0],
-          iterator_types = ["parallel"]} %arg1, %1 {
-        ^bb0(%arg3: f32, %arg4: f32):
-          %4 = exp %arg3 : f32
-          linalg.yield %4 : f32
-        }: memref<2xf32>, memref<2xf32>
-        %2 = memref.alloc() : memref<2xf32>
-        memref.copy(%1, %2) : memref<2xf32>, memref<2xf32>
-        dealloc %1 : memref<2xf32>
-        cf.br ^bb3(%2 : memref<2xf32>)
-      ^bb3(%3: memref<2xf32>):  // 2 preds: ^bb1, ^bb2
-        memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
-        dealloc %3 : memref<2xf32>
-        return
-      }
-
-    }
-    ```
-
-  }];
-  let constructor = "mlir::bufferization::createBufferDeallocationPass()";
-}
-
 def OwnershipBasedBufferDeallocation : BufferScopePass<
     "ownership-based-buffer-deallocation"> {
   let summary = "Adds all required dealloc operations for all allocations in "
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 6ea9e1d02bcb6..2ed1393cd0a41 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -87,13 +87,8 @@ struct AffineDataCopyGeneration
 
 /// Generates copies for memref's living in 'slowMemorySpace' into newly created
 /// buffers in 'fastMemorySpace', and replaces memory operations to the former
-<<<<<<< HEAD
-/// by the latter.
-std::unique_ptr<OperationPass<func::FuncOp>>
-=======
 /// by the latter. Only load op's handled for now.
 std::unique_ptr<AffineScopePassBase>
->>>>>>> 9bd74f961815 (Make affine and bufferization pass applicable to any AffineScopeOp/AutomaticAllocationScope)
 mlir::affine::createAffineDataCopyGenerationPass(
     unsigned slowMemorySpace, unsigned fastMemorySpace, unsigned tagMemorySpace,
     int minDmaTransferSize, uint64_t fastMemCapacityBytes) {

>From 5fc954d5f193e5c727530462128a5c5b3668e06a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Tue, 11 Mar 2025 12:54:03 +0100
Subject: [PATCH 19/23] Fix rebase

---
 .../Dialect/Bufferization/Transforms/Passes.td   | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index d9e559d42df47..972dd2236d672 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -15,7 +15,7 @@ class BufferScopePass<string name>
   : PassBase<name, "::mlir::bufferization::BufferScopePassBase">;
 
 
-def OwnershipBasedBufferDeallocation : BufferScopePass<
+def OwnershipBasedBufferDeallocationPass : BufferScopePass<
     "ownership-based-buffer-deallocation"> {
   let summary = "Adds all required dealloc operations for all allocations in "
                 "the input program";
@@ -156,7 +156,7 @@ def OwnershipBasedBufferDeallocation : BufferScopePass<
   ];
 }
 
-def BufferDeallocationSimplification :
+def BufferDeallocationSimplificationPass :
     BufferScopePass<"buffer-deallocation-simplification"> {
   let summary = "Optimizes `bufferization.dealloc` operation for more "
                 "efficient codegen";
@@ -173,7 +173,7 @@ def BufferDeallocationSimplification :
   ];
 }
 
-def OptimizeAllocationLiveness
+def OptimizeAllocationLivenessPass
     : BufferScopePass<"optimize-allocation-liveness"> {
   let summary = "This pass optimizes the liveness of temp allocations in the "
                 "input function";
@@ -188,7 +188,7 @@ def OptimizeAllocationLiveness
   let dependentDialects = ["mlir::memref::MemRefDialect"];
 }
 
-def LowerDeallocations : BufferScopePass<"bufferization-lower-deallocations"> {
+def LowerDeallocationsPass : BufferScopePass<"bufferization-lower-deallocations"> {
   let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`"
                 "operations";
   let description = [{
@@ -208,7 +208,7 @@ def LowerDeallocations : BufferScopePass<"bufferization-lower-deallocations"> {
   ];
 }
 
-def BufferHoisting : BufferScopePass<"buffer-hoisting"> {
+def BufferHoistingPass : BufferScopePass<"buffer-hoisting"> {
   let summary = "Optimizes placement of allocation operations by moving them "
                 "into common dominators and out of nested regions";
   let description = [{
@@ -217,7 +217,7 @@ def BufferHoisting : BufferScopePass<"buffer-hoisting"> {
   }];
 }
 
-def BufferLoopHoisting : BufferScopePass<"buffer-loop-hoisting"> {
+def BufferLoopHoistingPass : BufferScopePass<"buffer-loop-hoisting"> {
   let summary = "Optimizes placement of allocation operations by moving them "
                 "out of loop nests";
   let description = [{
@@ -466,7 +466,7 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> {
   ];
 }
 
-def PromoteBuffersToStack : BufferScopePass<"promote-buffers-to-stack"> {
+def PromoteBuffersToStackPass : BufferScopePass<"promote-buffers-to-stack"> {
   let summary = "Promotes heap-based allocations to automatically managed "
                 "stack-based allocations";
   let description = [{
@@ -486,7 +486,7 @@ def PromoteBuffersToStack : BufferScopePass<"promote-buffers-to-stack"> {
   ];
 }
 
-def EmptyTensorElimination : BufferScopePass<"eliminate-empty-tensors"> {
+def EmptyTensorEliminationPass : BufferScopePass<"eliminate-empty-tensors"> {
   let summary = "Try to eliminate all tensor.empty ops.";
   let description = [{
     Try to eliminate "tensor.empty" ops inside `op`. This transformation looks

>From 08a424a072acf1e2b58ab11b41b95ead4aaa8601 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Tue, 15 Apr 2025 01:54:19 +0200
Subject: [PATCH 20/23] Patch dataflow framework

The dead code analysis had a bug whereby the
regions of an op were marked as dead code, and
their successors were not properly populated,
if the constant value of any operand of a
RegionBranchOpInterface was not yet known.
The interface is already meant to support
partially-folded operands, so we should use
that instead of giving up. The same change is
applied to BranchOpInterface successors.

Some changes are also just to make debug prints
more helpful. Some of them were printing pointer
addresses.
---
 .../mlir/Analysis/DataFlow/DeadCodeAnalysis.h |  4 +++
 .../Analysis/DataFlow/DeadCodeAnalysis.cpp    | 34 ++++++++++++-------
 mlir/lib/Analysis/DataFlowFramework.cpp       |  8 ++---
 3 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 2250db823b551..3652435e4c59e 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -223,6 +223,10 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
   /// Get the constant values of the operands of the operation. Returns
   /// std::nullopt if any of the operand lattices are uninitialized.
   std::optional<SmallVector<Attribute>> getOperandValues(Operation *op);
+
+  /// Like getOperandValues, but never fails: a null attribute is used for
+  /// each operand whose constant-value lattice is still uninitialized.
+  SmallVector<Attribute> getOperandValuesBestEffort(Operation *op);
 
   /// The top-level operation the analysis is running on. This is used to detect
   /// if a callable is outside the scope of the analysis and thus must be
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index e805e21d878bf..ca14d58ab66d7 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -341,15 +341,20 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
 /// constant value lattices are uninitialized, return std::nullopt to indicate
 /// the analysis should bail out.
 static std::optional<SmallVector<Attribute>> getOperandValuesImpl(
-    Operation *op,
+    Operation *op, bool failIfAnyNull,
     function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
   SmallVector<Attribute> operands;
   operands.reserve(op->getNumOperands());
   for (Value operand : op->getOperands()) {
     const Lattice<ConstantValue> *cv = getLattice(operand);
     // If any of the operands' values are uninitialized, bail out.
-    if (cv->getValue().isUninitialized())
-      return {};
+    if (cv->getValue().isUninitialized()) {
+      if (failIfAnyNull)
+        return {};
+      operands.emplace_back();
+      continue;
+    }
+
     operands.push_back(cv->getValue().getConstantValue());
   }
   return operands;
@@ -357,7 +362,16 @@ static std::optional<SmallVector<Attribute>> getOperandValuesImpl(
 
 std::optional<SmallVector<Attribute>>
 DeadCodeAnalysis::getOperandValues(Operation *op) {
-  return getOperandValuesImpl(op, [&](Value value) {
+  return getOperandValuesImpl(op, true, [&](Value value) {
+    auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
+    lattice->useDefSubscribe(this);
+    return lattice;
+  });
+}
+
+SmallVector<Attribute>
+DeadCodeAnalysis::getOperandValuesBestEffort(Operation *op) {
+  return *getOperandValuesImpl(op, false, [&](Value value) {
     auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
     lattice->useDefSubscribe(this);
     return lattice;
@@ -366,11 +380,9 @@ DeadCodeAnalysis::getOperandValues(Operation *op) {
 
 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
   // Try to deduce a single successor for the branch.
-  std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
-  if (!operands)
-    return;
+  SmallVector<Attribute> operands = getOperandValuesBestEffort(branch);
 
-  if (Block *successor = branch.getSuccessorForOperands(*operands)) {
+  if (Block *successor = branch.getSuccessorForOperands(operands)) {
     markEdgeLive(branch->getBlock(), successor);
   } else {
     // Otherwise, mark all successors as executable and outgoing edges.
@@ -382,12 +394,10 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
 void DeadCodeAnalysis::visitRegionBranchOperation(
     RegionBranchOpInterface branch) {
   // Try to deduce which regions are executable.
-  std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
-  if (!operands)
-    return;
+  SmallVector<Attribute> operands = getOperandValuesBestEffort(branch);
 
   SmallVector<RegionSuccessor> successors;
-  branch.getEntrySuccessorRegions(*operands, successors);
+  branch.getEntrySuccessorRegions(operands, successors);
   for (const RegionSuccessor &successor : successors) {
     // The successor can be either an entry block or the parent operation.
     ProgramPoint *point =
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 29f57c602f9cb..cfa880e658f74 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -44,9 +44,9 @@ void AnalysisState::addDependency(ProgramPoint *dependent,
   (void)inserted;
   DATAFLOW_DEBUG({
     if (inserted) {
-      llvm::dbgs() << "Creating dependency between " << debugName << " of "
-                   << anchor << "\nand " << debugName << " on " << dependent
-                   << "\n";
+      llvm::dbgs() << "Creating dependency between \t" << debugName << " of "
+                   << anchor << "\n                      and\t" << debugName
+                   << " of " << *dependent << "\n";
     }
   });
 }
@@ -125,7 +125,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
     worklist.pop();
 
     DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
-                                << "' on: " << point << "\n");
+                                << "' on: " << *point << "\n");
     if (failed(analysis->visit(point)))
       return failure();
   }

>From 1f483c9e9abf26ba6ab6d302b5a8f6e5530e1341 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Thu, 24 Apr 2025 18:02:02 +0200
Subject: [PATCH 21/23] Fix liveness analysis

The problem is that the live-in
set is not necessarily populated
yet when we ask for the start
operation.
---
 mlir/lib/Analysis/Liveness.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
index e3245d68b3699..c1673b920ac23 100644
--- a/mlir/lib/Analysis/Liveness.cpp
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -369,7 +369,7 @@ Operation *LivenessBlockInfo::getStartOperation(Value value) const {
   Operation *definingOp = value.getDefiningOp();
   // The given value is either live-in or is defined
   // in the scope of this block.
-  if (isLiveIn(value) || !definingOp)
+  if (!definingOp || definingOp->getBlock() != block)
     return &block->front();
   return definingOp;
 }

>From 2b28223056aa993d1e71c94d66a9303d09bbaa4a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Tue, 29 Apr 2025 19:10:12 +0200
Subject: [PATCH 22/23] Fix affine fold memref alias pattern

in the case where the subview op uses
an index of an affine loop as, e.g., the
offset. The previous implementation
always generated symbols, and the
verifier failed, although the
transformation is valid if you just
generate dims for the variables that
are not valid symbols but are valid dims.
---
 .../Affine/Utils/ViewLikeInterfaceUtils.cpp   | 31 ++++++++++++++++---
 1 file changed, 26 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
index b74df4ff6060f..5b02e0406b6ca 100644
--- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
@@ -78,6 +78,26 @@ LogicalResult mlir::affine::mergeOffsetsSizesAndStrides(
       combinedOffsets, combinedSizes, combinedStrides);
 }
 
+static AffineMap bindSymbolsOrDims(
+    MLIRContext *ctx, llvm::ArrayRef<OpFoldResult> operands,
+    function_ref<AffineExpr(llvm::SmallVectorImpl<AffineExpr> &)> makeExpr) {
+  SmallVector<AffineExpr, 4> affineExprs(operands.size());
+  unsigned symbolCount = 0;
+  unsigned dimCount = 0;
+  for (auto [e, value] : llvm::zip_equal(affineExprs, operands)) {
+    auto asValue = llvm::dyn_cast_or_null<Value>(value);
+    if (asValue && !affine::isValidSymbol(asValue) &&
+        affine::isValidDim(asValue)) {
+      e = getAffineDimExpr(dimCount++, ctx);
+    } else {
+      // This is also done if it is not a valid symbol but
+      // we don't care, we need a fallback.
+      e = getAffineSymbolExpr(symbolCount++, ctx);
+    }
+  }
+  return AffineMap::get(dimCount, symbolCount, makeExpr(affineExprs));
+}
+
 void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
     RewriterBase &rewriter, Location loc,
     ArrayRef<OpFoldResult> mixedSourceOffsets,
@@ -100,11 +120,12 @@ void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
   resolvedIndices.clear();
   for (auto [offset, index, stride] :
        llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
-    AffineExpr off, idx, str;
-    bindSymbols(rewriter.getContext(), off, idx, str);
-    OpFoldResult ofr = makeComposedFoldedAffineApply(
-        rewriter, loc, AffineMap::get(0, 3, off + idx * str),
-        {offset, index, stride});
+    auto affineMap =
+        bindSymbolsOrDims(rewriter.getContext(), {offset, index, stride},
+                          [](auto &e) { return e[0] + e[1] * e[2]; });
+
+    OpFoldResult ofr = makeComposedFoldedAffineApply(rewriter, loc, affineMap,
+                                                     {offset, index, stride});
     resolvedIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }

>From 4d8ae5fb5aa3c55ddd5dc58fe1e0f6026789ae00 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Wed, 30 Apr 2025 19:48:54 +0200
Subject: [PATCH 23/23] Allow the subview folding to consider expressions of
 dims and symbols

and not just plain dims and symbols. This makes it possible
to fold memref.subview ops that use an affine expression of
valid symbols and dims as an offset, even if that expression
is computed by arith ops like muli and addi.
---
 mlir/include/mlir/Dialect/Affine/Utils.h      |  18 ++-
 .../Affine/Transforms/RaiseMemrefDialect.cpp  | 102 +------------
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 137 ++++++++++++++++++
 .../Affine/Utils/ViewLikeInterfaceUtils.cpp   |  44 +++---
 4 files changed, 177 insertions(+), 124 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 250c28d0c9d41..93b7af7d24f85 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -17,6 +17,8 @@
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 
 namespace mlir {
@@ -105,7 +107,7 @@ struct VectorizationStrategy {
 /// Replace affine store and load accesses by scalars by forwarding stores to
 /// loads and eliminate invariant affine loads; consequently, eliminate dead
 /// allocs.
-void affineScalarReplace(Operation* parentOp, DominanceInfo &domInfo,
+void affineScalarReplace(Operation *parentOp, DominanceInfo &domInfo,
                          PostDominanceInfo &postDomInfo,
                          AliasAnalysis &analysis);
 
@@ -338,6 +340,20 @@ OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
                             ArrayRef<OpFoldResult> multiIndex,
                             ArrayRef<OpFoldResult> basis);
 
+/// Given a set of indices into a memref which may be computed using arith
+/// ops, try to convert each index to an affine expr. This is only possible
+/// if every index is an expression of constants, valid dims and symbols.
+/// On success, `map` is populated and `mapArgs` receives the concrete
+/// bindings for its dims followed by its symbols; otherwise fails.
+LogicalResult
+convertValuesToAffineMapAndArgs(MLIRContext *ctx, ValueRange indices,
+                                AffineMap &map,
+                                llvm::SmallVectorImpl<Value> &mapArgs);
+LogicalResult
+convertValuesToAffineMapAndArgs(MLIRContext *ctx,
+                                ArrayRef<OpFoldResult> indices, AffineMap &map,
+                                llvm::SmallVectorImpl<OpFoldResult> &mapArgs);
+
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
 /// between the operations) do not have the potential memory effect
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
index b9d1df054390d..21f5a9d6aa7d7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -11,16 +11,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
 
 namespace mlir {
@@ -37,96 +33,6 @@ using namespace mlir::affine;
 
 namespace {
 
-/// Find the index of the given value in the `dims` list,
-/// and append it if it was not already in the list. The
-/// dims list is a list of symbols or dimensions of the
-/// affine map. Within the results of an affine map, they
-/// are identified by their index, which is why we need
-/// this function.
-static std::optional<size_t>
-findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
-                function_ref<bool(Value)> isValidElement) {
-
-  Value *loopIV = std::find(dims.begin(), dims.end(), value);
-  if (loopIV != dims.end()) {
-    // We found an IV that already has an index, return that index.
-    return {std::distance(dims.begin(), loopIV)};
-  }
-  if (isValidElement(value)) {
-    // This is a valid element for the dim/symbol list, push this as a
-    // parameter.
-    size_t idx = dims.size();
-    dims.push_back(value);
-    return idx;
-  }
-  return std::nullopt;
-}
-
-/// Convert a value to an affine expr if possible. Adds dims and symbols
-/// if needed.
-static AffineExpr toAffineExpr(Value value,
-                               llvm::SmallVectorImpl<Value> &affineDims,
-                               llvm::SmallVectorImpl<Value> &affineSymbols) {
-  using namespace matchers;
-  IntegerAttr::ValueType cst;
-  if (matchPattern(value, m_ConstantInt(&cst))) {
-    return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
-  }
-  Value lhs;
-  Value rhs;
-  if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
-      matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
-    AffineExpr lhsE;
-    AffineExpr rhsE;
-    if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) &&
-        (rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) {
-      AffineExprKind kind;
-      if (isa<arith::AddIOp>(value.getDefiningOp())) {
-        kind = mlir::AffineExprKind::Add;
-      } else {
-        kind = mlir::AffineExprKind::Mul;
-      }
-      return getAffineBinaryOpExpr(kind, lhsE, rhsE);
-    }
-  }
-
-  if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
-        return affine::isValidSymbol(v);
-      })) {
-    return getAffineSymbolExpr(*dimIx, value.getContext());
-  }
-
-  if (auto dimIx = findInListOrAdd(
-          value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
-
-    return getAffineDimExpr(*dimIx, value.getContext());
-  }
-
-  return {};
-}
-
-static LogicalResult
-computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
-                        llvm::SmallVectorImpl<Value> &mapArgs) {
-  SmallVector<AffineExpr> results;
-  SmallVector<Value> symbols;
-  SmallVector<Value> dims;
-
-  for (Value indexExpr : indices) {
-    AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
-    if (!res) {
-      return failure();
-    }
-    results.push_back(res);
-  }
-
-  map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
-
-  dims.append(symbols);
-  mapArgs.swap(dims);
-  return success();
-}
-
 struct RaiseMemrefDialect
     : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
 
@@ -140,8 +46,8 @@ struct RaiseMemrefDialect
       rewriter.setInsertionPoint(op);
       if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
 
-        if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
-                                              mapArgs))) {
+        if (succeeded(affine::convertValuesToAffineMapAndArgs(
+                ctx, store.getIndices(), map, mapArgs))) {
           rewriter.replaceOpWithNewOp<AffineStoreOp>(
               op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
           return;
@@ -151,8 +57,8 @@ struct RaiseMemrefDialect
                    << "[affine] Cannot raise memref op: " << op << "\n");
 
       } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
-        if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
-                                              mapArgs))) {
+        if (succeeded(affine::convertValuesToAffineMapAndArgs(
+                ctx, load.getIndices(), map, mapArgs))) {
           rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
                                                     mapArgs);
           return;
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4e1b0665f7cd7..791a9587a6f36 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -21,15 +21,19 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/LogicalResult.h"
 #include <optional>
 #include <tuple>
@@ -2203,3 +2207,136 @@ OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
   return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
                                                multiIndexAndStrides);
 }
+
+/// Returns the position of `value` in `operands`, which is the dim or symbol
+/// operand list of an affine map under construction (position N is referred
+/// to as d<N> or s<N> in the map results). A value that is not in the list
+/// yet is appended if `isValidElement` accepts it; otherwise std::nullopt is
+/// returned and the list is left unchanged.
+static std::optional<size_t>
+findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &operands,
+                function_ref<bool(Value)> isValidElement) {
+  auto *existing = std::find(operands.begin(), operands.end(), value);
+  if (existing != operands.end())
+    return static_cast<size_t>(std::distance(operands.begin(), existing));
+  if (!isValidElement(value))
+    return std::nullopt;
+  operands.push_back(value);
+  return operands.size() - 1;
+}
+
+/// Converts `value` to an affine expression whose dims and symbols index into
+/// `affineDims` and `affineSymbols`, appending newly referenced values to
+/// those lists. Integer constants become constant expressions, and
+/// `arith.addi` / `arith.muli` are decomposed recursively. Any other value is
+/// bound as a symbol if it is a valid affine symbol, else as a dim if it is a
+/// valid affine dim. Returns a null expression if none of this applies.
+static AffineExpr toAffineExpr(Value value,
+                               llvm::SmallVectorImpl<Value> &affineDims,
+                               llvm::SmallVectorImpl<Value> &affineSymbols) {
+  using namespace matchers;
+  IntegerAttr::ValueType cst;
+  if (matchPattern(value, m_ConstantInt(&cst)))
+    return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
+
+  Value lhs;
+  Value rhs;
+  bool isAdd =
+      matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs)));
+  if (isAdd ||
+      matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
+    // Remember the list sizes so that operands added while converting the
+    // sub-expressions can be dropped if the decomposition is abandoned.
+    size_t numDims = affineDims.size();
+    size_t numSymbols = affineSymbols.size();
+    AffineExpr lhsE = toAffineExpr(lhs, affineDims, affineSymbols);
+    AffineExpr rhsE =
+        lhsE ? toAffineExpr(rhs, affineDims, affineSymbols) : AffineExpr();
+    if (lhsE && rhsE) {
+      // The AffineExpr operators canonicalize and fold constants, unlike
+      // building the binary expression directly.
+      if (isAdd)
+        return lhsE + rhsE;
+      // A product is only affine when one factor is symbolic or constant;
+      // e.g. `d0 * d1` is not a valid affine expression.
+      if (lhsE.isSymbolicOrConstant() || rhsE.isSymbolicOrConstant())
+        return lhsE * rhsE;
+    }
+    // Undo operand additions from the abandoned decomposition, then fall
+    // back to treating `value` itself as an opaque symbol or dim.
+    affineDims.truncate(numDims);
+    affineSymbols.truncate(numSymbols);
+  }
+
+  if (auto symIdx = findInListOrAdd(value, affineSymbols, [](Value v) {
+        return affine::isValidSymbol(v);
+      }))
+    return getAffineSymbolExpr(*symIdx, value.getContext());
+
+  if (auto dimIdx = findInListOrAdd(
+          value, affineDims, [](Value v) { return affine::isValidDim(v); }))
+    return getAffineDimExpr(*dimIdx, value.getContext());
+
+  return {};
+}
+
+/// Builds an affine map whose results are the affine forms of `indices`.
+/// On success, `map` receives the map and `mapArgs` is overwritten with the
+/// map operands: every dim operand followed by every symbol operand. Fails,
+/// leaving `map` and `mapArgs` unmodified, as soon as one index cannot be
+/// expressed as an affine expression.
+LogicalResult mlir::affine::convertValuesToAffineMapAndArgs(
+    MLIRContext *ctx, ValueRange indices, AffineMap &map,
+    llvm::SmallVectorImpl<Value> &mapArgs) {
+  SmallVector<Value> dimOperands;
+  SmallVector<Value> symbolOperands;
+  SmallVector<AffineExpr> exprs;
+
+  for (Value index : indices) {
+    if (AffineExpr expr = toAffineExpr(index, dimOperands, symbolOperands))
+      exprs.push_back(expr);
+    else
+      return failure();
+  }
+
+  map = AffineMap::get(dimOperands.size(), symbolOperands.size(), exprs, ctx);
+  mapArgs.assign(dimOperands.begin(), dimOperands.end());
+  mapArgs.append(symbolOperands.begin(), symbolOperands.end());
+  return success();
+}
+
+/// Builds an affine map whose results are the affine forms of `indices`.
+/// Value entries are converted with toAffineExpr; Attribute entries are bound
+/// as fresh symbols and passed through unchanged as map operands. On success,
+/// `map` receives the map and `mapArgs` is overwritten (as in the ValueRange
+/// overload) with every dim operand followed by every symbol operand. Fails,
+/// leaving `map` and `mapArgs` unmodified, if an entry is null or a Value
+/// entry cannot be expressed as an affine expression.
+LogicalResult mlir::affine::convertValuesToAffineMapAndArgs(
+    MLIRContext *ctx, ArrayRef<OpFoldResult> indices, AffineMap &map,
+    llvm::SmallVectorImpl<OpFoldResult> &mapArgs) {
+  SmallVector<AffineExpr> results;
+  SmallVector<Value> symbols;
+  SmallVector<Value> dims;
+  // Attribute operands, in the order of their placeholder symbols.
+  SmallVector<OpFoldResult> constantSymbols;
+
+  for (OpFoldResult indexExpr : indices) {
+    if (!indexExpr)
+      return failure();
+    if (auto asValue = llvm::dyn_cast_or_null<Value>(indexExpr)) {
+      AffineExpr res = toAffineExpr(asValue, dims, symbols);
+      if (!res)
+        return failure();
+      results.push_back(res);
+      continue;
+    }
+    // Bind the attribute as a new symbol. The null Value reserves the symbol
+    // position in `symbols`; it never compares equal to a real Value, so
+    // toAffineExpr cannot reuse that position for another operand.
+    results.push_back(getAffineSymbolExpr(symbols.size(), ctx));
+    symbols.push_back(Value());
+    constantSymbols.push_back(indexExpr);
+  }
+
+  map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+  mapArgs.clear();
+  for (Value dim : dims)
+    mapArgs.push_back(dim);
+  // Replace each placeholder by its attribute, in order.
+  auto nextConstant = constantSymbols.begin();
+  for (Value symbol : symbols)
+    mapArgs.push_back(symbol ? OpFoldResult(symbol) : *nextConstant++);
+  return success();
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
index 5b02e0406b6ca..91c367ad8ca85 100644
--- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
@@ -8,7 +8,10 @@
 
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
@@ -78,26 +81,6 @@ LogicalResult mlir::affine::mergeOffsetsSizesAndStrides(
       combinedOffsets, combinedSizes, combinedStrides);
 }
 
-static AffineMap bindSymbolsOrDims(
-    MLIRContext *ctx, llvm::ArrayRef<OpFoldResult> operands,
-    function_ref<AffineExpr(llvm::SmallVectorImpl<AffineExpr> &)> makeExpr) {
-  SmallVector<AffineExpr, 4> affineExprs(operands.size());
-  unsigned symbolCount = 0;
-  unsigned dimCount = 0;
-  for (auto [e, value] : llvm::zip_equal(affineExprs, operands)) {
-    auto asValue = llvm::dyn_cast_or_null<Value>(value);
-    if (asValue && !affine::isValidSymbol(asValue) &&
-        affine::isValidDim(asValue)) {
-      e = getAffineDimExpr(dimCount++, ctx);
-    } else {
-      // This is also done if it is not a valid symbol but
-      // we don't care, we need a fallback.
-      e = getAffineSymbolExpr(symbolCount++, ctx);
-    }
-  }
-  return AffineMap::get(dimCount, symbolCount, makeExpr(affineExprs));
-}
-
 void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
     RewriterBase &rewriter, Location loc,
     ArrayRef<OpFoldResult> mixedSourceOffsets,
@@ -120,12 +103,23 @@ void mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
   resolvedIndices.clear();
   for (auto [offset, index, stride] :
        llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
-    auto affineMap =
-        bindSymbolsOrDims(rewriter.getContext(), {offset, index, stride},
-                          [](auto &e) { return e[0] + e[1] * e[2]; });
+    AffineMap map;
+    SmallVector<OpFoldResult> mapArgs;
+    auto *ctx = rewriter.getContext();
+    if (failed(affine::convertValuesToAffineMapAndArgs(
+            ctx, {offset, index, stride}, map, mapArgs))) {
+      // todo
+      resolvedIndices.push_back(Value{});
+      continue;
+    }
+    AffineExpr off, ix, str;
+    bindDims(ctx, off, ix, str);
+    auto nextMap = AffineMap::get(3, 0, off + ix * str);
+    auto composedMap = nextMap.compose(map);
+
+    OpFoldResult ofr =
+        makeComposedFoldedAffineApply(rewriter, loc, composedMap, mapArgs);
 
-    OpFoldResult ofr = makeComposedFoldedAffineApply(rewriter, loc, affineMap,
-                                                     {offset, index, stride});
     resolvedIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }



More information about the Mlir-commits mailing list