[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)

Gil Rapaport llvmlistbot at llvm.org
Mon Apr 8 06:27:58 PDT 2024


================
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Computes the block-order extent of `ops`.
+///
+/// Returns the (earliest, latest) pair of operations in `ops` when all of them
+/// live in the same block and every operation from the earliest up to the
+/// latest is itself a member of `ops`. Returns std::nullopt otherwise.
+/// `ops` must not be empty.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+  Operation *first = ops.front();
+  Operation *last = ops.back();
+  Block *parent = first->getBlock();
+  if (last->getBlock() != parent)
+    return std::nullopt;
+  // Seed the bounds so that `first` never comes after `last`.
+  if (last->isBeforeInBlock(first))
+    std::swap(first, last);
+
+  // Widen the bounds to cover every op, bailing out on a foreign block.
+  for (Operation *candidate : ops) {
+    if (candidate->getBlock() != parent)
+      return std::nullopt;
+    if (candidate->isBeforeInBlock(first)) {
+      first = candidate;
+      continue;
+    }
+    if (last->isBeforeInBlock(candidate))
+      last = candidate;
+  }
+
+  // Every operation in [first, last) must belong to ops; `last` is known to.
+  auto cursor = first->getIterator();
+  auto stop = last->getIterator();
+  while (cursor != stop) {
+    if (!ops.contains(&*cursor))
+      return std::nullopt;
+    ++cursor;
+  }
+
+  return std::make_pair(first, last);
+}
+
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  Value ops = getWhat();
+  SetVector<Operation *> opsPayload;
+  SmallVector<Operation *> opsClones;
+  SmallVector<Value> liveIns;
+  DenseMap<Value, unsigned> liveInIndex;
+  SmallVector<Value> liveOuts;
+  DenseMap<Value, unsigned> liveOutIndex;
+  SmallVector<Type> resultTypes;
+
+  for (Operation *payload : state.getPayloadOps(ops))
+    opsPayload.insert(payload);
+
+  Operation *wherePayload = nullptr;
+  Value whereOp = getWhereOp();
+  auto whereAttr = getWhere();
+  RelativeLocation where;
+  if (whereOp) {
+    auto wherePayloadOps = state.getPayloadOps(whereOp);
+    if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+      auto diag = emitDefiniteFailure()
+                  << "expects a single location for the scope";
+      diag.attachNote(whereOp.getLoc()) << "single location";
+      return diag;
+    }
+    wherePayload = *wherePayloadOps.begin();
+    where = *whereAttr;
+  } else {
+    // No insertion point specified, so payload ops must form an interval.
+    auto interval = getInterval(opsPayload);
+    if (!interval) {
+      auto diag = emitDefiniteFailure()
+                  << "payload ops must form an interval unless insertion point "
+                     "is specified";
+      diag.attachNote(ops.getLoc()) << "not an interval";
+      return diag;
+    }
+    Operation *last;
+    std::tie(std::ignore, last) = *interval;
+    wherePayload = last;
+    where = RelativeLocation::After;
+  }
+
+  auto isInScope = [&opsPayload](Operation *operation) {
+    return opsPayload.contains(operation);
+  };
+
+  unsigned nextScopeOperandIndex = 0;
+  unsigned nextScopeResultIndex = 0;
+
+  for (auto payload : opsPayload) {
+    if (payload == getOperation()) {
+      auto diag = emitDefiniteFailure()
+                  << "scope ops must not contain the transform being applied";
+      diag.attachNote(payload->getLoc()) << "scope";
+      return diag;
+    }
+
+    for (Value operand : payload->getOperands()) {
+      if (liveInIndex.contains(operand))
+        continue; // Already set as a live in.
+
+      Operation *def = operand.getDefiningOp();
+      if (def && opsPayload.contains(def))
+        continue; // Will be defined within scope, so not a live in.
+
+      // Set this operand as the next operand of the scope op.
+      liveIns.push_back(operand);
+      liveInIndex[operand] = nextScopeOperandIndex++;
+    }
+    for (Value result : payload->getResults()) {
+      if (liveOutIndex.contains(result))
+        continue; // Already set as a live out.
+      if (llvm::all_of(result.getUsers(), isInScope))
+        continue; // All users are in scope, so not a live out.
+
+      // Set this result as the next result of the scope op.
+      liveOuts.push_back(result);
+      liveOutIndex[result] = nextScopeResultIndex++;
+      resultTypes.push_back(result.getType());
+    }
+  }
+
+  if (where == RelativeLocation::Before)
+    rewriter.setInsertionPoint(wherePayload);
+  else
+    rewriter.setInsertionPointAfter(wherePayload);
+
+#define WRAP_IN_EXECUTE_REGION
+#ifdef WRAP_IN_EXECUTE_REGION
+  TypeRange noTypes;
+  ValueRange noValues;
+  auto executeRegion = rewriter.create<scf::ExecuteRegionOp>(
+      wherePayload->getLoc(), noTypes, noValues);
+  Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
+  rewriter.setInsertionPointToStart(&executeRegionBody);
+#endif
+
+  auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
+                                             resultTypes, liveIns);
+  Region *scopeBody = &scope.getBody();
+  // TODO: Move into builder.
+  Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
+
+  // TODO: Move into builder.
----------------
aniragil wrote:

Done

https://github.com/llvm/llvm-project/pull/87352


More information about the Mlir-commits mailing list