[Mlir-commits] [mlir] e9fb4dc - [mlir][linalg][bufferize] Remove buffer equivalence from bufferize
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 6 00:51:07 PST 2021
Author: Matthias Springer
Date: 2021-12-06T17:48:31+09:00
New Revision: e9fb4dc9e918e23384550df9b66c2fd87cb1ffdd
URL: https://github.com/llvm/llvm-project/commit/e9fb4dc9e918e23384550df9b66c2fd87cb1ffdd
DIFF: https://github.com/llvm/llvm-project/commit/e9fb4dc9e918e23384550df9b66c2fd87cb1ffdd.diff
LOG: [mlir][linalg][bufferize] Remove buffer equivalence from bufferize
Remove all function calls related to buffer equivalence from bufferize implementations.
Add a new PostAnalysisStep for scf.for that ensures that yielded values are equivalent to the corresponding BBArgs. (This was previously checked in `bufferize`.) This will be relaxed in a subsequent commit.
Note: This commit changes two test cases. These were broken by design
and should not have passed. With the new scf.for PostAnalysisStep, this
bug was fixed.
Differential Revision: https://reviews.llvm.org/D114927
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
index 0a4b140a1f96..3ab5cc3525fc 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -19,6 +19,13 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace scf_ext {
+/// Equivalence analysis for scf.for. Raise an error if iter_args are not
+/// equivalent to their corresponding loop yield values.
+struct AssertDestinationPassingStyle : public PostAnalysisStep {
+ LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ SmallVector<Operation *> &newOps) override;
+};
+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace scf_ext
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index fa0c96275daf..0298a492360a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -37,7 +37,6 @@ struct ConstantOpInterface
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
- state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
state.mapBuffer(constantOp, memref);
return success();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 03ea6bdd63b9..af5362bfe1f5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -141,22 +141,7 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
/// Return `true` if a value was marked as in-place bufferized.
bool BufferizationAliasInfo::isInPlace(OpResult opResult) const {
- bool inplace = inplaceBufferized.contains(opResult);
-#ifndef NDEBUG
- if (inplace) {
- auto bufferizableOp =
- dyn_cast<BufferizableOpInterface>(opResult.getDefiningOp());
- assert(bufferizableOp &&
- "expected that in-place bufferized op is bufferizable");
- SmallVector<OpOperand *> operands =
- bufferizableOp.getAliasingOpOperand(opResult);
- for (OpOperand *operand : operands)
- assert(areAliasingBufferizedValues(operand->get(), opResult) &&
- "expected that in-place bufferized OpResult aliases with "
- "aliasing OpOperand");
- }
-#endif // NDEBUG
- return inplace;
+ return inplaceBufferized.contains(opResult);
}
/// Set the inPlace bufferization spec to true.
@@ -593,7 +578,6 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
- aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
// 2. Create memory deallocation.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 6a3a25c93582..decfb1d41f8a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -253,8 +253,6 @@ struct TiledLoopOpInterface
return failure();
// Insert mapping and aliasing info.
- state.aliasInfo.createAliasInfoEntry(resultBuffer);
- state.aliasInfo.insertNewBufferEquivalence(opResult, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
// Insert new operand and bbArg.
@@ -263,9 +261,6 @@ struct TiledLoopOpInterface
body->insertArgument(nextOutputBBArgIndex, resultBuffer.getType());
BlockArgument oldTensorBBArg = body->getArgument(oldOutputBBArgIndex);
// Insert mapping and aliasing info.
- state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
- state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
- newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Set operand of `linalg.yield` to the bbArg so it just canonicalizes
@@ -303,9 +298,6 @@ struct TiledLoopOpInterface
BlockArgument oldTensorBBArg = body->getArgument(oldInputBBArgIndex);
// Insert mapping and aliasing info.
- state.aliasInfo.createAliasInfoEntry(newBufferBBArg);
- state.aliasInfo.insertNewBufferEquivalence(oldTensorBBArg,
- newBufferBBArg);
state.mapBuffer(oldTensorBBArg, newBufferBBArg);
// Increment indices.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index e65b1eb441a3..5dc37968c95b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -223,7 +223,6 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
BufferizationState &state) {
LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// If nothing to do then we are done.
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -321,15 +320,12 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
auto castOp = b.create<memref::CastOp>(
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
toMemrefOp.memref().replaceAllUsesWith(castOp);
- aliasInfo.insertNewBufferEquivalence(castOp.dest(),
- toMemrefOp.memref());
}
}
// Replace all remaining uses by a to_tensor.
if (!bbArg.use_empty()) {
auto toTensorOp =
b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
- aliasInfo.insertNewBufferEquivalence(toTensorOp, bbArg);
bbArg.replaceAllUsesWith(toTensorOp);
}
frontBlock.eraseArgument(0);
@@ -562,7 +558,6 @@ struct CallOpInterface
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural
// info.
- state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
state.mapBuffer(oldRes, buffer);
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
@@ -572,7 +567,6 @@ struct CallOpInterface
b.create<bufferization::ToTensorOp>(callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
// Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(toTensorOp, buffer);
state.mapBuffer(toTensorOp, buffer);
continue;
}
@@ -615,7 +609,6 @@ struct CallOpInterface
Value castBuffer =
b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
// Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
state.mapBuffer(tensorOperand, castBuffer);
buffer = castBuffer;
}
@@ -663,7 +656,6 @@ struct ReturnOpInterface
Value returnTensor = b.create<bufferization::ToTensorOp>(
returnOp.getLoc(), v);
operand.set(returnTensor);
- state.aliasInfo.insertNewBufferEquivalence(returnTensor, v);
state.mapBuffer(returnTensor, v);
}
return success();
@@ -690,7 +682,6 @@ struct FuncOpInterface
: getContiguousOrUnrankedMemRefType(tensorType);
Value bufferCast = b.create<bufferization::ToMemrefOp>(funcOp.getLoc(),
memRefType, bbArg);
- state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
state.mapBuffer(bbArg, bufferCast);
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index c41457d7da76..d9416347fab5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -147,7 +147,6 @@ struct IfOpInterface
if (!resultBuffer)
return failure();
- state.aliasInfo.createAliasInfoEntry(resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
@@ -237,8 +236,6 @@ struct ForOpInterface
OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
- state.aliasInfo.createAliasInfoEntry(resultBuffer);
- state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
state.mapBuffer(bbArg, resultBuffer);
state.mapBuffer(opResult, resultBuffer);
}
@@ -257,15 +254,6 @@ struct ForOpInterface
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
- bbArg)) {
- // TODO: this could get resolved with copies but it can also turn into
- // swaps so we need to be careful about order of copies.
- return yieldOp->emitError()
- << "Yield operand #" << operand.getOperandNumber()
- << " does not bufferize to an equivalent buffer to the matching"
- << " enclosing scf::for operand";
- }
// Buffers are equivalent so the work is already done and we just yield
// the bbArg so that it later canonicalizes away.
@@ -275,6 +263,41 @@ struct ForOpInterface
}
};
+LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
+ AssertDestinationPassingStyle::run(FuncOp funcOp, BufferizationState &state,
+ SmallVector<Operation *> &newOps) {
+ LogicalResult status = success();
+ funcOp->walk([&](scf::YieldOp yieldOp) {
+ auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+ if (!forOp)
+ return WalkResult::advance();
+
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
+
+ OpOperand &forOperand = forOp.getOpOperandForResult(
+ forOp->getResult(operand.getOperandNumber()));
+ auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+ if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
+ bbArg)) {
+ // TODO: this could get resolved with copies but it can also turn into
+ // swaps so we need to be careful about order of copies.
+ status =
+ yieldOp->emitError()
+ << "Yield operand #" << operand.getOperandNumber()
+ << " does not bufferize to an equivalent buffer to the matching"
+ << " enclosing scf::for operand";
+ return WalkResult::interrupt();
+ }
+ }
+
+ return WalkResult::advance();
+ });
+ return status;
+}
+
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
scf::YieldOp> {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index f595de42b7e7..7f1bdb703d18 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -80,7 +80,6 @@ struct CastOpInterface
castOp.getResult().getType(), layout, memorySpace);
Value res =
b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
- state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
state.mapBuffer(castOp.getResult(), res);
return success();
}
@@ -233,7 +232,6 @@ struct InsertOpInterface
b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
insertOp.indices());
state.mapBuffer(insertOp, destMemref);
- state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
return success();
}
@@ -421,8 +419,6 @@ struct InsertSliceOpInterface
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
- // Insert new alias.
- state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 9da3ed883cfb..f90b5e5e20ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -96,6 +96,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
// TODO: Find a way to enable this step automatically when bufferizing tensor
// dialect ops.
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+ options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 6e82f65cc905..0e1e7602b2af 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1113,7 +1113,7 @@ func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
// Read from %t1 via alias %e.
%v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
- scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
+ scf.yield %t2, %v2 : tensor<?xf32>, vector<5xf32>
}
// CHECK: __inplace_results_attr__ = ["true", "false"]
@@ -1154,14 +1154,10 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
// This loop does not read from %t1. It only writes to it.
// CHECK: scf.for
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
- // CHECK: tensor.extract_slice
- // CHECK-SAME: __inplace_results_attr__ = ["true"]
- %e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
-
- // Write to %t1 via alias. (Overwrite %t3.)
+ // Write to %t1 via %t2. (Overwrite %t3.)
// CHECK: linalg.generic
// CHECK-SAME: __inplace_results_attr__ = ["true"]
- %o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
+ %o2 = linalg.generic #trait outs (%t2 : tensor<?xf32>) {
^bb(%0: f32) :
linalg.yield %cst : f32
} -> (tensor<?xf32>)
@@ -1172,8 +1168,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
}
// Use %t3 in some way without reading it, so that it does not get DCE'd.
- // CHECK: linalg.generic
- // CHECK-SAME: __inplace_results_attr__ = ["true"]
+ // CHECK: linalg.generic
+ // CHECK-SAME: __inplace_results_attr__ = ["true"]
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
^bb(%0: f32) :
linalg.yield %cst : f32
More information about the Mlir-commits
mailing list