[Mlir-commits] [mlir] 032be23 - [mlir][bufferize] Improve buffer writability analysis
Matthias Springer
llvmlistbot at llvm.org
Wed Jun 8 01:12:10 PDT 2022
Author: Matthias Springer
Date: 2022-06-08T10:11:52+02:00
New Revision: 032be2330928995ae264a4886fd2610bc3e49656
URL: https://github.com/llvm/llvm-project/commit/032be2330928995ae264a4886fd2610bc3e49656
DIFF: https://github.com/llvm/llvm-project/commit/032be2330928995ae264a4886fd2610bc3e49656.diff
LOG: [mlir][bufferize] Improve buffer writability analysis
Find writability conflicts (writes to buffers that are not allowed to be written to) by checking SSA use-def chains. This is better than the current writability analysis, which is too conservative and finds false positives.
Differential Revision: https://reviews.llvm.org/D127256
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 8d44d579fbc8..2663e480f281 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -167,6 +167,9 @@ class OneShotAnalysisState : public AnalysisState {
/// not be called for values inside not yet analyzed functions.
bool isValueWritten(Value value) const;
+ /// Return true if the buffer of the given tensor value is writable.
+ bool isWritable(Value value) const;
+
private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
/// functions and `runOneShotBufferize` may access this object.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index f29d9a96b0b3..5447f6b0bdc2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -305,6 +305,21 @@ bool OneShotAnalysisState::isValueWritten(Value value) const {
return isWritten;
}
+bool OneShotAnalysisState::isWritable(Value value) const {
+ // TODO: Out-of-place bufferized value could be considered writable.
+ if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
+ return bufferizableOp.isWritable(value, *this);
+
+ // Query BufferizableOpInterface to see if the BlockArgument is writable.
+ if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bufferizableOp =
+ getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
+ return bufferizableOp.isWritable(bbArg, *this);
+
+ // Not a bufferizable op: The conservative answer is "not writable".
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
@@ -312,7 +327,7 @@ bool OneShotAnalysisState::isValueWritten(Value value) const {
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
+ const AnalysisState &state) {
// OpOperands that do not bufferize to a memory write do not write in-place.
if (!state.bufferizesToMemoryWrite(opOperand))
return false;
@@ -320,49 +335,6 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
return aliasInfo.isInPlace(opOperand);
}
-/// Return true if, under current bufferization decisions, the buffer of `value`
-/// is not writable.
-static bool aliasesNonWritableBuffer(Value value,
- const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
- bool foundNonWritableBuffer = false;
- aliasInfo.applyOnAliases(value, [&](Value v) {
- // Query BufferizableOpInterface to see if the value is writable.
- // TODO: Out-of-place bufferized value could be considered writable.
- if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
- if (bufferizableOp && bufferizableOp.isWritable(v, state))
- return;
-
- // Query BufferizableOpInterface to see if the BlockArgument is writable.
- if (auto bbArg = v.dyn_cast<BlockArgument>())
- if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
- bbArg.getOwner()->getParentOp()))
- if (bufferizableOp.isWritable(bbArg, state))
- return;
-
- foundNonWritableBuffer = true;
- });
-
- return foundNonWritableBuffer;
-}
-
-/// Return true if the buffer to which `operand` would bufferize is equivalent
-/// to some buffer write.
-static bool aliasesInPlaceWrite(Value value,
- const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
- bool foundInplaceWrite = false;
- aliasInfo.applyOnAliases(value, [&](Value v) {
- for (auto &use : v.getUses()) {
- if (isInplaceMemoryWrite(use, aliasInfo, state)) {
- foundInplaceWrite = true;
- return;
- }
- }
- });
- return foundInplaceWrite;
-}
-
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b,
@@ -604,6 +576,30 @@ static bool hasReadAfterWriteInterference(
return false;
}
+// Helper function to iterate on aliases of `root` and capture the writes.
+static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
+ const BufferizationAliasInfo &aliasInfo,
+ const AnalysisState &state) {
+ aliasInfo.applyOnAliases(root, [&](Value alias) {
+ for (auto &use : alias.getUses())
+ // Inplace write to a value that aliases root.
+ if (isInplaceMemoryWrite(use, aliasInfo, state))
+ res.insert(&use);
+ });
+}
+
+// Helper function to iterate on aliases of `root` and capture the reads.
+static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
+ const BufferizationAliasInfo &aliasInfo,
+ const AnalysisState &state) {
+ aliasInfo.applyOnAliases(root, [&](Value alias) {
+ for (auto &use : alias.getUses())
+ // Read of a value that aliases root.
+ if (state.bufferizesToMemoryRead(use))
+ res.insert(&use);
+ });
+}
+
/// Return true if bufferizing `operand` inplace would create a conflict. A read
/// R and a write W of the same alias set is a conflict if inplace bufferization
/// of W changes the value read by R to a value different from the one that
@@ -637,33 +633,13 @@ static bool wouldCreateReadAfterWriteInterference(
OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
const BufferizationAliasInfo &aliasInfo,
bool checkConsistencyOnly = false) {
- // Helper function to iterate on aliases of `root` and capture the reads.
- auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
- aliasInfo.applyOnAliases(root, [&](Value alias) {
- for (auto &use : alias.getUses())
- // Read to a value that aliases root.
- if (state.bufferizesToMemoryRead(use))
- res.insert(&use);
- });
- };
-
- // Helper function to iterate on aliases of `root` and capture the writes.
- auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
- aliasInfo.applyOnAliases(root, [&](Value alias) {
- for (auto &use : alias.getUses())
- // Inplace write to a value that aliases root.
- if (isInplaceMemoryWrite(use, aliasInfo, state))
- res.insert(&use);
- });
- };
-
// Collect reads and writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesRead, usesWrite;
- getAliasingReads(usesRead, operand.get());
- getAliasingInplaceWrites(usesWrite, operand.get());
+ getAliasingReads(usesRead, operand.get(), aliasInfo, state);
+ getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
for (OpResult result : state.getAliasingOpResult(operand)) {
- getAliasingReads(usesRead, result);
- getAliasingInplaceWrites(usesWrite, result);
+ getAliasingReads(usesRead, result, aliasInfo, state);
+ getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
}
if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
@@ -672,28 +648,60 @@ static bool wouldCreateReadAfterWriteInterference(
aliasInfo);
}
-/// Return true if bufferizing `opOperand` inplace would create a write to a
-/// non-writable buffer.
+/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
+/// non-writable tensor values. Stop searching when an out-of-place bufferized
+/// OpOperand was found (or when the OpOperand was not bufferized yet).
+/// `currentOpOperand` is assumed to be in-place, even if that decision was not
+/// materialized in `aliasInfo` yet.
static bool
-wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
- const BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
- // Certain buffers are not writeable:
- // 1. A function bbArg that is not inplaceable or
- // 2. A constant op.
- bool nonWritable =
- aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
- if (!nonWritable)
- return false;
+hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
+ const BufferizationAliasInfo &aliasInfo,
+ const OneShotAnalysisState &state) {
+ SmallVector<Value> worklist;
+ worklist.push_back(value);
+ while (!worklist.empty()) {
+ Value nextVal = worklist.pop_back_val();
+ if (!state.isWritable(nextVal))
+ return true;
+
+ // If `nextVal` is not an OpResult (i.e., it is a BlockArgument): End of use-def chain reached.
+ auto opResult = nextVal.dyn_cast<OpResult>();
+ if (!opResult)
+ continue;
+
+ // Follow reverse SSA use-def chain.
+ SmallVector<OpOperand *> aliasingOpOperands =
+ state.getAliasingOpOperand(opResult);
+ for (OpOperand *opOperand : aliasingOpOperands)
+ if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
+ worklist.push_back(opOperand->get());
+ }
+ return false;
+}
- // This is a problem only if the buffer is written to via some alias.
- bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
- state.bufferizesToMemoryWrite(opOperand);
+/// Return true if bufferizing `operand` inplace would create a write to a
+/// non-writable buffer.
+static bool wouldCreateWriteToNonWritableBuffer(
+ OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
+ OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
+ // Collect writes of all aliases of OpOperand and OpResult.
+ DenseSet<OpOperand *> usesWrite;
+ getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
+ for (OpResult result : state.getAliasingOpResult(operand)) {
+ getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
+ }
+ if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
+ usesWrite.insert(&operand);
- for (OpResult opResult : state.getAliasingOpResult(opOperand))
- hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
+ // Assuming that `operand` bufferizes in-place: For each write (to each
+ // alias), check if there is a non-writable tensor in the reverse SSA use-def
+ // chain.
+ for (OpOperand *uWrite : usesWrite)
+ if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
+ aliasInfo, state))
+ return true;
- return hasWrite;
+ return false;
}
//===----------------------------------------------------------------------===//
@@ -702,8 +710,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
/// Determine if `operand` can be bufferized in-place.
static LogicalResult bufferizableInPlaceAnalysisImpl(
- OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
- const DominanceInfo &domInfo) {
+ OpOperand &operand, BufferizationAliasInfo &aliasInfo,
+ OneShotAnalysisState &state, const DominanceInfo &domInfo) {
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
@@ -736,7 +744,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
/// RaW dependence violations.
static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
- AnalysisState &state,
+ OneShotAnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
if (analysisFuzzerSeed) {
@@ -769,7 +777,7 @@ static bool hasTensorSemantics(Operation *op) {
/// Analyze all ops that are contained in `op`.
static LogicalResult inPlaceAnalysis(Operation *op,
BufferizationAliasInfo &aliasInfo,
- AnalysisState &state,
+ OneShotAnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
// Collect ops so we can build our own reverse traversal.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
index 9fff7f990b39..3145959e767e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
@@ -31,3 +31,34 @@ func.func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32>
// CHECK: return %[[r_tensor]]
return %r : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @write_to_alloc_tensor_or_readonly_tensor(
+// CHECK-SAME: %[[arg0:.*]]: tensor<i32>
+func.func @write_to_alloc_tensor_or_readonly_tensor(%arg0: tensor<i32>,
+ %cond: i1, %val: i32)
+ -> tensor<i32>
+{
+ // CHECK: %[[r:.*]] = scf.if {{.*}} {
+ // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
+ // CHECK: %[[clone:.*]] = bufferization.clone %[[arg0_m]]
+ // CHECK: scf.yield %[[clone]]
+ // CHECK: } else {
+ // CHECK: %[[alloc:.*]] = memref.alloc
+ // CHECK: memref.store %{{.*}}, %[[alloc]]
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK: scf.yield %[[casted]]
+ // CHECK: }
+ // CHECK: %[[r_t:.*]] = bufferization.to_tensor %[[r]]
+ // CHECK: memref.dealloc %[[r]]
+ // CHECK: return %[[r_t]]
+ %3 = scf.if %cond -> (tensor<i32>) {
+ scf.yield %arg0 : tensor<i32>
+ } else {
+ %7 = bufferization.alloc_tensor() : tensor<i32>
+ %8 = tensor.insert %val into %7[] : tensor<i32>
+ scf.yield %8 : tensor<i32>
+ }
+ return %3 : tensor<i32>
+}
More information about the Mlir-commits
mailing list