[Mlir-commits] [mlir] [mlir][transform] Add an op for replacing values with function calls (PR #78398)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Jan 17 01:58:13 PST 2024
================
@@ -36,6 +37,202 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
return success();
}
+//===----------------------------------------------------------------------===//
+// CastAndCallOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Value> inputs;
+ if (getInputs())
+ for (Value input : state.getPayloadValues(getInputs()))
+ inputs.push_back(input);
+ SmallVector<Value> outputs;
+ if (getOutputs())
+ for (Value output : state.getPayloadValues(getOutputs()))
+ outputs.push_back(output);
+
+ // Verify that the set of output values to be replaced is unique.
+ llvm::SmallDenseSet<Value> outputSet;
+ for (Value output : outputs) {
+ outputSet.insert(output);
+ }
+ if (outputSet.size() != outputs.size()) {
+ return emitSilenceableFailure(getLoc())
+ << "cast and call output values must be unique";
+ }
+
+ // Get the insertion point for the call.
+ auto insertionOps = state.getPayloadOps(getInsertionPoint());
+ if (!llvm::hasSingleElement(insertionOps)) {
+ return emitSilenceableFailure(getLoc())
+ << "Only one op can be specified as an insertion point";
+ }
+ bool insertAfter = getInsertAfter();
+ Operation *insertionPoint = *insertionOps.begin();
+
+ // Check that all inputs dominate the insertion point, and the insertion
+ // point dominates all users of the outputs.
+ DominanceInfo dom(insertionPoint);
+ for (Value output : outputs) {
+ for (Operation *user : output.getUsers()) {
+ // If we are inserting after the insertion point operation, the
+ // insertion point operation must properly dominate the user. Otherwise
+ // basic dominance is enough.
+ bool doesDominate = insertAfter
+ ? dom.properlyDominates(insertionPoint, user)
+ : dom.dominates(insertionPoint, user);
+ if (!doesDominate) {
+ return emitDefiniteFailure()
+ << "User " << user << " is not dominated by insertion point "
+ << insertionPoint;
+ }
+ }
+ }
+
+ for (Value input : inputs) {
+ // If we are inserting before the insertion point operation, the
+ // input must properly dominate the insertion point operation. Otherwise
+ // basic dominance is enough.
+ bool doesDominate = insertAfter
+ ? dom.dominates(input, insertionPoint)
+ : dom.properlyDominates(input, insertionPoint);
+ if (!doesDominate) {
+ return emitDefiniteFailure()
+ << "input " << input << " does not dominate insertion point "
+ << insertionPoint;
+ }
+ }
+
+ // Get the function to inline. This can either be specified by symbol or as a
+ // transform handle.
+ func::FuncOp targetFunction = nullptr;
+ if (getFunctionName()) {
+ targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
+ insertionPoint, *getFunctionName());
+ if (!targetFunction) {
+ return emitDefiniteFailure()
+ << "unresolved symbol " << *getFunctionName();
+ }
+ } else if (getFunction()) {
+ auto payloadOps = state.getPayloadOps(getFunction());
+ if (!llvm::hasSingleElement(payloadOps)) {
+ return emitDefiniteFailure() << "requires a single function to call";
+ }
+ targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
+ if (!targetFunction) {
+ return emitDefiniteFailure() << "invalid non-function callee";
+ }
+ } else {
+ llvm_unreachable("Invalid CastAndCall op without a function to call");
+ return emitDefiniteFailure();
+ }
+ assert(targetFunction && "no target function found");
+
+ // Verify that the function argument and result lengths match the inputs and
+ // outputs given to this op.
+ if (targetFunction.getNumArguments() != inputs.size()) {
+ return emitSilenceableFailure(targetFunction.getLoc())
+ << "mismatch between number of function arguments "
+ << targetFunction.getNumArguments() << " and number of inputs "
+ << inputs.size();
+ }
+ if (targetFunction.getNumResults() != outputs.size()) {
+ return emitSilenceableFailure(targetFunction.getLoc())
+ << "mismatch between number of function results "
+ << targetFunction->getNumResults() << " and number of outputs "
+ << outputs.size();
+ }
+
+ // Gather all specified converters.
+ MLIRContext *ctx = insertionPoint->getContext();
+ mlir::TypeConverter converter;
+ if (!getRegion().empty()) {
+ for (Operation &op : getRegion().front()) {
+ cast<transform::TypeConversionOpInterface>(&op)
+ .populateTypeMaterializations(converter);
+ }
+ }
+
+ OpBuilder builder(ctx);
----------------
ftynse wrote:
This should use the rewriter provided to the `apply` call.
https://github.com/llvm/llvm-project/pull/78398
More information about the Mlir-commits
mailing list