[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)
Gil Rapaport
llvmlistbot at llvm.org
Mon Apr 8 06:31:12 PDT 2024
================
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the first and last operations, in block order, of `ops` when the
+/// set forms an interval: all operations live in the same block and every
+/// operation positioned between the first and the last one is itself a member
+/// of `ops`. Returns std::nullopt otherwise. `ops` must not be empty.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+
+  // Seed the bounds with the two ends of the set, ordered by block position.
+  Operation *first = ops.front();
+  Operation *last = ops.back();
+  Block *parent = first->getBlock();
+  if (last->getBlock() != parent)
+    return std::nullopt;
+  if (last->isBeforeInBlock(first))
+    std::swap(first, last);
+
+  // Widen [first, last] until it covers every operation in the set, bailing
+  // out as soon as an operation from another block is found.
+  for (Operation *candidate : ops) {
+    if (candidate->getBlock() != parent)
+      return std::nullopt;
+    if (candidate->isBeforeInBlock(first)) {
+      first = candidate;
+      continue;
+    }
+    if (last->isBeforeInBlock(candidate))
+      last = candidate;
+  }
+
+  // Walk the block from `first` up to (excluding) `last`; `last` is known to
+  // be in the set, so only the operations before it need to be checked.
+  for (auto cursor = first->getIterator(), stop = last->getIterator();
+       cursor != stop; ++cursor) {
+    Operation *between = &*cursor;
+    if (!ops.contains(between))
+      return std::nullopt;
+  }
+
+  return std::make_pair(first, last);
+}
+
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ Value ops = getWhat();
+ SetVector<Operation *> opsPayload;
+ SmallVector<Operation *> opsClones;
+ SmallVector<Value> liveIns;
+ DenseMap<Value, unsigned> liveInIndex;
+ SmallVector<Value> liveOuts;
+ DenseMap<Value, unsigned> liveOutIndex;
+ SmallVector<Type> resultTypes;
+
+ for (Operation *payload : state.getPayloadOps(ops))
+ opsPayload.insert(payload);
+
+ Operation *wherePayload = nullptr;
+ Value whereOp = getWhereOp();
+ auto whereAttr = getWhere();
+ RelativeLocation where;
+ if (whereOp) {
+ auto wherePayloadOps = state.getPayloadOps(whereOp);
+ if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+ auto diag = emitDefiniteFailure()
+ << "expects a single location for the scope";
+ diag.attachNote(whereOp.getLoc()) << "single location";
+ return diag;
+ }
+ wherePayload = *wherePayloadOps.begin();
+ where = *whereAttr;
+ } else {
+ // No insertion point specified, so payload ops must form an interval.
+ auto interval = getInterval(opsPayload);
+ if (!interval) {
+ auto diag = emitDefiniteFailure()
+ << "payload ops must form an interval unless insertion point "
+ "is specified";
+ diag.attachNote(ops.getLoc()) << "not an interval";
+ return diag;
+ }
+ Operation *last;
+ std::tie(std::ignore, last) = *interval;
+ wherePayload = last;
+ where = RelativeLocation::After;
+ }
+
+ auto isInScope = [&opsPayload](Operation *operation) {
+ return opsPayload.contains(operation);
+ };
+
+ unsigned nextScopeOperandIndex = 0;
+ unsigned nextScopeResultIndex = 0;
+
+ for (auto payload : opsPayload) {
+ if (payload == getOperation()) {
+ auto diag = emitDefiniteFailure()
+ << "scope ops must not contain the transform being applied";
+ diag.attachNote(payload->getLoc()) << "scope";
+ return diag;
+ }
+
+ for (Value operand : payload->getOperands()) {
+ if (liveInIndex.contains(operand))
+ continue; // Already set as a live in.
+
+ Operation *def = operand.getDefiningOp();
+ if (def && opsPayload.contains(def))
+ continue; // Will be defined within scope, so not a live in.
+
+ // Set this operand as the next operand of the scope op.
+ liveIns.push_back(operand);
+ liveInIndex[operand] = nextScopeOperandIndex++;
+ }
+ for (Value result : payload->getResults()) {
+ if (liveOutIndex.contains(result))
+ continue; // Already set as a live out.
+ if (llvm::all_of(result.getUsers(), isInScope))
+ continue; // All users are in scope, so not a live out.
+
+ // Set this result as the next result of the scope op.
+ liveOuts.push_back(result);
+ liveOutIndex[result] = nextScopeResultIndex++;
+ resultTypes.push_back(result.getType());
+ }
+ }
+
+ if (where == RelativeLocation::Before)
+ rewriter.setInsertionPoint(wherePayload);
+ else
+ rewriter.setInsertionPointAfter(wherePayload);
+
+#define WRAP_IN_EXECUTE_REGION
+#ifdef WRAP_IN_EXECUTE_REGION
+ TypeRange noTypes;
+ ValueRange noValues;
+ auto executeRegion = rewriter.create<scf::ExecuteRegionOp>(
+ wherePayload->getLoc(), noTypes, noValues);
+ Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
+ rewriter.setInsertionPointToStart(&executeRegionBody);
+#endif
+
+ auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
+ resultTypes, liveIns);
+ Region *scopeBody = &scope.getBody();
+ // TODO: Move into builder.
+ Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
+
+ // TODO: Move into builder.
+ for (Value arg : liveIns)
+ scopeBlock->addArgument(arg.getType(), arg.getLoc());
+
+ IRMapping mapper;
+ for (Value liveIn : liveIns)
+ mapper.map(liveIn, scopeBlock->getArgument(liveInIndex[liveIn]));
+
+ rewriter.setInsertionPointToEnd(scopeBlock);
+ for (auto payload : opsPayload)
+ rewriter.clone(*payload, mapper);
+
+ SmallVector<Value> scopeResults;
+ for (Value liveOut : liveOuts)
+ scopeResults.push_back(mapper.lookup(liveOut));
+ rewriter.create<scf::YieldOp>(scope.getLoc(), scopeResults);
+
+ Region &body = getBody();
+
+ DiagnosedSilenceableFailure result = DiagnosedSilenceableFailure::success();
+
+ // Create a region scope for the transform::as_scope region and run the
+ // transformations it contains. This is done in a syntactic C++ scope since
+ // we'll be deleting the scf::scope, and its mapping will become invalid if
+ // left until compactOpHandles() is called.
+ {
+ auto regionScope = state.make_region_scope(body);
+ if (failed(state.mapBlockArguments(body.front().getArgument(0), {scope})))
+ return DiagnosedSilenceableFailure::definiteFailure();
+
+ for (Operation &transform : body.front().without_terminator()) {
+ result = state.applyTransform(cast<TransformOpInterface>(transform));
+ if (result.isSilenceableFailure()) {
+ LLVM_DEBUG(DBGS() << "transformation on scope failed: "
+ << result.getMessage() << "\n");
+ break;
+ }
+
+ if (::mlir::failed(result.silence()))
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+#ifdef WRAP_IN_EXECUTE_REGION
+ // Since the transformations applied may have replaced the scope, get the
+ // current scope from the wrapping scf.execute_region.
+ scope = dyn_cast<scf::ScopeOp>(*executeRegionBody.getOperations().begin());
+ scope->moveAfter(executeRegion);
+ rewriter.eraseOp(executeRegion);
+#else
+ // Since the transformations applied may have replaced the scope, get the
+ // updated payload of the block argument.
+ auto newPayloadOps = state.getPayloadOps(body.front().getArgument(0));
+ if (llvm::range_size(newPayloadOps) != 1) {
+ LLVM_DEBUG(DBGS() << "expected a single scope post transformation\n");
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ Operation *newPayloadOp = *newPayloadOps.begin();
+ scope = dyn_cast<scf::ScopeOp>(newPayloadOp);
+#endif
+ if (!scope) {
+ LLVM_DEBUG(DBGS() << "scope missing post transformation\n");
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ if (failed(scope.verify())) {
+ LLVM_DEBUG(DBGS() << "invalid scope post transformation\n");
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+ scopeBody = &scope.getBody();
+ if (!scopeBody->hasOneBlock()) {
+ LLVM_DEBUG(DBGS() << "multiple blocks in scope post transformation\n");
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
----------------
aniragil wrote:
Removed
https://github.com/llvm/llvm-project/pull/87352
More information about the Mlir-commits
mailing list