[Mlir-commits] [mlir] 6b16683 - [mlir][Linalg] Improve comprehensive bufferization for scf.yield.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jul 12 03:36:36 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-12T10:36:25Z
New Revision: 6b1668397fd33440847f5a82675c5b83c4137018
URL: https://github.com/llvm/llvm-project/commit/6b1668397fd33440847f5a82675c5b83c4137018
DIFF: https://github.com/llvm/llvm-project/commit/6b1668397fd33440847f5a82675c5b83c4137018.diff
LOG: [mlir][Linalg] Improve comprehensive bufferization for scf.yield.
Previously, comprehensive bufferization of scf.yield did not have enough information
to detect whether an enclosing scf::for bbargs would bufferize to a buffer equivalent
to that of the matching scf::yield operand.
As a consequence a separate sanity check step would be required to determine whether
bufferization occurred properly.
This late check would miss the case of calling a function in a loop.
Instead, we now pass and update aliasInfo during bufferization and it is possible to
improve bufferization of scf::yield and drop that post-pass check.
Add an example use case that was failing previously.
This slightly modifies the error conditions, which are also updated as part of this
revision.
Differential Revision: https://reviews.llvm.org/D105803
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 8c37bebe10003..be39eec14a993 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -2075,14 +2075,24 @@ static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
+
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- if (getInPlace(bbArg) == InPlaceSpec::True)
- operand.set(bbArg);
- else
- operand.set(
- b.create<memref::TensorLoadOp>(yieldOp.getLoc(), lookup(bvm, bbArg)));
+ Value yieldedBuffer = lookup(bvm, operand.get());
+ Value bbArgBuffer = lookup(bvm, bbArg);
+ if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) {
+ // TODO: this could get resolved with copies but it can also turn into
+ // swaps so we need to be careful about order of copies.
+ return yieldOp->emitError()
+ << "Yield operand #" << operand.getOperandNumber()
+ << " does not bufferize to an equivalent buffer to the matching"
+ << " enclosing scf::for operand";
+ }
+
+ // Buffers are equivalent so the work is already done and we just yield the
+ // bbArg so that it later canonicalizes away.
+ operand.set(bbArg);
}
return success();
}
@@ -2205,38 +2215,6 @@ bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
return success();
}
-/// Return `failure()` if either
-/// scf::YieldOp are not explicitly bufferized and we need to perform a separate
-/// sanity check for now.
-static LogicalResult
-bufferizationSanityCheck(scf::YieldOp yieldOp,
- const BufferizationAliasInfo &aliasInfo) {
- auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
- if (!parentForOp)
- return yieldOp->emitError() << "not nested under ForOp";
-
- for (OpOperand &operand : yieldOp->getOpOperands()) {
- OpResult matchingForOpResult =
- parentForOp->getResult(operand.getOperandNumber());
- // Nothing to do if operand bufferizes out of place.
- if (getInPlace(matchingForOpResult) != InPlaceSpec::True)
- continue;
- OpOperand &machingForOpOperand =
- parentForOp.getOpOperandForResult(matchingForOpResult);
- BlockArgument matchingForOpIterArg =
- parentForOp.getRegionIterArgForOpOperand(machingForOpOperand);
- if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
- operand.get())) {
- return yieldOp->emitError()
- << "Yield operand #" << operand.getOperandNumber()
- << " does not bufferize to an equivalent buffer to the matching"
- << " enclosing scf::for operand -> Fail the pass\n";
- }
- }
-
- return success();
-}
-
/// Analyze the `funcOp` body to determine which OpResults are inplaceable.
static LogicalResult
inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
@@ -2275,13 +2253,14 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
return failure();
}
- // Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled
- // separately.
+ // Analyze all ops that return a tensors, except ExtractSliceOp and
+ // InsertSliceOp which are handled separately.
// Walk other ops in reverse for better interference behavior.
for (Operation *op : reverse(nonSliceOps))
for (OpOperand &opOperand : op->getOpOperands())
if (OpResult result = getInplaceableOpResult(opOperand))
- if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
+ if (result.getType().isa<TensorType>() &&
+ failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
domInfo)))
return failure();
@@ -2292,14 +2271,9 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
return failure();
- // Sanity checks.
- auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult {
- return bufferizationSanityCheck(yieldOp, aliasInfo);
- });
-
LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
- return success(!walkResult.wasInterrupted());
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 15be096dd86a5..cdf35c035ef14 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -18,7 +18,7 @@ func private @foo() -> tensor<?xf32>
// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
- -> (tensor<f32>, tensor<f32>)
+ -> (tensor<f32>, tensor<f32>)
{
cond_br %cond1, ^bb1, ^bb2
@@ -64,7 +64,7 @@ func @scf_for(%A : tensor<?xf32>,
// Throw a wrench in the system by swapping yielded values: this result in a
// ping-pong of values at each iteration on which we currently want to fail.
- // expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}}
+ // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
}
@@ -73,6 +73,27 @@ func @scf_for(%A : tensor<?xf32>,
// -----
+func private @fun_with_side_effects(%A: tensor<?xf32> {linalg.inplaceable = true})
+
+func @foo(%A: tensor<?xf32> {linalg.inplaceable = true}) -> (tensor<?xf32>) {
+ call @fun_with_side_effects(%A) : (tensor<?xf32>) -> ()
+ return %A: tensor<?xf32>
+}
+
+func @scf_yield_needs_copy(%A : tensor<?xf32> {linalg.inplaceable = true}, %iters : index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor<?xf32>) {
+ %r = call @foo(%A) : (tensor<?xf32>) -> (tensor<?xf32>)
+ // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}}
+ scf.yield %r : tensor<?xf32>
+ }
+ call @fun_with_side_effects(%res) : (tensor<?xf32>) -> ()
+ return
+}
+
+// -----
+
func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
-> tensor<4xf32>
{
@@ -92,8 +113,8 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
{
- %r = scf.if %b -> (tensor<4xf32>) {
- // expected-error @+1 {{not nested under ForOp}}
+ // expected-error @+1 {{unsupported op with tensors}}
+ %r = scf.if %b -> (tensor<4xf32>) {
scf.yield %A : tensor<4xf32>
} else {
scf.yield %B : tensor<4xf32>
More information about the Mlir-commits
mailing list