[Mlir-commits] [mlir] 45cd0e4 - [mlir][bufferization][NFC] Make getEnclosingRepetitiveRegion public

Matthias Springer llvmlistbot at llvm.org
Fri Jan 13 07:39:57 PST 2023


Author: Matthias Springer
Date: 2023-01-13T16:39:41+01:00
New Revision: 45cd0e453d5d0457a3bf9e47ec7232c48729ff59

URL: https://github.com/llvm/llvm-project/commit/45cd0e453d5d0457a3bf9e47ec7232c48729ff59
DIFF: https://github.com/llvm/llvm-project/commit/45cd0e453d5d0457a3bf9e47ec7232c48729ff59.diff

LOG: [mlir][bufferization][NFC] Make getEnclosingRepetitiveRegion public

These functions are generally useful and not specific to One-Shot Analysis. Move them to `BufferizableOpInterface.h`, make them public, and add an overload that takes a `Block *`.

Differential Revision: https://reviews.llvm.org/D141685

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 24754d5993a1f..799aff951c295 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -522,6 +522,19 @@ getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
 /// owner of the block. In case of an OpResult that is the defining op.
 Operation *getOwnerOfValue(Value value);
 
+/// Return the closest enclosing repetitive region around `op`, or nullptr.
+Region *getEnclosingRepetitiveRegion(Operation *op,
+                                     const BufferizationOptions &options);
+
+/// Return the closest enclosing repetitive region around the place where
+/// `value` is defined, or nullptr if there is none.
+Region *getEnclosingRepetitiveRegion(Value value,
+                                     const BufferizationOptions &options);
+
+/// Return the closest enclosing repetitive region around `block`, or nullptr.
+Region *getEnclosingRepetitiveRegion(Block *block,
+                                     const BufferizationOptions &options);
+
 namespace detail {
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index af0d48a126af8..9e7dbf5071060 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -41,6 +41,39 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
 using namespace mlir;
 using namespace bufferization;
 
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Operation *op, const BufferizationOptions &options) {
+  Block *block = op->getBlock();
+  // An op that is not inside a block has no enclosing region: yield nullptr.
+  return block ? getEnclosingRepetitiveRegion(block, options) : nullptr;
+}
+
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Value value, const BufferizationOptions &options) {
+  // Visit enclosing regions from the innermost one outwards.
+  for (Region *current = value.getParentRegion(); current != nullptr;) {
+    Operation *owner = current->getParentOp();
+    auto iface = options.dynCastBufferizableOp(owner);
+    if (iface && iface.isRepetitiveRegion(current->getRegionNumber()))
+      return current;
+    current = owner->getParentRegion();
+  }
+  return nullptr;
+}
+
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Block *block, const BufferizationOptions &options) {
+  // Walk outwards; a block without a parent region (detached) yields nullptr.
+  for (Region *region = block->getParent(); region;) {
+    Operation *op = region->getParentOp();
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+    region = op->getParentRegion();
+  }
+  return nullptr;
+}
+
 Operation *bufferization::getOwnerOfValue(Value value) {
   if (auto opResult = value.dyn_cast<OpResult>())
     return opResult.getDefiningOp();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index cd06899595f4c..7dfd626323b71 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -355,31 +355,6 @@ static bool happensBefore(Operation *a, Operation *b,
   return false;
 }
 
-static Region *
-getEnclosingRepetitiveRegion(Operation *op,
-                             const BufferizationOptions &options) {
-  while (Region *region = op->getParentRegion()) {
-    op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
-  }
-  return nullptr;
-}
-
-static Region *
-getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
-  Region *region = value.getParentRegion();
-  while (region) {
-    Operation *op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
-    region = op->getParentRegion();
-  }
-  return nullptr;
-}
-
 /// Return `true` if the given tensor value is a memory write. Most values are
 /// tensor writes, but ops that define a tensor SSA value without specifying its
 /// contents (e.g., alloc_tensor) are not.


        


More information about the Mlir-commits mailing list