[Mlir-commits] [mlir] c94b80b - [mlir][linalg][bufferize][NFC] Allow returning arbitrary memrefs

Matthias Springer llvmlistbot at llvm.org
Thu Nov 25 18:32:09 PST 2021


Author: Matthias Springer
Date: 2021-11-26T11:26:46+09:00
New Revision: c94b80b4380ce851b5cf406a961eab472a43b3df

URL: https://github.com/llvm/llvm-project/commit/c94b80b4380ce851b5cf406a961eab472a43b3df
DIFF: https://github.com/llvm/llvm-project/commit/c94b80b4380ce851b5cf406a961eab472a43b3df.diff

LOG: [mlir][linalg][bufferize][NFC] Allow returning arbitrary memrefs

If `allowReturnMemref` is set to true, arbitrary memrefs may be returned from FuncOps. Also remove allocation hoisting code, which is only partly implemented at the moment.

The purpose of this commit is to untangle `bufferize` from `aliasInfo`. (Even with this change, they are not fully untangled yet.)

Differential Revision: https://reviews.llvm.org/D114507

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 03d1cab526d02..e354195dfca0b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -34,11 +34,22 @@ struct ModuleBufferizationState : public BufferizationState {
 
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
+
+  /// A mapping of return values to equivalent BlockArguments.
+  DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
 };
 } // namespace
 
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
+/// Look through any chain of memref::CastOps defining `value` and return the
+/// first non-cast source. If `value` is not a cast result, return it directly.
+static Value getNonCastedValue(Value value) {
+  while (auto castOp = value.getDefiningOp<memref::CastOp>())
+    value = castOp.source();
+  return value;
+}
+
 /// Remove the attribute that triggers inplace bufferization on a FuncOp
 /// argument `bbArg`.
 static void removeBufferizationFuncArguments(BlockArgument bbArg) {
@@ -113,62 +124,40 @@ static FunctionType getOrCreateBufferizedFunctionType(
   return it2.first->second;
 }
 
-/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
-/// an op. Return null otherwise.
-static Operation *getEquivalentAlloc(Value value,
-                                     const BufferizationAliasInfo &aliasInfo) {
-  Operation *res = nullptr;
-  aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
-    if (!res)
-      if (auto interface =
-              dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
-        if (auto effect =
-                interface.getEffectOnValue<MemoryEffects::Allocate>(v))
-          res = v.getDefiningOp();
-  });
-  return res;
-}
+/// Store function BlockArguments that are equivalent to a returned value in
+/// the given ModuleBufferizationState.
+static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
+                                           ModuleBufferizationState &state) {
+  // Support only single return-terminated block in the function.
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  assert(returnOp && "expected func with single return op");
 
-/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
-/// Return null if no such bbArg can be found.
-static BlockArgument
-getEquivalentEnclosingFuncBBArg(Value v,
-                                const BufferizationAliasInfo &aliasInfo) {
-  if (!v.getType().isa<RankedTensorType>())
-    return nullptr;
-  Operation *op = v.getParentBlock()->getParentOp();
-  FuncOp funcOp = dyn_cast<FuncOp>(op);
-  if (!funcOp)
-    funcOp = op->getParentOfType<FuncOp>();
-  assert(funcOp && "expected non-null FuncOp");
-  for (BlockArgument bbArg : funcOp.getArguments()) {
-    if (!bbArg.getType().isa<RankedTensorType>())
-      continue;
-    if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
-      return bbArg;
-  }
-  return nullptr;
+  for (Value returnVal : returnOp.operands())
+    if (returnVal.getType().isa<RankedTensorType>())
+      for (BlockArgument bbArg : funcOp.getArguments())
+        if (bbArg.getType().isa<RankedTensorType>())
+          if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
+            state.equivalentReturnValToBBArg[returnVal] = bbArg;
 }
 
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
 /// buffer form (using the canonical memref layout for now), according to the
 /// inPlace-bufferizable information of the function arguments.
+///
 /// This relies on a buffer equivalence analysis of each return operand. When a
-/// result buffer is equivalent to:
-///   1. a BlockArgument of `funcOp`, it can be dropped from the return values
-///      and becomes inplaceable at all callers. This assumes all CallOp perform
-///      the necessary work to clone operands so as to make them inplaceable.
-//       Reliance on this logic will need to be relaxed in thefuture.
-///   2. an op with an Alloc effect, this currently fails bufferization but is a
-///      candidate for hoisting and creating a new inplace operand at all caller
-///      sites.
-///   3. if such a hoisting for 2. is not possible (e.g. data-dependent that
-///      prevents hoisting), this is currently unsupported and will require a
-///      refcounted buffer type.
-static LogicalResult bufferizeFuncOpBoundary(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be
+/// dropped from the return values and becomes inplaceable at all callers. This
+/// assumes all CallOp perform the necessary work to clone operands so as to
+/// make them inplaceable. Reliance on this logic will need to be relaxed in the
+/// future.
+///
+/// Note: Returning a memref currently fails bufferization unless
+/// `allowReturnMemref` is set. If such memrefs originate from an op with an
+/// Alloc effect, they could be hoisted in the future.
+static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
+                                             ModuleBufferizationState &state) {
   LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // If nothing to do then we are done.
   if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -197,9 +186,9 @@ static LogicalResult bufferizeFuncOpBoundary(
     if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
       return funcOp->emitError() << "cannot bufferize bodiless function that "
                                  << "returns a tensor";
-    FunctionType bufferizedFuncType =
-        getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
-                                          TypeRange{}, bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+        funcOp, funcOp.getType().getInputs(), TypeRange{},
+        state.bufferizedFunctionTypes);
     funcOp.setType(bufferizedFuncType);
     LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
     return success();
@@ -212,37 +201,27 @@ static LogicalResult bufferizeFuncOpBoundary(
   // 1. For each FuncOp result, keep track of which inplace argument it reuses.
   SmallVector<Value> returnValues;
   for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+    Value returnVal = returnOperand.get();
+
     // If not a renturn tensor type just forward it.
-    if (!returnOperand.get().getType().isa<RankedTensorType>()) {
-      returnValues.push_back(returnOperand.get());
+    if (!returnVal.getType().isa<RankedTensorType>()) {
+      returnValues.push_back(returnVal);
       continue;
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    Value returnVal = returnOperand.get();
-    if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+    if (state.equivalentReturnValToBBArg.count(returnVal))
       continue;
 
-    // TODO: Need to hoist above function boundary.
-    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
-      returnValues.push_back(allocOp->getResult(0));
-      continue;
-    }
-
-    // Other cases legitimately need to return a tensor, this is currently not
-    // supported. For instance, if hoisting across function boundary has
-    // failed, it may be due to e.g. data-dependent sizes. In such a case, we
-    // would need a better type than memref.
-    int64_t returnIdx = returnOperand.getOperandNumber();
-    return returnOp->emitError()
-           << "buffer result #" << returnIdx << " not produced by an alloc\n";
+    // Cast values at the call site if necessary.
+    returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal)));
   }
 
   // 2. Rewrite the terminator without the inPlace bufferizable values.
   ValueRange retValues{returnValues};
   FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
       funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
-      bufferizedFunctionTypes);
+      state.bufferizedFunctionTypes);
   OpBuilder b(returnOp);
   b.create<ReturnOp>(returnOp.getLoc(), returnValues);
   returnOp->erase();
@@ -495,6 +474,7 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected Callop to a FuncOp");
+    auto &moduleState = static_cast<ModuleBufferizationState &>(state);
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
@@ -507,11 +487,10 @@ struct CallOpInterface
     //      semantics is ill-defined.
     //    - if the callee has a body, we perform inter-procedural equivalence
     //      analysis. When successful, a result folds onto an operand. When
-    //      unsuccessful, additional work is needed to either:
+    //      unsuccessful, additional work is needed (TODO) to either:
     //        * hoist a result into an inplaceable operand or
     //        * devise a better representation to truly return a buffer.
     SmallVector<Type> resultTypes;
-    SmallVector<Value> hoistedArguments;
     if (funcOp.body().empty()) {
       if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
         return callOp->emitError()
@@ -530,8 +509,9 @@ struct CallOpInterface
 
         // If return operand is equivalent to some bbArg, no need to return it.
         Value returnVal = returnOperand.get();
-        if (BlockArgument bbArg =
-                getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
+        if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
+          BlockArgument bbArg =
+              moduleState.equivalentReturnValToBBArg[returnVal];
           Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
           int64_t idx = bbArg.getArgNumber();
           Value buffer = state.lookupBuffer(callOp->getOperand(idx));
@@ -552,35 +532,17 @@ struct CallOpInterface
           continue;
         }
 
-        // TODO: Need to hoist above function boundary.
-        if (Operation *allocOp =
-                getEquivalentAlloc(returnVal, state.aliasInfo)) {
-          hoistedArguments.push_back(allocOp->getResult(0));
-          continue;
-        }
-
-        // Other cases legitimately need to return a tensor, this is currently
-        // not supported. For instance, if hoisting across function boundary has
-        // failed, it may be due to e.g. data-dependent sizes. In such a case,
-        // we would we need a better type than memref.
         resultTypes.push_back(returnType);
-
-        int64_t returnIdx = returnOperand.getOperandNumber();
-        return returnOp->emitError() << "buffer result #" << returnIdx
-                                     << " not produced by an alloc\n";
       }
     }
 
     // 2. Compute bufferized FunctionType.
     SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
-    ValueRange hoistedArgs{hoistedArguments};
-    llvm::append_range(argumentTypes, hoistedArgs.getTypes());
     // Get the bufferized FunctionType for funcOp or construct it if not yet
     // available.
-    // TODO: Assert that `state` is a ModuleBufferizationState.
-    FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
-        funcOp, argumentTypes, resultTypes,
-        static_cast<ModuleBufferizationState &>(state).bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType =
+        getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
+                                          moduleState.bufferizedFunctionTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
     SmallVector<Value> newOperands;
@@ -713,6 +675,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     // Analyze and bufferize funcOp.
     if (failed(runComprehensiveBufferize(funcOp, options, state)))
       return failure();
+
+    populateEquivalentFuncOpBBArgs(funcOp, state);
   }
 
   if (options.testAnalysisOnly)
@@ -721,8 +685,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   for (FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
-                                       state.bufferizedFunctionTypes)))
+    if (failed(bufferizeFuncOpBoundary(funcOp, state)))
       return failure();
 
     if (!options.allowReturnMemref &&

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index d9e0027a3d8ce..c0a91e3f1df83 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -110,6 +110,7 @@ func @scf_yield_needs_copy(%A : tensor<?xf32> {linalg.inplaceable = true}, %iter
 
 // -----
 
+// expected-error @+1 {{memref return type is unsupported}}
 func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   ->  tensor<4xf32>
 {
@@ -121,7 +122,6 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   //     argument aliasing).
   %r0 = tensor.extract_slice %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
 
-  // expected-error @+1 {{buffer result #0 not produced by an alloc}}
   return %r0: tensor<4xf32>
 }
 


        


More information about the Mlir-commits mailing list