[Mlir-commits] [mlir] cb4d0bf - [mlir][linalg][bufferize][NFC] Collect equivalent FuncOp BBArgs in PostAnalysisStep

Matthias Springer llvmlistbot at llvm.org
Mon Dec 6 00:32:26 PST 2021


Author: Matthias Springer
Date: 2021-12-06T17:31:39+09:00
New Revision: cb4d0bf9976ccb0da2850ea894254635f96fa9c2

URL: https://github.com/llvm/llvm-project/commit/cb4d0bf9976ccb0da2850ea894254635f96fa9c2
DIFF: https://github.com/llvm/llvm-project/commit/cb4d0bf9976ccb0da2850ea894254635f96fa9c2.diff

LOG: [mlir][linalg][bufferize][NFC] Collect equivalent FuncOp BBArgs in PostAnalysisStep

Collect equivalent BBArgs right after the equivalence analysis of the FuncOp and before bufferizing. This is in preparation for decoupling bufferization from aliasInfo.

Also gather equivalence info for CallOps, which was missing in the
previous commit.

Differential Revision: https://reviews.llvm.org/D114847

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 2d690b5f1045e..75ca131ff6d34 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -71,6 +71,8 @@ struct PostAnalysisStep {
                             SmallVector<Operation *> &newOps) = 0;
 };
 
+using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
+
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
   BufferizationOptions();
@@ -107,7 +109,7 @@ struct BufferizationOptions {
   bool testAnalysisOnly = false;
 
   /// Registered post analysis steps.
-  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
+  PostAnalysisStepList postAnalysisSteps;
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index cd6b5268f442f..aa9cc9ced9c70 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -18,13 +18,16 @@ namespace comprehensive_bufferize {
 
 struct BufferizationOptions;
 struct BufferizationState;
+struct PostAnalysisStep;
 
 /// Bufferize the given function. Does not bufferize the function boundary.
+/// Reuses an existing BufferizationState object.
 // TODO: This function is meant to be called from ModuleBufferize and cannot
 // yet be called standalone.
-LogicalResult runComprehensiveBufferize(FuncOp funcOp,
-                                        const BufferizationOptions &options,
-                                        BufferizationState &state);
+LogicalResult runComprehensiveBufferize(
+    FuncOp funcOp, const BufferizationOptions &options,
+    BufferizationState &state,
+    const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 23da486b34f69..74d0e0b33d84c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -726,7 +726,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     FuncOp funcOp, const BufferizationOptions &options,
-    BufferizationState &state) {
+    BufferizationState &state, const PostAnalysisStepList &extraSteps) {
 
   DominanceInfo domInfo(funcOp);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
@@ -744,16 +744,23 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     return failure();
   equivalenceAnalysis(op, aliasInfo);
 
-  for (const std::unique_ptr<PostAnalysisStep> &step :
-       options.postAnalysisSteps) {
-    SmallVector<Operation *> newOps;
-    if (failed(step->run(funcOp, state, newOps)))
-      return failure();
-    // Analyze ops that were created by the PostAnalysisStep.
-    if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
-      return failure();
-    equivalenceAnalysis(newOps, aliasInfo);
-  }
+  auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
+    for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
+      SmallVector<Operation *> newOps;
+      if (failed(step->run(funcOp, state, newOps)))
+        return failure();
+      // Analyze ops that were created by the PostAnalysisStep.
+      if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+        return failure();
+      equivalenceAnalysis(newOps, aliasInfo);
+    }
+    return success();
+  };
+
+  if (failed(runPostAnalysisSteps(extraSteps)))
+    return failure();
+  if (failed(runPostAnalysisSteps(options.postAnalysisSteps)))
+    return failure();
 
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly) {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 4b014c9198dc8..e65b1eb441a33 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -33,8 +33,9 @@ struct ModuleBufferizationState : public DialectBufferizationState {
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
 
-  /// A mapping of return values to equivalent BlockArguments.
-  DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
+  /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
+  /// indices.
+  DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
 };
 } // namespace
 
@@ -44,6 +45,70 @@ getModuleBufferizationState(BufferizationState &state) {
       StandardOpsDialect::getDialectNamespace());
 }
 
+/// Return the unique ReturnOp that terminates `funcOp`.
+/// Return nullptr if there is no such unique ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+  ReturnOp returnOp;
+  for (Block &b : funcOp.body()) {
+    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+      if (returnOp)
+        return nullptr;
+      returnOp = candidateOp;
+    }
+  }
+  return returnOp;
+}
+
+namespace {
+/// Store function BlockArguments that are equivalent to a returned value in
+/// ModuleBufferizationState.
+struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
+  /// Annotate IR with the results of the analysis. For testing purposes only.
+  static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
+    const char *kEquivalentArgsAttr = "__equivalent_func_args__";
+    Operation *op = returnVal.getOwner();
+
+    SmallVector<int64_t> equivBbArgs;
+    if (op->hasAttr(kEquivalentArgsAttr)) {
+      auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
+      equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
+        return a.cast<IntegerAttr>().getValue().getSExtValue();
+      }));
+    } else {
+      equivBbArgs.append(op->getNumOperands(), -1);
+    }
+    equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
+
+    OpBuilder b(op->getContext());
+    op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
+  }
+
+  LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override {
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+
+    // Support only single return-terminated block in the function.
+    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+    assert(returnOp && "expected func with single return op");
+
+    for (OpOperand &returnVal : returnOp->getOpOperands())
+      if (returnVal.get().getType().isa<RankedTensorType>())
+        for (BlockArgument bbArg : funcOp.getArguments())
+          if (bbArg.getType().isa<RankedTensorType>())
+            if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
+                                                              bbArg)) {
+              moduleState
+                  .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
+                  bbArg.getArgNumber();
+              if (state.options.testAnalysisOnly)
+                annotateReturnOp(returnVal, bbArg);
+            }
+
+    return success();
+  }
+};
+} // namespace
+
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
 /// If `value` is a memref::CastOp, return its source. Otherwise, return
@@ -73,20 +138,6 @@ static FuncOp getCalledFunction(CallOpInterface callOp) {
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  ReturnOp returnOp;
-  for (Block &b : funcOp.body()) {
-    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 /// Return the FunctionType with `argumentTypes` and `resultTypes` where each
 /// tensor is replaced by the corresponding buffer type.
 /// In order for all the callers to agree, this *must* bufferize to the most
@@ -128,22 +179,30 @@ static FunctionType getOrCreateBufferizedFunctionType(
   return it2.first->second;
 }
 
-/// Store function BlockArguments that are equivalent to a returned value in
-/// the given ModuleBufferizationState.
-static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
-                                           BufferizationState &state) {
-  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-
-  // Support only single return-terminated block in the function.
-  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
+/// Gather equivalence info of CallOps.
+/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
+// TODO: This does not handle cyclic function call graphs etc.
+static void equivalenceAnalysis(FuncOp funcOp,
+                                BufferizationAliasInfo &aliasInfo,
+                                ModuleBufferizationState &moduleState) {
+  funcOp->walk([&](CallOp callOp) {
+    FuncOp calledFunction = getCalledFunction(callOp);
+    assert(calledFunction && "could not retrieved called FuncOp");
+
+    // No equivalence info available for the called function.
+    if (!moduleState.equivalentFuncArgs.count(calledFunction))
+      return WalkResult::skip();
+
+    for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
+      int64_t returnIdx = it.first;
+      int64_t bbargIdx = it.second;
+      Value returnVal = callOp.getResult(returnIdx);
+      Value argVal = callOp->getOperand(bbargIdx);
+      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
+    }
 
-  for (Value returnVal : returnOp.operands())
-    if (returnVal.getType().isa<RankedTensorType>())
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (bbArg.getType().isa<RankedTensorType>())
-          if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
-            moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
+    return WalkResult::advance();
+  });
 }
 
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
@@ -217,7 +276,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    if (moduleState.equivalentReturnValToBBArg.count(returnVal))
+    if (moduleState.equivalentFuncArgs[funcOp].count(
+            returnOperand.getOperandNumber()))
       continue;
 
     // Cast values at the call site if necessary.
@@ -493,12 +553,12 @@ struct CallOpInterface
         }
 
         // If return operand is equivalent to some bbArg, no need to return it.
-        Value returnVal = returnOperand.get();
-        if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
-          BlockArgument bbArg =
-              moduleState.equivalentReturnValToBBArg[returnVal];
+        if (moduleState.equivalentFuncArgs[funcOp].count(
+                returnOperand.getOperandNumber())) {
+          int64_t idx =
+              moduleState
+                  .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
           Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-          int64_t idx = bbArg.getArgNumber();
           Value buffer = state.lookupBuffer(callOp->getOperand(idx));
           // Add CallOp operand/result equivalence: this is interprocedural
           // info.
@@ -661,6 +721,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     return failure();
 
   BufferizationState state(moduleOp, options);
+  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // Interestingly, all function args that are not visible outside of a module
@@ -692,11 +753,17 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
         aliasInfo.setBufferizesToWritableMemory(bbArg);
     }
 
+    // Register extra post analysis steps. These cannot be stored in `options`
+    // because `options` is immutable.
+    PostAnalysisStepList extraSteps;
+    extraSteps.emplace_back(std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+
+    // Gather equivalence info for CallOps.
+    equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+
     // Analyze and bufferize funcOp.
-    if (failed(runComprehensiveBufferize(funcOp, options, state)))
+    if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
       return failure();
-
-    populateEquivalentFuncOpBBArgs(funcOp, state);
   }
 
   if (options.testAnalysisOnly)

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 2e2792b1146cd..6e82f65cc905a 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -40,15 +40,17 @@ func @insert_slice_fun(
   -> (tensor<?xf32>, tensor<?xf32>)
 {
   // must bufferize out of place.
-  //     CHECK: tensor.insert_slice
+  //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %r0 = tensor.insert_slice %C into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   // bufferizes inplace.
-  //     CHECK: tensor.insert_slice
+  //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %r1 = tensor.insert_slice %C into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -81,6 +83,8 @@ func @conflict_on_B(
                      outs(%B: tensor<4x4xf32>)
     -> tensor<4x4xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, -1, 1]}
   return %C, %D, %E: tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>
 }
 
@@ -136,6 +140,8 @@ func @insert_slice_insert_slice(
   // CHECK: {__inplace_results_attr__ = ["false"]}
   %r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -172,6 +178,8 @@ func @extract_slice_nonmatching_insert_slice(
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %r3 = tensor.insert_slice %r2 into %B[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -208,6 +216,8 @@ func @extract_slice_matching_insert_slice(
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -234,6 +244,9 @@ func @read_of_matching_insert_slice_source(
   %2 = tensor.insert_slice %1 into %A[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
 
   %3 = vector.transfer_read %1[%idx2], %cst2 : tensor<?xf32>, vector<5xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %2, %3 : tensor<?xf32>, vector<5xf32>
 }
 
@@ -274,6 +287,8 @@ func @read_of_matching_insert_slice_source_interleaved(
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %6 = tensor.insert_slice %5 into %2[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %6, %3 : tensor<?xf32>, vector<5xf32>
 }
 
@@ -306,6 +321,8 @@ func @extract_slice_linalg_readonly_use(
                      outs(%C: tensor<4x4xf32>)
     -> tensor<4x4xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 2]}
   return %D, %E: tensor<4x4xf32>, tensor<4x4xf32>
 }
 
@@ -372,6 +389,8 @@ func @insert_slice_double_extract_slice(
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %20 = tensor.insert_slice %19 into %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor<?x?xf32> into tensor<30x20xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [6]}
   return %20 : tensor<30x20xf32>
 }
 
@@ -504,6 +523,8 @@ func @nested_extract_slice_and_insert(
   %rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor<?x4xf32> into tensor<?x?xf32>
   %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1, 2]}
   return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
 
@@ -533,6 +554,8 @@ func @scf_for_yield_only(%A : tensor<?xf32>,
     scf.yield %t : tensor<?xf32>
   }
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -564,6 +587,8 @@ func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
     scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
   }
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -623,6 +648,8 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
     linalg.yield %t : tensor<?xf32>
   }
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -768,6 +795,8 @@ builtin.func @matmul_on_tensors(
          ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
         outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [2]}
   return %r : tensor<256x256xf32>
 }
 
@@ -813,6 +842,8 @@ builtin.func @matmul_on_tensors(
          ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
         outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [2]}
   return %r : tensor<256x256xf32>
 }
 
@@ -858,6 +889,8 @@ func @insert_slice_chain(
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %14 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [4]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -883,6 +916,9 @@ func @ip(%t: tensor<10x20xf32> {linalg.inplaceable = true},
     %t3 = tensor.insert_slice %t2 into %arg1[%x, 0] [5, %y] [1, 1] : tensor<5x?xf32> into tensor<10x20xf32>
     scf.yield %t3 : tensor<10x20xf32>
   }
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
  return %r : tensor<10x20xf32>
 }
 
@@ -910,6 +946,9 @@ func @linalg_op_same_out_tensors(
       ^bb(%0: f32, %1: f32, %2 : f32) :
         linalg.yield %0, %0 : f32, f32
     } -> (tensor<?xf32>, tensor<?xf32>)
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [1, -1]}
   return %o#0, %o#1 : tensor<?xf32>, tensor<?xf32>
 }
 
@@ -951,6 +990,8 @@ func @double_insert_slice_into_alias(
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %14 into %e[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<?x?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [2, -1]}
   return %8, %15 : tensor<62x90xf32>, tensor<?x?xf32>
 }
 
@@ -980,6 +1021,8 @@ func @interleaved_extract_insert_slice_chain_1(
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %10 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -1009,6 +1052,8 @@ func @interleaved_extract_insert_slice_chain_2(
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %10 into %8[31, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -1031,6 +1076,8 @@ func @extract_once_insert_twice(
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %2 into %8[15, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -1132,6 +1179,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
         linalg.yield %cst : f32
     } -> (tensor<?xf32>)
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %o, %v3 : tensor<?xf32>, vector<5xf32>
 }
 
@@ -1160,6 +1209,9 @@ func @buffer_forwarding_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = true
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %3 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 0]}
   return %2, %3 : tensor<?xf32>, tensor<?xf32>
 }
 
@@ -1180,6 +1232,9 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 0]}
   return %2, %2 : tensor<?xf32>, tensor<?xf32>
 }
 
@@ -1214,6 +1269,8 @@ func @scf_if_inplace2(%t1: tensor<?xf32> {linalg.inplaceable = true},
     %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
     scf.yield %t2 : tensor<?xf32>
   }
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r : tensor<?xf32>
 }
 
@@ -1263,6 +1320,9 @@ func @scf_if_in_place4(%t1: tensor<?xf32> {linalg.inplaceable = true},
     scf.yield %r : tensor<?xf32>
   }
   %v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
 }
 
@@ -1288,6 +1348,9 @@ func @scf_if_inplace5(%t1: tensor<?xf32> {linalg.inplaceable = true},
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r2 : tensor<?xf32>
 }
 
@@ -1318,6 +1381,9 @@ func @scf_if_inplace6(%t1: tensor<?xf32> {linalg.inplaceable = true},
     %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
     scf.yield %t3 : tensor<?xf32>
   }
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r : tensor<?xf32>
 }
 
@@ -1396,6 +1462,9 @@ func @scf_if_out_of_place1b(%t1: tensor<?xf32> {linalg.inplaceable = true},
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r2 : tensor<?xf32>
 }
 
@@ -1420,6 +1489,9 @@ func @scf_if_out_of_place1c(%t1: tensor<?xf32> {linalg.inplaceable = true},
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r2 : tensor<?xf32>
 }
 
@@ -1533,3 +1605,44 @@ func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
 
   return %r1, %r2 : vector<5xf32>, vector<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @inner_func
+func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
+  return %t : tensor<?xf32>
+}
+
+func @equivalent_func_arg(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
+  // This test does not check IR. It just asserts there is no failure due to
+  // non-equivalent scf.for yield values.
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    %3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %3 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @inner_func_2
+func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
+  return %0 : tensor<?xf32>
+}
+
+func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
+  // This test does not check IR. It just asserts there is no failure due to
+  // non-equivalent scf.for yield values.
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %3 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 3a70adba5c2f9..2cda811365607 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -928,3 +928,54 @@ func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
   // CHECK: return
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @inner_func(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: memref.store %{{.*}}, %[[arg0]]
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @equivalent_func_arg(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @equivalent_func_arg(%t0: tensor<?xf32> {linalg.inplaceable = true},
+                          %c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
+  // CHECK-NOT: copy
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    // CHECK: call @inner_func(%[[arg0]])
+    %3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %3 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @inner_func_2(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: memref.store %{{.*}}, %[[arg0]]
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @equivalent_func_arg_2(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
+                            %c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    // TODO: There should be a memory copy here. This is a bug in CallOp
+    // bufferization.
+    // CHECK: call @inner_func_2(%[[arg0]])
+    %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %t1 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}


        


More information about the Mlir-commits mailing list