[Mlir-commits] [mlir] 1fdf06d - [mlir][bufferization] Reads from tensors with undefined data are not a conflict
Matthias Springer
llvmlistbot at llvm.org
Mon Feb 6 07:16:23 PST 2023
Author: Matthias Springer
Date: 2023-02-06T16:11:13+01:00
New Revision: 1fdf06d6d79ea0ced79d680b7fcd622ef63fb9a5
URL: https://github.com/llvm/llvm-project/commit/1fdf06d6d79ea0ced79d680b7fcd622ef63fb9a5
DIFF: https://github.com/llvm/llvm-project/commit/1fdf06d6d79ea0ced79d680b7fcd622ef63fb9a5.diff
LOG: [mlir][bufferization] Reads from tensors with undefined data are not a conflict
Reading from tensor.empty or bufferization.alloc_tensor (without copy) cannot cause a conflict because these ops do not specify the contents of their result tensors.
Differential Revision: https://reviews.llvm.org/D143183
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 61ead5b58c2ae..a93a5d9a2cfed 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -412,14 +412,16 @@ class AnalysisState {
/// in the operands) because their defining ops do not define the contents of
/// the tensor.
///
+ /// Example:
+ /// %a = tensor.empty() : tensor<10xf32>
+ /// %b = arith.constant ... : tensor<10xf32>
+ /// %r = arith.select %cond, %a, %b : tensor<10xf32>
+ /// findDefinitions(%r) = {%b}. %a is excluded because it does not define the
+ /// contents of the tensor.
+ ///
/// Note: OpResults of unknown ops are handled conservatively and assumed to
/// be definitions.
- ///
- /// Note: When reaching an end of the reverse SSA use-def chain, that value
- /// is included regardless of whether it is a definition or not unless
- /// `alwaysIncludeLeaves` is unset.
- SetVector<Value> findDefinitions(Value value,
- bool alwaysIncludeLeaves = true) const;
+ SetVector<Value> findDefinitions(Value value) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
virtual bool isInPlace(OpOperand &opOperand) const;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 10f5e294f06b6..12fc89740d15c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -494,11 +494,10 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
}
// Find the values that define the contents of the given value.
-llvm::SetVector<Value>
-AnalysisState::findDefinitions(Value value, bool alwaysIncludeLeaves) const {
+llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
return findValueInReverseUseDefChain(
value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
- /*followEquivalentOnly=*/false, alwaysIncludeLeaves);
+ /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
}
AnalysisState::AnalysisState(const BufferizationOptions &options)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 24a3aad53af8e..7aaf79b8d72f9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -272,7 +272,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// If there is no preceding definition, the tensor contents are
// undefined.
- if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty())
+ if (findDefinitions(opResult).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
@@ -513,8 +513,11 @@ static bool hasReadAfterWriteInterference(
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
+ LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
+ LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
+ << " of " << *readingOp << "\n");
- // Find most recent writes of uRead by following the SSA use-def chain.
+ // Find the definition of uRead by following the SSA use-def chain.
// E.g.:
//
// %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
@@ -525,14 +528,16 @@ static bool hasReadAfterWriteInterference(
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
SetVector<Value> definitions = state.findDefinitions(uRead->get());
+ if (definitions.empty()) {
+ // Fast path: No conflict if there are no definitions.
+ LLVM_DEBUG(llvm::dbgs()
+ << " no conflict: read value has no definitions\n");
+ continue;
+ }
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
for (OpOperand *uConflictingWrite : usesWrite) {
- LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
- LLVM_DEBUG(llvm::dbgs()
- << " uRead = operand " << uRead->getOperandNumber() << " of "
- << *uRead->getOwner() << "\n");
LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
<< uConflictingWrite->getOperandNumber() << " of "
<< *uConflictingWrite->getOwner() << "\n");
@@ -608,15 +613,15 @@ static bool hasReadAfterWriteInterference(
LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
// No conflict if the conflicting write happens before the definition.
- if (Operation *writingOp = definition.getDefiningOp()) {
- if (happensBefore(conflictingWritingOp, writingOp, domInfo)) {
- // conflictingWritingOp happens before writingOp. No conflict.
+ if (Operation *defOp = definition.getDefiningOp()) {
+ if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
+ // conflictingWritingOp happens before defOp. No conflict.
LLVM_DEBUG(llvm::dbgs()
<< " no conflict: write happens before definition\n");
continue;
}
- // No conflict if conflictingWritingOp is contained in writingOp.
- if (writingOp->isProperAncestor(conflictingWritingOp)) {
+ // No conflict if conflictingWritingOp is contained in defOp.
+ if (defOp->isProperAncestor(conflictingWritingOp)) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: write is contained in definition\n");
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index 6da126837a1fa..4eaa7dc2bcbf5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -32,3 +32,27 @@ func.func @unknown_op_writing(%f: f32, %f2: f32, %pos: index) -> f32 {
%3 = tensor.extract %1[%pos] : tensor<10xf32>
return %3 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @read_of_undef_is_not_a_conflict(
+func.func @read_of_undef_is_not_a_conflict(%f: f32, %idx: index) -> f32 {
+ %0 = tensor.empty() : tensor<10xf32>
+ // This can be in-place because the read below reads undefined data.
+ // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
+ %1 = tensor.insert %f into %0[%idx] : tensor<10xf32>
+ %2 = tensor.extract %0[%idx] : tensor<10xf32>
+ return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @read_of_alloc_tensor_is_not_a_conflict(
+func.func @read_of_alloc_tensor_is_not_a_conflict(%f: f32, %idx: index) -> f32 {
+ %0 = bufferization.alloc_tensor() : tensor<10xf32>
+ // This can be in-place because the read below reads undefined data.
+ // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
+ %1 = tensor.insert %f into %0[%idx] : tensor<10xf32>
+ %2 = tensor.extract %0[%idx] : tensor<10xf32>
+ return %2 : f32
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 587eed843c71f..131f3066ec2c3 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -713,26 +713,28 @@ func.func @scf_foreach_privatized_but_not_copied(
// -----
// CHECK-LABEL: func @scf_if_memory_space
-func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32)
+func.func @scf_if_memory_space(%c: i1, %f: f32, %cst: f32) -> (f32, f32)
{
%c0 = arith.constant 0 : index
// CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
- %0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
+ %alloc = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
+ // CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<5xf32, 1>)
+ %filled = linalg.fill ins(%cst : f32) outs(%alloc : tensor<5xf32>) -> tensor<5xf32>
// CHECK: scf.if %{{.*}} -> (memref<5xf32, 1>) {
%1 = scf.if %c -> tensor<5xf32> {
// CHECK: %[[cloned:.*]] = bufferization.clone %[[alloc]]
// CHECK: scf.yield %[[cloned]]
- scf.yield %0 : tensor<5xf32>
+ scf.yield %filled : tensor<5xf32>
} else {
// CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
// CHECK: memref.store %{{.*}}, %[[alloc2]]
// CHECK: %[[cloned2:.*]] = bufferization.clone %[[alloc2]]
// CHECK: memref.dealloc %[[alloc2]]
// CHECK: scf.yield %[[cloned2]]
- %2 = tensor.insert %f into %0[%c0] : tensor<5xf32>
+ %2 = tensor.insert %f into %filled[%c0] : tensor<5xf32>
scf.yield %2 : tensor<5xf32>
}
- %r0 = tensor.extract %0[%c0] : tensor<5xf32>
+ %r0 = tensor.extract %filled[%c0] : tensor<5xf32>
%r1 = tensor.extract %1[%c0] : tensor<5xf32>
return %r0, %r1 : f32, f32
}
More information about the Mlir-commits
mailing list