[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)

Gil Rapaport llvmlistbot at llvm.org
Mon Apr 8 06:27:35 PDT 2024


================
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Computes the contiguous span of a block covered by `ops`.
+///
+/// Returns the (earliest, latest) operations of `ops` in block order when all
+/// of them live in the same block and every operation positioned between
+/// those two endpoints is itself a member of `ops`. Returns std::nullopt
+/// otherwise. `ops` must not be empty.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+
+  // Seed both endpoints with the first element and widen them while scanning.
+  Operation *first = ops.front();
+  Operation *last = first;
+  Block *parent = first->getBlock();
+  for (Operation *candidate : ops) {
+    // Ordering queries are only made on operations sharing the block.
+    if (candidate->getBlock() != parent)
+      return std::nullopt;
+    if (candidate->isBeforeInBlock(first)) {
+      first = candidate;
+      continue;
+    }
+    if (last->isBeforeInBlock(candidate))
+      last = candidate;
+  }
+
+  // Walk [first, last) and reject any operation missing from the set
+  // (`last` itself was taken from `ops`, so it needs no check).
+  for (Operation *cursor = first; cursor != last;
+       cursor = cursor->getNextNode()) {
+    if (!ops.contains(cursor))
+      return std::nullopt;
+  }
+
+  return std::make_pair(first, last);
+}
+
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  Value ops = getWhat();
+  SetVector<Operation *> opsPayload;
+  SmallVector<Operation *> opsClones;
+  SmallVector<Value> liveIns;
+  DenseMap<Value, unsigned> liveInIndex;
+  SmallVector<Value> liveOuts;
+  DenseMap<Value, unsigned> liveOutIndex;
+  SmallVector<Type> resultTypes;
+
+  for (Operation *payload : state.getPayloadOps(ops))
+    opsPayload.insert(payload);
+
+  Operation *wherePayload = nullptr;
+  Value whereOp = getWhereOp();
+  auto whereAttr = getWhere();
+  RelativeLocation where;
+  if (whereOp) {
+    auto wherePayloadOps = state.getPayloadOps(whereOp);
+    if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+      auto diag = emitDefiniteFailure()
----------------
aniragil wrote:

Done

https://github.com/llvm/llvm-project/pull/87352


More information about the Mlir-commits mailing list