[Mlir-commits] [mlir] 8835a19 - [mlir][linalg][bufferize] Allow non-tensor mappings in BufferizationState

Matthias Springer llvmlistbot at llvm.org
Mon Nov 15 02:40:42 PST 2021


Author: Matthias Springer
Date: 2021-11-15T19:40:30+09:00
New Revision: 8835a1924e37fe8029f9af7aacd9eb81c2848401

URL: https://github.com/llvm/llvm-project/commit/8835a1924e37fe8029f9af7aacd9eb81c2848401
DIFF: https://github.com/llvm/llvm-project/commit/8835a1924e37fe8029f9af7aacd9eb81c2848401.diff

LOG: [mlir][linalg][bufferize] Allow non-tensor mappings in BufferizationState

This change makes it possible to set up custom mappings in a PostAnalysisStep. Some users of Comprehensive Bufferize have custom tensor types, and it is most convenient for them to reuse the same BlockAndValueMapping (bvm).

Also add some more assertions.
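
For illustration, a minimal sketch (not part of this commit) of how downstream code in a PostAnalysisStep could use the new mapping API; the helper names below are hypothetical and only show the intended usage of mapValue/isMapped/markOpObsolete:

  // Hypothetical helpers, assuming the BufferizationState API added here.
  #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"

  using namespace mlir;
  using namespace mlir::linalg::comprehensive_bufferize;

  // Record a custom (possibly non-tensor) value-to-value mapping so that
  // later bufferization steps can retrieve it via lookupValue().
  static void mapCustomValue(BufferizationState &state, Value from, Value to) {
    if (!state.isMapped(from))
      state.mapValue(from, to);
  }

  // Schedule an op for deletion; obsolete ops are erased after
  // bufferizeFuncOpInternals has finished.
  static void retireOp(BufferizationState &state, Operation *op) {
    state.markOpObsolete(op);
  }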

Differential Revision: https://reviews.llvm.org/D113726

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 07f065d909e5..a9a344d75c1c 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
 
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
@@ -40,6 +41,9 @@ class BufferizationAliasInfo {
 public:
   explicit BufferizationAliasInfo(Operation *rootOp);
 
+  // BufferizationAliasInfo should be passed as a reference.
+  BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
+
   /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
   /// beginning the alias and equivalence sets only contain `v` itself.
   void createAliasInfoEntry(Value v);
@@ -237,14 +241,18 @@ struct AllocationCallbacks {
 /// the results of the analysis.
 struct BufferizationState {
   BufferizationState(BufferizationAliasInfo &aliasInfo,
-                     AllocationCallbacks &allocationFns,
-                     BlockAndValueMapping &tensorToBufferMap)
-      : aliasInfo(aliasInfo), allocationFns(allocationFns),
-        tensorToBufferMap(tensorToBufferMap) {}
+                     AllocationCallbacks &allocationFns)
+      : aliasInfo(aliasInfo), allocationFns(allocationFns) {}
+
+  // BufferizationState should be passed as a reference.
+  BufferizationState(const BufferizationState &) = delete;
 
   /// Map tensor values to memref buffers.
   void mapBuffer(ValueRange tensors, ValueRange buffers);
 
+  /// Map a value to another value.
+  void mapValue(Value from, Value to);
+
   /// Map a tensor value to a memref buffer.
   void mapBuffer(Value tensor, Value buffer);
 
@@ -252,6 +260,16 @@ struct BufferizationState {
   /// Asserts if no buffer is associated.
   Value lookupBuffer(Value tensor) const;
 
+  /// Lookup the value that is associated to the given value. Asserts if no
+  /// value is associated.
+  Value lookupValue(Value value) const;
+
+  /// Return `true` if the given value is mapped.
+  bool isMapped(Value value) const;
+
+  /// Mark `op` as obsolete, so that it is deleted after bufferization.
+  void markOpObsolete(Operation *op);
+
   /// `aliasInfo` keeps track of aliasing and equivalent values.
   BufferizationAliasInfo &aliasInfo;
 
@@ -259,8 +277,12 @@ struct BufferizationState {
   /// ops and memcpy ops.
   AllocationCallbacks &allocationFns;
 
-  /// The mapping of tensors to buffers.
-  BlockAndValueMapping &tensorToBufferMap;
+  /// The mapping of tensors to buffers. May also contain mappings of non-tensor
+  /// values.
+  BlockAndValueMapping mapping;
+
+  /// Obsolete ops that should be deleted after bufferization.
+  SmallVector<Operation *> obsoleteOps;
 };
 
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 0778f1dbd3b6..150ffd7e45f3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -398,7 +398,19 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
 void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
     ValueRange tensors, ValueRange buffers) {
   assert(!tensors.empty() && "unexpected empty tensors");
-  return tensorToBufferMap.map(tensors, buffers);
+#ifndef NDEBUG
+  for (Value tensor : tensors) {
+    assert(tensor && "unexpected empty tensor");
+    assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+  }
+  for (Value buffer : buffers) {
+    assert(buffer && "unexpected empty buffer");
+    assert((buffer.getType().isa<MemRefType>() ||
+            buffer.getType().isa<UnrankedMemRefType>()) &&
+           "expected that tensor is mapped to memref");
+  }
+#endif // NDEBUG
+  return mapping.map(tensors, buffers);
 }
 
 /// Wrapper for better debugging.
@@ -406,7 +418,17 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
     Value tensor, Value buffer) {
   assert(tensor && "unexpected empty tensor");
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-  return tensorToBufferMap.map(tensor, buffer);
+  assert(buffer && "unexpected empty buffer");
+  assert((buffer.getType().isa<MemRefType>() ||
+          buffer.getType().isa<UnrankedMemRefType>()) &&
+         "expected that tensor is mapped to memref");
+  return mapping.map(tensor, buffer);
+}
+
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapValue(
+    Value from, Value to) {
+  assert(from && "unexpected empty value");
+  return mapping.map(from, to);
 }
 
 /// Wrapper for better debugging.
@@ -414,7 +436,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
     Value tensor) const {
   // TODO: if key comes from bbArg, forward.
   assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
-  Value v = tensorToBufferMap.lookupOrNull(tensor);
+  Value v = mapping.lookupOrNull(tensor);
 
   if (!v) {
     // Dump tensor for easier debugging.
@@ -423,5 +445,28 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
     return Value();
   }
 
+  assert((v.getType().isa<MemRefType>() ||
+          v.getType().isa<UnrankedMemRefType>()) &&
+         "expected that tensor is mapped to memref");
   return v;
 }
+
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupValue(
+    Value value) const {
+  Value v = mapping.lookupOrNull(value);
+  if (!v) {
+    llvm_unreachable("tensor is not mapped");
+    return Value();
+  }
+  return v;
+}
+
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
+    Value value) const {
+  return mapping.contains(value);
+}
+
+void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
+    Operation *op) {
+  obsoleteOps.push_back(op);
+}

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index f0024214909b..6482d740f151 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1677,12 +1677,16 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
 
     // Bufferization phase.
     if (!options.testAnalysisOnly) {
-      BlockAndValueMapping tensorToBufferMap;
-      BufferizationState state(aliasInfo, *options.allocationFns,
-                               tensorToBufferMap);
+      BufferizationState state(aliasInfo, *options.allocationFns);
+
+      // Bufferize all ops in funcOp.
       if (failed(
               bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
         return failure();
+
+      // Erase all obsolete ops.
+      for (Operation *op : state.obsoleteOps)
+        op->erase();
     }
   }
   // Annotate operations if we only want to report the analysis.


        

