[Mlir-commits] [mlir] 8d0994e - [mlir][linalg][bufferize][NFC] Remove special casing of CallOps
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 22 18:16:56 PST 2021
Author: Matthias Springer
Date: 2021-11-23T11:14:10+09:00
New Revision: 8d0994ed21b2fa063bc530ece580109d55593d29
URL: https://github.com/llvm/llvm-project/commit/8d0994ed21b2fa063bc530ece580109d55593d29
DIFF: https://github.com/llvm/llvm-project/commit/8d0994ed21b2fa063bc530ece580109d55593d29.diff
LOG: [mlir][linalg][bufferize][NFC] Remove special casing of CallOps
Differential Revision: https://reviews.llvm.org/D113966
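For readers skimming the change: the map of bufferized function types is no
longer threaded through bufferizeOp as an extra parameter; it now lives in
BufferizationState, so CallOps go through the same entry point as every other
op. A minimal before/after sketch of a call site, reconstructed from the diff
below:

  // Before: CallOps needed extra state passed alongside BufferizationState.
  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
  if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes)))
    return failure();

  // After: BufferizationState owns bufferizedFunctionTypes.
  if (failed(bufferizeOp(op, state)))
    return failure();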
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index a9a344d75c1c3..7ef016f7c5dd6 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -11,6 +11,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
@@ -240,9 +241,8 @@ struct AllocationCallbacks {
/// BufferizationState keeps track of bufferization state and provides access to
/// the results of the analysis.
struct BufferizationState {
- BufferizationState(BufferizationAliasInfo &aliasInfo,
- AllocationCallbacks &allocationFns)
- : aliasInfo(aliasInfo), allocationFns(allocationFns) {}
+ BufferizationState(ModuleOp moduleOp, AllocationCallbacks &allocationFns)
+ : aliasInfo(moduleOp), allocationFns(allocationFns) {}
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@@ -270,8 +270,11 @@ struct BufferizationState {
/// Mark `op` as obsolete, so that it is deleted after bufferization.
void markOpObsolete(Operation *op);
+ /// Erase all ops that were marked obsolete.
+ void eraseObsoleteOps();
+
/// `aliasInfo` keeps track of aliasing and equivalent values.
- BufferizationAliasInfo &aliasInfo;
+ BufferizationAliasInfo aliasInfo;
/// `allocationFns` contains helper functions for creating alloc ops, dealloc
/// ops and memcpy ops.
@@ -283,6 +286,10 @@ struct BufferizationState {
/// Obsolete ops that should be deleted after bufferization.
SmallVector<Operation *> obsoleteOps;
+
+ /// A map for looking up bufferized function types.
+ // TODO: Disentangle function calls and FuncOps from the remaining bufferization.
+ DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
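Construction of BufferizationState changes accordingly (an illustrative
sketch based on the hunk above): the alias info is now built internally from
the module rather than supplied by the caller.

  // Before: the caller constructed and owned the alias info.
  BufferizationAliasInfo aliasInfo(moduleOp);
  BufferizationState state(aliasInfo, allocationFns);

  // After: the state owns its alias info, keyed off the ModuleOp.
  BufferizationState state(moduleOp, allocationFns);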
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 653ec7b36eb86..6db6ba6db3c5e 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -25,11 +25,7 @@ static constexpr int64_t kBufferAlignments = 128;
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
/// Bufferize one particular op.
-/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
-/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
-LogicalResult
-bufferizeOp(Operation *op, BufferizationState &state,
- DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
+LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
/// Register external models implemented for the `BufferizableOpInterface`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 150ffd7e45f3a..630415bf469de 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -470,3 +470,10 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete(
Operation *op) {
obsoleteOps.push_back(op);
}
+
+void mlir::linalg::comprehensive_bufferize::BufferizationState::
+ eraseObsoleteOps() {
+ for (Operation *op : obsoleteOps)
+ op->erase();
+ obsoleteOps.clear();
+}
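The new eraseObsoleteOps helper pairs with markOpObsolete to defer deletion
until after the IR walk, because erasing an op while the walk is still
visiting it would invalidate the traversal. A minimal usage sketch, with a
hypothetical predicate standing in for the real erasure condition:

  funcOp.walk([&](Operation *op) {
    if (shouldEraseAfterBufferization(op)) // hypothetical predicate
      state.markOpObsolete(op);            // record only; erase later
  });
  state.eraseObsoleteOps(); // safe: the walk has finished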
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 22f5493b6b80b..ea9309e01d874 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -783,144 +783,6 @@ static Value createNewAllocDeallocPairForShapedValue(
// Bufferization as simple BlockAndValueMapping rewrites.
//===----------------------------------------------------------------------===//
-/// In a first approximation, all the function arguments of a FuncOp are marked
-/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
-/// to allow FuncOps that are inplaceable to write inPlace.
-static LogicalResult
-bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state,
- DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
- FuncOp funcOp = getCalledFunction(callOp);
- assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
- "expected CallOp to call a FuncOp");
-
- // If nothing to do then we are done.
- if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
- !llvm::any_of(funcOp.getType().getResults(), isaTensor))
- return success();
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(callOp);
-
- // 1. Filter return types:
- // - if the callee is bodiless / external, we cannot inspect it and we
- // cannot assume anything. We can just assert that it does not return a
- // tensor as this would have to bufferize to "return a memref", whose
- // semantics is ill-defined.
- // - if the callee has a body, we perform inter-procedural equivalence
- // analysis. When successful, a result folds onto an operand. When
- // unsuccessful, additional work is needed to either:
- // * hoist a result into an inplaceable operand or
- // * devise a better representation to truly return a buffer.
- SmallVector<Type> resultTypes;
- SmallVector<Value> hoistedArguments;
- if (funcOp.body().empty()) {
- if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
- return callOp->emitError()
- << "cannot bufferize bodiless function that returns a tensor";
- } else {
- ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- // For each FuncOp result, keep track of which inplace argument it reuses.
- for (OpOperand &returnOperand : returnOp->getOpOperands()) {
- Type returnType = returnOperand.get().getType();
- if (!isaTensor(returnType)) {
- resultTypes.push_back(returnType);
- continue;
- }
-
- // If return operand is equivalent to some bbArg, no need to return it.
- Value returnVal = returnOperand.get();
- if (BlockArgument bbArg =
- getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
- Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
- int64_t idx = bbArg.getArgNumber();
- Value buffer = state.lookupBuffer(callOp->getOperand(idx));
- // Add CallOp operand/result equivalence: this is interprocedural info.
- state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
- state.mapBuffer(oldRes, buffer);
- // Add a TensorLoadOp to kill all uses of the CallOp return.
- // Replace all uses of the CallOp results so we can erase the CallOp.
- // This TensorLoadOp must fold/DCE away or bufferization should be
- // considered failed.
- Value tensorLoad =
- b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
- oldRes.replaceAllUsesWith(tensorLoad);
- // Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
- state.mapBuffer(tensorLoad, buffer);
- continue;
- }
-
- // TODO: Need to hoist above function boundary.
- if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) {
- hoistedArguments.push_back(allocOp->getResult(0));
- continue;
- }
-
- // Other cases legitimately need to return a tensor; this is currently not
- // supported. For instance, if hoisting across the function boundary has
- // failed, it may be due to e.g. data-dependent sizes. In such a case, we
- // would need a better type than memref.
- resultTypes.push_back(returnType);
-
- int64_t returnIdx = returnOperand.getOperandNumber();
- return returnOp->emitError()
- << "buffer result #" << returnIdx << " not produced by an alloc\n";
- }
- }
-
- // 2. Compute bufferized FunctionType.
- SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
- ValueRange hoistedArgs{hoistedArguments};
- llvm::append_range(argumentTypes, hoistedArgs.getTypes());
- // Get the bufferized FunctionType for funcOp or construct it if not yet
- // available.
- FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
- funcOp, argumentTypes, resultTypes, bufferizedFunctionTypes);
-
- // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
- SmallVector<Value> newOperands;
- newOperands.reserve(callOp->getNumOperands());
- for (OpOperand &opOperand : callOp->getOpOperands()) {
- Value tensorOperand = opOperand.get();
- // Non-tensor operands are just copied.
- if (!tensorOperand.getType().isa<TensorType>()) {
- newOperands.push_back(tensorOperand);
- continue;
- }
-
- // Tensor operands are guaranteed to have been bufferized.
- int64_t idx = opOperand.getOperandNumber();
- Value buffer = state.lookupBuffer(tensorOperand);
-
- // Caller / callee type mismatch is handled with a CastOp.
- auto memRefType = bufferizedFuncType.getInput(idx);
- // Since we don't yet have a clear layout story, buffer_cast may
- // conservatively turn tensors into a more dynamic memref than necessary.
- // If the memref type of the callee does not match, introduce an extra
- // memref.cast that will either canonicalize away or fail compilation
- // until we can do something better.
- if (buffer.getType() != memRefType) {
- Value castBuffer =
- b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
- // Add new op equivalence info.
- state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
- state.mapBuffer(tensorOperand, castBuffer);
- buffer = castBuffer;
- }
- newOperands.push_back(buffer);
- }
-
- // 4. Create the new CallOp.
- Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
- resultTypes, newOperands);
- newCallOp->setAttrs(callOp->getAttrs());
- // Delete the op at the end of bufferization.
- return success();
-}
-
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BufferizationState &state) {
@@ -1065,20 +927,11 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
// Bufferization entry-point for functions.
//===----------------------------------------------------------------------===//
-LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
- Operation *op, BufferizationState &state,
- DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
+LogicalResult
+mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
+ BufferizationState &state) {
OpBuilder b(op->getContext());
- // CallOps are handled separately.
- if (auto callOp = dyn_cast<CallOpInterface>(op)) {
- LDBG("Begin bufferize:\n" << callOp << '\n');
- if (!bufferizedFunctionTypes)
- llvm_unreachable(
- "null bufferizedFunctionTypes when bufferizing CallOpInterface");
- return bufferize(b, callOp, state, *bufferizedFunctionTypes);
- }
-
// Skip BufferCast and TensorLoad ops.
if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
return success();
@@ -1098,9 +951,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
return op->emitError() << "unsupported op with tensors";
}
-static LogicalResult bufferizeFuncOpInternals(
- FuncOp funcOp, BufferizationState &state,
- DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp,
+ BufferizationState &state) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
OpBuilder b(funcOp->getContext());
@@ -1109,19 +961,9 @@ static LogicalResult bufferizeFuncOpInternals(
if (failed(bufferize(b, funcOp, state)))
return failure();
- // Cannot erase ops during the traversal. Do that afterwards.
- SmallVector<Operation *> toErase;
-
auto walkFunc = [&](Operation *op) -> WalkResult {
- if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes)))
+ if (failed(bufferizeOp(op, state)))
return failure();
-
- // Register post-walk erasure, if necessary.
- if (isa<CallOpInterface>(op))
- if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
- llvm::any_of(op->getResultTypes(), isaTensor))
- toErase.push_back(op);
-
return success();
};
@@ -1133,9 +975,6 @@ static LogicalResult bufferizeFuncOpInternals(
LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
- for (Operation *op : toErase)
- op->erase();
-
return success();
}
@@ -1516,12 +1355,12 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
ModuleOp moduleOp, const BufferizationOptions &options) {
SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
- DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
DominanceInfo domInfo(moduleOp);
- BufferizationAliasInfo aliasInfo(moduleOp);
+ BufferizationState state(moduleOp, *options.allocationFns);
+ BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// Interestingly, all function args that are not visible outside of a module
// can be fully bufferized inplace by guaranteeing the CallOp is bufferized
@@ -1564,16 +1403,12 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Bufferization phase.
if (!options.testAnalysisOnly) {
- BufferizationState state(aliasInfo, *options.allocationFns);
-
// Bufferize all ops in funcOp.
- if (failed(
- bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes)))
+ if (failed(bufferizeFuncOpInternals(funcOp, state)))
return failure();
// Erase all obsolete ops.
- for (Operation *op : state.obsoleteOps)
- op->erase();
+ state.eraseObsoleteOps();
}
}
// Annotate operations if we only want to report the analysis.
@@ -1586,7 +1421,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
- bufferizedFunctionTypes)))
+ state.bufferizedFunctionTypes)))
return failure();
if (!options.allowReturnMemref &&
@@ -1986,10 +1821,142 @@ struct CallOpInterface
return BufferRelation::Equivalent;
}
+ /// In a first approximation, all the function arguments of a FuncOp are
+ /// marked inplaceable. For now, it is the responsibility of the `callOp`
+ /// bufferization to allow FuncOps that are inplaceable to write inPlace.
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
- llvm_unreachable("CallOps are handled separately");
- return failure();
+ CallOp callOp = cast<CallOp>(op);
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
+ "expected CallOp to call a FuncOp");
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(callOp);
+
+ // 1. Filter return types:
+ // - if the callee is bodiless / external, we cannot inspect it and we
+ // cannot assume anything. We can just assert that it does not return a
+ // tensor as this would have to bufferize to "return a memref", whose
+ // semantics is ill-defined.
+ // - if the callee has a body, we perform inter-procedural equivalence
+ // analysis. When successful, a result folds onto an operand. When
+ // unsuccessful, additional work is needed to either:
+ // * hoist a result into an inplaceable operand or
+ // * devise a better representation to truly return a buffer.
+ SmallVector<Type> resultTypes;
+ SmallVector<Value> hoistedArguments;
+ if (funcOp.body().empty()) {
+ if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return callOp->emitError()
+ << "cannot bufferize bodiless function that returns a tensor";
+ } else {
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ assert(returnOp && "expected func with single return op");
+
+ // For each FuncOp result, keep track of which inplace argument it reuses.
+ for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+ Type returnType = returnOperand.get().getType();
+ if (!isaTensor(returnType)) {
+ resultTypes.push_back(returnType);
+ continue;
+ }
+
+ // If return operand is equivalent to some bbArg, no need to return it.
+ Value returnVal = returnOperand.get();
+ if (BlockArgument bbArg =
+ getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
+ Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
+ int64_t idx = bbArg.getArgNumber();
+ Value buffer = state.lookupBuffer(callOp->getOperand(idx));
+ // Add CallOp operand/result equivalence: this is interprocedural
+ // info.
+ state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+ state.mapBuffer(oldRes, buffer);
+ // Add a TensorLoadOp to kill all uses of the CallOp return.
+ // Replace all uses of the CallOp results so we can erase the CallOp.
+ // This TensorLoadOp must fold/DCE away or bufferization should be
+ // considered failed.
+ Value tensorLoad =
+ b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
+ oldRes.replaceAllUsesWith(tensorLoad);
+ // Add new op equivalence info.
+ state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+ state.mapBuffer(tensorLoad, buffer);
+ continue;
+ }
+
+ // TODO: Need to hoist above function boundary.
+ if (Operation *allocOp =
+ getEquivalentAlloc(returnVal, state.aliasInfo)) {
+ hoistedArguments.push_back(allocOp->getResult(0));
+ continue;
+ }
+
+ // Other cases legitimately need to return a tensor; this is currently
+ // not supported. For instance, if hoisting across the function boundary
+ // has failed, it may be due to e.g. data-dependent sizes. In such a case,
+ // we would need a better type than memref.
+ resultTypes.push_back(returnType);
+
+ int64_t returnIdx = returnOperand.getOperandNumber();
+ return returnOp->emitError() << "buffer result #" << returnIdx
+ << " not produced by an alloc\n";
+ }
+ }
+
+ // 2. Compute bufferized FunctionType.
+ SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
+ ValueRange hoistedArgs{hoistedArguments};
+ llvm::append_range(argumentTypes, hoistedArgs.getTypes());
+ // Get the bufferized FunctionType for funcOp or construct it if not yet
+ // available.
+ FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+ funcOp, argumentTypes, resultTypes, state.bufferizedFunctionTypes);
+
+ // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
+ SmallVector<Value> newOperands;
+ newOperands.reserve(callOp->getNumOperands());
+ for (OpOperand &opOperand : callOp->getOpOperands()) {
+ Value tensorOperand = opOperand.get();
+ // Non-tensor operands are just copied.
+ if (!tensorOperand.getType().isa<TensorType>()) {
+ newOperands.push_back(tensorOperand);
+ continue;
+ }
+
+ // Tensor operands are guaranteed to have been bufferized.
+ int64_t idx = opOperand.getOperandNumber();
+ Value buffer = state.lookupBuffer(tensorOperand);
+
+ // Caller / callee type mismatch is handled with a CastOp.
+ auto memRefType = bufferizedFuncType.getInput(idx);
+ // Since we don't yet have a clear layout story, buffer_cast may
+ // conservatively turn tensors into a more dynamic memref than necessary.
+ // If the memref type of the callee does not match, introduce an extra
+ // memref.cast that will either canonicalize away or fail compilation
+ // until we can do something better.
+ if (buffer.getType() != memRefType) {
+ Value castBuffer =
+ b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+ // Add new op equivalence info.
+ state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+ state.mapBuffer(tensorOperand, castBuffer);
+ buffer = castBuffer;
+ }
+ newOperands.push_back(buffer);
+ }
+
+ // 4. Create the new CallOp.
+ Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
+ resultTypes, newOperands);
+ newCallOp->setAttrs(callOp->getAttrs());
+
+ // 5. Delete the op at the end of bufferization.
+ state.markOpObsolete(callOp);
+
+ return success();
}
};
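With the special case gone, the generic bufferizeOp driver can dispatch a
CallOp like any other op. A sketch of the assumed dispatch shape (the exact
driver code is not part of this diff):

  // CallOp now resolves to the CallOpInterface external model shown above.
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
    return bufferizableOp.bufferize(b, state);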