[Mlir-commits] [mlir] 988748c - [mlir][bufferize] Do not copy buffers with undefined contents

Matthias Springer llvmlistbot at llvm.org
Fri May 6 01:35:02 PDT 2022


Author: Matthias Springer
Date: 2022-05-06T17:31:01+09:00
New Revision: 988748c0774f3acc408bd65f5e57e5da6fe8aecb

URL: https://github.com/llvm/llvm-project/commit/988748c0774f3acc408bd65f5e57e5da6fe8aecb
DIFF: https://github.com/llvm/llvm-project/commit/988748c0774f3acc408bd65f5e57e5da6fe8aecb.diff

LOG: [mlir][bufferize] Do not copy buffers with undefined contents

Buffers with undefined contents (e.g., the result of an init_tensor) are no longer copied.

Differential Revision: https://reviews.llvm.org/D125015

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/test/Dialect/Linalg/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 94504c1e9b104..7d80e472c90ea 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -337,6 +337,9 @@ class AnalysisState {
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
 
+  /// Return `true` if the tensor used by `opOperand` has undefined contents.
+  virtual bool hasUndefinedContents(OpOperand *opOperand) const = 0;
+
   /// Return true if the given tensor (or an aliasing tensor) is yielded from
   /// the containing block. Also include all aliasing tensors in the same block.
   ///
@@ -410,6 +413,9 @@ class AlwaysCopyAnalysisState : public AnalysisState {
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
 
+  /// Return `true` if the tensor used by `opOperand` has undefined contents.
+  bool hasUndefinedContents(OpOperand *opOperand) const override;
+
   /// Return true if the given tensor (or an aliasing tensor) is yielded from
   /// the containing block. Also include all aliasing tensors in the same block.
   bool isTensorYielded(Value tensor) const override;

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 2a6020118c2f7..22a7e0c402ba6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -166,10 +166,17 @@ class OneShotAnalysisState : public AnalysisState {
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
 
+  /// Return `true` if the tensor used by `opOperand` has undefined contents.
+  bool hasUndefinedContents(OpOperand *opOperand) const override;
+
   /// Return true if the given tensor (or an aliasing tensor) is yielded from
   /// the containing block. Also include all aliasing tensors in the same block.
   bool isTensorYielded(Value tensor) const override;
 
+  /// Find all tensor OpResults of `op` and its nested ops that have undefined
+  /// contents and store their uses in `undefinedTensorUses`.
+  void gatherUndefinedTensorUses(Operation *op);
+
   /// Find all tensors that are yielded/returned from a block and store them in
   /// `yieldedTensors`. Also include all aliasing tensors in the same block.
   void gatherYieldedTensors(Operation *op);
@@ -182,6 +189,9 @@ class OneShotAnalysisState : public AnalysisState {
   /// A set of all tensors (and maybe aliasing tensors) that yielded from a
   /// block.
   DenseSet<Value> yieldedTensors;
+
+  /// A set of uses of tensors that have undefined contents.
+  DenseSet<OpOperand *> undefinedTensorUses;
 };
 
 /// Analyze `op` and its nested ops. Bufferization decisions are stored in

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 73da6c85e761f..ed1ff6a2c021f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -303,18 +303,8 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
       rewriter, loc, operandBuffer, dealloc && getOptions().createDeallocs);
   if (failed(resultBuffer))
     return failure();
-  // Do not copy if the last preceding writes of `operand` are ops that do
-  // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
-  // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
-  // use-def chain, it returns that value, regardless of whether it is a
-  // memory write or not.
-  SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
-  if (llvm::none_of(lastWrites, [&](Value lastWrite) {
-        if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
-          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
-                                              analysisState);
-        return true;
-      }))
+  // Do not copy the buffer if its contents are undefined.
+  if (analysisState.hasUndefinedContents(&opOperand))
     return resultBuffer;
   // Do not copy if the copied data is never read.
   if (!aliasingOpResults.empty() &&
@@ -407,6 +397,12 @@ bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
   return false;
 }
 
+/// Return `true` if the tensor used by `opOperand` has undefined contents.
+bool AlwaysCopyAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
+  // There is no analysis, so the conservative answer is "false".
+  return false;
+}
+
 /// Return true if the given tensor (or an aliasing tensor) is yielded from
 /// the containing block. Also include all aliasing tensors in the same block.
 bool AlwaysCopyAnalysisState::isTensorYielded(Value tensor) const {

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 593cbf0bb43ab..704a5164c543f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -249,6 +249,43 @@ void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
   });
 }
 
+void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
+  op->walk([&](Operation *op) {
+    // Skip ops for which `dynCastBufferizableOp` returns nothing (unknown ops).
+    auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
+    if (!bufferizableOp)
+      return WalkResult::skip();
+
+    // Inspect every tensor-typed OpResult; other results are ignored.
+    for (OpResult opResult : op->getOpResults()) {
+      if (!opResult.getType().isa<TensorType>())
+        continue;
+
+      // The tensor contents are undefined if none of the values returned by
+      // `findLastPrecedingWrite` is an actual memory write.
+      // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
+      // use-def chain, it returns that value, regardless of whether it is a
+      // memory write or not.
+      SetVector<Value> lastWrites = findLastPrecedingWrite(opResult);
+      bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) {
+        if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite))
+          return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
+                                              *this);
+        return true; // Not bufferizable: conservatively count as a write.
+      });
+      if (isUndefined)
+        for (OpOperand &use : opResult.getUses())
+          undefinedTensorUses.insert(&use);
+    }
+
+    return WalkResult::advance();
+  });
+}
+
+bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
+  return undefinedTensorUses.contains(opOperand); // Populated by the analysis.
+}
+
 bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
   return yieldedTensors.contains(tensor);
 }
@@ -915,8 +952,9 @@ LogicalResult bufferization::analyzeOp(Operation *op,
         failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps));
   }
 
-  // Gather all yielded tensors.
+  // Gather some extra analysis data.
   state.gatherYieldedTensors(op);
+  state.gatherUndefinedTensorUses(op);
 
   // Analysis verification: After setting up alias/equivalence sets, each op
   // can check for expected invariants/limitations and fail the analysis if

diff  --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 7a1072c75d234..dc4135560778f 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -359,3 +359,19 @@ func.func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x
   return %3 : tensor<?x1x6x8xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @do_not_copy_init_tensors(
+func.func @do_not_copy_init_tensors(%f1: f32, %f2: f32, %idx: index)
+  -> (tensor<5xf32>, tensor<5xf32>)
+{
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK-NOT: copy
+  // CHECK: memref.store
+  // CHECK: memref.store
+  %0 = linalg.init_tensor [5] : tensor<5xf32>  // Contents are undefined.
+  %1 = tensor.insert %f1 into %0[%idx] : tensor<5xf32>
+  %2 = tensor.insert %f2 into %0[%idx] : tensor<5xf32>
+  return %1, %2 : tensor<5xf32>, tensor<5xf32>
+}


        


More information about the Mlir-commits mailing list