[Mlir-commits] [mlir] a4547dc - [mlir][linalg][bufferize] Move more helper functions/structs to interface

Matthias Springer llvmlistbot at llvm.org
Wed Nov 10 21:16:38 PST 2021


Author: Matthias Springer
Date: 2021-11-11T14:16:20+09:00
New Revision: a4547dc5758ee9f8ea8aa8e41cce05a7cdec3d56

URL: https://github.com/llvm/llvm-project/commit/a4547dc5758ee9f8ea8aa8e41cce05a7cdec3d56
DIFF: https://github.com/llvm/llvm-project/commit/a4547dc5758ee9f8ea8aa8e41cce05a7cdec3d56.diff

LOG: [mlir][linalg][bufferize] Move more helper functions/structs to interface

Move helper functions for traversing reverse use-def chains. These are useful for implementing custom optimizations (e.g., custom InitTensorOp eliminations).

Also move over the AllocationCallbacks struct. This is in preparation for decoupling ComprehensiveBufferize from various dialects.

Differential Revision: https://reviews.llvm.org/D113386

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 318819cf49a71..1b2b2950776ec 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/SetVector.h"
 
 namespace mlir {
 class BlockAndValueMapping;
@@ -21,7 +22,6 @@ class BlockAndValueMapping;
 namespace linalg {
 namespace comprehensive_bufferize {
 
-struct AllocationCallbacks;
 class BufferizationAliasInfo;
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
@@ -160,6 +160,61 @@ bool isValueRead(Value value);
 /// OpResult that it may alias with. Return None if the op is not bufferizable.
 BufferRelation bufferRelation(OpOperand &opOperand);
 
+/// Starting from `value`, follow the use-def chain in reverse, always selecting
+/// the aliasing OpOperands. Find and return Values for which `condition`
+/// evaluates to true. OpOperands of such matching Values are not traversed any
+/// further.
+///
+/// When reaching the end of a chain (BlockArgument or Value without aliasing
+/// OpOperands), also return the last Value of that chain.
+///
+/// Example:
+///
+///                               8
+///                               |
+///   6*         7*         +-----+----+
+///   |          |          |          |
+///   2*         3          4*         5
+///   |          |          |          |
+///   +----------+----------+----------+
+///              |
+///              1
+///
+/// In the above example, Values with a star satisfy the condition. When
+/// starting the traversal from Value 1, the resulting SetVector is:
+/// { 2, 7, 8, 5 }
+llvm::SetVector<Value>
+findValueInReverseUseDefChain(Value value,
+                              std::function<bool(Value)> condition);
+
+/// Find the Value of the last preceding write of a given Value.
+///
+/// Note: Unknown ops are handled conservatively and assumed to be writes.
+/// Furthermore, BlockArguments are also assumed to be writes. There is no
+/// analysis across block boundaries.
+///
+/// Note: When reaching an end of the reverse SSA use-def chain, that value
+/// is returned regardless of whether it is a memory write or not.
+Value findLastPrecedingWrite(Value value);
+
+/// Callback functions that are used by the comprehensive bufferization pass to
+/// allocate/deallocate memory. The `deallocationFn` is guaranteed to receive
+/// the `Value` returned by the `allocationFn`.
+struct AllocationCallbacks {
+  using AllocationFn = std::function<Optional<Value>(
+      OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
+  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
+  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+
+  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
+                      MemCpyFn copyFn)
+      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
+
+  AllocationFn allocationFn;
+  DeallocationFn deallocationFn;
+  MemCpyFn memCpyFn;
+};
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 83ce9cc026a41..35b6e6f2abe2c 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
 
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/SetOperations.h"
@@ -23,8 +24,6 @@ class ModuleOp;
 namespace linalg {
 namespace comprehensive_bufferize {
 
-class BufferizationAliasInfo;
-
 // TODO: from some HW description.
 static constexpr int64_t kBufferAlignments = 128;
 
@@ -49,29 +48,8 @@ void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
 /// pass. Creates a `linalg.copy` op.
 void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to);
 
-/// Callback functions that are used by the comprehensive bufferization pass to
-/// allocate/deallocate memory. These default to use the
-/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
-/// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned
-/// by the `allocationFn`.
-struct AllocationCallbacks {
-  using AllocationFn = std::function<Optional<Value>(
-      OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
-  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
-  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-
-  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
-                      MemCpyFn copyFn)
-      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
-
-  AllocationCallbacks()
-      : allocationFn(defaultAllocationFn),
-        deallocationFn(defaultDeallocationFn), memCpyFn(defaultMemCpyFn) {}
-
-  AllocationFn allocationFn;
-  DeallocationFn deallocationFn;
-  MemCpyFn memCpyFn;
-};
+/// Return default allocation callbacks.
+std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 
 /// Bufferize one particular op.
 /// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
@@ -108,8 +86,7 @@ LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
     FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
 
 struct BufferizationOptions {
-  BufferizationOptions()
-      : allocationFns(std::make_unique<AllocationCallbacks>()) {}
+  BufferizationOptions();
 
   std::unique_ptr<AllocationCallbacks> allocationFns;
   bool allowReturnMemref = false;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 3efc6e8c33038..1bc2b55f314d4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -260,3 +260,56 @@ mlir::linalg::comprehensive_bufferize::bufferRelation(OpOperand &opOperand) {
   // Conservatively return None.
   return BufferRelation::None;
 }
+
+// Starting from `value`, follow the use-def chain in reverse, always selecting
+// the aliasing OpOperands. Find and return Values for which `condition`
+// evaluates to true. OpOperands of such matching Values are not traversed any
+// further.
+llvm::SetVector<Value>
+mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
+    Value value, std::function<bool(Value)> condition) {
+  llvm::SetVector<Value> result, workingSet;
+  workingSet.insert(value);
+
+  while (!workingSet.empty()) {
+    Value value = workingSet.pop_back_val();
+    if (condition(value) || value.isa<BlockArgument>()) {
+      result.insert(value);
+      continue;
+    }
+
+    OpResult opResult = value.cast<OpResult>();
+    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
+    if (opOperands.empty()) {
+      result.insert(value);
+      continue;
+    }
+
+    for (OpOperand *o : opOperands)
+      workingSet.insert(o->get());
+  }
+
+  return result;
+}
+
+// Find the Value of the last preceding write of a given Value.
+Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
+    Value value) {
+  SetVector<Value> result =
+      findValueInReverseUseDefChain(value, [](Value value) {
+        Operation *op = value.getDefiningOp();
+        if (!op)
+          return true;
+        auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+        if (!bufferizableOp)
+          return true;
+        return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
+      });
+
+  // To simplify the analysis, `scf.if` ops are considered memory writes. There
+  // are currently no other ops where one OpResult may alias with multiple
+  // OpOperands. Therefore, this function should return exactly one result at
+  // the moment.
+  assert(result.size() == 1 && "expected exactly one result");
+  return result.front();
+}

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 03e3e5e68c85d..1fadf149f164d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -402,84 +402,6 @@ static bool aliasesInPlaceWrite(Value value,
   return foundInplaceWrite;
 }
 
-/// Starting from `value`, follow the use-def chain in reverse, always selecting
-/// the aliasing OpOperands. Find and return Values for which `condition`
-/// evaluates to true. OpOperands of such matching Values are not traversed any
-/// further.
-///
-/// When reaching the end of a chain (BlockArgument or Value without aliasing
-/// OpOperands), also return the last Value of that chain.
-///
-/// Example:
-///
-///                               8
-///                               |
-///   6*         7*         +-----+----+
-///   |          |          |          |
-///   2*         3          4*         5
-///   |          |          |          |
-///   +----------+----------+----------+
-///              |
-///              1
-///
-/// In the above example, Values with a star satisfy the condition. When
-/// starting the traversal from Value 1, the resulting SetVector is:
-/// { 2, 7, 8, 5 }
-static llvm::SetVector<Value>
-findValueInReverseUseDefChain(Value value,
-                              std::function<bool(Value)> condition) {
-  llvm::SetVector<Value> result, workingSet;
-  workingSet.insert(value);
-
-  while (!workingSet.empty()) {
-    Value value = workingSet.pop_back_val();
-    if (condition(value) || value.isa<BlockArgument>()) {
-      result.insert(value);
-      continue;
-    }
-
-    OpResult opResult = value.cast<OpResult>();
-    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-    if (opOperands.empty()) {
-      result.insert(value);
-      continue;
-    }
-
-    for (OpOperand *o : opOperands)
-      workingSet.insert(o->get());
-  }
-
-  return result;
-}
-
-/// Find the Value of the last preceding write of a given Value.
-///
-/// Note: Unknown ops are handled conservatively and assumed to be writes.
-/// Furthermore, BlockArguments are also assumed to be writes. There is no
-/// analysis across block boundaries.
-///
-/// Note: When reaching an end of the reverse SSA use-def chain, that value
-/// is returned regardless of whether it is a memory write or not.
-static Value findLastPrecedingWrite(Value value) {
-  SetVector<Value> result =
-      findValueInReverseUseDefChain(value, [](Value value) {
-        Operation *op = value.getDefiningOp();
-        if (!op)
-          return true;
-        auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
-        if (!bufferizableOp)
-          return true;
-        return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
-      });
-
-  // To simplify the analysis, `scf.if` ops are considered memory writes. There
-  // are currently no other ops where one OpResult may alias with multiple
-  // OpOperands. Therefore, this function should return exactly one result at
-  // the moment.
-  assert(result.size() == 1 && "expected exactly one result");
-  return result.front();
-}
-
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
 static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
@@ -2035,6 +1957,17 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   return success();
 }
 
+std::unique_ptr<AllocationCallbacks>
+mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
+  return std::make_unique<AllocationCallbacks>(
+      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
+}
+
+// Default constructor for BufferizationOptions that sets all allocation
+// callbacks to their default functions.
+BufferizationOptions::BufferizationOptions()
+    : allocationFns(defaultAllocationCallbacks()) {}
+
 //===----------------------------------------------------------------------===//
 // BufferizableOpInterface Implementations
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index d95ba926f81e7..b9fbfd727adfb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Pass/Pass.h"

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 6aa22bf533b2d..bb3e66695c5d0 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6380,6 +6380,7 @@ cc_library(
         ":AffineUtils",
         ":Analysis",
         ":ArithmeticDialect",
+        ":BufferizableOpInterface",
         ":ComplexDialect",
         ":ComprehensiveBufferize",
         ":DialectUtils",


        


More information about the Mlir-commits mailing list