[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)
Gil Rapaport
llvmlistbot at llvm.org
Mon Apr 8 06:27:35 PDT 2024
================
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the first and last operation of `ops` in block order, provided
+/// that every operation in `ops` lives in the same block and that no
+/// operation outside `ops` sits between those two. Returns std::nullopt
+/// otherwise. `ops` must be non-empty (asserted).
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(const SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+  Operation *earliest = ops.front();
+  Operation *latest = earliest;
+  Block *block = earliest->getBlock();
+
+  // Find the block-order extremes. The block check comes first so that
+  // isBeforeInBlock is only ever called on operations of the same block.
+  for (Operation *op : ops) {
+    if (op->getBlock() != block)
+      return std::nullopt;
+    if (op->isBeforeInBlock(earliest))
+      earliest = op;
+    else if (latest->isBeforeInBlock(op))
+      latest = op;
+  }
+
+  // Make sure all operations in [earliest, latest) are in ops; `latest`
+  // itself is a member of `ops` by construction.
+  auto end = latest->getIterator();
+  for (auto it = earliest->getIterator(); it != end; ++it)
+    if (!ops.contains(&*it))
+      return std::nullopt;
+
+  return std::make_pair(earliest, latest);
+}
+
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ Value ops = getWhat();
+ SetVector<Operation *> opsPayload;
+ SmallVector<Operation *> opsClones;
+ SmallVector<Value> liveIns;
+ DenseMap<Value, unsigned> liveInIndex;
+ SmallVector<Value> liveOuts;
+ DenseMap<Value, unsigned> liveOutIndex;
+ SmallVector<Type> resultTypes;
+
+ for (Operation *payload : state.getPayloadOps(ops))
+ opsPayload.insert(payload);
+
+ Operation *wherePayload = nullptr;
+ Value whereOp = getWhereOp();
+ auto whereAttr = getWhere();
+ RelativeLocation where;
+ if (whereOp) {
+ auto wherePayloadOps = state.getPayloadOps(whereOp);
+ if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+ auto diag = emitDefiniteFailure()
----------------
aniragil wrote:
Done
https://github.com/llvm/llvm-project/pull/87352
More information about the Mlir-commits
mailing list