[Mlir-commits] [mlir] 1fdf06d - [mlir][bufferization] Reads from tensors with undefined data are not a conflict

Matthias Springer llvmlistbot at llvm.org
Mon Feb 6 07:16:23 PST 2023


Author: Matthias Springer
Date: 2023-02-06T16:11:13+01:00
New Revision: 1fdf06d6d79ea0ced79d680b7fcd622ef63fb9a5

URL: https://github.com/llvm/llvm-project/commit/1fdf06d6d79ea0ced79d680b7fcd622ef63fb9a5
DIFF: https://github.com/llvm/llvm-project/commit/1fdf06d6d79ea0ced79d680b7fcd622ef63fb9a5.diff

LOG: [mlir][bufferization] Reads from tensors with undefined data are not a conflict

Reading from tensor.empty or bufferization.alloc_tensor (without copy) cannot cause a conflict because these ops do not specify the contents of their result tensors.

Differential Revision: https://reviews.llvm.org/D143183

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
    mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 61ead5b58c2ae..a93a5d9a2cfed 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -412,14 +412,16 @@ class AnalysisState {
   /// in the operands) because their defining ops do not define the contents of
   /// the tensor.
   ///
+  /// Example:
+  /// %a = tensor.empty() : tensor<10xf32>
+  /// %b = arith.constant ... : tensor<10xf32>
+  /// %r = arith.select %cond, %a, %b : tensor<10xf32>
+  /// findDefinitions(%r) = {%b}. %a is excluded because it does not define the
+  /// contents of the tensor.
+  ///
   /// Note: OpResults of unknown ops are handled conservatively and assumed to
   /// be definitions.
-  ///
-  /// Note: When reaching an end of the reverse SSA use-def chain, that value
-  /// is included regardless of whether it is a definition or not unless
-  /// `alwaysIncludeLeaves` is unset.
-  SetVector<Value> findDefinitions(Value value,
-                                   bool alwaysIncludeLeaves = true) const;
+  SetVector<Value> findDefinitions(Value value) const;
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   virtual bool isInPlace(OpOperand &opOperand) const;

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 10f5e294f06b6..12fc89740d15c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -494,11 +494,10 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
 }
 
 // Find the values that define the contents of the given value.
-llvm::SetVector<Value>
-AnalysisState::findDefinitions(Value value, bool alwaysIncludeLeaves) const {
+llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
   return findValueInReverseUseDefChain(
       value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
-      /*followEquivalentOnly=*/false, alwaysIncludeLeaves);
+      /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 24a3aad53af8e..7aaf79b8d72f9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -272,7 +272,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
 
       // If there is no preceding definition, the tensor contents are
       // undefined.
-      if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty())
+      if (findDefinitions(opResult).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -513,8 +513,11 @@ static bool hasReadAfterWriteInterference(
 
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
+    LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
+    LLVM_DEBUG(llvm::dbgs() << "  uRead = operand " << uRead->getOperandNumber()
+                            << " of " << *readingOp << "\n");
 
-    // Find most recent writes of uRead by following the SSA use-def chain.
+    // Find the definition of uRead by following the SSA use-def chain.
     // E.g.:
     //
     // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
@@ -525,14 +528,16 @@ static bool hasReadAfterWriteInterference(
     // definition is %0. Note that operations that create an alias but do not
     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
     SetVector<Value> definitions = state.findDefinitions(uRead->get());
+    if (definitions.empty()) {
+      // Fast path: No conflict if there are no definitions.
+      LLVM_DEBUG(llvm::dbgs()
+                 << "  no conflict: read value has no definitions\n");
+      continue;
+    }
 
     // Look for conflicting memory writes. Potential conflicts are writes to an
     // alias that have been decided to bufferize inplace.
     for (OpOperand *uConflictingWrite : usesWrite) {
-      LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
-      LLVM_DEBUG(llvm::dbgs()
-                 << "  uRead = operand " << uRead->getOperandNumber() << " of "
-                 << *uRead->getOwner() << "\n");
       LLVM_DEBUG(llvm::dbgs() << "  unConflictingWrite = operand "
                               << uConflictingWrite->getOperandNumber() << " of "
                               << *uConflictingWrite->getOwner() << "\n");
@@ -608,15 +613,15 @@ static bool hasReadAfterWriteInterference(
         LLVM_DEBUG(llvm::dbgs() << "  * definition = " << definition << "\n");
 
         // No conflict if the conflicting write happens before the definition.
-        if (Operation *writingOp = definition.getDefiningOp()) {
-          if (happensBefore(conflictingWritingOp, writingOp, domInfo)) {
-            // conflictingWritingOp happens before writingOp. No conflict.
+        if (Operation *defOp = definition.getDefiningOp()) {
+          if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
+            // conflictingWritingOp happens before defOp. No conflict.
             LLVM_DEBUG(llvm::dbgs()
                        << "    no conflict: write happens before definition\n");
             continue;
           }
-          // No conflict if conflictingWritingOp is contained in writingOp.
-          if (writingOp->isProperAncestor(conflictingWritingOp)) {
+          // No conflict if conflictingWritingOp is contained in defOp.
+          if (defOp->isProperAncestor(conflictingWritingOp)) {
             LLVM_DEBUG(
                 llvm::dbgs()
                 << "    no conflict: write is contained in definition\n");

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index 6da126837a1fa..4eaa7dc2bcbf5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -32,3 +32,27 @@ func.func @unknown_op_writing(%f: f32, %f2: f32, %pos: index) -> f32 {
   %3 = tensor.extract %1[%pos] : tensor<10xf32>
   return %3 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @read_of_undef_is_not_a_conflict(
+func.func @read_of_undef_is_not_a_conflict(%f: f32, %idx: index) -> f32 {
+  %0 = tensor.empty() : tensor<10xf32>
+  // This can be in-place because the read below reads undefined data.
+  // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
+  %1 = tensor.insert %f into %0[%idx] : tensor<10xf32>
+  %2 = tensor.extract %0[%idx] : tensor<10xf32>
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @read_of_alloc_tensor_is_not_a_conflict(
+func.func @read_of_alloc_tensor_is_not_a_conflict(%f: f32, %idx: index) -> f32 {
+  %0 = bufferization.alloc_tensor() : tensor<10xf32>
+  // This can be in-place because the read below reads undefined data.
+  // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
+  %1 = tensor.insert %f into %0[%idx] : tensor<10xf32>
+  %2 = tensor.extract %0[%idx] : tensor<10xf32>
+  return %2 : f32
+}

diff  --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 587eed843c71f..131f3066ec2c3 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -713,26 +713,28 @@ func.func @scf_foreach_privatized_but_not_copied(
 // -----
 
 // CHECK-LABEL: func @scf_if_memory_space
-func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32)
+func.func @scf_if_memory_space(%c: i1, %f: f32, %cst: f32) -> (f32, f32)
 {
   %c0 = arith.constant 0 : index
   // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
-  %0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
+  %alloc = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
+  // CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<5xf32, 1>)
+  %filled = linalg.fill ins(%cst : f32) outs(%alloc : tensor<5xf32>) -> tensor<5xf32>
   // CHECK: scf.if %{{.*}} -> (memref<5xf32, 1>) {
   %1 = scf.if %c -> tensor<5xf32> {
     // CHECK: %[[cloned:.*]] = bufferization.clone %[[alloc]]
     // CHECK: scf.yield %[[cloned]]
-    scf.yield %0 : tensor<5xf32>
+    scf.yield %filled : tensor<5xf32>
   } else {
     // CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
     // CHECK: memref.store %{{.*}}, %[[alloc2]]
     // CHECK: %[[cloned2:.*]] = bufferization.clone %[[alloc2]]
     // CHECK: memref.dealloc %[[alloc2]]
     // CHECK: scf.yield %[[cloned2]]
-    %2 = tensor.insert %f into %0[%c0] : tensor<5xf32>
+    %2 = tensor.insert %f into %filled[%c0] : tensor<5xf32>
     scf.yield %2 : tensor<5xf32>
   }
-  %r0 = tensor.extract %0[%c0] : tensor<5xf32>
+  %r0 = tensor.extract %filled[%c0] : tensor<5xf32>
   %r1 = tensor.extract %1[%c0] : tensor<5xf32>
   return %r0, %r1 : f32, f32
 }


        


More information about the Mlir-commits mailing list