[Mlir-commits] [mlir] 70777d9 - [mlir][bufferize][NFC] Move FuncOp bufferization to BufferizableOpInterface impl

Matthias Springer llvmlistbot at llvm.org
Fri Apr 22 02:47:58 PDT 2022


Author: Matthias Springer
Date: 2022-04-22T18:47:12+09:00
New Revision: 70777d967fb79b4b12caff2fabbff06b7f11acc7

URL: https://github.com/llvm/llvm-project/commit/70777d967fb79b4b12caff2fabbff06b7f11acc7
DIFF: https://github.com/llvm/llvm-project/commit/70777d967fb79b4b12caff2fabbff06b7f11acc7.diff

LOG: [mlir][bufferize][NFC] Move FuncOp bufferization to BufferizableOpInterface impl

FuncOps are now less special. They must still be analyzed and bufferized in a certain order, but they are now bufferized the same way as other ops that have a region: Bufferize the op first (`bufferize` interface method), then bufferize the region body with other bufferization patterns. In the case of FuncOps, the function signature is bufferized together with ReturnOps. This is similar to how, e.g., scf.for ops are bufferized together with scf.yield ops.

This change is essentially a reimplementation of the FuncOp bufferization, but mostly NFC from a user's perspective (apart from error messages). This change is in preparation for moving the code to the bufferization dialect.

Differential Revision: https://reviews.llvm.org/D123214

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 8571617b2a677..02130ae459f55 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -237,6 +237,12 @@ static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
 /// Return true if the given op has a tensor result or a tensor operand.
 static bool hasTensorSemantics(Operation *op) {
+  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+    bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
+    bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
+    return hasTensorArg || hasTensorResult;
+  }
+
   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
   return hasTensorResult || hasTensorOperand;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 7178e03a0c22d..b798728b33dd6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -337,16 +337,6 @@ funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
 }
 } // namespace
 
-static bool isaTensor(Type t) { return t.isa<TensorType>(); }
-
-/// If `value` is a memref::CastOp, return its source. Otherwise, return
-/// `value` directly.
-static Value getNonCastedValue(Value value) {
-  while (auto castOp = value.getDefiningOp<memref::CastOp>())
-    value = castOp.source();
-  return value;
-}
-
 /// Remove the attribute that triggers inplace bufferization on a func::FuncOp
 /// argument `bbArg`.
 static void removeBufferizationFuncArguments(BlockArgument bbArg) {
@@ -366,26 +356,15 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
-/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
-/// tensor is replaced by the corresponding buffer type.
-/// In order for all the callers to agree, this *must* bufferize to the most
-/// dynamic buffer type supported.
-/// A later pass across all CallOps in the module can decide whether to simplify
-/// the types of to version according to some cost model.
-static FunctionType
-getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes,
-                          TypeRange resultTypes,
-                          const BufferizationOptions &options) {
-  auto rewrite = [&](Type t) -> Type {
-    // TODO: non-zero address space.
-    // TODO: layout information if relevant.
-    if (auto tensorType = t.dyn_cast<TensorType>())
-      return getMemRefType(tensorType, options);
-    return t;
-  };
-  auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
-  auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
-  return FunctionType::get(ctx, argTypes, retTypes);
+/// Compute the memref type that the `index`-th input of `funcOp` bufferizes
+/// to. Precondition: that input must have a tensor type (checked by an
+/// assertion in debug builds).
+static BaseMemRefType
+getBufferizedFunctionArgType(func::FuncOp funcOp, int64_t index,
+                             const BufferizationOptions &options) {
+  Type inputType = funcOp.getFunctionType().getInput(index);
+  auto tensorType = inputType.dyn_cast<TensorType>();
+  assert(tensorType && "expected TensorType");
+  return getMemRefType(tensorType, options);
 }
 
 /// Gather equivalence info of CallOps.
@@ -415,150 +394,6 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
   });
 }
 
-/// Rewrite the `funcOp` arguments analysis return values and terminator into
-/// buffer form (using the canonical memref layout for now), according to the
-/// inPlace-bufferizable information of the function arguments.
-///
-/// This relies on a buffer equivalence analysis of each return operand. When a
-/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be
-/// dropped from the return values and becomes inplaceable at all callers. This
-/// assumes all CallOp perform the necessary work to clone operands so as to
-/// make them inplaceable. Reliance on this logic will need to be relaxed in the
-/// future.
-///
-/// Note: Returning a memref currently fails bufferization. If such memrefs
-/// originate from an op with an Alloc effect, they could be hoisted in the
-/// future.
-static LogicalResult bufferizeFuncOpBoundary(func::FuncOp funcOp,
-                                             RewriterBase &rewriter,
-                                             BufferizationState &state) {
-  const FuncAnalysisState &funcState =
-      getFuncAnalysisState(state.getAnalysisState());
-
-  // If nothing to do then we are done.
-  if (!llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) &&
-      !llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor))
-    return success();
-
-  // Get the bufferized FunctionType for funcOp or construct it if not yet
-  // available.
-  // TODO: Atm we have 3 cases:
-  // 1. if a function is called from within the Module, it must have bufferized
-  //    to inplaceable tensor results.
-  // 2. if it is bodiless, it must have bufferized and is not allowed to have
-  //    result tensors.
-  // 3. if it is not called internally, it still must bufferize to inplaceable
-  //    tensor results and we construct it now (e.g. top-level function called
-  //    externally).
-  // -> Figure out a better layering.
-  TypeRange resultTypes;
-
-  // Corner case: Bodiless FuncOp
-  // ============================
-  // The body of such functions is assumed opaque and we can't know the
-  // bufferization contract they want to enforce atm.
-  // As a consequence, only support functions that don't return any tensor atm.
-  if (funcOp.getBody().empty()) {
-    if (llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor))
-      return funcOp->emitError() << "cannot bufferize bodiless function that "
-                                 << "returns a tensor";
-    FunctionType bufferizedFuncType = getBufferizedFunctionType(
-        funcOp.getContext(), funcOp.getFunctionType().getInputs(),
-        funcOp.getFunctionType().getResults(), state.getOptions());
-    funcOp.setType(bufferizedFuncType);
-    return success();
-  }
-
-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  // 1. For each FuncOp result, keep track of which inplace argument it reuses.
-  SmallVector<Value> returnValues;
-  for (OpOperand &returnOperand : returnOp->getOpOperands()) {
-    Value returnVal = returnOperand.get();
-
-    // If not a renturn tensor type just forward it.
-    if (!returnVal.getType().isa<RankedTensorType>()) {
-      returnValues.push_back(returnVal);
-      continue;
-    }
-
-    // If return operand is equivalent to some bbArg, no need to return it.
-    auto funcOpIt = funcState.equivalentFuncArgs.find(funcOp);
-    if (funcOpIt != funcState.equivalentFuncArgs.end() &&
-        funcOpIt->second.count(returnOperand.getOperandNumber()))
-      continue;
-
-    // Cast values at the call site if necessary.
-    returnValues.push_back(
-        getNonCastedValue(*state.getBuffer(rewriter, returnOperand)));
-  }
-
-  // 2. Rewrite the terminator without the inPlace bufferizable values.
-  ValueRange retValues{returnValues};
-  FunctionType bufferizedFuncType = getBufferizedFunctionType(
-      funcOp.getContext(), funcOp.getFunctionType().getInputs(),
-      retValues.getTypes(), state.getOptions());
-  OpBuilder b(returnOp);
-  b.create<func::ReturnOp>(returnOp.getLoc(), returnValues);
-  returnOp->erase();
-
-  // 3. Rewrite the bbArgs.
-  // Iterate on the original `numArgs` and replace them in order.
-  // This guarantees the argument order still matches after the rewrite.
-  Block &frontBlock = funcOp.getBody().front();
-  unsigned numArgs = frontBlock.getNumArguments();
-  for (unsigned idx = 0; idx < numArgs; ++idx) {
-    auto bbArg = frontBlock.getArgument(0);
-    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
-    // Non-tensor types are just forwarded.
-    if (!tensorType) {
-      frontBlock.addArgument(bbArg.getType(), bbArg.getLoc());
-      bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
-      frontBlock.eraseArgument(0);
-      continue;
-    }
-
-    // Get the buffer type from the bufferized function type.
-    Type memrefType = bufferizedFuncType.getInput(idx);
-    Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc());
-    OpBuilder b(funcOp->getContext());
-    b.setInsertionPointToStart(&frontBlock);
-    // Replace all uses of bbArg through a ToMemRefOp.
-    for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
-      if (auto toMemrefOp =
-              dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
-        if (memref.getType() != toMemrefOp.memref().getType()) {
-          // Type has changed, insert a cast.
-          assert(memref::CastOp::areCastCompatible(
-                     memref.getType(), toMemrefOp.memref().getType()) &&
-                 "bufferizeFuncOpBoundary: cast incompatible");
-          auto castOp = b.create<memref::CastOp>(
-              funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
-          toMemrefOp.memref().replaceAllUsesWith(castOp);
-        } else {
-          // Type did not change, replace directly.
-          toMemrefOp.memref().replaceAllUsesWith(memref);
-        }
-      }
-    }
-    // Replace all remaining uses by a to_tensor.
-    if (!bbArg.use_empty()) {
-      auto toTensorOp =
-          b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
-      bbArg.replaceAllUsesWith(toTensorOp);
-    }
-    frontBlock.eraseArgument(0);
-    // TODO: add support to erase aliasInfo entries if deemed necessary.
-  }
-
-  // 4. Rewrite the FuncOp type to buffer form.
-  funcOp.setType(bufferizedFuncType);
-
-  return success();
-}
-
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
@@ -826,9 +661,8 @@ struct CallOpInterface
     return BufferRelation::Equivalent;
   }
 
-  /// In a first approximation, all the function arguments of a func::FuncOp are
-  /// marked inplaceable. For now, it is the responsibility of the `callOp`
-  /// bufferization to allow func::FuncOp that are inplaceable to write inPlace.
+  /// All function arguments are writable. It is the responsibility of the
+  /// CallOp to insert buffer copies where necessary.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     func::CallOp callOp = cast<func::CallOp>(op);
@@ -871,7 +705,7 @@ struct CallOpInterface
     for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
       unsigned returnValIdx = it.index();
       Type returnType = it.value();
-      if (!isaTensor(returnType)) {
+      if (!returnType.isa<TensorType>()) {
         // Non-tensor values are returned.
         retValMapping[returnValIdx] = resultTypes.size();
         resultTypes.push_back(returnType);
@@ -903,12 +737,10 @@ struct CallOpInterface
           funcOp.getFunctionType().getResult(resultTypes.size()));
     }
 
-    // 2. Compute bufferized FunctionType.
-    SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
-    // Get the bufferized FunctionType for funcOp or construct it if not yet
-    // available.
-    FunctionType bufferizedFuncType = getBufferizedFunctionType(
-        funcOp.getContext(), argumentTypes, resultTypes, options);
+    // 2. Get the bufferized FunctionType of the called function. Recursive or
+    // circular call graphs are not currently supported, so we can be sure that
+    // the called function was already bufferized.
+    FunctionType bufferizedFuncType = funcOp.getFunctionType();
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
     for (OpOperand &opOperand : callOp->getOpOperands()) {
@@ -993,6 +825,8 @@ struct ReturnOpInterface
     assert(isa<func::FuncOp>(returnOp->getParentOp()) &&
            "only support FuncOp parent for ReturnOp");
 #endif // NDEBUG
+
+    // ReturnOps are bufferized as part of FuncOps.
     return failure();
   }
 };
@@ -1000,9 +834,128 @@ struct ReturnOpInterface
 struct FuncOpInterface
     : public BufferizableOpInterface::ExternalModel<FuncOpInterface,
                                                     func::FuncOp> {
+  /// Rewrite function bbArgs and return values into buffer form (using the
+  /// canonical memref layout for now). This function bufferizes the function
+  /// signature and the ReturnOp. When the entire function body has been
+  /// bufferized, function return types can be switched to more concise memref
+  /// types as part of `foldMemRefCasts`.
+  ///
+  /// When a tensor function argument is known to be equivalent to a tensor
+  /// result, it is dropped from the return values.
+  ///
+  /// All function bbArgs are writable unless they are explicitly marked as
+  /// read-only. Callers must insert copies when needed.
+  ///
+  /// Arguments and results are both classified with `isa<TensorType>` (ranked
+  /// and unranked tensors), matching how the CallOp bufferization classifies
+  /// call results.
+  ///
+  /// Note: Returning a memref is possible, but corresponding CallOp
+  /// bufferizations fail unless `allowReturnAllocs`.
+  ///
+  /// Returns failure if a bodiless function returns a tensor, if a buffer copy
+  /// for an equivalent bbArg cannot be created, or if the buffer of a returned
+  /// tensor cannot be retrieved.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
-    return failure();
+    auto funcOp = cast<func::FuncOp>(op);
+    FunctionType funcType = funcOp.getFunctionType();
+    const FuncAnalysisState &moduleState =
+        getFuncAnalysisState(state.getAnalysisState());
+    const BufferizationOptions &options = state.getOptions();
+
+    // Construct the bufferized argument types: tensors become memrefs, all
+    // other types are kept unchanged.
+    SmallVector<Type> argTypes;
+    for (const auto &it : llvm::enumerate(funcType.getInputs())) {
+      Type argType = it.value();
+      if (argType.isa<TensorType>()) {
+        argTypes.push_back(
+            getBufferizedFunctionArgType(funcOp, it.index(), options));
+        continue;
+      }
+      argTypes.push_back(argType);
+    }
+
+    // Bodiless functions are assumed opaque and we cannot know the
+    // bufferization contract they want to enforce. As a consequence, only
+    // support functions that don't return any tensors atm.
+    if (funcOp.getBody().empty()) {
+      SmallVector<Type> retTypes;
+      for (Type resultType : funcType.getResults()) {
+        if (resultType.isa<TensorType>())
+          return funcOp->emitError() << "cannot bufferize bodiless function "
+                                     << "that returns a tensor";
+        retTypes.push_back(resultType);
+      }
+      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
+      return success();
+    }
+
+    // TODO: Support functions with multiple returns.
+    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+    assert(returnOp && "expected func with single return op");
+
+    // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
+    Block &frontBlock = funcOp.getBody().front();
+    for (BlockArgument bbArg : frontBlock.getArguments()) {
+      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+      // Non-tensor types stay the same.
+      if (!tensorType)
+        continue;
+
+      // Collect all uses of the bbArg before its type is changed.
+      SmallVector<OpOperand *> bbArgUses;
+      for (OpOperand &use : bbArg.getUses())
+        bbArgUses.push_back(&use);
+
+      // Change the bbArg type to memref.
+      Type memrefType =
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
+      bbArg.setType(memrefType);
+
+      // Replace all uses of the original tensor bbArg.
+      rewriter.setInsertionPointToStart(&frontBlock);
+      if (!bbArgUses.empty()) {
+        // Insert to_tensor because the remaining function body has not been
+        // bufferized yet.
+        Value toTensorOp =
+            rewriter.create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
+        for (OpOperand *use : bbArgUses)
+          use->set(toTensorOp);
+      }
+    }
+
+    // 2. For each result, keep track of which inplace argument it reuses.
+    SmallVector<Value> returnValues;
+    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+      Value returnVal = returnOperand.get();
+
+      // If not a tensor type just forward it. This must use the same
+      // `TensorType` check as the arguments above and the CallOp
+      // bufferization; a `RankedTensorType` check would forward unranked
+      // tensor results unchanged and leave a tensor in the signature.
+      if (!returnVal.getType().isa<TensorType>()) {
+        returnValues.push_back(returnVal);
+        continue;
+      }
+
+      // If return operand is equivalent to some bbArg, no need to return it.
+      if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
+              funcOp, moduleState, returnOperand.getOperandNumber())) {
+        rewriter.setInsertionPoint(returnOp);
+        Location loc = returnOp.getLoc();
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            loc, getMemRefType(returnVal.getType().cast<TensorType>(), options),
+            returnVal);
+        BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
+        // Note: This copy will fold away. It must be inserted here to ensure
+        // that `returnVal` still has at least one use and does not fold away.
+        if (failed(
+                createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
+          return funcOp->emitError("could not generate copy for bbArg");
+        continue;
+      }
+
+      // `getBuffer` can fail; propagate the failure instead of dereferencing
+      // an invalid result.
+      FailureOr<Value> buffer = state.getBuffer(rewriter, returnOperand);
+      if (failed(buffer))
+        return failure();
+      returnValues.push_back(*buffer);
+    }
+
+    // 3. Rewrite the terminator without the in-place bufferizable values.
+    returnOp.operandsMutable().assign(returnValues);
+
+    // 4. Rewrite the FuncOp type to buffer form.
+    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
+                                     ValueRange(returnValues).getTypes()));
+
+    return success();
   }
 
   /// Return `true` if the given function argument is writable.
@@ -1058,6 +1011,34 @@ static void annotateOpsWithBufferizationMarkers(func::FuncOp funcOp,
       setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
 }
 
+/// Fold return values that are memref casts and update function return types.
+///
+/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
+/// is not known yet. Therefore, the bufferization uses memref types with the
+/// most generic layout map as function return types. After bufferizing the
+/// entire function body, a more concise memref type can potentially be used for
+/// the return type of the function.
+///
+/// Chains of memref.cast ops are stripped completely, so each return operand
+/// is replaced by the first value in its cast chain that is not itself
+/// produced by a memref.cast. Bodiless functions and functions without a
+/// unique ReturnOp are left unchanged.
+static void foldMemRefCasts(func::FuncOp funcOp) {
+  if (funcOp.getBody().empty())
+    return;
+
+  // Without a unique ReturnOp there is no single set of return operands to
+  // rewrite; leave such functions untouched.
+  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  if (!returnOp)
+    return;
+
+  SmallVector<Type> resultTypes;
+  for (OpOperand &operand : returnOp->getOpOperands()) {
+    // Walk through nested casts (e.g. cast(cast(%m))), not just one level.
+    Value value = operand.get();
+    while (auto castOp = value.getDefiningOp<memref::CastOp>())
+      value = castOp.source();
+    operand.set(value);
+    resultTypes.push_back(value.getType());
+  }
+
+  auto newFuncType = FunctionType::get(
+      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
+  funcOp.setType(newFuncType);
+}
+
 LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
     ModuleOp moduleOp, OneShotBufferizationOptions options) {
   IRRewriter rewriter(moduleOp.getContext());
@@ -1108,15 +1089,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
 
   // Bufferize functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
-    // No body => no analysis.
-    if (!funcOp.getBody().empty())
-      if (failed(bufferizeOp(funcOp, bufferizationState)))
-        return failure();
-
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, bufferizationState)))
+    if (failed(bufferizeOp(funcOp, bufferizationState)))
       return failure();
+    foldMemRefCasts(funcOp);
   }
 
   // Check result.

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 43c59328678cf..34a553f93516d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -11,6 +11,7 @@ func.func @bar() -> tensor<?xf32> {
 
 // -----
 
+// expected-error @+2 {{op was not bufferized}}
 // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
 func.func private @foo() -> tensor<?xf32>
 
@@ -212,6 +213,7 @@ func.func @to_memref_op_is_writing(
 
 // -----
 
+// expected-error @+2 {{op was not bufferized}}
 // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
 func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
 


        


More information about the Mlir-commits mailing list