[Mlir-commits] [mlir] 06dacf5 - [mlir][func][bufferization][NFC] Simplify implementation
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 15 03:00:21 PDT 2023
Author: Matthias Springer
Date: 2023-08-15T12:00:12+02:00
New Revision: 06dacf5ea798741930fa714ec752c2e971625679
URL: https://github.com/llvm/llvm-project/commit/06dacf5ea798741930fa714ec752c2e971625679
DIFF: https://github.com/llvm/llvm-project/commit/06dacf5ea798741930fa714ec752c2e971625679.diff
LOG: [mlir][func][bufferization][NFC] Simplify implementation
The bufferization implementation of `func.func` and `func.call` can be simplified. It still contained code that was necessary when One-Shot Bufferize removed return values. That functionality was extracted into a separate pass a while ago.
Differential Revision: https://reviews.llvm.org/D157893
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 1f55728cd5f533..da9b1d9868b571 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -438,16 +438,9 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
// Otherwise, we have to use a memref type with a fully dynamic layout map to
// avoid copies. We are currently missing patterns for layout maps to
// canonicalize away (or canonicalize to more precise layouts).
- //
- // FuncOps must be bufferized before their bodies, so add them to the worklist
- // first.
SmallVector<Operation *> worklist;
- op->walk([&](func::FuncOp funcOp) {
- if (hasTensorSemantics(funcOp))
- worklist.push_back(funcOp);
- });
op->walk<WalkOrder::PostOrder>([&](Operation *op) {
- if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
+ if (hasTensorSemantics(op))
worklist.push_back(op);
});
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 89904db43508f5..fbefaf06e11737 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -197,68 +197,67 @@ struct CallOpInterface
return result;
}
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ auto callOp = cast<func::CallOp>(op);
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(funcOp && "expected CallOp to a FuncOp");
+
+ // The callee was already bufferized, so we can directly take the type from
+ // its signature.
+ FunctionType funcType = funcOp.getFunctionType();
+ return cast<BaseMemRefType>(
+ funcType.getResult(cast<OpResult>(value).getResultNumber()));
+ }
+
/// All function arguments are writable. It is the responsibility of the
/// CallOp to insert buffer copies where necessary.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
func::CallOp callOp = cast<func::CallOp>(op);
- unsigned numResults = callOp.getNumResults();
- unsigned numOperands = callOp->getNumOperands();
- FuncOp funcOp = getCalledFunction(callOp);
- assert(funcOp && "expected CallOp to a FuncOp");
- FunctionType funcType = funcOp.getFunctionType();
-
- // Result types of the bufferized CallOp.
- SmallVector<Type> resultTypes;
- // Replacement values for the existing CallOp. These are usually the results
- // of the bufferized CallOp, unless a tensor result folds onto an operand.
- SmallVector<Value> replacementValues(numResults, Value());
- // For non-tensor results: A mapping from return val indices of the old
- // CallOp to return val indices of the bufferized CallOp.
- SmallVector<std::optional<unsigned>> retValMapping(numResults,
- std::nullopt);
- // Operands of the bufferized CallOp.
- SmallVector<Value> newOperands(numOperands, Value());
// 1. Compute the result types of the new CallOp.
- for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
- unsigned returnValIdx = it.index();
- Type returnType = it.value();
+ SmallVector<Type> resultTypes;
+ for (Value result : callOp.getResults()) {
+ Type returnType = result.getType();
if (!isa<TensorType>(returnType)) {
// Non-tensor values are returned.
- retValMapping[returnValIdx] = resultTypes.size();
resultTypes.push_back(returnType);
continue;
}
// Returning a memref.
- retValMapping[returnValIdx] = resultTypes.size();
- resultTypes.push_back(funcType.getResult(resultTypes.size()));
+ FailureOr<BaseMemRefType> resultType =
+ bufferization::getBufferType(result, options);
+ if (failed(resultType))
+ return failure();
+ resultTypes.push_back(*resultType);
}
- // 2. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
- for (OpOperand &opOperand : callOp->getOpOperands()) {
- unsigned idx = opOperand.getOperandNumber();
- Value tensorOperand = opOperand.get();
+ // 2. Rewrite tensor operands as memrefs based on type of the already
+ // bufferized callee.
+ SmallVector<Value> newOperands;
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(funcOp && "expected CallOp to a FuncOp");
+ FunctionType funcType = funcOp.getFunctionType();
+ for (OpOperand &opOperand : callOp->getOpOperands()) {
// Non-tensor operands are just copied.
- if (!isa<TensorType>(tensorOperand.getType())) {
- newOperands[idx] = tensorOperand;
+ if (!isa<TensorType>(opOperand.get().getType())) {
+ newOperands.push_back(opOperand.get());
continue;
}
// Retrieve buffers for tensor operands.
- Value buffer = newOperands[idx];
- if (!buffer) {
- FailureOr<Value> maybeBuffer =
- getBuffer(rewriter, opOperand.get(), options);
- if (failed(maybeBuffer))
- return failure();
- buffer = *maybeBuffer;
- }
+ FailureOr<Value> maybeBuffer =
+ getBuffer(rewriter, opOperand.get(), options);
+ if (failed(maybeBuffer))
+ return failure();
+ Value buffer = *maybeBuffer;
// Caller / callee type mismatch is handled with a CastOp.
- auto memRefType = funcType.getInput(idx);
+ auto memRefType = funcType.getInput(opOperand.getOperandNumber());
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
@@ -272,22 +271,16 @@ struct CallOpInterface
memRefType, buffer);
buffer = castBuffer;
}
- newOperands[idx] = buffer;
+ newOperands.push_back(buffer);
}
// 3. Create the new CallOp.
Operation *newCallOp = rewriter.create<func::CallOp>(
callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
- // Get replacement values.
- for (unsigned i = 0; i < replacementValues.size(); ++i) {
- if (replacementValues[i])
- continue;
- replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
- }
// 4. Replace the old op with the new op.
- replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);
+ replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->getResults());
return success();
}
@@ -326,6 +319,17 @@ struct ReturnOpInterface
struct FuncOpInterface
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ auto funcOp = cast<FuncOp>(op);
+ auto bbArg = cast<BlockArgument>(value);
+ // Unstructured control flow is not supported.
+ assert(bbArg.getOwner() == &funcOp.getBody().front() &&
+ "expected that block argument belongs to first block");
+ return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
+ }
+
/// Rewrite function bbArgs and return values into buffer form. This function
/// bufferizes the function signature and the ReturnOp. When the entire
/// function body has been bufferized, function return types can be switched
@@ -384,9 +388,11 @@ struct FuncOpInterface
bbArgUses.push_back(&use);
// Change the bbArg type to memref.
- Type memrefType =
- getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
- bbArg.setType(memrefType);
+ FailureOr<BaseMemRefType> memrefType =
+ bufferization::getBufferType(bbArg, options);
+ if (failed(memrefType))
+ return failure();
+ bbArg.setType(*memrefType);
// Replace all uses of the original tensor bbArg.
rewriter.setInsertionPointToStart(&frontBlock);
More information about the Mlir-commits mailing list