[Mlir-commits] [mlir] 8835a19 - [mlir][linalg][bufferize] Allow non-tensor mappings in BufferizationState
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 15 02:40:42 PST 2021
Author: Matthias Springer
Date: 2021-11-15T19:40:30+09:00
New Revision: 8835a1924e37fe8029f9af7aacd9eb81c2848401
URL: https://github.com/llvm/llvm-project/commit/8835a1924e37fe8029f9af7aacd9eb81c2848401
DIFF: https://github.com/llvm/llvm-project/commit/8835a1924e37fe8029f9af7aacd9eb81c2848401.diff
LOG: [mlir][linalg][bufferize] Allow non-tensor mappings in BufferizationState
This change makes it possible to set up custom mappings in a PostAnalysisStep. Some users of Comprehensive Bufferize have custom tensor types, and it is most convenient for them to reuse the same BlockAndValueMapping (bvm).
Also add some more assertions.
Differential Revision: https://reviews.llvm.org/D113726
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 07f065d909e5..a9a344d75c1c 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
@@ -40,6 +41,9 @@ class BufferizationAliasInfo {
public:
explicit BufferizationAliasInfo(Operation *rootOp);
+ // BufferizationAliasInfo should be passed as a reference.
+ BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
+
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void createAliasInfoEntry(Value v);
@@ -237,14 +241,18 @@ struct AllocationCallbacks {
/// the results of the analysis.
struct BufferizationState {
BufferizationState(BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFns,
- BlockAndValueMapping &tensorToBufferMap)
- : aliasInfo(aliasInfo), allocationFns(allocationFns),
- tensorToBufferMap(tensorToBufferMap) {}
+ AllocationCallbacks &allocationFns)
+ : aliasInfo(aliasInfo), allocationFns(allocationFns) {}
+
+ // BufferizationState should be passed as a reference.
+ BufferizationState(const BufferizationState &) = delete;
/// Map tensor values to memref buffers.
void mapBuffer(ValueRange tensors, ValueRange buffers);
+ /// Map a value to another value.
+ void mapValue(Value from, Value to);
+
/// Map a tensor value to a memref buffer.
void mapBuffer(Value tensor, Value buffer);
@@ -252,6 +260,16 @@ struct BufferizationState {
/// Asserts if no buffer is associated.
Value lookupBuffer(Value tensor) const;
+ /// Lookup the value that is associated to the given value. Asserts if no
+ /// value is associated.
+ Value lookupValue(Value value) const;
+
+ /// Return `true` if the given value is mapped.
+ bool isMapped(Value value) const;
+
+ /// Mark `op` as obsolete, so that it is deleted after bufferization.
+ void markOpObsolete(Operation *op);
+
/// `aliasInfo` keeps track of aliasing and equivalent values.
BufferizationAliasInfo &aliasInfo;
@@ -259,8 +277,12 @@ struct BufferizationState {
/// ops and memcpy ops.
AllocationCallbacks &allocationFns;
- /// The mapping of tensors to buffers.
- BlockAndValueMapping &tensorToBufferMap;
+ /// The mapping of tensors to buffers. May also contain mappings of non-tensor
+ /// values.
+ BlockAndValueMapping mapping;
+
+ /// Obsolete ops that should be deleted after bufferization.
+ SmallVector<Operation *> obsoleteOps;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 0778f1dbd3b6..150ffd7e45f3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -398,7 +398,19 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
ValueRange tensors, ValueRange buffers) {
assert(!tensors.empty() && "unexpected empty tensors");
- return tensorToBufferMap.map(tensors, buffers);
+#ifndef NDEBUG
+ for (Value tensor : tensors) {
+ assert(tensor && "unexpected empty tensor");
+ assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
+ }
+ for (Value buffer : buffers) {
+ assert(buffer && "unexpected empty buffer");
+ assert((buffer.getType().isa<MemRefType>() ||
+ buffer.getType().isa<UnrankedMemRefType>()) &&
+ "expected that tensor is mapped to memref");
+ }
+#endif // NDEBUG
+ return mapping.map(tensors, buffers);
}
/// Wrapper for better debugging.
@@ -406,7 +418,17 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::mapBuffer(
Value tensor, Value buffer) {
assert(tensor && "unexpected empty tensor");
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
- return tensorToBufferMap.map(tensor, buffer);
+ assert(buffer && "unexpected empty buffer");
+ assert((buffer.getType().isa<MemRefType>() ||
+ buffer.getType().isa<UnrankedMemRefType>()) &&
+ "expected that tensor is mapped to memref");
+ return mapping.map(tensor, buffer);
+}
+
+void mlir::linalg::comprehensive_bufferize::BufferizationState::mapValue(
+ Value from, Value to) {
+ assert(from && to && "unexpected empty value"); // null `to` breaks lookupValue
+ return mapping.map(from, to);
+}
/// Wrapper for better debugging.
@@ -414,7 +436,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
Value tensor) const {
// TODO: if key comes from bbArg, forward.
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
- Value v = tensorToBufferMap.lookupOrNull(tensor);
+ Value v = mapping.lookupOrNull(tensor);
if (!v) {
// Dump tensor for easier debugging.
@@ -423,5 +445,28 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
return Value();
}
+ assert((v.getType().isa<MemRefType>() ||
+ v.getType().isa<UnrankedMemRefType>()) &&
+ "expected that tensor is mapped to memref");
return v;
}
+
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupValue(
+ Value value) const {
+ // Unlike lookupBuffer, `value` may be of any type, not only a tensor.
+ Value v = mapping.lookupOrNull(value);
+ // llvm_unreachable does not return, so no fallback value is needed.
+ if (!v)
+ llvm_unreachable("value is not mapped");
+ return v;
+}
+
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isMapped(
+ Value value) const {
+ return mapping.contains(value); // No type check: tensor and non-tensor keys.
+}
+
+void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
+ Operation *op) {
+ obsoleteOps.push_back(op); // Erasure is deferred; entries are not deduplicated.
+}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index f0024214909b..6482d740f151 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1677,12 +1677,16 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Bufferization phase.
if (!options.testAnalysisOnly) {
- BlockAndValueMapping tensorToBufferMap;
- BufferizationState state(aliasInfo, *options.allocationFns,
- tensorToBufferMap);
+ BufferizationState state(aliasInfo, *options.allocationFns);
+
+ // Bufferize all ops in funcOp.
if (failed(
bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
return failure();
+
+ // Erase all obsolete ops.
+ for (Operation *op : state.obsoleteOps)
+ op->erase();
}
}
// Annotate operations if we only want to report the analysis.
More information about the Mlir-commits mailing list