[Mlir-commits] [mlir] [llvm] [mlir][bufferization][WIP] Use `BufferOriginAnalysis` to fold away `dealloc` runtime checks (PR #79602)

Matthias Springer llvmlistbot at llvm.org
Fri Jan 26 06:50:32 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/79602

This commit adds the `BufferOriginAnalysis`, which can be queried by the buffer deallocation pass to check if two buffers originate from the same allocation.

This analysis enables additional simplifications during `-buffer-deallocation-simplification`. In particular, "regular" `scf.for` loop nests that yield buffers (or reallocations thereof) in the same order as they appear in the `iter_args` are now handled much more efficiently. (TODO: Add test case.) Such IR patterns are generated by the sparse compiler.
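A small, self-contained C++ model can illustrate the three-valued "same allocation?" query. It is only a sketch. Origins are plain strings tagged with a hypothetical `OriginKind`, and the decision rules are assumptions rather than the exact logic of the patch below:
- an unknown origin, or two function arguments, yields "cannot tell";
- a single shared origin yields "same";
- disjoint origins that cannot alias yield "different".

Presumably, a definite `true`/`false` answer is what lets the simplification drop a runtime aliasing check, while "cannot tell" keeps it.

```cpp
#include <cassert>
#include <initializer_list>
#include <optional>
#include <set>
#include <string>

// Hypothetical classification of a terminal origin (not an MLIR API).
enum class OriginKind { Allocation, FunctionArgument, Unknown };

struct Origin {
  std::string name; // toy SSA name such as "%0"
  OriginKind kind;
  bool operator<(const Origin &other) const { return name < other.name; }
  bool operator==(const Origin &other) const { return name == other.name; }
};

// Three-valued query over the terminal origins of two buffers:
// true = same allocation, false = different, nullopt = cannot tell.
std::optional<bool> isSameAllocation(const std::set<Origin> &o1,
                                     const std::set<Origin> &o2) {
  if (o1.empty() || o2.empty())
    return std::nullopt;
  // An origin the analysis knows nothing about defeats any conclusion.
  for (const std::set<Origin> *s : {&o1, &o2})
    for (const Origin &o : *s)
      if (o.kind == OriginKind::Unknown)
        return std::nullopt;
  // Both buffers trace back to exactly one and the same origin.
  if (o1.size() == 1 && o1 == o2)
    return true;
  for (const Origin &a : o1) {
    // A shared origin means the buffers *may* be the same allocation.
    if (o2.count(a))
      return std::nullopt;
    // Two distinct function arguments may still alias each other.
    for (const Origin &b : o2)
      if (a.kind == OriginKind::FunctionArgument &&
          b.kind == OriginKind::FunctionArgument)
        return std::nullopt;
  }
  // Disjoint origins that cannot alias: different allocations.
  return false;
}
```

In this model, an `arith.select` between two allocations has origins `{%0, %1}`. Querying it against `%0` therefore returns `std::nullopt`, which matches "Example 3" in the `isSameAllocation` doc comment further down.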


From 65e0e8b47fb0266b763ebb41b1421576ae989079 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 25 Jan 2024 08:47:24 +0000
Subject: [PATCH 1/2] [mlir][bufferization] Add `BufferViewFlowOpInterface`

This commit adds the `BufferViewFlowOpInterface` to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the `BufferViewFlowAnalysis`.

There are currently no ops that implement this interface. The first op implementations will be added in a subsequent commit.
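The interface splits the work between `populateDependencies` (operand-to-result/bbArg edges) and `isTerminalBuffer` (where tracing stops). The toy model below sketches how an analysis might consume both. SSA values are strings and ops come from a hard-coded table; both are assumptions for illustration, not MLIR APIs. It mirrors the external models and fallback in the diff:
- `arith.select` registers both branches;
- `memref.realloc` registers its source and marks its result terminal;
- any other op is unknown, so its results become terminals.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

using Value = std::string;             // toy stand-in for mlir::Value
using ValueRange = std::vector<Value>; // toy stand-in for mlir::ValueRange
using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;

// Hypothetical toy op: a name, its buffer operands and its buffer results.
struct ToyOp {
  std::string name;
  ValueRange operands;
  ValueRange results;
};

struct ToyAnalysis {
  std::map<Value, std::set<Value>> dependencies; // operand -> dependent values
  std::set<Value> terminals;                     // values where tracing stops

  void build(const std::vector<ToyOp> &ops) {
    RegisterDependenciesFn registerDependencies = [this](ValueRange values,
                                                         ValueRange deps) {
      // Pairwise registration; ranges must match (like llvm::zip_equal).
      assert(values.size() == deps.size() && "mismatched ranges");
      for (size_t i = 0; i < values.size(); ++i)
        dependencies[values[i]].insert(deps[i]);
    };
    for (const ToyOp &op : ops) {
      if (op.name == "arith.select") {
        // operands = {cond, trueValue, falseValue}; either may be selected.
        registerDependencies({op.operands[1]}, {op.results[0]});
        registerDependencies({op.operands[2]}, {op.results[0]});
      } else if (op.name == "memref.realloc") {
        // The result may be the source buffer or a fresh allocation.
        registerDependencies({op.operands[0]}, {op.results[0]});
        terminals.insert(op.results[0]);
      } else {
        // Unknown op: no information, so every result is a terminal.
        for (const Value &r : op.results)
          terminals.insert(r);
      }
    }
  }
};
```

Here `memref.alloc` falls into the "unknown op" branch. That is consistent with the diff, where ops that implement none of the queried interfaces end up in `populateTerminalValues`.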

---
 .../BufferViewFlowOpInterfaceImpl.h           | 20 ++++++
 .../IR/BufferViewFlowOpInterface.h            | 27 +++++++
 .../IR/BufferViewFlowOpInterface.td           | 67 +++++++++++++++++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |  1 +
 .../Transforms/BufferViewFlowAnalysis.h       |  6 ++
 .../BufferViewFlowOpInterfaceImpl.h           | 20 ++++++
 mlir/include/mlir/InitAllDialects.h           |  4 ++
 .../BufferViewFlowOpInterfaceImpl.cpp         | 44 ++++++++++++
 .../Dialect/Arith/Transforms/CMakeLists.txt   |  1 +
 .../IR/BufferViewFlowOpInterface.cpp          | 18 +++++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |  1 +
 .../Transforms/BufferViewFlowAnalysis.cpp     | 72 +++++++++++++++----
 .../BufferViewFlowOpInterfaceImpl.cpp         | 48 +++++++++++++
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |  2 +
 .../llvm-project-overlay/mlir/BUILD.bazel     | 36 ++++++++++
 15 files changed, 353 insertions(+), 14 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
 create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 000000000000000..a2b3a9bb655b874
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
new file mode 100644
index 000000000000000..84e67fe72b623b1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
@@ -0,0 +1,27 @@
+//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class ValueRange;
+
+namespace bufferization {
+
+using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
+
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
new file mode 100644
index 000000000000000..091d28b9a0a5e88
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
@@ -0,0 +1,67 @@
+//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
+#define BUFFER_VIEW_FLOW_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def BufferViewFlowOpInterface :
+    OpInterface<"BufferViewFlowOpInterface"> {
+  let description = [{
+    An op interface for the buffer view flow analysis. This interface describes
+    buffer dependencies between operands and op results/region entry block
+    arguments.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Populate buffer dependencies between operands and op results/region
+          entry block arguments.
+
+          Implementations should register dependencies between an operand ("X")
+          and an op result/region entry block argument ("Y") if Y may depend
+          on X. Y depends on X if Y and X are the same buffer or if Y is a
+          subview of X.
+
+          Example:
+          ```
+          %r = arith.select %c, %m1, %m2 : memref<5xf32>
+          ```
+          In the above example, %r may depend on %m1 or %m2 and a correct
+          interface implementation should call:
+          - "registerDependenciesFn(%m1, %r)"
+          - "registerDependenciesFn(%m2, %r)"
+        }],
+        /*retType=*/"void",
+        /*methodName=*/"populateDependencies",
+        /*args=*/(ins
+            "::mlir::bufferization::RegisterDependenciesFn"
+                :$registerDependenciesFn)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if the given value is a terminal buffer. A buffer value
+          is "terminal" if it cannot be traced back any further in the buffer
+          view flow analysis, e.g., because the value is a newly allocated
+          buffer or because not enough information is available.
+
+          Implementations can assume that the given SSA value is an OpResult of
+          this operation or a region entry block argument of this operation.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isTerminalBuffer",
+        /*args=*/(ins "Value":$value),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return false;"
+      >,
+  ];
+}
+
+#endif  // BUFFER_VIEW_FLOW_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 31a553f9a32f554..13a5bc370a4fce4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferDeallocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(BufferViewFlowOpInterface)
 
 set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
 mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 24825db69f90c50..894831c3848219b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -63,6 +63,9 @@ class BufferViewFlowAnalysis {
   /// results have to be changed.
   void rename(Value from, Value to);
 
+  /// Returns "true" if the given value is a terminal buffer.
+  bool isTerminalBuffer(Value value) const;
+
 private:
   /// This function constructs a mapping from values to its immediate
   /// dependencies.
@@ -70,6 +73,9 @@ class BufferViewFlowAnalysis {
 
   /// Maps values to all immediate dependencies this value can have.
   ValueMapT dependencies;
+
+  /// A set of all terminal values, i.e., values at which the analysis stopped.
+  DenseSet<Value> terminals;
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 000000000000000..714518a21e975db
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04f..938337232295696 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -53,6 +54,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -145,6 +147,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   affine::registerValueBoundsOpInterfaceExternalModels(registry);
   arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
+  arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   linalg::registerTilingInterfaceExternalModels(registry);
   linalg::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
+  memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..9df9df86b64fb58
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,44 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct SelectOpInterface
+    : public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
+                                                      SelectOp> {
+  void
+  populateDependencies(Operation *op,
+                       RegisterDependenciesFn registerDependenciesFn) const {
+    auto selectOp = cast<SelectOp>(op);
+
+    // Either one of the true/false value may be selected at runtime.
+    registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
+    registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
+  }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void arith::registerBufferViewFlowOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    SelectOp::attachInterface<SelectOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 02240601bcd35a1..12659eaba1fa5ef 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  BufferViewFlowOpInterfaceImpl.cpp
   EmulateUnsupportedFloats.cpp
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
new file mode 100644
index 000000000000000..ea726a4bfc3fb93
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
@@ -0,0 +1,18 @@
+//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 9895db9d93ce0bb..63dcc1eb233e928 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferDeallocationOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
+  BufferViewFlowOpInterface.cpp
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 88ef1b639fc5ceb..7cf202ac81d7c07 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,12 +8,16 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
+using namespace mlir::bufferization;
 
 /// Constructs a new alias analysis using the op provided.
 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
-    for (auto [value, dep] : llvm::zip(values, dependencies))
+    for (auto [value, dep] : llvm::zip_equal(values, dependencies))
       this->dependencies[value].insert(dep);
   };
 
+  // Mark all buffer results and buffer region entry block arguments of the
+  // given op as terminals.
+  auto populateTerminalValues = [&](Operation *op) {
+    for (Value v : op->getResults())
+      if (isa<BaseMemRefType>(v.getType()))
+        this->terminals.insert(v);
+    for (Region &r : op->getRegions())
+      for (BlockArgument v : r.getArguments())
+        if (isa<BaseMemRefType>(v.getType()))
+          this->terminals.insert(v);
+  };
+
   op->walk([&](Operation *op) {
-    // TODO: We should have an op interface instead of a hard-coded list of
-    // interfaces/ops.
+    // Query BufferViewFlowOpInterface. If the op does not implement that
+    // interface, try to infer the dependencies from other interfaces that the
+    // op may implement.
+    if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
+      bufferViewFlowOp.populateDependencies(registerDependencies);
+      for (Value v : op->getResults())
+        if (isa<BaseMemRefType>(v.getType()) &&
+            bufferViewFlowOp.isTerminalBuffer(v))
+          this->terminals.insert(v);
+      for (Region &r : op->getRegions())
+        for (BlockArgument v : r.getArguments())
+          if (isa<BaseMemRefType>(v.getType()) &&
+              bufferViewFlowOp.isTerminalBuffer(v))
+            this->terminals.insert(v);
+      return WalkResult::advance();
+    }
 
     // Add additional dependencies created by view changes to the alias list.
     if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
-      dependencies[viewInterface.getViewSource()].insert(
-          viewInterface->getResult(0));
+      registerDependencies(viewInterface.getViewSource(),
+                           viewInterface->getResult(0));
       return WalkResult::advance();
     }
 
@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       return WalkResult::advance();
     }
 
-    // Unknown op: Assume that all operands alias with all results.
-    for (Value operand : op->getOperands()) {
-      if (!isa<BaseMemRefType>(operand.getType()))
-        continue;
-      for (Value result : op->getResults()) {
-        if (!isa<BaseMemRefType>(result.getType()))
-          continue;
-        registerDependencies({operand}, {result});
-      }
+    // Region terminators are handled together with RegionBranchOpInterface.
+    if (isa<RegionBranchTerminatorOpInterface>(op))
+      return WalkResult::advance();
+
+    if (isa<CallOpInterface>(op)) {
+      // This is an intra-function analysis. We have no information about other
+      // functions. Conservatively assume that each operand may alias with each
+      // result. Also mark the results as terminals because the function could
+      // return newly allocated buffers.
+      populateTerminalValues(op);
+      for (Value operand : op->getOperands())
+        for (Value result : op->getResults())
+          registerDependencies({operand}, {result});
+      return WalkResult::advance();
     }
+
+    // We have no information about unknown ops.
+    populateTerminalValues(op);
+
     return WalkResult::advance();
   });
 }
+
+bool BufferViewFlowAnalysis::isTerminalBuffer(Value value) const {
+  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+  return terminals.contains(value);
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..407fd38e6647eec
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,48 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+struct ReallocOpInterface
+    : public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface,
+                                                      ReallocOp> {
+  void
+  populateDependencies(Operation *op,
+                       RegisterDependenciesFn registerDependenciesFn) const {
+    auto reallocOp = cast<ReallocOp>(op);
+    // memref.realloc may return the source operand.
+    registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult());
+  }
+
+  bool isTerminalBuffer(Operation *op, Value value) const {
+    // The return value of memref.realloc is a terminal buffer because the op
+    // may return a newly allocated buffer.
+    return true;
+  }
+};
+
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void memref::registerBufferViewFlowOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    ReallocOp::attachInterface<ReallocOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 08b7eab726eb7e8..f150ac7ac2d634a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   AllocationOpInterfaceImpl.cpp
+  BufferViewFlowOpInterfaceImpl.cpp
   ComposeSubView.cpp
   ExpandOps.cpp
   ExpandRealloc.cpp
@@ -27,6 +28,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRArithDialect
   MLIRArithTransforms
   MLIRBufferizationDialect
+  MLIRBufferizationTransforms
   MLIRDialectUtils
   MLIRFuncDialect
   MLIRGPUDialect
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2e75c48510b9c47..936b3e3cccd8e32 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10547,6 +10547,36 @@ gentbl_cc_library(
     ],
 )
 
+td_library(
+    name = "BufferViewFlowOpInterfaceTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":OpBaseTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "BufferViewFlowOpInterfaceIncGen",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td",
+    deps = [
+        ":BufferViewFlowOpInterfaceTdFiles",
+    ],
+)
+
 td_library(
     name = "SubsetOpInterfaceTdFiles",
     srcs = [
@@ -12660,6 +12690,7 @@ cc_library(
         ":ArithTransforms",
         ":ArithUtils",
         ":BufferizationDialect",
+        ":BufferizationTransforms",
         ":ControlFlowDialect",
         ":DialectUtils",
         ":FuncDialect",
@@ -12924,6 +12955,7 @@ td_library(
     includes = ["include"],
     deps = [
         ":AllocationOpInterfaceTdFiles",
+        ":BufferViewFlowOpInterfaceTdFiles",
         ":BufferizableOpInterfaceTdFiles",
         ":CopyOpInterfaceTdFiles",
         ":DestinationStyleOpInterfaceTdFiles",
@@ -13065,6 +13097,7 @@ cc_library(
     name = "BufferizationDialect",
     srcs = [
         "lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp",
+        "lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
         "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
@@ -13072,6 +13105,7 @@ cc_library(
     ],
     hdrs = [
         "include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h",
+        "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
         "include/mlir/Dialect/Bufferization/IR/Bufferization.h",
         "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
@@ -13084,6 +13118,7 @@ cc_library(
         ":Analysis",
         ":ArithDialect",
         ":BufferDeallocationOpInterfaceIncGen",
+        ":BufferViewFlowOpInterfaceIncGen",
         ":BufferizableOpInterfaceIncGen",
         ":BufferizationBaseIncGen",
         ":BufferizationEnumsIncGen",
@@ -13140,6 +13175,7 @@ cc_library(
         ":ControlFlowDialect",
         ":ControlFlowInterfaces",
         ":FuncDialect",
+        ":FunctionInterfaces",
         ":IR",
         ":LoopLikeInterface",
         ":MemRefDialect",

From 97d0b087702c6dc696b6781bf7df26dad399a41a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 26 Jan 2024 14:45:06 +0000
Subject: [PATCH 2/2] [mlir][bufferization] BufferOriginAnalysis

---
 .../Bufferization/IR/BufferizationOps.td      |   1 +
 .../Transforms/BufferViewFlowAnalysis.h       |   2 +
 .../BufferDeallocationSimplification.cpp      | 227 +++++++++++++++---
 .../Transforms/BufferViewFlowAnalysis.cpp     |  32 ++-
 .../buffer-deallocation-simplification.mlir   |  14 +-
 5 files changed, 220 insertions(+), 56 deletions(-)
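
This patch only declares `resolveReverse` and `reverseDependencies` in the header hunk below; the matching `BufferViewFlowAnalysis.cpp` changes are not visible in this excerpt. A plausible reading is that `resolveReverse` mirrors `resolve` over a second, reversed edge map. The toy sketch below assumes exactly that. It uses strings for SSA values and a hypothetical `registerDependency` helper, and is not the patch's implementation. Like `resolve`, the closure includes the queried value itself.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

using Value = std::string; // toy stand-in for mlir::Value
using ValueMap = std::map<Value, std::set<Value>>;

struct ToyBidirectionalFlow {
  ValueMap dependencies;        // x -> values that may depend on x
  ValueMap reverseDependencies; // y -> values that y may depend on

  // Record "`to` may be (a view of) `from`" in both directions.
  void registerDependency(const Value &from, const Value &to) {
    dependencies[from].insert(to);
    reverseDependencies[to].insert(from);
  }

  // Transitive closure over `edges`, including `start` itself.
  static std::set<Value> closure(const ValueMap &edges, const Value &start) {
    std::set<Value> result;
    std::vector<Value> worklist{start};
    while (!worklist.empty()) {
      Value current = worklist.back();
      worklist.pop_back();
      if (!result.insert(current).second)
        continue; // already visited
      auto it = edges.find(current);
      if (it == edges.end())
        continue;
      for (const Value &next : it->second)
        worklist.push_back(next);
    }
    return result;
  }

  std::set<Value> resolve(const Value &v) const {
    return closure(dependencies, v);
  }
  std::set<Value> resolveReverse(const Value &v) const {
    return closure(reverseDependencies, v);
  }
};
```

For `%1 = memref.realloc %0` followed by `%2 = arith.select %c, %1, %arg0`, `resolveReverse("%2")` yields `{%0, %1, %2, %arg0}`. The terminal members of that set (`%0`, `%arg0`) are the "origins" that `BufferOriginAnalysis` compares.
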

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 9dc6afcaab31c86..4f609ddff9a4138 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -10,6 +10,7 @@
 #define BUFFERIZATION_OPS
 
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 894831c3848219b..9d31ff5872c42a2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -53,6 +53,7 @@ class BufferViewFlowAnalysis {
   ///
   /// Results in resolve(B) returning {B, C}
   ValueSetT resolve(Value value) const;
+  ValueSetT resolveReverse(Value value) const;
 
   /// Removes the given values from all alias sets.
   void remove(const SetVector<Value> &aliasValues);
@@ -73,6 +74,7 @@ class BufferViewFlowAnalysis {
 
   /// Maps values to all immediate dependencies this value can have.
   ValueMapT dependencies;
+  ValueMapT reverseDependencies;
 
   /// A set of all terminal values. I.e., values where the analysis stopped.
   DenseSet<Value> terminals;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index e30779868b47539..860ce891c9bc2f4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -12,8 +12,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -30,6 +30,161 @@ namespace bufferization {
 using namespace mlir;
 using namespace mlir::bufferization;
 
+//===----------------------------------------------------------------------===//
+// BufferOriginAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Return "true" if the given value is the result of a memory allocation.
+static bool hasAllocateSideEffect(Value v) {
+  Operation *op = v.getDefiningOp();
+  if (!op)
+    return false;
+  return hasEffect<MemoryEffects::Allocate>(op, v);
+}
+
+/// Return "true" if the given value is a function block argument.
+static bool isFunctionArgument(Value v) {
+  auto bbArg = dyn_cast<BlockArgument>(v);
+  if (!bbArg)
+    return false;
+  Block *b = bbArg.getOwner();
+  auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
+  if (!funcOp)
+    return false;
+  return bbArg.getOwner() == &funcOp.getFunctionBody().front();
+}
+
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+    value = viewLikeOp.getViewSource();
+  return value;
+}
+
+/// An is-same-buffer analysis that checks if two SSA values belong to the same
+/// buffer allocation or not.
+class BufferOriginAnalysis {
+public:
+  BufferOriginAnalysis(Operation *op) : analysis(op) {}
+
+  /// Return "true" if `v1` and `v2` originate from the same buffer allocation.
+  /// Return "false" if `v1` and `v2` originate from different allocations.
+  /// Return "nullopt" if we do not know for sure.
+  ///
+  /// Example 1: isSameAllocation(%0, %1) == true
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.subview %0
+  /// ```
+  ///
+  /// Example 2: isSameAllocation(%0, %1) == false
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// ```
+  ///
+  /// Example 3: isSameAllocation(%0, %2) == nullopt
+  /// ```
+  /// %0 = memref.alloc()
+  /// %1 = memref.alloc()
+  /// %2 = arith.select %c, %0, %1
+  /// ```
+  std::optional<bool> isSameAllocation(Value v1, Value v2) {
+    assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
+    assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+
+    // Skip over all view-like ops.
+    v1 = getViewBase(v1);
+    v2 = getViewBase(v2);
+
+    // Fast path: If both buffers are the same SSA value, we can be sure that
+    // they originate from the same allocation.
+    if (v1 == v2)
+      return true;
+
+    // Compute the SSA values from which the buffers `v1` and `v2` originate.
+    SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
+    SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
+
+    // Originating buffers are "terminal" if they could not be traced back any
+    // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
+    // - function block arguments
+    // - values defined by allocation ops such as "memref.alloc"
+    // - values defined by ops that are unknown to the buffer view flow analysis
+    // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
+    SmallPtrSet<Value, 16> terminal1, terminal2;
+
+    // While gathering terminal buffers, keep track of whether all terminal
+    // buffers are newly allocated buffers or function entry arguments.
+    bool allAllocs1 = true, allAllocs2 = true;
+    bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
+
+    // Gather terminal buffers for `v1`.
+    for (Value v : origin1) {
+      if (isa<BaseMemRefType>(v.getType()) && analysis.isTerminalBuffer(v)) {
+        terminal1.insert(v);
+        allAllocs1 &= hasAllocateSideEffect(v);
+        allAllocsOrFuncEntryArgs1 &=
+            isFunctionArgument(v) || hasAllocateSideEffect(v);
+      }
+    }
+    assert(!terminal1.empty() && "expected non-empty terminal set");
+
+    // Gather terminal buffers for `v2`.
+    bool distinctTerminalSets = true;
+    for (Value v : origin2) {
+      if (isa<BaseMemRefType>(v.getType()) && analysis.isTerminalBuffer(v)) {
+        terminal2.insert(v);
+        allAllocs2 &= hasAllocateSideEffect(v);
+        allAllocsOrFuncEntryArgs2 &=
+            isFunctionArgument(v) || hasAllocateSideEffect(v);
+        distinctTerminalSets &= !terminal1.contains(v);
+      }
+    }
+    assert(!terminal2.empty() && "expected non-empty terminal set");
+
+    // If both `v1` and `v2` have a single matching terminal buffer, they are
+    // guaranteed to originate from the same buffer allocation.
+    if (llvm::hasSingleElement(terminal1) &&
+        llvm::hasSingleElement(terminal2) &&
+        *terminal1.begin() == *terminal2.begin())
+      return true;
+
+    // Otherwise, the two terminal sets are not the same singleton set.
+
+    // If there is overlap between the terminal buffers of `v1` and `v2`, we
+    // cannot make an accurate decision without further analysis.
+    if (!distinctTerminalSets)
+      return std::nullopt;
+
+    // If `v1` originates from only allocs, and `v2` is guaranteed to originate
+    // from different allocations (that is guaranteed if `v2` originates from
+    // only distinct allocs or function entry arguments), we can be sure that
+    // `v1` and `v2` originate from different allocations. The same argument can
+    // be made when swapping `v1` and `v2`.
+    // `allAllocsN` implies `allAllocsOrFuncEntryArgsN`, so checking the weaker
+    // condition for the other value is sufficient.
+    bool isolatedAlloc1 = allAllocs1 && allAllocsOrFuncEntryArgs2;
+    bool isolatedAlloc2 = allAllocsOrFuncEntryArgs1 && allAllocs2;
+    if (isolatedAlloc1 || isolatedAlloc2)
+      return false;
+
+    // Otherwise: We do not know whether `v1` and `v2` originate from the same
+    // allocation or not.
+    // TODO: Function arguments are currently handled conservatively. We assume
+    // that they could be the same allocation.
+    // TODO: Terminals other than allocations and function arguments are
+    // currently handled conservatively. We assume that they could be the same
+    // allocation. E.g., we currently return "nullopt" for values that originate
+    // from different "memref.get_global" ops (with different symbols).
+    return std::nullopt;
+  }
+
+private:
+  BufferViewFlowAnalysis analysis;
+};
+
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
@@ -49,14 +204,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
   return success();
 }
 
-/// Given a memref value, return the "base" value by skipping over all
-/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
-static Value getViewBase(Value value) {
-  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
-    value = viewLikeOp.getViewSource();
-  return value;
-}
-
 /// Return "true" if the given values are guaranteed to be different (and
 /// non-aliasing) allocations based on the fact that one value is the result
 /// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +227,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
 /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
 /// often a requirement of optimization patterns that there cannot be any
 /// aliasing memref in order to perform the desired simplification.
-static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
                                      ValueRange otherList, Value memref) {
   for (auto other : otherList) {
     if (distinctAllocAndBlockArgument(other, memref))
       continue;
-    if (!analysis.alias(other, memref).isNo())
+    std::optional<bool> analysisResult =
+        analysis.isSameAllocation(other, memref);
+    if (!analysisResult.has_value() || analysisResult == true)
       return true;
   }
   return false;
@@ -129,8 +278,8 @@ namespace {
 struct RemoveDeallocMemrefsContainedInRetained
     : public OpRewritePattern<DeallocOp> {
   RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
-                                          AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                          BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   /// The passed 'memref' must not have a may-alias relation to any retained
   /// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +296,11 @@ struct RemoveDeallocMemrefsContainedInRetained
     // deallocated in some situations and can thus not be dropped).
     bool atLeastOneMustAlias = false;
     for (Value retained : deallocOp.getRetained()) {
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMay())
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (!analysisResult.has_value())
         return failure();
-      if (analysisResult.isMust() || analysisResult.isPartial())
+      if (analysisResult == true)
         atLeastOneMustAlias = true;
     }
     if (!atLeastOneMustAlias)
@@ -161,8 +311,9 @@ struct RemoveDeallocMemrefsContainedInRetained
     // we can remove that operand later on.
     for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
       Value updatedCondition = deallocOp.getUpdatedConditions()[i];
-      AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-      if (analysisResult.isMust() || analysisResult.isPartial()) {
+      std::optional<bool> analysisResult =
+          analysis.isSameAllocation(retained, memref);
+      if (analysisResult == true) {
         auto disjunction = rewriter.create<arith::OrIOp>(
             deallocOp.getLoc(), updatedCondition, cond);
         rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +357,7 @@ struct RemoveDeallocMemrefsContainedInRetained
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +379,15 @@ struct RemoveDeallocMemrefsContainedInRetained
 struct RemoveRetainedMemrefsGuaranteedToNotAlias
     : public OpRewritePattern<DeallocOp> {
   RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
-                                            AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                            BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> newRetainedMemrefs, replacements;
 
     for (auto retainedMemref : deallocOp.getRetained()) {
-      if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+      if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
                                    retainedMemref)) {
         newRetainedMemrefs.push_back(retainedMemref);
         replacements.push_back({});
@@ -264,7 +415,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +448,8 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
 struct SplitDeallocWhenNotAliasingAnyOther
     : public OpRewritePattern<DeallocOp> {
   SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
-                                      AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                      BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -314,7 +465,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
       SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
       otherMemrefs.erase(otherMemrefs.begin() + i);
       // Check if `memref` can split off into a separate bufferization.dealloc.
-      if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
+      if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
         // `memref` alias with other memrefs, do not split off.
         remainingMemrefs.push_back(memref);
         remainingConditions.push_back(cond);
@@ -352,7 +503,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 /// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +532,8 @@ struct SplitDeallocWhenNotAliasingAnyOther
 struct RetainedMemrefAliasingAlwaysDeallocatedMemref
     : public OpRewritePattern<DeallocOp> {
   RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
-                                                AliasAnalysis &aliasAnalysis)
-      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+                                                BufferOriginAnalysis &analysis)
+      : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
 
   LogicalResult matchAndRewrite(DeallocOp deallocOp,
                                 PatternRewriter &rewriter) const override {
@@ -396,8 +547,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!matchPattern(cond, m_One()))
           continue;
 
-        AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
-        if (analysisResult.isMust() || analysisResult.isPartial()) {
+        std::optional<bool> analysisResult =
+            analysis.isSameAllocation(retained, memref);
+        if (analysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -411,10 +563,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
         if (!extractOp)
           continue;
 
-        AliasResult extractAnalysisResult =
-            aliasAnalysis.alias(retained, extractOp.getOperand());
-        if (extractAnalysisResult.isMust() ||
-            extractAnalysisResult.isPartial()) {
+        std::optional<bool> extractAnalysisResult =
+            analysis.isSameAllocation(retained, extractOp.getOperand());
+        if (extractAnalysisResult == true) {
           rewriter.replaceAllUsesWith(res, cond);
           aliasesWithConstTrueMemref[i] = true;
           canDropMemref = true;
@@ -434,7 +585,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
   }
 
 private:
-  AliasAnalysis &aliasAnalysis;
+  BufferOriginAnalysis &analysis;
 };
 
 } // namespace
@@ -452,13 +603,13 @@ struct BufferDeallocationSimplificationPass
     : public bufferization::impl::BufferDeallocationSimplificationBase<
           BufferDeallocationSimplificationPass> {
   void runOnOperation() override {
-    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    BufferOriginAnalysis analysis(getOperation());
     RewritePatternSet patterns(&getContext());
     patterns.add<RemoveDeallocMemrefsContainedInRetained,
                  RemoveRetainedMemrefsGuaranteedToNotAlias,
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
-                                                                aliasAnalysis);
+                                                                analysis);
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
     if (failed(
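The rewritten patterns use the three-valued result of `isSameAllocation` in two different ways. `potentiallyAliasesMemref` treats `std::nullopt` like `true` (conservatively "may alias"). `RemoveDeallocMemrefsContainedInRetained` gives up on any `std::nullopt` and also requires at least one definite `true`. The sketch below isolates only that decision logic. It assumes the per-value query results have already been computed, and `potentiallyAliases` and `hasDefiniteSameAllocation` are hypothetical helpers, not functions from the patch.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical model: each entry is the result of one isSameAllocation query.
//   true         -> guaranteed to be the same allocation
//   false        -> guaranteed to be different allocations
//   std::nullopt -> the analysis does not know
using SameAllocResult = std::optional<bool>;

// Like potentiallyAliasesMemref: "unknown" is conservatively treated as
// "may alias", so only a definite `false` lets an entry be skipped.
bool potentiallyAliases(const std::vector<SameAllocResult> &results) {
  for (const SameAllocResult &r : results)
    if (r.value_or(true))
      return true;
  return false;
}

// Like the check in RemoveDeallocMemrefsContainedInRetained: bail out on any
// "unknown" answer, and require at least one definite "same allocation".
bool hasDefiniteSameAllocation(const std::vector<SameAllocResult> &results) {
  bool atLeastOneSame = false;
  for (const SameAllocResult &r : results) {
    if (!r.has_value())
      return false;
    if (*r)
      atLeastOneSame = true;
  }
  return atLeastOneSame;
}
```

For example, `potentiallyAliases({false, std::nullopt})` returns `true`. `hasDefiniteSameAllocation({false, true})` returns `true`, but `hasDefiniteSameAllocation({true, std::nullopt})` returns `false`.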
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 7cf202ac81d7c07..d6e9f59ec0a78d9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -22,19 +22,16 @@ using namespace mlir::bufferization;
 /// Constructs a new alias analysis using the op provided.
 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
 
-/// Find all immediate and indirect dependent buffers this value could
-/// potentially have. Note that the resulting set will also contain the value
-/// provided as it is a dependent alias of itself.
-BufferViewFlowAnalysis::ValueSetT
-BufferViewFlowAnalysis::resolve(Value rootValue) const {
-  ValueSetT result;
+static BufferViewFlowAnalysis::ValueSetT
+resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
+  BufferViewFlowAnalysis::ValueSetT result;
   SmallVector<Value, 8> queue;
-  queue.push_back(rootValue);
+  queue.push_back(value);
   while (!queue.empty()) {
     Value currentValue = queue.pop_back_val();
     if (result.insert(currentValue).second) {
-      auto it = dependencies.find(currentValue);
-      if (it != dependencies.end()) {
+      auto it = map.find(currentValue);
+      if (it != map.end()) {
         for (Value aliasValue : it->second)
           queue.push_back(aliasValue);
       }
@@ -43,6 +40,19 @@ BufferViewFlowAnalysis::resolve(Value rootValue) const {
   return result;
 }
 
+/// Find all immediate and indirect dependent buffers this value could
+/// potentially have. Note that the resulting set will also contain the value
+/// provided as it is a dependent alias of itself.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolve(Value rootValue) const {
+  return resolveValues(dependencies, rootValue);
+}
+
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
+  return resolveValues(reverseDependencies, rootValue);
+}
+
 /// Removes the given values from all alias sets.
 void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
   for (auto &entry : dependencies)
@@ -69,8 +79,10 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
-    for (auto [value, dep] : llvm::zip_equal(values, dependencies))
+    for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
       this->dependencies[value].insert(dep);
+      this->reverseDependencies[dep].insert(value);
+    }
   };
 
   // Mark all buffer results and buffer region entry block arguments of the
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index eee69acbe821b37..b40a17cf800bf30 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -92,15 +92,13 @@ func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>,
 //  CHECK-NEXT:   [[ALLOC0:%.+]] = memref.alloc(
 //  CHECK-NEXT:   [[ALLOC1:%.+]] = memref.alloc(
 //  CHECK-NEXT:   [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
-// COM: there is only one value in the retained list because the
-// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here and
-// COM: removes %arg1 from the list. In the second dealloc, this does not apply
-// COM: because function arguments are assumed potentially alias (even if the
-// COM: types don't exactly match).
+// COM: there is only one value in the retained lists because the
+// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here:
+// COM: - %alloc is guaranteed to not alias with %arg1.
+// COM: - %arg2 is guaranteed to not alias with %0.
 //  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
-//  CHECK-NEXT:   [[V2:%.+]]:2 = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]], [[V0]] : memref<2xi32>, memref<2xi32>)
-//  CHECK-NEXT:   [[V3:%.+]] = arith.ori [[V1]], [[V2]]#1
-//  CHECK-NEXT:   return [[V2]]#0, [[V3]] :
+//  CHECK-NEXT:   [[V2:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]] : memref<2xi32>)
+//  CHECK-NEXT:   return [[V2]], [[V1]] :
 
 // -----
 
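The updated CHECK lines follow from the terminal-set reasoning in `BufferOriginAnalysis::isSameAllocation`. `[[V0]]` is an `arith.select` of two fresh allocations, so both of its terminals are allocations, while `%arg1` and `%arg2` are function entry arguments. The following standalone sketch re-implements only that decision step. It assumes the terminal sets have already been computed and that each terminal is tagged as an allocation, a function entry argument or something else. `Kind`, `Terminal` and the tagging are hypothetical stand-ins for `hasAllocateSideEffect` and `isFunctionArgument`, and the view-base fast path is omitted.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <string>

// Hypothetical classification of a terminal buffer.
enum class Kind { Alloc, FuncArg, Other };

struct Terminal {
  std::string name;
  Kind kind;
  // Terminals are identified by name only.
  bool operator<(const Terminal &other) const { return name < other.name; }
};

using TerminalSet = std::set<Terminal>;

// True if every terminal is an allocation (or, if `allowFuncArgs` is set, an
// allocation or a function entry argument).
static bool allOf(const TerminalSet &terminals, bool allowFuncArgs) {
  for (const Terminal &t : terminals) {
    bool ok = t.kind == Kind::Alloc ||
              (allowFuncArgs && t.kind == Kind::FuncArg);
    if (!ok)
      return false;
  }
  return true;
}

// Decision step of isSameAllocation on precomputed terminal sets:
// true = same allocation, false = different allocations, nullopt = unknown.
std::optional<bool> isSameAllocation(const TerminalSet &t1,
                                     const TerminalSet &t2) {
  assert(!t1.empty() && !t2.empty() && "expected non-empty terminal sets");

  // Both values have the same single terminal buffer.
  if (t1.size() == 1 && t2.size() == 1 &&
      t1.begin()->name == t2.begin()->name)
    return true;

  // Overlapping terminal sets: cannot decide without further analysis.
  for (const Terminal &t : t2)
    if (t1.count(t))
      return std::nullopt;

  // Fresh allocations cannot share a buffer with a disjoint set of fresh
  // allocations and/or function entry arguments.
  if ((allOf(t1, false) && allOf(t2, true)) ||
      (allOf(t1, true) && allOf(t2, false)))
    return false;

  return std::nullopt;
}
```

Consider `alloc0 = {%alloc0: Alloc}`, `sel = {%alloc0: Alloc, %alloc1: Alloc}` and `arg2 = {%arg2: FuncArg}`. For `(arg2, sel)` the sketch answers `false`, so `[[V0]]` can be dropped from the second dealloc's retained list. For `(alloc0, sel)` it answers `std::nullopt`, so `[[V0]]` stays retained by the first dealloc.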