[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)

Gil Rapaport llvmlistbot at llvm.org
Thu Apr 18 03:51:31 PDT 2024


================
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the first and last operations of `ops` (in block order) if all of
+/// them live in the same block and form a contiguous range there, i.e. every
+/// operation between the first and the last one is itself a member of `ops`.
+/// Returns std::nullopt otherwise, including when the operations have no
+/// parent block. `ops` must be non-empty; it is not modified.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(const SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+  Operation *earliest = ops.front();
+  Operation *latest = earliest;
+  Block *block = earliest->getBlock();
+  // Detached/top-level ops have no block order, so no interval exists.
+  if (!block)
+    return std::nullopt;
+
+  // Find the earliest and latest ops in block order. isBeforeInBlock() is only
+  // valid between ops of the same block, so reject foreign ops before
+  // comparing them.
+  for (Operation *op : ops) {
+    if (op->getBlock() != block)
+      return std::nullopt;
+    if (op->isBeforeInBlock(earliest))
+      earliest = op;
+    else if (latest->isBeforeInBlock(op))
+      latest = op;
+  }
+
+  // Make sure all operations between earliest and latest are in ops. `latest`
+  // is a member by construction, so checking the half-open range suffices.
+  auto end = latest->getIterator();
+  for (auto it = earliest->getIterator(); it != end; ++it)
+    if (!ops.contains(&*it))
+      return std::nullopt;
+
+  return std::make_pair(earliest, latest);
+}
+
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  Value ops = getWhat();
+  SetVector<Operation *> opsPayload;
+  SmallVector<Operation *> opsClones;
+  SmallVector<Value> liveIns;
+  DenseMap<Value, unsigned> liveInIndex;
+  SmallVector<Value> liveOuts;
+  DenseMap<Value, unsigned> liveOutIndex;
+  SmallVector<Type> resultTypes;
+
+  for (Operation *payload : state.getPayloadOps(ops))
+    opsPayload.insert(payload);
+
+  Operation *wherePayload = nullptr;
+  Value whereOp = getWhereOp();
+  auto whereAttr = getWhere();
+  RelativeLocation where;
+  if (whereOp) {
+    auto wherePayloadOps = state.getPayloadOps(whereOp);
+    if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+      auto diag = emitDefiniteFailure()
+                  << "expects a single location for the scope";
+      diag.attachNote(whereOp.getLoc()) << "single location";
+      return diag;
+    }
+    wherePayload = *wherePayloadOps.begin();
+    where = *whereAttr;
+  } else {
+    // No insertion point specified, so payload ops must form an interval.
+    auto interval = getInterval(opsPayload);
+    if (!interval) {
+      auto diag = emitDefiniteFailure()
+                  << "payload ops must form an interval unless insertion point "
+                     "is specified";
+      diag.attachNote(ops.getLoc()) << "not an interval";
+      return diag;
+    }
+    Operation *last;
+    std::tie(std::ignore, last) = *interval;
+    wherePayload = last;
+    where = RelativeLocation::After;
+  }
+
+  auto isInScope = [&opsPayload](Operation *operation) {
+    return opsPayload.contains(operation);
+  };
+
+  unsigned nextScopeOperandIndex = 0;
+  unsigned nextScopeResultIndex = 0;
+
+  for (auto payload : opsPayload) {
+    if (payload == getOperation()) {
+      auto diag = emitDefiniteFailure()
+                  << "scope ops must not contain the transform being applied";
+      diag.attachNote(payload->getLoc()) << "scope";
+      return diag;
+    }
+
+    for (Value operand : payload->getOperands()) {
+      if (liveInIndex.contains(operand))
+        continue; // Already set as a live in.
+
+      Operation *def = operand.getDefiningOp();
+      if (def && opsPayload.contains(def))
+        continue; // Will be defined within scope, so not a live in.
+
+      // Set this operand as the next operand of the scope op.
+      liveIns.push_back(operand);
+      liveInIndex[operand] = nextScopeOperandIndex++;
+    }
+    for (Value result : payload->getResults()) {
+      if (liveOutIndex.contains(result))
+        continue; // Already set as a live out.
+      if (llvm::all_of(result.getUsers(), isInScope))
+        continue; // All users are in scope, so not a live out.
+
+      // Set this result as the next result of the scope op.
+      liveOuts.push_back(result);
+      liveOutIndex[result] = nextScopeResultIndex++;
+      resultTypes.push_back(result.getType());
+    }
+  }
+
+  if (where == RelativeLocation::Before)
+    rewriter.setInsertionPoint(wherePayload);
+  else
+    rewriter.setInsertionPointAfter(wherePayload);
+
+#define WRAP_IN_EXECUTE_REGION
+#ifdef WRAP_IN_EXECUTE_REGION
+  TypeRange noTypes;
+  ValueRange noValues;
+  auto executeRegion = rewriter.create<scf::ExecuteRegionOp>(
+      wherePayload->getLoc(), noTypes, noValues);
+  Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
+  rewriter.setInsertionPointToStart(&executeRegionBody);
+#endif
+
+  auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
+                                             resultTypes, liveIns);
+  Region *scopeBody = &scope.getBody();
+  // TODO: Move into builder.
+  Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
+
+  // TODO: Move into builder.
+  for (Value arg : liveIns)
+    scopeBlock->addArgument(arg.getType(), arg.getLoc());
+
+  IRMapping mapper;
+  for (Value liveIn : liveIns)
+    mapper.map(liveIn, scopeBlock->getArgument(liveInIndex[liveIn]));
+
+  rewriter.setInsertionPointToEnd(scopeBlock);
+  for (auto payload : opsPayload)
+    rewriter.clone(*payload, mapper);
+
+  SmallVector<Value> scopeResults;
+  for (Value liveOut : liveOuts)
+    scopeResults.push_back(mapper.lookup(liveOut));
+  rewriter.create<scf::YieldOp>(scope.getLoc(), scopeResults);
+
+  Region &body = getBody();
+
+  DiagnosedSilenceableFailure result = DiagnosedSilenceableFailure::success();
+
+  // Create a region scope for the transform::as_scope region and run the
+  // transformations it contains. This is done in a syntactic C++ scope since
+  // we'll be deleting the scf::scope, and its mapping will become invalid if
+  // left until compactOpHandles() is called.
+  {
+    auto regionScope = state.make_region_scope(body);
+    if (failed(state.mapBlockArguments(body.front().getArgument(0), {scope})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (Operation &transform : body.front().without_terminator()) {
+      result = state.applyTransform(cast<TransformOpInterface>(transform));
+      if (result.isSilenceableFailure()) {
+        LLVM_DEBUG(DBGS() << "transformation on scope failed: "
+                          << result.getMessage() << "\n");
+        break;
+      }
+
+      if (::mlir::failed(result.silence()))
+        return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+#ifdef WRAP_IN_EXECUTE_REGION
+    // Since the transformations applied may have replaced the scope, get the
+    // current scope from the wrapping scf.execute_region.
+    scope = dyn_cast<scf::ScopeOp>(*executeRegionBody.getOperations().begin());
+    scope->moveAfter(executeRegion);
+    rewriter.eraseOp(executeRegion);
+#else
----------------
aniragil wrote:

Keeping the wrapping-op solution (the `scf.execute_region` around the scope) for now, as this seems to require a broader discussion.

https://github.com/llvm/llvm-project/pull/87352


More information about the Mlir-commits mailing list