[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)
Gil Rapaport
llvmlistbot at llvm.org
Mon Apr 8 06:27:58 PDT 2024
================
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the (earliest, latest) operations of `ops`, in block order, if the
+/// operations in `ops` form a contiguous interval inside a single block: all of
+/// them share one parent block and every operation between the earliest and
+/// the latest one (inclusive) is a member of `ops`. Returns std::nullopt
+/// otherwise.
+///
+/// An empty `ops` also yields std::nullopt rather than asserting: a transform
+/// handle may map to no payload operations, and indexing `ops[0]` on an empty
+/// set would read out of bounds in builds without assertions.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(const SetVector<Operation *> &ops) {
+  if (ops.empty())
+    return std::nullopt;
+
+  Operation *first = ops.front();
+  Block *block = first->getBlock();
+  Operation *earliest = first;
+  Operation *latest = first;
+  for (Operation *op : ops) {
+    // Compare blocks before calling isBeforeInBlock, which is only meaningful
+    // for two operations in the same block.
+    if (op->getBlock() != block)
+      return std::nullopt;
+    if (op->isBeforeInBlock(earliest))
+      earliest = op;
+    else if (latest->isBeforeInBlock(op))
+      latest = op;
+  }
+
+  // Make sure all operations between earliest and latest are in ops. `latest`
+  // is a member of `ops` by construction, so the walk stops just before it.
+  auto end = latest->getIterator();
+  for (auto it = earliest->getIterator(); it != end; ++it)
+    if (!ops.contains(&*it))
+      return std::nullopt;
+
+  return std::make_pair(earliest, latest);
+}
+
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ Value ops = getWhat();
+ SetVector<Operation *> opsPayload;
+ SmallVector<Operation *> opsClones;
+ SmallVector<Value> liveIns;
+ DenseMap<Value, unsigned> liveInIndex;
+ SmallVector<Value> liveOuts;
+ DenseMap<Value, unsigned> liveOutIndex;
+ SmallVector<Type> resultTypes;
+
+ for (Operation *payload : state.getPayloadOps(ops))
+ opsPayload.insert(payload);
+
+ Operation *wherePayload = nullptr;
+ Value whereOp = getWhereOp();
+ auto whereAttr = getWhere();
+ RelativeLocation where;
+ if (whereOp) {
+ auto wherePayloadOps = state.getPayloadOps(whereOp);
+ if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+ auto diag = emitDefiniteFailure()
+ << "expects a single location for the scope";
+ diag.attachNote(whereOp.getLoc()) << "single location";
+ return diag;
+ }
+ wherePayload = *wherePayloadOps.begin();
+ where = *whereAttr;
+ } else {
+ // No insertion point specified, so payload ops must form an interval.
+ auto interval = getInterval(opsPayload);
+ if (!interval) {
+ auto diag = emitDefiniteFailure()
+ << "payload ops must form an interval unless insertion point "
+ "is specified";
+ diag.attachNote(ops.getLoc()) << "not an interval";
+ return diag;
+ }
+ Operation *last;
+ std::tie(std::ignore, last) = *interval;
+ wherePayload = last;
+ where = RelativeLocation::After;
+ }
+
+ auto isInScope = [&opsPayload](Operation *operation) {
+ return opsPayload.contains(operation);
+ };
+
+ unsigned nextScopeOperandIndex = 0;
+ unsigned nextScopeResultIndex = 0;
+
+ for (auto payload : opsPayload) {
+ if (payload == getOperation()) {
+ auto diag = emitDefiniteFailure()
+ << "scope ops must not contain the transform being applied";
+ diag.attachNote(payload->getLoc()) << "scope";
+ return diag;
+ }
+
+ for (Value operand : payload->getOperands()) {
+ if (liveInIndex.contains(operand))
+ continue; // Already set as a live in.
+
+ Operation *def = operand.getDefiningOp();
+ if (def && opsPayload.contains(def))
+ continue; // Will be defined within scope, so not a live in.
+
+ // Set this operand as the next operand of the scope op.
+ liveIns.push_back(operand);
+ liveInIndex[operand] = nextScopeOperandIndex++;
+ }
+ for (Value result : payload->getResults()) {
+ if (liveOutIndex.contains(result))
+ continue; // Already set as a live out.
+ if (llvm::all_of(result.getUsers(), isInScope))
+ continue; // All users are in scope, so not a live out.
+
+ // Set this result as the next result of the scope op.
+ liveOuts.push_back(result);
+ liveOutIndex[result] = nextScopeResultIndex++;
+ resultTypes.push_back(result.getType());
+ }
+ }
+
+ if (where == RelativeLocation::Before)
+ rewriter.setInsertionPoint(wherePayload);
+ else
+ rewriter.setInsertionPointAfter(wherePayload);
+
+#define WRAP_IN_EXECUTE_REGION
+#ifdef WRAP_IN_EXECUTE_REGION
+ TypeRange noTypes;
+ ValueRange noValues;
+ auto executeRegion = rewriter.create<scf::ExecuteRegionOp>(
+ wherePayload->getLoc(), noTypes, noValues);
+ Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
+ rewriter.setInsertionPointToStart(&executeRegionBody);
+#endif
+
+ auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
+ resultTypes, liveIns);
+ Region *scopeBody = &scope.getBody();
+ // TODO: Move into builder.
+ Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
+
+ // TODO: Move into builder.
----------------
aniragil wrote:
Done
https://github.com/llvm/llvm-project/pull/87352
More information about the Mlir-commits
mailing list