[Mlir-commits] [mlir] 5f1a1af - [mlir][Linalg] Properly order extract_slice traversal in comprehensive bufferization

Nicolas Vasilache llvmlistbot at llvm.org
Fri Sep 10 00:10:11 PDT 2021


Author: Nicolas Vasilache
Date: 2021-09-10T07:10:06Z
New Revision: 5f1a1af4bfb1314081e259939ff313eade72aeab

URL: https://github.com/llvm/llvm-project/commit/5f1a1af4bfb1314081e259939ff313eade72aeab
DIFF: https://github.com/llvm/llvm-project/commit/5f1a1af4bfb1314081e259939ff313eade72aeab.diff

LOG: [mlir][Linalg] Properly order extract_slice traversal in comprehensive bufferization

This revision fixes the traversal order of extract_slice during the inplace analysis.
It was previously thought that such ops could be analyzed at the very end.
This is unfortunately not true, as the AliasInfo for dependents of these ops needs to be updated.

This change allows the aliases introduced by the bufferization of extract_slice to be properly propagated.

Differential Revision: https://reviews.llvm.org/D109519
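
For illustration, here is a reduced sketch of the pattern this fixes, adapted
from the first test added below (ops, names and shapes are taken from that
test; the enclosing func and constants are elided):

  %7 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
  // Both fills write into aliases of %7.
  %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
  %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
  // %sA aliases %8 and %sB aliases %11. Ops are analyzed in reverse, so these
  // aliases must already be recorded in AliasInfo when the fills are reached;
  // postponing all extract_slice analysis to the very end would miss them.
  %sA = tensor.extract_slice %8[0, 0][256, 16][1, 1]: tensor<256x256xf32> to tensor<256x16xf32>
  %sB = tensor.extract_slice %11[0, 0][16, 256][1, 1]: tensor<256x256xf32> to tensor<16x256xf32>
  %r = linalg.matmul ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
        outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>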

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index b19fca41af77..cb7f0c39d304 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -130,7 +130,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 
-#define DEBUG_TYPE "comprehensive-func-bufferize"
+#define DEBUG_TYPE "comprehensive-module-bufferize"
 
 using namespace mlir;
 using namespace linalg;
@@ -140,8 +140,8 @@ using namespace tensor;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
 
 // Forward declarations.
-static std::string printOperationInfo(Operation *);
-static std::string printValueInfo(Value);
+static std::string printOperationInfo(Operation *, bool prefix = true);
+static std::string printValueInfo(Value, bool prefix = true);
 
 //===----------------------------------------------------------------------===//
 // Generic helpers.
@@ -365,30 +365,31 @@ static void printTensorOrBufferInfo(std::string prefix, Value value,
 }
 
 /// Print the operation name and bufferization information.
-static std::string printOperationInfo(Operation *op) {
+static std::string printOperationInfo(Operation *op, bool prefix) {
   std::string result;
   llvm::raw_string_ostream os(result);
   AsmState state(op->getParentOfType<mlir::FuncOp>());
-  os << op->getName();
+  StringRef tab = prefix ? "\n[" DEBUG_TYPE "]\t" : "";
+  os << tab << op->getName();
   SmallVector<Value> shapedOperands;
   for (OpOperand &opOperand : op->getOpOperands()) {
     std::string prefix =
-        llvm::formatv("\n\t-> #{0} ", opOperand.getOperandNumber());
+        llvm::formatv("{0}  -> #{1} ", tab, opOperand.getOperandNumber());
     printTensorOrBufferInfo(prefix, opOperand.get(), state, os);
   }
   for (OpResult opResult : op->getOpResults()) {
     std::string prefix =
-        llvm::formatv("\n\t<- #{0} ", opResult.getResultNumber());
+        llvm::formatv("{0}  <- #{1} ", tab, opResult.getResultNumber());
     printTensorOrBufferInfo(prefix, opResult, state, os);
   }
   return result;
 }
 
 /// Print the bufferization information for the defining op or block argument.
-static std::string printValueInfo(Value value) {
+static std::string printValueInfo(Value value, bool prefix) {
   auto *op = value.getDefiningOp();
   if (op)
-    return printOperationInfo(op);
+    return printOperationInfo(op, prefix);
   // Print the block argument bufferization information.
   std::string result;
   llvm::raw_string_ostream os(result);
@@ -552,6 +553,7 @@ static Optional<OpOperand *> getAliasingOpOperand(OpResult result) {
       .Case([&](scf::ForOp op) {
         return &op.getIterOpOperands()[result.getResultNumber()];
       })
+      .Case([&](InitTensorOp op) { return nullptr; })
       .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
       .Case([&](LinalgOp op) {
         return op.getOutputTensorOperands()[result.getResultNumber()];
@@ -580,7 +582,7 @@ static Optional<OpResult> getAliasingOpResult(OpOperand &opOperand) {
     return None;
   return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
       // These terminators legitimately have no result.
-      .Case<ReturnOp, linalg::YieldOp, scf::YieldOp>(
+      .Case<ReturnOp, linalg::InitTensorOp, linalg::YieldOp, scf::YieldOp>(
           [&](auto op) { return OpResult(); })
       // DimOp has no tensor result.
       .Case<tensor::DimOp>([&](auto op) { return None; })
@@ -759,10 +761,12 @@ class BufferizationAliasInfo {
   void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
 
   /// Print to `os`.
-  void print(raw_ostream &os) const;
+  void printAliases(raw_ostream &os) const;
+  void printEquivalences(raw_ostream &os) const;
 
   /// Print to `errs()`.
-  void dump() const { print(llvm::errs()); }
+  void dumpAliases() const { printAliases(llvm::errs()); }
+  void dumpEquivalences() const { printEquivalences(llvm::errs()); }
 
 private:
   /// Check that aliasInfo for `v` exists and return a reference to it.
@@ -954,10 +958,12 @@ void BufferizationAliasInfo::bufferizeInPlace(OpResult result,
   setInPlaceOpResult(result, InPlaceSpec::True);
   if (mergeAliases(result, operand.get()))
     mergeAliasesToFixedPoint();
+  // Dump the updated alias analysis.
+  LLVM_DEBUG(dumpAliases());
   if (bufferRelation == BufferRelation::Equivalent)
     equivalentInfo.unionSets(result, operand.get());
-  // Dump the updated analysis.
-  LLVM_DEBUG(dump());
+  // Dump the updated equivalence analysis.
+  LLVM_DEBUG(dumpEquivalences());
 }
 
 /// Set the inPlace bufferization spec to false.
@@ -984,7 +990,7 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
   Operation *opToBufferize = result.getDefiningOp();
   Value root = (*maybeAliasingOperand)->get();
   LDBG("----Start wouldCreateReadAfterWriteInterference\n");
-  LDBG("--------rootValue: " << printValueInfo(root) << "\n");
+  LDBG("--------aliasing rootValue: " << printValueInfo(root) << "\n");
 
   // Collect:
   //   1. all the inplace write uses of some alias of `root`.
@@ -1046,10 +1052,10 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
       // At this point, aliasingWriteOp properly dominates aliasingReadOp or
       // there is no clear dominance and we need to be conservative.
       LDBG("---->found RaW interference\n");
-      LDBG("     Interfering  read -> #" << uRead->getOperandNumber() << ":\n"
+      LDBG("     Interfering  read -> #" << uRead->getOperandNumber() << ":"
                                          << printOperationInfo(aliasingReadOp)
                                          << '\n');
-      LDBG("     Interfering write -> #" << uWrite->getOperandNumber() << ":\n"
+      LDBG("     Interfering write -> #" << uWrite->getOperandNumber() << ":"
                                          << printOperationInfo(aliasingWriteOp)
                                          << '\n');
       LDBG("---->opportunity to clobber RaW interference\n");
@@ -1098,28 +1104,34 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
   }
 }
 
-void BufferizationAliasInfo::print(raw_ostream &os) const {
+void BufferizationAliasInfo::printAliases(raw_ostream &os) const {
   os << "\n/========================== AliasInfo "
         "==========================\n";
   for (auto it : aliasInfo) {
-    os << "|\n| -- source: " << printValueInfo(it.getFirst()) << '\n';
+    os << "|\n| -- source: " << printValueInfo(it.getFirst(), /*prefix=*/false)
+       << '\n';
     for (auto v : it.getSecond())
-      os << "| ---- target: " << printValueInfo(v) << '\n';
+      os << "| ---- target: " << printValueInfo(v, /*prefix=*/false) << '\n';
   }
   os << "|\n\\====================== End AliasInfo "
         "======================\n\n";
+}
+
+void BufferizationAliasInfo::printEquivalences(raw_ostream &os) const {
   os << "\n/********************* Equivalent Buffers *********************\n";
   for (auto it = equivalentInfo.begin(), eit = equivalentInfo.end(); it != eit;
        ++it) {
     if (!it->isLeader())
       continue;
     Value leader = it->getData();
-    os << "|\n| -- leader: " << printValueInfo(leader) << '\n';
+    os << "|\n| -- leader: " << printValueInfo(leader, /*prefix=*/false)
+       << '\n';
     for (auto mit = equivalentInfo.member_begin(it),
               meit = equivalentInfo.member_end();
          mit != meit; ++mit) {
       Value v = static_cast<Value>(*mit);
-      os << "| ---- equivalent member: " << printValueInfo(v) << '\n';
+      os << "| ---- equivalent member: " << printValueInfo(v, /*prefix=*/false)
+         << '\n';
     }
   }
   os << "|\n\\***************** End Equivalent Buffers *****************\n\n";
@@ -1195,12 +1207,13 @@ bool BufferizationAliasInfo::existsInterleavedValueClobber(
     auto leaderIt = equivalentInfo.findLeader(valueToClobber);
     for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
          ++mit) {
-      /// Note: the "would write to memory after bufferization" condition is
-      /// verified by `candidateOp` since it would produce a value that
-      /// bufferizes to an equivalent buffer.
       Operation *candidateOp = mit->v.getDefiningOp();
       if (!candidateOp)
         continue;
+      auto maybeAliasingOperand = getAliasingOpOperand(mit->v.cast<OpResult>());
+      if (!maybeAliasingOperand || !*maybeAliasingOperand ||
+          !bufferizesToMemoryWrite(**maybeAliasingOperand))
+        continue;
       LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)
                                          << '\n');
       if (domInfo.properlyDominates(aliasingWriteOp, candidateOp) &&
@@ -2311,7 +2324,12 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
   return success();
 }
 
-/// Analyze the `funcOp` body to determine which OpResults are inplaceable.
+/// Analyze the `funcOp` body to determine which OpResults are inplaceable:
+///   1. First, analyze InsertSliceOp greedily: we almost never want to
+///      bufferize the tensor "inserted into" to become out-of-place.
+///   2. Walk the other ops in reverse. This is a good starter heuristic.
+///      ExtractSliceOps are interleaved with other ops in traversal order.
+///
 static LogicalResult
 inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
                           const DominanceInfo &domInfo) {
@@ -2321,26 +2339,22 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
          "expected a funcOp definition with a body");
 
   // Collect ops so we can build our own traversal.
-  SmallVector<ExtractSliceOp> extractSliceOps;
+  SmallVector<Operation *> otherOps;
   SmallVector<InsertSliceOp> insertSliceOps;
-  SmallVector<Operation *> nonSliceOps;
   funcOp.walk([&](Operation *op) {
-    if (auto extractSliceOp = dyn_cast<ExtractSliceOp>(op))
-      return extractSliceOps.push_back(extractSliceOp);
     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(op))
       return insertSliceOps.push_back(insertSliceOp);
     // No tensors => no buffers.
     if (none_of(op->getOperandTypes(), isaTensor) &&
         none_of(op->getResultTypes(), isaTensor))
       return;
-    nonSliceOps.push_back(op);
+    otherOps.push_back(op);
   });
 
-  // Bufferize InsertSliceOp greedily: we almost never want to bufferize
+  // First, analyze InsertSliceOp greedily: we almost never want to bufferize
   // the tensor "inserted into" to become out-of-place. This implementation
   // does not distinguish between different InsertSliceOp. If we want
   // finer-grained behavior, we could order the InsertSliceOp with some metric.
-  // Walk InsertSliceOp in reverse for better interference behavior.
   for (InsertSliceOp insertSliceOp : reverse(insertSliceOps)) {
     OpOperand &destOpOperand = insertSliceOp->getOpOperand(1);
     if (failed(bufferizableInPlaceAnalysis(
@@ -2349,23 +2363,27 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
       return failure();
   }
 
-  // Analyze all ops that return a tensors, except ExtractSliceOp and
-  // InsertSliceOp which are handled separately.
-  // Walk other ops in reverse for better interference behavior.
-  for (Operation *op : reverse(nonSliceOps))
-    for (OpOperand &opOperand : op->getOpOperands())
+  // Walk ops in reverse for better interference analysis.
+  for (Operation *op : reverse(otherOps)) {
+    for (OpOperand &opOperand : op->getOpOperands()) {
       if (OpResult result = getInplaceableOpResult(opOperand))
         if (result.getType().isa<TensorType>() &&
             failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
                                                domInfo)))
           return failure();
-
-  // Finally, bufferize ExtractSliceOp.
-  // Walk ExtractSliceOps in reverse for better clobbering behavior: it is
-  // easier to detect clobbers of smaller slices before larger ones.
-  for (ExtractSliceOp extractSliceOp : reverse(extractSliceOps))
-    if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
-      return failure();
+    }
+    // Special logic to analyze ExtractSliceOp.
+    // Note that ExtractSliceOp analysis needs to be interleaved with other ops
+    // to properly capture aliases.
+    // Walk ExtractSliceOps in reverse for better clobbering analysis behavior:
+    // it is easier to detect clobbers of smaller slices before larger ones.
+    if (auto extractSliceOp = dyn_cast<ExtractSliceOp>(op)) {
+      if (failed(
+              bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
+        return failure();
+      continue;
+    }
+  }
 
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
 

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 2c2a14ead0f3..1fa9ae6dbd41 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -247,20 +247,20 @@ func @extract_slice_to_linalg_write_use(
     %C : tensor<?x?xf32> {linalg.inplaceable = true})
   ->  (tensor<4x4xf32>, tensor<4x4xf32>)
 {
-  // Step 3. %sB forward propagates to a write in %D but it is not inplace.
+  // Step 4. %sB forward propagates to a write in %D but it is not inplace.
   // So this is only ever read and can bufferize inplace.
   //     CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %sB = tensor.extract_slice %B[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
 
-  // Step 2. %sB has a read interference in %E, it does not bufferize inplace.
+  // Step 3. %sB has a read interference in %E, it does not bufferize inplace.
   //     CHECK: linalg.matmul
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %D = linalg.matmul  ins(%B, %C: tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%sB: tensor<4x4xf32>)
     -> tensor<4x4xf32>
 
-  // Step 4. %sC forward propagates to an inplace write in %E.
+  // Step 2. %sC forward propagates to an inplace write in %E.
   // %sC backward propagates to %C which is inplaceable.
   // As a consequence this is bufferized inplace.
   //     CHECK: tensor.extract_slice
@@ -298,7 +298,7 @@ func @extract_slice_to_linalg_write_use(
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %sB = tensor.extract_slice %B[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
 
-  // Step 1. %sB backprops to the tensor.extract_slice producer which is not
+  // Step 3. %sB backprops to the tensor.extract_slice producer which is not
   // considered an interference. This bufferizes inplace.
   //     CHECK: linalg.matmul
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
@@ -306,7 +306,7 @@ func @extract_slice_to_linalg_write_use(
                      outs(%sB: tensor<4x4xf32>)
     -> tensor<4x4xf32>
 
-  // Step 3. %sC forward propagates to an inplace write in %E.
+  // Step 2. %sC forward propagates to an inplace write in %E.
   // %sC backward propagates to %C which is inplaceable.
   // As a consequence this is bufferized inplace.
   //     CHECK: tensor.extract_slice
@@ -482,7 +482,7 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
                    %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
-  // %r0 must be out of place because one use of %t in the subsequent production 
+  // %r0 must be out of place because one use of %t in the subsequent production
   // of %r1 is read.
   //      CHECK: scf.for
   // CHECK-NEXT: call
@@ -503,7 +503,7 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
     scf.yield %t : tensor<?xf32>
   }
 
-  // %r2 must be out of place because one use of %t in the subsequent production 
+  // %r2 must be out of place because one use of %t in the subsequent production
   // of %r3 is read.
   //      CHECK: linalg.tiled_loop
   // CHECK-NEXT: call
@@ -619,3 +619,86 @@ func @read_dependence_through_scf_and_call(
   call @bar(%B2) : (tensor<64xf32>) -> ()
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Transitive cases through extract_slice.
+//===----------------------------------------------------------------------===//
+
+builtin.func @matmul_on_tensors(
+    %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true})
+    -> tensor<256x256xf32>
+{
+  %c0 = constant 0 : index
+  %cst_0 = constant 0.000000e+00 : f32
+  %cst_1 = constant 1.000000e+00 : f32
+
+  %7 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
+
+  //      CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
+  //      CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
+  %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+
+  //      CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  //      CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  //      CHECK: linalg.matmul
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  %sA = tensor.extract_slice %8[0, 0][256, 16][1, 1]: tensor<256x256xf32> to tensor<256x16xf32>
+  %sB = tensor.extract_slice %11[0, 0][16, 256][1, 1]: tensor<256x256xf32> to tensor<16x256xf32>
+  %r = linalg.matmul
+         ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
+        outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
+
+  return %r : tensor<256x256xf32>
+}
+
+// -----
+
+builtin.func @matmul_on_tensors(
+    %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},
+    %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true})
+    -> tensor<256x256xf32>
+{
+  %c0 = constant 0 : index
+  %cst_0 = constant 0.000000e+00 : f32
+  %cst_1 = constant 1.000000e+00 : f32
+
+  %7 = linalg.init_tensor [256, 256] : tensor<256x256xf32>
+
+  //     CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  //      CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %8 = linalg.fill(%cst_0, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %9 = vector.transfer_read %arg0[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32>
+  %10 = vector.transfer_write %9, %8[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32>
+
+  //      CHECK: linalg.fill
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  //      CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %11 = linalg.fill(%cst_1, %7) : f32, tensor<256x256xf32> -> tensor<256x256xf32>
+  %12 = vector.transfer_read %arg1[%c0, %c0], %cst_0 {in_bounds = [false, true]} : tensor<518x518xf32>, vector<256x256xf32>
+  %13 = vector.transfer_write %12, %11[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf32>, tensor<256x256xf32>
+
+  //      CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  //      CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  //      CHECK: linalg.matmul
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  %sA = tensor.extract_slice %10[0, 0][256, 16][1, 1]: tensor<256x256xf32> to tensor<256x16xf32>
+  %sB = tensor.extract_slice %13[0, 0][16, 256][1, 1]: tensor<256x256xf32> to tensor<16x256xf32>
+  %r = linalg.matmul
+         ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
+        outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
+
+  return %r : tensor<256x256xf32>
+}
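
For reference, these analysis tests run the pass in test-analysis-only mode,
where results are annotated with __inplace_results_attr__ instead of being
rewritten to memrefs. The RUN line sits at the top of the test file and is not
part of this diff hunk; it should be along the lines of (assumption, check the
file for the exact invocation):

  // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s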