[Mlir-commits] [mlir] a98c5a0 - [mlir][linalg][bufferize] Fix CallOps with non-tensor operands
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 5 07:19:45 PST 2022
Author: Matthias Springer
Date: 2022-01-06T00:19:23+09:00
New Revision: a98c5a08b15ebc31b9303206b10a50d9949eb4d4
URL: https://github.com/llvm/llvm-project/commit/a98c5a08b15ebc31b9303206b10a50d9949eb4d4
DIFF: https://github.com/llvm/llvm-project/commit/a98c5a08b15ebc31b9303206b10a50d9949eb4d4.diff
LOG: [mlir][linalg][bufferize] Fix CallOps with non-tensor operands
Such CallOps were not handled properly. When computing the new result types (and replacement values) of a CallOp, non-tensor return values were not accounted for.
Differential Revision: https://reviews.llvm.org/D116445
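For illustration, here is a hand-written sketch (mirroring the new test case added below, not taken from an actual reproducer; function and value names are hypothetical) of the pattern that used to break. The callee returns both a tensor, which folds onto the inplaceable operand, and an f32, which must survive as a genuine result of the bufferized CallOp:

  // @callee's body is omitted; assume the equivalence analysis proves its
  // tensor result equivalent to its argument (as @inner_func below does).
  func @caller(%t0: tensor<?xf32> {linalg.inplaceable = true}) -> f32 {
    // Result #0 (tensor) folds onto %t0 and needs no result in the
    // bufferized call; result #1 (f32) previously fell through the
    // computation of the new result types and replacement values.
    %0, %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
    return %1 : f32
  }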
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index ee4ced0b17396..f919bbfc2363d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,6 +490,19 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {
+/// Return the index of the parent function's bbArg that is equivalent to the
+/// given ReturnOp operand (if any).
+static Optional<int64_t>
+getEquivalentFuncArgIdx(ModuleBufferizationState &state,
+ OpOperand &returnOperand) {
+ FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
+ if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
+ // Return value has no equivalent bbArg.
+ return None;
+
+ return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+}
+
struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@@ -515,57 +528,67 @@ struct CallOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op);
+ unsigned numResults = callOp.getNumResults();
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
- "expected Callop to a FuncOp");
+ "expected CallOp to a FuncOp");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- // 1. Filter return types:
- // - if the callee is bodiless / external, we cannot inspect it and we
- // cannot assume anything. We can just assert that it does not return a
- // tensor as this would have to bufferize to "return a memref", whose
- // semantics is ill-defined.
- // - if the callee has a body, we perform inter-procedural equivalence
- // analysis. When successful, a result folds onto an operand. When
- // unsuccessful, additional work is needed (TODO) to either:
- // * hoist a result into an inplaceable operand or
- // * devise a better representation to truly return a buffer.
+ // Result types of the bufferized CallOp.
SmallVector<Type> resultTypes;
+ // Replacement values for the existing CallOp. These are usually the results
+ // of the bufferized CallOp, unless a tensor result folds onto an operand.
+ SmallVector<Value> replacementValues(numResults, Value());
+ // For non-tensor results: A mapping from return val indices of the old
+ // CallOp to return val indices of the bufferized CallOp.
+ SmallVector<Optional<unsigned>> retValMapping(numResults, None);
+
if (funcOp.body().empty()) {
- if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
- return callOp->emitError()
- << "cannot bufferize bodiless function that returns a tensor";
+ // The callee is bodiless / external, so we cannot inspect it and we
+ // cannot assume anything. We can just assert that it does not return a
+ // tensor as this would have to bufferize to "return a memref", whose
+ // semantics is ill-defined.
+ for (int i = 0; i < numResults; ++i) {
+ Type returnType = callOp.getResult(i).getType();
+ if (isaTensor(returnType))
+ return callOp->emitError()
+ << "cannot bufferize bodiless function that returns a tensor";
+ resultTypes.push_back(returnType);
+ retValMapping[i] = i;
+ }
} else {
+ // The callee has a body. Based on previously gathered equivalence
+ // information, we know if a tensor result folds onto an operand. These
+ // are the only tensor value returns that are supported at the moment.
+ //
+ // For tensor return values that do not fold onto an operand, additional
+ // work is needed (TODO) to either:
+ // * hoist a result into an inplaceable operand or
+ // * devise a better representation to truly return a buffer.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+ unsigned returnIdx = returnOperand.getOperandNumber();
Type returnType = returnOperand.get().getType();
if (!isaTensor(returnType)) {
+ // Non-tensor values are returned.
+ retValMapping[returnIdx] = resultTypes.size();
resultTypes.push_back(returnType);
continue;
}
- // If return operand is equivalent to some bbArg, no need to return it.
- if (moduleState.equivalentFuncArgs[funcOp].count(
- returnOperand.getOperandNumber())) {
- int64_t idx =
- moduleState
- .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
- Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
- Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
- // Add a ToTensorOp to kill all uses of the CallOp return.
- // Replace all uses of the CallOp results so we can erase the CallOp.
- // This ToTensorOp must fold/DCE away or bufferization should be
- // considered failed.
- Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
- callOp.getLoc(), buffer);
- oldRes.replaceAllUsesWith(toTensorOp);
+ if (Optional<int64_t> bbArgIdx =
+ getEquivalentFuncArgIdx(moduleState, returnOperand)) {
+ // Return operands that are equivalent to some bbArg are not
+ // returned.
+ replacementValues[returnIdx] =
+ state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
continue;
}
- resultTypes.push_back(returnType);
+ llvm_unreachable("returning non-equivalent tensors not supported");
}
}
@@ -612,8 +635,13 @@ struct CallOpInterface
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
- // 5. Delete the op at the end of bufferization.
- callOp->erase();
+ // 5. Replace the old op with the new op.
+ for (int i = 0; i < replacementValues.size(); ++i) {
+ if (replacementValues[i])
+ continue;
+ replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
+ }
+ state.replaceOp(rewriter, callOp, replacementValues);
return success();
}
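To see the new bookkeeping in isolation, here is a minimal, self-contained sketch in plain C++ (hypothetical stand-in types and names, not the MLIR API): each old result index either already has a replacement because it folded onto an operand's buffer, or is forwarded from the bufferized call through retValMapping:

  #include <cassert>
  #include <iostream>
  #include <optional>
  #include <string>
  #include <vector>

  int main() {
    // Old call: (tensor, f32). The tensor result is equivalent to operand 0.
    std::vector<std::string> operandBuffers = {"buffer_of_operand_0"};
    unsigned numResults = 2;

    std::vector<std::string> replacementValues(numResults);         // "" = pending
    std::vector<std::optional<unsigned>> retValMapping(numResults); // old -> new idx
    std::vector<std::string> newResultTypes;

    // Result #0: tensor equivalent to a bbArg. Replace it with the operand's
    // buffer and drop it from the new result list.
    replacementValues[0] = operandBuffers[0];
    // Result #1: non-tensor (f32). Keep it and record its new index.
    retValMapping[1] = newResultTypes.size();
    newResultTypes.push_back("f32");

    // Stand-ins for the results of the bufferized call (one per new type).
    std::vector<std::string> newCallResults = {"f32_from_new_call"};

    // Fill the remaining replacements from the bufferized call's results.
    for (unsigned i = 0; i < numResults; ++i) {
      if (!replacementValues[i].empty())
        continue;
      assert(retValMapping[i] && "pending result must have a mapping");
      replacementValues[i] = newCallResults[*retValMapping[i]];
    }

    for (unsigned i = 0; i < numResults; ++i)
      std::cout << "old result " << i << " -> " << replacementValues[i] << "\n";
    // Prints:
    //   old result 0 -> buffer_of_operand_0
    //   old result 1 -> f32_from_new_call
  }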
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index f55f3008f2bd2..8f08e37c6774e 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1000,6 +1000,34 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
// -----
+// CHECK-LABEL: func @inner_func(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+ // CHECK-NOT: copy
+ %f = arith.constant 1.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: memref.store %{{.*}}, %[[arg0]]
+ %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+ // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+ %1 = tensor.extract %0[%c1] : tensor<?xf32>
+ // CHECK: return %[[load]] : f32
+ return %0, %1 : tensor<?xf32>, f32
+}
+
+// CHECK-LABEL: func @call_func_with_non_tensor_return(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
+func @call_func_with_non_tensor_return(
+ %t0: tensor<?xf32> {linalg.inplaceable = true}) -> (f32, tensor<?xf32>) {
+ // CHECK-NOT: copy
+ // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+ %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+ // CHECK: return %[[call]] : f32
+ return %1, %0 : f32, tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @func_without_tensor_args
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
// CHECK: %[[alloc:.*]] = memref.alloc()
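For reference, a hand-written sketch (not actual pass output; the layout maps that appear on the bufferized memref types are omitted) of roughly what the new test bufferizes to, consistent with the CHECK lines above. The tensor results fold onto %arg0 in both functions, so only the f32 values remain as real results:

  func @inner_func(%arg0: memref<?xf32>) -> f32 {
    %f = arith.constant 1.0 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    memref.store %f, %arg0[%c0] : memref<?xf32>
    %0 = memref.load %arg0[%c1] : memref<?xf32>
    return %0 : f32
  }

  func @call_func_with_non_tensor_return(%arg0: memref<?xf32>) -> f32 {
    %0 = call @inner_func(%arg0) : (memref<?xf32>) -> f32
    return %0 : f32
  }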