[Mlir-commits] [mlir] [mlir][transform] Add an op for replacing values with function calls (PR #78398)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Jan 17 01:58:13 PST 2024


================
@@ -36,6 +37,202 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// CastAndCallOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
+                                transform::TransformResults &results,
+                                transform::TransformState &state) {
+  SmallVector<Value> inputs;
+  if (getInputs())
+    for (Value input : state.getPayloadValues(getInputs()))
+      inputs.push_back(input);
+  SmallVector<Value> outputs;
+  if (getOutputs())
+    for (Value output : state.getPayloadValues(getOutputs()))
+      outputs.push_back(output);
+
+  // Verify that the set of output values to be replaced is unique.
+  llvm::SmallDenseSet<Value> outputSet;
+  for (Value output : outputs) {
+    outputSet.insert(output);
+  }
+  if (outputSet.size() != outputs.size()) {
+    return emitSilenceableFailure(getLoc())
+           << "cast and call output values must be unique";
+  }
+
+  // Get the insertion point for the call.
+  auto insertionOps = state.getPayloadOps(getInsertionPoint());
+  if (!llvm::hasSingleElement(insertionOps)) {
+    return emitSilenceableFailure(getLoc())
+           << "Only one op can be specified as an insertion point";
+  }
+  bool insertAfter = getInsertAfter();
+  Operation *insertionPoint = *insertionOps.begin();
+
+  // Check that all inputs dominate the insertion point, and the insertion
+  // point dominates all users of the outputs.
+  DominanceInfo dom(insertionPoint);
+  for (Value output : outputs) {
+    for (Operation *user : output.getUsers()) {
+      // If we are inserting after the insertion point operation, the
+      // insertion point operation must properly dominate the user. Otherwise
+      // basic dominance is enough.
+      bool doesDominate = insertAfter
+                              ? dom.properlyDominates(insertionPoint, user)
+                              : dom.dominates(insertionPoint, user);
+      if (!doesDominate) {
+        return emitDefiniteFailure()
+               << "User " << user << " is not dominated by insertion point "
+               << insertionPoint;
+      }
+    }
+  }
+
+  for (Value input : inputs) {
+    // If we are inserting before the insertion point operation, the
+    // input must properly dominate the insertion point operation. Otherwise
+    // basic dominance is enough.
+    bool doesDominate = insertAfter
+                            ? dom.dominates(input, insertionPoint)
+                            : dom.properlyDominates(input, insertionPoint);
+    if (!doesDominate) {
+      return emitDefiniteFailure()
+             << "input " << input << " does not dominate insertion point "
+             << insertionPoint;
+    }
+  }
+
+  // Get the function to inline. This can either be specified by symbol or as a
+  // transform handle.
+  func::FuncOp targetFunction = nullptr;
+  if (getFunctionName()) {
+    targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
+        insertionPoint, *getFunctionName());
+    if (!targetFunction) {
+      return emitDefiniteFailure()
+             << "unresolved symbol " << *getFunctionName();
+    }
+  } else if (getFunction()) {
+    auto payloadOps = state.getPayloadOps(getFunction());
+    if (!llvm::hasSingleElement(payloadOps)) {
+      return emitDefiniteFailure() << "requires a single function to call";
+    }
+    targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
+    if (!targetFunction) {
+      return emitDefiniteFailure() << "invalid non-function callee";
+    }
+  } else {
+    llvm_unreachable("Invalid CastAndCall op without a function to call");
+    return emitDefiniteFailure();
+  }
+  assert(targetFunction && "no target function found");
+
+  // Verify that the function argument and result lengths match the inputs and
+  // outputs given to this op.
+  if (targetFunction.getNumArguments() != inputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function arguments "
+           << targetFunction.getNumArguments() << " and number of inputs "
+           << inputs.size();
+  }
+  if (targetFunction.getNumResults() != outputs.size()) {
+    return emitSilenceableFailure(targetFunction.getLoc())
+           << "mismatch between number of function results "
+           << targetFunction->getNumResults() << " and number of outputs "
+           << outputs.size();
+  }
+
+  // Gather all specified converters.
+  MLIRContext *ctx = insertionPoint->getContext();
+  mlir::TypeConverter converter;
+  if (!getRegion().empty()) {
+    for (Operation &op : getRegion().front()) {
+      cast<transform::TypeConversionOpInterface>(&op)
+          .populateTypeMaterializations(converter);
+    }
+  }
+
+  OpBuilder builder(ctx);
----------------
ftynse wrote:

This should use the rewriter provided to the `apply` call.

https://github.com/llvm/llvm-project/pull/78398


More information about the Mlir-commits mailing list