[Mlir-commits] [mlir] 2c5c5ca - [mlir][linalg][bufferize] Fix CallOp bufferization

Matthias Springer llvmlistbot at llvm.org
Tue Jan 11 03:10:34 PST 2022


Author: Matthias Springer
Date: 2022-01-11T20:10:21+09:00
New Revision: 2c5c5ca8681a2788229cde61d09129316448508b

URL: https://github.com/llvm/llvm-project/commit/2c5c5ca8681a2788229cde61d09129316448508b
DIFF: https://github.com/llvm/llvm-project/commit/2c5c5ca8681a2788229cde61d09129316448508b.diff

LOG: [mlir][linalg][bufferize] Fix CallOp bufferization

Previously, CallOps did not have any aliasing OpResult/OpOperand pairs. Therefore, CallOps were mostly ignored by the analysis and buffer copies were not inserted when necessary.
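
As a minimal illustration (the function and value names are made up, following the pattern in the new header comment of ModuleBufferization.cpp): the callee below writes to its tensor argument, while the caller still reads the original tensor after the call, so the call operand must bufferize out-of-place.

```
func @callee(%t : tensor<?xf32>) -> tensor<?xf32> {
  %f = arith.constant 1.0 : f32
  %c0 = arith.constant 0 : index
  // Writes to the buffer of %t when bufferized in-place.
  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @caller(%t0 : tensor<?xf32> {linalg.inplaceable = true}, %pos : index) -> f32 {
  // Without a copy at the call site, the callee would overwrite the buffer
  // that the tensor.extract below still needs to read.
  %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
  %2 = tensor.extract %t0[%pos] : tensor<?xf32>
  return %2 : f32
}
```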

This commit introduces the following changes:
* Function bbArgs are writable by default. A function can now be bufferized without inspecting its callers.
* Callers must introduce buffer copies of function arguments when necessary (a sketch of the resulting IR is shown below). If a function is external, the caller must conservatively assume that every function argument is modified by the callee after bufferization. If the function is not external, the caller inspects the callee to determine whether a function argument is modified.
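
For the example above, the bufferized IR would roughly look as follows. This is a hand-written sketch that mirrors the updated test expectations (memref.alloc + memref.cast + linalg.copy before the call); layout maps are simplified and the exact IR produced by the pass may differ.

```
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// The return value of @callee is equivalent to its bbArg, so it is dropped
// and the callee writes through its memref argument directly.
func @callee(%t : memref<?xf32, #map>) {
  %f = arith.constant 1.0 : f32
  %c0 = arith.constant 0 : index
  memref.store %f, %t[%c0] : memref<?xf32, #map>
  return
}

func @caller(%t0 : memref<?xf32>, %pos : index) -> f32 {
  %c0 = arith.constant 0 : index
  %dim = memref.dim %t0, %c0 : memref<?xf32>
  // The caller copies %t0 because the callee writes to its argument and the
  // original data is still read by the memref.load below.
  %alloc = memref.alloc(%dim) : memref<?xf32>
  %casted = memref.cast %alloc : memref<?xf32> to memref<?xf32, #map>
  linalg.copy(%t0, %alloc) : memref<?xf32>, memref<?xf32>
  call @callee(%casted) : (memref<?xf32, #map>) -> ()
  %v = memref.load %t0[%pos] : memref<?xf32>
  return %v : f32
}
```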

Differential Revision: https://reviews.llvm.org/D116457

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 2167908414319..fe5fc26c3d2ba 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -256,6 +256,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
 /// themselves (e.g., ExtractSliceOp).
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
     Value value) const {
+  assert(value.getType().isa<TensorType>() && "expected TensorType");
   SmallVector<OpOperand *> workingSet;
   for (OpOperand &use : value.getUses())
     workingSet.push_back(&use);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 95ecb21cf8e96..5bf26365caa6e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -6,87 +6,68 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of Comprehensive Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp, along with a few helper
-// functions that control the order in which functions are bufferized.
+// implementations for FuncOp, CallOp and ReturnOp.
 //
-// Three cases can occur during bufferization of FuncOps.
+// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
+// This function analyzes the given module and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
 //
-//     i. inplaceable function arguments may be reused in place after the
-//        function itself has been bufferized. This is encoded by IR resembling:
+// After analyzing a FuncOp, additional information about its bbArgs is
+// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
 //
-//        ```
-//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//           func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//              -> tensor<?xf32> {
-//            %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//            // ... uses of %0
-//            %res = bufferization.to_tensor %0 : memref<?xf32, #map>
-//            return %res : tensor<?xf32>
-//          }
-//        ```
+// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
+//   tensor return value (if any).
+// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
+//   read/written.
 //
-//        this is the cue for the bufferization of the function foo (and calls
-//        to it) may bufferize to `func @foo(%A: memref<?xf32, some_layout>)`.
-//        To fully achieve bufferization, an additional analysis is needed to
-//        determine whether function argument/operand pairs bufferize to a
-//        single inplace buffer argument (i.e. functions may return tensors in
-//        arbitrary order that may not match argument numbers).
+// Only tensors that are equivalent to some FuncOp bbArg may be returned.
+// Bufferization currently fails if other tensors (in particular tensors that
+// bufferize out-of-place and result in a new buffer allocation) are returned.
+// In the future, such allocations could be hoisted to the caller.
 //
-//    ii. results that don't map to an inplaceable function argument are
-//        generally allocated. Since memref semantics wrt ownership of the
-//        underlying memory region are not well-defined, comprehensive
-//        bufferization chooses to perform allocations in a scoped fashion:
-//        returning memrefs is always considered illegal.
-//        Such scenarios are encoded by IR resembling:
+// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
+// ```
+// func @foo() -> tensor<?xf32> {
+//   %0 = linalg.init_tensor [...] : tensor<?xf32>
+//   return %0 : tensor<?xf32>
+// }
+// ```
 //
-//        ```
-//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//          func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//              -> tensor<?xf32> {
-//            %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//            %1 = memref.dim %0, %c0 : memref<?xf32, #map>
-//            %2 = memref.alloc(%1) : memref<?xf32>
-//            %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
-//            // ... uses of %3
-//            memref.dealloc %2 : memref<?xf32, #map>
-//            %res = bufferization.to_tensor %3 : memref<?xf32, #map>
-//            return %res : tensor<?xf32>
-//          }
-//       ```
+// Module Bufferization implements the following calling convention.
 //
-//        this is the cue for the bufferization of the function foo (and calls
-//        to it) that it must bufferize to `func @foo(%A: memref<?xf32,
-//        some_layout>,
-//                   %B: memref<?xf32, some_layout>)` (i.e. make a cloned
-//        allocation of the result tensor)
-//        To fully achieve bufferization, the alloc/dealloc pair must be lifted
-//        out of the function at each call site.
+// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
+//   always be written to in-place.
+// * If a tensor operand of a CallOp is written to by the callee and read again
+//   after the CallOp, the CallOp operand must bufferize out-of-place.
 //
-//   iii. as an optimization over ii., it may be possible to reuse an argument
-//        and only want to return a slice.
-//        This may forego allocation by letting *all* callers decide whether to
-//        pass a new *aliasing* memref function argument (i.e. a subview).
-//        Without loss of generality, callers may agree to allocate a new buffer
-//        to avoid this aliasing. Such scenarios are encoded by IR resembling:
+// Example: The tensor.insert op bufferizes in-place because it is allowed to
+// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
+// out-of-place because `%t0` is modified by the callee but read by the
+// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
+// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
+// ```
+// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
+//   %f = ... : f32
+//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
+//   return %0 : tensor<?xf32>
+// }
 //
-//        ```
-//          #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//          func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
-//              -> tensor<4xf32> {
-//            %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
-//            %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
-//                                                memref<4xf32, #map>
-//            // ... inplace computes into %1
-//            %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
-//            return %3 : tensor<4xf32>
-//          }
-//        ```
+// func @caller() -> () {
+//   %t0 = ... : tensor<?xf32>
+//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
+//   %2 = tensor.extract %t0[...] : tensor<?xf32>
+// }
+// ```
 //
-//  Note: In the future, it may be worthwhile to design special bufferization
-//  ops to encode the desired semantics at function boundaries for i., ii. and
-//  iii.
+// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
+// analyze the function body. In such a case, the CallOp analysis conservatively
+// assumes that each tensor OpOperand is both read and written.
+//
+// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
+// as "not reading" and/or "not writing".
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
 
@@ -103,6 +84,9 @@ using namespace tensor;
 using namespace comprehensive_bufferize;
 
 namespace {
+/// The state of analysis of a FuncOp.
+enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
+
 /// Extra bufferization state that is required for bufferization of function
 /// boundaries.
 struct ModuleBufferizationState : public DialectBufferizationState {
@@ -110,8 +94,22 @@ struct ModuleBufferizationState : public DialectBufferizationState {
   /// indices.
   DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
 
+  /// A set of all read BlockArguments of FuncOps.
+  // Note: BlockArgument knows about its owner, so we do not need to store
+  // FuncOps here.
+  DenseSet<BlockArgument> readBbArgs;
+
+  /// A set of all written-to BlockArguments of FuncOps.
+  DenseSet<BlockArgument> writtenBbArgs;
+
+  /// Keep track of which FuncOps are fully analyzed or currently being
+  /// analyzed.
+  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+
+  /// Functions in the order in which they are analyzed and bufferized.
   SmallVector<FuncOp> orderedFuncOps;
 
+  /// A mapping of FuncOps to their callers.
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
 };
 } // namespace
@@ -133,6 +131,17 @@ getModuleBufferizationState(BufferizationState &state) {
       StandardOpsDialect::getDialectNamespace());
 }
 
+/// Return the state (phase) of analysis of the FuncOp.
+static FuncOpAnalysisState
+getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) {
+  const ModuleBufferizationState &moduleState =
+      getModuleBufferizationState(state);
+  auto it = moduleState.analyzedFuncOps.find(funcOp);
+  if (it == moduleState.analyzedFuncOps.end())
+    return FuncOpAnalysisState::NotAnalyzed;
+  return it->second;
+}
+
 /// Return the unique ReturnOp that terminates `funcOp`.
 /// Return nullptr if there is no such unique ReturnOp.
 static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
@@ -197,6 +206,69 @@ struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
     return success();
   }
 };
+
+/// Return true if the buffer of the given tensor value is written to. Must not
+/// be called for values inside functions that have not been analyzed yet.
+/// (Post-analysis steps do not have to be run yet, i.e., "in progress" is OK.)
+static bool isValueWritten(Value value, const BufferizationState &state,
+                           const BufferizationAliasInfo &aliasInfo) {
+#ifndef NDEBUG
+  assert(value.getType().isa<TensorType>() && "expected TensorType");
+  FuncOp funcOp;
+  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+    Operation *owner = bbArg.getOwner()->getParentOp();
+    funcOp = isa<FuncOp>(owner) ? cast<FuncOp>(owner)
+                                : owner->getParentOfType<FuncOp>();
+  } else {
+    funcOp = value.getDefiningOp()->getParentOfType<FuncOp>();
+  }
+  assert(getFuncOpAnalysisState(state, funcOp) !=
+             FuncOpAnalysisState::NotAnalyzed &&
+         "FuncOp must be fully analyzed or analysis in progress");
+#endif // NDEBUG
+
+  bool isWritten = false;
+  aliasInfo.applyOnAliases(value, [&](Value val) {
+    for (OpOperand &use : val.getUses())
+      if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use))
+        isWritten = true;
+  });
+  return isWritten;
+}
+
+/// Determine which FuncOp bbArgs are read and which are written. If this
+/// PostAnalysisStep is run on a function with unknown ops, it will
+/// conservatively assume that such ops bufferize to a read + write.
+struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
+  LogicalResult run(Operation *op, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
+                    SmallVector<Operation *> &newOps) override {
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+    auto funcOp = cast<FuncOp>(op);
+
+    // If the function has no body, conservatively assume that all args are
+    // read + written.
+    if (funcOp.getBody().empty()) {
+      for (BlockArgument bbArg : funcOp.getArguments()) {
+        moduleState.readBbArgs.insert(bbArg);
+        moduleState.writtenBbArgs.insert(bbArg);
+      }
+
+      return success();
+    }
+
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (!bbArg.getType().isa<TensorType>())
+        continue;
+      if (state.isValueRead(bbArg))
+        moduleState.readBbArgs.insert(bbArg);
+      if (isValueWritten(bbArg, state, aliasInfo))
+        moduleState.writtenBbArgs.insert(bbArg);
+    }
+
+    return success();
+  }
+};
 } // namespace
 
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
@@ -575,43 +647,101 @@ namespace std_ext {
 static Optional<int64_t>
 getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
                         int64_t returnValIdx) {
-  if (!state.equivalentFuncArgs.count(funcOp))
+  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
+  if (funcOpIt == state.equivalentFuncArgs.end())
     // No equivalence info stored for funcOp.
     return None;
 
-  const DenseMap<int64_t, int64_t> &equivFuncArgs =
-      state.equivalentFuncArgs.lookup(funcOp);
-  if (!equivFuncArgs.count(returnValIdx))
+  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
+  if (retValIt == funcOpIt->getSecond().end())
     // Return value has no equivalent bbArg.
     return None;
 
-  return equivFuncArgs.lookup(returnValIdx);
+  return retValIt->getSecond();
 }
 
 struct CallOpInterface
     : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
-    // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
-    // of the matching bbArg may. It is the responsibility of the caller to
-    // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
-    // conservative.
-    return true;
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
+      // FuncOp not analyzed yet. Assume that OpOperand is read.
+      return true;
+
+    return moduleState.readBbArgs.contains(
+        funcOp.getArgument(opOperand.getOperandNumber()));
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
-    return false;
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
+      // FuncOp not analyzed yet. Assume that OpOperand is written.
+      return true;
+
+    return moduleState.writtenBbArgs.contains(
+        funcOp.getArgument(opOperand.getOperandNumber()));
   }
 
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
-    // CallOpInterface is special, it needs to wait for the callee to be
-    // bufferized and needs to inspect the BufferAliasInfo object. It can't
-    // make a proper determination by itself and needs to be conservative.
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+
+    for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
+         ++resultIdx)
+      if (Optional<int64_t> maybeArgNumber =
+              getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
+        if (*maybeArgNumber == opOperand.getOperandNumber())
+          return callOp->getOpResult(resultIdx);
+
+    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+    // supported and will fail bufferization. (Even with allow-return-memref,
+    // it will fail when the function is called.)
     return OpResult();
   }
 
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const BufferizationState &state) const {
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+
+    // TODO: We should be looking for aliasing block arguments here. The
+    // current condition is actually stronger than necessary. Once we check
+    // for aliasing block arguments, there may be multiple aliasing OpOperands.
+    if (Optional<int64_t> maybeArgNumber = getEquivalentFuncArgIdx(
+            funcOp, moduleState, opResult.getResultNumber()))
+      return {&op->getOpOperand(*maybeArgNumber)};
+
+    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+    // supported and will fail bufferization.
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
   /// In a first approximation, all the function arguments of a FuncOp are
   /// marked inplaceable. For now, it is the responsibility of the `callOp`
   /// bufferization to allow FuncOp that are inplaceable to write inPlace.
@@ -667,11 +797,12 @@ struct CallOpInterface
               getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
         // Return operands that are equivalent to some bbArg, are not
         // returned.
-        Value buffer =
-            *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
-                             /*forceInPlace=*/true);
-        replacementValues[returnValIdx] = buffer;
-        newOperands[*bbArgIdx] = buffer;
+        FailureOr<Value> bufferOrFailure =
+            state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
+        if (failed(bufferOrFailure))
+          return failure();
+        replacementValues[returnValIdx] = *bufferOrFailure;
+        newOperands[*bbArgIdx] = *bufferOrFailure;
         continue;
       }
 
@@ -700,11 +831,15 @@ struct CallOpInterface
      // Retrieve buffers for tensor operands. Tensor operand buffers, whose
       // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
       // already stored in `newOperands` during Step 1.
-      Value buffer = newOperands[idx] ? newOperands[idx]
-                                      : *state.getBuffer(rewriter, opOperand,
-                                                         /*forceInPlace=*/true);
+      Value buffer = newOperands[idx];
+      if (!buffer) {
+        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
+        if (failed(bufferOrFailure))
+          return failure();
+        buffer = *bufferOrFailure;
+      }
 
-      // Caller / callee type mistmatch is handled with a CastOp.
+      // Caller / callee type mismatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
@@ -782,8 +917,6 @@ struct FuncOpInterface
     auto funcOp = cast<FuncOp>(op);
     BlockArgument bbArg = value.dyn_cast<BlockArgument>();
     assert(bbArg && "expected BlockArgument");
-    const ModuleBufferizationState &moduleState =
-        getModuleBufferizationState(state);
 
     // "linalg.inplaceable" overrides other writability decisions. This is
     // currently used for testing only.
@@ -792,16 +925,8 @@ struct FuncOpInterface
             BufferizableOpInterface::kInplaceableAttrName))
       return inplaceAttr.getValue();
 
-    // In a first approximation:
-    // =========================
-    // If the function is called, we can allocate on the caller side which lets
-    // us force inplace arguments at function boundaries.
-    // TODO: do not rely on this behavior.
-    if (moduleState.callerMap.find(funcOp) != moduleState.callerMap.end())
-      return true;
-
-    // All other function arguments are not writable.
-    return false;
+    // All function arguments are writable by default.
+    return true;
   }
 
   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
@@ -849,11 +974,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
                                       moduleState.callerMap)))
     return failure();
 
-  // Interestingly, all function args that are not visible outside of a module
-  // can be fully bufferized inplace by guaranteeing the CallOp is bufferized
-  // inplace. Therefore, we just bufferize funcOp as if none of its results were
-  // inplaceable, detect which operands are cloned internally and decide what to
-  // do at call sites.
+  // Collect bbArg/return value information after the analysis.
+  options->postAnalysisSteps.emplace_back(
+      std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+  options->postAnalysisSteps.emplace_back(
+      std::make_unique<FuncOpBbArgReadWriteAnalysis>());
 
   // Analyze ops.
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
@@ -861,17 +986,20 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     if (funcOp.body().empty())
       continue;
 
-    // Collect bbArg/return value information after the analysis.
-    options->postAnalysisSteps.emplace_back(
-        std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
-
-    // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+    // Now analyzing function.
+    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
 
     // Analyze funcOp.
     if (failed(analyzeOp(funcOp, state)))
       return failure();
 
+    // Gather equivalence info for CallOps.
+    // TODO: Make this a post-analysis step.
+    equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+
+    // Mark op as fully analyzed.
+    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+
     // Add annotations to function arguments.
     if (options->testAnalysisOnly)
       annotateOpsWithBufferizationMarkers(funcOp, state);

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 96725d16bd16c..929fc150f8946 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -630,7 +630,7 @@ func @scf_for_deps(
   // of %r1 is read.
   //      CHECK: scf.for
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: scf.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   //      CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -642,7 +642,7 @@ func @scf_for_deps(
   // %r1 bufferizes inplace fine.
   //      CHECK: scf.for
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: scf.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   //      CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
@@ -655,7 +655,7 @@ func @scf_for_deps(
   // of %r3 is read.
   //      CHECK: linalg.tiled_loop
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: linalg.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   //      CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -669,7 +669,7 @@ func @scf_for_deps(
   // %r3 bufferizes inplace fine.
   //      CHECK: linalg.tiled_loop
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: linalg.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   //      CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 05c120bcf557d..a9c2bcba865e6 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -410,7 +410,9 @@ func @main() {
 //      CHECK:   %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
   %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
 
-//      CHECK:   %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+//      CHECK:   %[[alloc:.*]] = memref.alloc
+//      CHECK:   %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+//      CHECK:   linalg.copy(%[[A]], %[[alloc]])
 //      CHECK:   call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%A) : (tensor<4xi32>) -> ()
 
@@ -430,7 +432,9 @@ func @main() {
 //      CHECK:   %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
   %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
 
-//      CHECK:   %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+//      CHECK:   %[[alloc:.*]] = memref.alloc
+//      CHECK:   %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+//      CHECK:   linalg.copy(%[[A]], %[[alloc]])
 //      CHECK:   call @some_external_func_within_scf_execute(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   scf.execute_region {
     call @some_external_func_within_scf_execute(%A) : (tensor<4xi32>) -> ()
@@ -488,16 +492,19 @@ func @bar(
     %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
-// CHECK-NEXT:   call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
+//      CHECK:   call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
   %r0:2 = call @scf_for_with_tensor_insert_slice(%A, %B, %C, %lb, %ub, %step) :
       (tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
         -> (tensor<?xf32>, tensor<?xf32>)
 
-  // %r0#0 is actually %B after inplaceable results are swapped in the callee.
-// CHECK-NEXT:   call @some_external_func(%[[B]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
+  // %r0#0 requires a copy because the external function may write to the buffer.
+//      CHECK:   %[[alloc:.*]] = memref.alloc
+//      CHECK:   %[[casted:.*]] = memref.cast %[[alloc]]
+//      CHECK:   linalg.copy(%[[B]], %[[alloc]])
+// CHECK-NEXT:   call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
 
-// CHECK-NEXT:   return
+//      CHECK:   return
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -745,8 +752,21 @@ func @callee(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -
 func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
             %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
             %C : tensor<?xf32> {linalg.inplaceable = false}) {
-// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
-// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
+// Note: `callee` does not write to its bbArg directly, but `external_func`
+// does. Inside `callee`, the writes via `external_func` do not cause a
+// conflict. However, inside `entry`, the writes do cause a conflict because
+// %A, %B and %C are not inplaceable. This test case shows that this kind of
+// conflict detection has a "transitive" nature.
+//      CHECK: %[[ALLOC_C:.*]] = memref.alloc
+//      CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
+//      CHECK: %[[ALLOC_B:.*]] = memref.alloc
+//      CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
+//      CHECK: %[[ALLOC_A:.*]] = memref.alloc
+//      CHECK: linalg.copy(%[[A]], %[[ALLOC_A]])
+//      CHECK: linalg.copy(%[[B]], %[[ALLOC_B]])
+//      CHECK: linalg.copy(%[[C]], %[[ALLOC_C]])
+//      CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
   call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
   return
 }
@@ -992,9 +1012,10 @@ func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
 func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
                             %c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
   %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
-    // TODO: There should be a memory copy here. This is a bug in CallOp
-    // bufferization.
-    // CHECK: call @inner_func_2(%[[arg0]])
+    // CHECK: %[[alloc:.*]] = memref.alloc
+    // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+    // CHECK: linalg.copy(%[[arg0]], %[[alloc]])
+    // CHECK: call @inner_func_2(%[[casted]])
     %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
     scf.yield %t1 : tensor<?xf32>
   }


        

