[Mlir-commits] [mlir] b15b015 - [mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps
Matthias Springer
llvmlistbot at llvm.org
Wed Jan 5 07:29:12 PST 2022
Author: Matthias Springer
Date: 2022-01-06T00:28:47+09:00
New Revision: b15b0156cae73c12fe9251688266c0b2302e5d05
URL: https://github.com/llvm/llvm-project/commit/b15b0156cae73c12fe9251688266c0b2302e5d05
DIFF: https://github.com/llvm/llvm-project/commit/b15b0156cae73c12fe9251688266c0b2302e5d05.diff
LOG: [mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps
There is no need to inspect the ReturnOp of the called function.
This change also refactors CallOp bufferization so that `lookupBuffer` is called only a single time for each tensor value. This is important for a later change that fixes CallOp bufferization. (There is currently a TODO among the test cases.)
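To illustrate the supported pattern, here is a minimal before/after sketch (hypothetical function names; the layout maps that comprehensive bufferization places on memref types are omitted): a tensor result that is equivalent to a FuncOp bbArg folds onto the corresponding CallOp operand, so the bufferized CallOp no longer returns it.

  // Before bufferization: the callee returns its own bbArg, so the result
  // of the call is equivalent to the operand %t.
  func @callee(%t : tensor<?xf32>) -> tensor<?xf32> {
    return %t : tensor<?xf32>
  }
  func @caller(%t : tensor<?xf32>) -> tensor<?xf32> {
    %0 = call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }

  // After bufferization (sketch): the equivalent tensor results fold onto
  // the operands' buffers, so neither function returns a value anymore.
  func @callee(%m : memref<?xf32>) {
    return
  }
  func @caller(%m : memref<?xf32>) {
    call @callee(%m) : (memref<?xf32>) -> ()
    return
  }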
Note: Although this change modifies a test case, it is marked as NFC: there is no change in functionality, but calls to FuncOps with empty bodies are now reported with a different error message.
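For example, the following call (a sketch with a hypothetical callee name) used to be rejected with "cannot bufferize bodiless function that returns a tensor"; it is now rejected with the generic non-equivalence error, because no equivalence information can be gathered for an external function:

  func private @external_fn(%t : tensor<?xf32>) -> tensor<?xf32>

  func @caller(%t : tensor<?xf32>) {
    // error: call to FuncOp that returns non-equivalent tensors not supported
    %0 = call @external_fn(%t) : (tensor<?xf32>) -> tensor<?xf32>
    return
  }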
Differential Revision: https://reviews.llvm.org/D116446
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index f919bbfc2363..c84545dc21df 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,17 +490,16 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {
-/// Return the index of the parent function's bbArg that is equivalent to the
-/// given ReturnOp operand (if any).
+/// Return the index of the bbArg in the given FuncOp that is equivalent to the
+/// specified return value (if any).
static Optional<int64_t>
-getEquivalentFuncArgIdx(ModuleBufferizationState &state,
- OpOperand &returnOperand) {
- FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
- if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
+getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
+ int64_t returnValIdx) {
+ if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
// Return value has no equivalent bbArg.
return None;
- return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+ return state.equivalentFuncArgs[funcOp][returnValIdx];
}
struct CallOpInterface
@@ -529,6 +528,7 @@ struct CallOpInterface
BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op);
unsigned numResults = callOp.getNumResults();
+ unsigned numOperands = callOp->getNumOperands();
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected CallOp to a FuncOp");
@@ -542,54 +542,48 @@ struct CallOpInterface
// For non-tensor results: A mapping from return val indices of the old
// CallOp to return val indices of the bufferized CallOp.
SmallVector<Optional<unsigned>> retValMapping(numResults, None);
-
- if (funcOp.body().empty()) {
- // The callee is bodiless / external, so we cannot inspect it and we
- // cannot assume anything. We can just assert that it does not return a
- // tensor as this would have to bufferize to "return a memref", whose
- // semantics is ill-defined.
- for (int i = 0; i < numResults; ++i) {
- Type returnType = callOp.getResult(i).getType();
- if (isaTensor(returnType))
- return callOp->emitError()
- << "cannot bufferize bodiless function that returns a tensor";
+ // Operands of the bufferized CallOp.
+ SmallVector<Value> newOperands(numOperands, Value());
+
+ // Based on previously gathered equivalence information, we know if a
+ // tensor result folds onto an operand. These are the only tensor
+ // results that are supported at the moment.
+ //
+ // For tensor return values that do not fold onto an operand, additional
+ // work is needed (TODO) to either:
+ // * hoist a result into an inplaceable operand or
+ // * devise a better representation to truly return a buffer.
+ //
+ // Note: If a function has no body, no equivalence information is
+ // available. Consequently, a tensor return value cannot be proven to fold
+ // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
+ // the moment.
+
+ // 1. Compute the result types of the new CallOp. Tensor results that are
+ // equivalent to a FuncOp bbArg are no longer returned.
+ for (auto it : llvm::enumerate(callOp.getResultTypes())) {
+ unsigned returnValIdx = it.index();
+ Type returnType = it.value();
+ if (!isaTensor(returnType)) {
+ // Non-tensor values are returned.
+ retValMapping[returnValIdx] = resultTypes.size();
resultTypes.push_back(returnType);
- retValMapping[i] = i;
+ continue;
}
- } else {
- // The callee has a body. Based on previously gathered equivalence
- // information, we know if a tensor result folds onto an operand. These
- // are the only tensor value returns that are supported at the moment.
- //
- // For tensors return values that do not fold onto an operand, additional
- // work is needed (TODO) to either:
- // * hoist a result into an inplaceable operand or
- // * devise a better representation to truly return a buffer.
- ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- // For each FuncOp result, keep track of which inplace argument it reuses.
- for (OpOperand &returnOperand : returnOp->getOpOperands()) {
- unsigned returnIdx = returnOperand.getOperandNumber();
- Type returnType = returnOperand.get().getType();
- if (!isaTensor(returnType)) {
- // Non-tensor values are returned.
- retValMapping[returnIdx] = resultTypes.size();
- resultTypes.push_back(returnType);
- continue;
- }
-
- if (Optional<int64_t> bbArgIdx =
- getEquivalentFuncArgIdx(moduleState, returnOperand)) {
- // Return operands that are equivalent to some bbArg, are not
- // returned.
- replacementValues[returnIdx] =
- state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
- continue;
- }
-
- llvm_unreachable("returning non-equivalent tensors not supported");
+
+ if (Optional<int64_t> bbArgIdx =
+ getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
+ // Return values that are equivalent to some bbArg are not
+ // returned.
+ Value buffer =
+ state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
+ replacementValues[returnValIdx] = buffer;
+ newOperands[*bbArgIdx] = buffer;
+ continue;
}
+
+ return callOp->emitError(
+ "call to FuncOp that returns non-equivalent tensors not supported");
}
// 2. Compute bufferized FunctionType.
@@ -601,23 +595,26 @@ struct CallOpInterface
moduleState.bufferizedFunctionTypes);
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
- SmallVector<Value> newOperands;
- newOperands.reserve(callOp->getNumOperands());
for (OpOperand &opOperand : callOp->getOpOperands()) {
+ unsigned idx = opOperand.getOperandNumber();
Value tensorOperand = opOperand.get();
+
// Non-tensor operands are just copied.
if (!tensorOperand.getType().isa<TensorType>()) {
- newOperands.push_back(tensorOperand);
+ newOperands[idx] = tensorOperand;
continue;
}
- // Tensor operands are guaranteed to have been buferized.
- int64_t idx = opOperand.getOperandNumber();
- Value buffer = state.lookupBuffer(rewriter, tensorOperand);
+ // Retrieve buffers for tensor operands. Tensor operand buffers whose
+ // corresponding FuncOp bbArgs are equivalent to a returned tensor were
+ // already stored in `newOperands` during Step 1.
+ Value buffer = newOperands[idx]
+ ? newOperands[idx]
+ : state.lookupBuffer(rewriter, tensorOperand);
// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
- // Since we don't yet have a clear layout story, buffer_cast may
+ // Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee does not match, introduce an extra memref.cast
// that will either canonicalize away or fail compilation until we can do
@@ -627,20 +624,21 @@ struct CallOpInterface
memRefType, buffer);
buffer = castBuffer;
}
- newOperands.push_back(buffer);
+ newOperands[idx] = buffer;
}
// 4. Create the new CallOp.
Operation *newCallOp = rewriter.create<CallOp>(
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
-
- // 5. Replace the old op with the new op.
+ // Get replacement values for non-tensor / non-equivalent results.
for (int i = 0; i < replacementValues.size(); ++i) {
if (replacementValues[i])
continue;
replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
}
+
+ // 5. Replace the old op with the new op.
state.replaceOp(rewriter, callOp, replacementValues);
return success();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 02431d9175a9..14bb6ce48f2d 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -187,3 +187,26 @@ func @to_memref_op_is_writing(
return %r1, %r2 : vector<5xf32>, vector<5xf32>
}
+
+// -----
+
+func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+ // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+ call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+ return
+}
+
+// -----
+
+func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
+ %0 = linalg.init_tensor [5] : tensor<5xf32>
+ return %0 : tensor<5xf32>
+}
+
+func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
+ // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+ call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return
+}