[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 2 07:21:21 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Gil Rapaport (aniragil)
<details>
<summary>Changes</summary>
Add scf.scope modeling a code block and transform.as_scope which applies
transformations over it. The scf.scope op is similar to scf.execute_region
except it's isolated-from-above, which allows applying transforms to a
limited region without resorting to function outlining.
The transform.as_scope op creates a temporary scf.scope containing
clones of the payload ops associated with its first operand and applies
the transformations within its body to that scope. If successful, the
scope's results replace all uses of the original operations outside the
scope, the scope is inlined and the original operations are erased. On
failure, the scf.scope is erased, leaving the payload IR unmodified.
---
Patch is 33.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87352.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+23-1)
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td (+9)
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+66)
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+19)
- (modified) mlir/lib/Dialect/Transform/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+292)
- (modified) mlir/test/Dialect/SCF/invalid.mlir (+26-1)
- (modified) mlir/test/Dialect/SCF/ops.mlir (+25)
- (modified) mlir/test/Dialect/Transform/ops-invalid.mlir (+30)
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+217)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..2f7a005d96f403 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -931,6 +931,28 @@ def ReduceReturnOp :
let hasVerifier = 1;
}
+def ScopeOp : SCF_Op<"scope",
+    [AutomaticAllocationScope,
+     RecursiveMemoryEffects,
+     IsolatedFromAbove,
+     SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+  let summary = "isolated code scope";
+  let description = [{
+    This op models a region that is executed exactly once, similar to
+    `scf.execute_region`, except that it is isolated from above. This makes
+    it possible to scope code locally without resorting to function
+    outlining.
+
+    Because the body cannot refer to values defined outside the op, such
+    values are passed as operands and are expected to correspond to the
+    arguments of the single body block. The values yielded by the
+    terminating `scf.yield` become the results of the op.
+
+    Example:
+
+    ```mlir
+    %sum = scf.scope %a, %b : (i32, i32) -> i32 {
+    ^bb0(%x: i32, %y: i32):
+      %0 = arith.addi %x, %y : i32
+      scf.yield %0 : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>:$results);
+  // Exactly one block: SingleBlockImplicitTerminator and the verifier both
+  // rely on the body having an entry block.
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    $operands attr-dict `:` functional-type($operands, $results) $body
+  }];
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
@@ -1155,7 +1177,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
- "WhileOp"]>]> {
+ "ScopeOp", "WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index ebad2994880e75..1d888e433c168f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -33,4 +33,13 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
let cppNamespace = "::mlir::transform";
}
+// Enumerators of RelativeLocation.
+def BeforeCase
+    : I32EnumAttrCase</*sym=*/"Before", /*val=*/1, /*str=*/"before">;
+def AfterCase
+    : I32EnumAttrCase</*sym=*/"After", /*val=*/2, /*str=*/"after">;
+
+// Specifies a location relative to an anchor operation (before or after it).
+def RelativeLocation
+    : I32EnumAttr<"RelativeLocation", "Relative location specifier",
+                  [BeforeCase, AfterCase]> {
+  let cppNamespace = "::mlir::transform";
+}
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 21c9595860d4c5..29da1b919add3b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -434,6 +434,72 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
}];
}
+def AsScopeOp : TransformDialectOp<"as_scope", [
+    DeclareOpInterfaceMethods<TransformOpInterface>,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    IsolatedFromAbove,
+    SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+  let summary = "Applies transforms to a temporary scope of payload ops";
+  let description = [{
+    The purpose of this op is to provide an ad-hoc scope for transformations,
+    in order to narrow their impact to specific ops.
+
+    This op creates a temporary `scf.scope` containing clones of the payload
+    ops of its first operand and applies the transformations within its body
+    to that scope. If successful, the scope's results replace all uses of any
+    original operation outside the scope, the scope is inlined and the
+    original operations are erased. On failure, the `scf.scope` is erased,
+    leaving the payload IR unmodified.
+
+    The operation takes as arguments a handle whose payload ops are to be
+    scoped in the order they are listed, and an optional pair of a relative
+    location (`before` or `after`) and a handle defining an insertion point
+    for the `scf.scope`. If specified, the insertion point handle must hold a
+    single payload op. If omitted, the payload ops must form an interval and
+    the scope is inserted after the last of them. It is the user's
+    responsibility to make sure that the payload ops can be moved to the
+    designated location in the designated order.
+
+    The body consists of a single block whose only argument is a handle to
+    the temporary `scf.scope` op.
+
+    This operation is useful for narrowing the set of ops affected by some
+    transformation, either for functional reasons, e.g.
+
+    ```mlir
+    transform.as_scope %what before %where
+        : (!transform.any_op, !transform.any_op) -> () {
+    ^bb0(%s: !transform.any_op):
+      transform.apply_patterns to %s {
+        transform.apply_patterns.canonicalization
+      } : !transform.any_op
+      transform.yield
+    }
+    ```
+
+    or for saving on compile time/memory by minimizing cloning, e.g.
+
+    ```mlir
+    transform.as_scope %what after %where
+        : (!transform.any_op, !transform.any_op) -> () {
+    ^bb0(%s: !transform.any_op):
+      transform.alternatives %s : !transform.any_op {
+      ^bb0(%arg2: !transform.any_op):
+        // ...
+      }, {
+      ^bb0(%arg2: !transform.any_op):
+        // ...
+      }
+      transform.yield
+    }
+    ```
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$what,
+                       OptionalAttr<RelativeLocation>:$where,
+                       Optional<TransformHandleTypeInterface>:$whereOp);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    $what ($where $whereOp^)? `:` functional-type(operands, results)
+    attr-dict-with-keyword $body
+  }];
+  let hasVerifier = 1;
+}
+
def CastOp : TransformDialectOp<"cast",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<CastOpInterface>,
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..6a8d28afdbc251 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3079,6 +3079,25 @@ LogicalResult ReduceReturnOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the body exists, that its entry block arguments match the
+/// operands (the op is isolated from above, so operands can only reach the
+/// body through those arguments), and that the terminator yields values of
+/// the op's result types.
+LogicalResult scf::ScopeOp::verify() {
+  Region &body = getBody();
+  // Guard before dereferencing the entry block: an empty region is not
+  // rejected by every region constraint.
+  if (body.empty())
+    return emitOpError() << "expects a non-empty body";
+  Block &block = body.front();
+
+  if (block.getArgumentTypes() != getOperands().getTypes())
+    return emitOpError() << "expects body block arguments to have the same "
+                            "types as the operands of the operation";
+
+  Operation *terminator = block.getTerminator();
+  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects terminator operands to have the "
+                                 "same type as results of the operation";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// WhileOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 5b4989f328e690..9690004bbf8129 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRTransformDialect
MLIRPass
MLIRRewrite
MLIRSideEffectInterfaces
+ MLIRSCFDialect
MLIRTransforms
MLIRTransformDialectInterfaces
MLIRTransformDialectUtils
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..aeb930001ef9e8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -8,9 +8,12 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -19,6 +22,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the first and last operations (in block order) of `ops` if all of
+/// them live in the same block and form a contiguous run there, i.e. every
+/// operation between the first and the last one is in `ops`. Returns
+/// std::nullopt otherwise. `ops` must be non-empty.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(const SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+  Operation *earliest = ops.front();
+  Operation *latest = ops.front();
+  Block *block = earliest->getBlock();
+  for (Operation *op : ops) {
+    // isBeforeInBlock requires both ops to share a block.
+    if (op->getBlock() != block)
+      return std::nullopt;
+    if (op->isBeforeInBlock(earliest))
+      earliest = op;
+    else if (latest->isBeforeInBlock(op))
+      latest = op;
+  }
+
+  // Make sure all operations between earliest and latest are in ops.
+  for (auto it = earliest->getIterator(), end = latest->getIterator();
+       it != end; ++it)
+    if (!ops.contains(&*it))
+      return std::nullopt;
+
+  return std::make_pair(earliest, latest);
+}
+
+/// Applies the transformations nested in this op to a temporary scf.scope
+/// that holds clones of the payload ops associated with `what`.
+///
+/// The scope is inserted before/after the `whereOp` payload (or after the
+/// last payload op when no insertion point is given and the payload ops form
+/// an interval). It is wrapped in a temporary scf.execute_region so it can be
+/// located again even if the nested transformations replace it. On success,
+/// the transformed scope body is inlined, its yielded values replace all
+/// outside uses of the original payload results, and the original payload
+/// ops are erased. On silenceable failure the scope is erased and the
+/// silenceable failure is propagated.
+///
+/// NOTE(review): `results` is never populated here, so handles declared as
+/// results of this op are not mapped -- confirm the intended semantics.
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  Value ops = getWhat();
+  SetVector<Operation *> opsPayload;
+  for (Operation *payload : state.getPayloadOps(ops))
+    opsPayload.insert(payload);
+
+  // Determine where the scope goes.
+  Operation *wherePayload = nullptr;
+  RelativeLocation where = RelativeLocation::After;
+  if (Value whereOp = getWhereOp()) {
+    auto wherePayloadOps = state.getPayloadOps(whereOp);
+    if (!llvm::hasSingleElement(wherePayloadOps)) {
+      auto diag = emitDefiniteFailure()
+                  << "expects a single location for the scope";
+      diag.attachNote(whereOp.getLoc()) << "single location";
+      return diag;
+    }
+    // The generic op form may carry the handle without the attribute.
+    std::optional<RelativeLocation> whereAttr = getWhere();
+    if (!whereAttr) {
+      auto diag = emitDefiniteFailure()
+                  << "expects a relative location with the insertion point";
+      return diag;
+    }
+    wherePayload = *wherePayloadOps.begin();
+    where = *whereAttr;
+  } else {
+    // No insertion point specified, so payload ops must form an interval
+    // (an empty set cannot) and the scope goes right after its last op.
+    std::optional<std::pair<Operation *, Operation *>> interval;
+    if (!opsPayload.empty())
+      interval = getInterval(opsPayload);
+    if (!interval) {
+      auto diag = emitDefiniteFailure()
+                  << "payload ops must form an interval unless insertion point "
+                     "is specified";
+      diag.attachNote(ops.getLoc()) << "not an interval";
+      return diag;
+    }
+    wherePayload = interval->second;
+    where = RelativeLocation::After;
+  }
+
+  // An op is in scope if it, or one of its ancestors, is a payload op.
+  auto isInScope = [&](Operation *op) {
+    for (; op; op = op->getParentOp())
+      if (opsPayload.contains(op))
+        return true;
+    return false;
+  };
+  // A value is defined in scope if its defining op (or, for a block argument,
+  // the op owning its block) is in scope.
+  auto isDefinedInScope = [&](Value value) {
+    Operation *def = value.getDefiningOp();
+    return isInScope(def ? def : value.getParentBlock()->getParentOp());
+  };
+
+  // Compute the values flowing into and out of the scope.
+  SetVector<Value> liveIns;
+  SmallVector<Value> liveOuts;
+  SmallVector<Type> resultTypes;
+  for (Operation *payload : opsPayload) {
+    // Cloning and later erasing an ancestor of this transform would destroy
+    // the transform script being interpreted.
+    if (payload->isAncestor(getOperation())) {
+      auto diag = emitDefiniteFailure()
+                  << "scope ops must not contain the transform being applied";
+      diag.attachNote(payload->getLoc()) << "scope";
+      return diag;
+    }
+
+    // scf.scope is isolated from above, so every value used by the payload
+    // op or by ops nested in its regions that is defined outside the scope
+    // must be passed in explicitly. Pre-order visits the payload op first.
+    payload->walk<WalkOrder::PreOrder>([&](Operation *nested) {
+      for (Value operand : nested->getOperands())
+        if (!isDefinedInScope(operand))
+          liveIns.insert(operand);
+    });
+
+    for (Value result : payload->getResults()) {
+      if (llvm::all_of(result.getUsers(), isInScope))
+        continue; // All users are in scope, so not a live out.
+      liveOuts.push_back(result);
+      resultTypes.push_back(result.getType());
+    }
+  }
+
+  if (where == RelativeLocation::Before)
+    rewriter.setInsertionPoint(wherePayload);
+  else
+    rewriter.setInsertionPointAfter(wherePayload);
+
+  // Wrap the scope in a temporary scf.execute_region so it can be found again
+  // after the nested transformations, which may replace it.
+  Location loc = wherePayload->getLoc();
+  auto executeRegion =
+      rewriter.create<scf::ExecuteRegionOp>(loc, TypeRange(), ValueRange());
+  Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
+  rewriter.setInsertionPointToStart(&executeRegionBody);
+
+  auto scope =
+      rewriter.create<scf::ScopeOp>(loc, resultTypes, liveIns.getArrayRef());
+  Region *scopeBody = &scope.getBody();
+  Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
+
+  // Each live-in is replaced inside the scope by a matching block argument.
+  IRMapping mapper;
+  for (Value liveIn : liveIns)
+    mapper.map(liveIn,
+               scopeBlock->addArgument(liveIn.getType(), liveIn.getLoc()));
+
+  // Clone the payload ops in handle order and yield the live-outs.
+  rewriter.setInsertionPointToEnd(scopeBlock);
+  for (Operation *payload : opsPayload)
+    rewriter.clone(*payload, mapper);
+
+  SmallVector<Value> scopeResults;
+  for (Value liveOut : liveOuts)
+    scopeResults.push_back(mapper.lookup(liveOut));
+  rewriter.create<scf::YieldOp>(scope.getLoc(), scopeResults);
+
+  Region &body = getBody();
+  DiagnosedSilenceableFailure result = DiagnosedSilenceableFailure::success();
+
+  // Create a region scope for the transform::as_scope region and run the
+  // transformations it contains. This is done in a syntactic C++ scope since
+  // we'll be deleting the scf::scope, and its mapping will become invalid if
+  // left until compactOpHandles() is called.
+  {
+    auto regionScope = state.make_region_scope(body);
+    if (failed(state.mapBlockArguments(body.front().getArgument(0), {scope})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (Operation &transform : body.front().without_terminator()) {
+      result = state.applyTransform(cast<TransformOpInterface>(transform));
+      if (result.isSilenceableFailure()) {
+        LLVM_DEBUG(DBGS() << "transformation on scope failed: "
+                          << result.getMessage() << "\n");
+        break;
+      }
+
+      if (::mlir::failed(result.silence()))
+        return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    // Since the transformations applied may have replaced the scope, get the
+    // current scope from the wrapping scf.execute_region, then drop the
+    // wrapper (through the rewriter, so listeners are notified).
+    scf::ScopeOp transformedScope;
+    if (!executeRegionBody.empty())
+      transformedScope = dyn_cast<scf::ScopeOp>(&executeRegionBody.front());
+    if (transformedScope)
+      rewriter.moveOpAfter(transformedScope, executeRegion);
+    rewriter.eraseOp(executeRegion);
+    if (!transformedScope) {
+      LLVM_DEBUG(DBGS() << "scope missing post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    scope = transformedScope;
+
+    scopeBody = &scope.getBody();
+    if (!scopeBody->hasOneBlock()) {
+      LLVM_DEBUG(DBGS() << "multiple blocks in scope post transformation\n");
+      rewriter.eraseOp(scope);
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    // Run the full op verifier (traits included, e.g. the terminator), not
+    // just the custom verify() hook.
+    if (failed(::mlir::verify(scope, /*verifyRecursively=*/false))) {
+      LLVM_DEBUG(DBGS() << "invalid scope post transformation\n");
+      rewriter.eraseOp(scope);
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    scopeBlock = &scopeBody->front();
+  }
+
+  if (!result.succeeded()) {
+    // Erase the scope and return the silenceable failure.
+    rewriter.eraseOp(scope);
+    return result;
+  }
+
+  // The transformations applied to the scope succeeded. Complete the
+  // transformation by inlining the new contents of the scope and having
+  // the scope's yielded values replace the original live-out values.
+  // Finally, erase the original ops.
+
+  LLVM_DEBUG(DBGS() << "scope after transformations: " << scope << "\n");
+
+  // Update the live-outs to what the scope's terminator now yields.
+  auto scopeYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
+  SmallVector<Value> transformedLiveOuts =
+      llvm::to_vector(scopeYield.getResults());
+  if (transformedLiveOuts.size() != liveOuts.size()) {
+    LLVM_DEBUG(DBGS() << "scope results changed post transformation\n");
+    rewriter.eraseOp(scope);
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  IRMapping inlineMapper;
+  for (auto [blockArg, operand] :
+       llvm::zip(scopeBlock->getArguments(), scope.getOperands()))
+    inlineMapper.map(blockArg, operand);
+
+  rewriter.setInsertionPoint(scope);
+  for (Operation &scopeOp : scopeBlock->without_terminator())
+    rewriter.clone(scopeOp, inlineMapper);
+
+  for (auto [liveOut, transformedLiveOut] :
+       llvm::zip(liveOuts, transformedLiveOuts)) {
+    Value inlinedLiveOut = inlineMapper.lookup(transformedLiveOut);
+    rewriter.replaceAllUsesWith(liveOut, inlinedLiveOut);
+  }
+
+  rewriter.eraseOp(scope);
+
+  // Erase the original payload ops. After the replacement above, remaining
+  // uses can only come from other (original) payload ops, so repeatedly
+  // erase the unused ones. Stop if a round makes no progress rather than
+  // looping forever.
+  SmallVector<Operation *> pending(opsPayload.begin(), opsPayload.end());
+  while (!pending.empty()) {
+    SmallVector<Operation *> stillUsed;
+    for (Operation *payload : pending) {
+      if (payload->use_empty())
+        rewriter.eraseOp(payload);
+      else
+        stillUsed.push_back(payload);
+    }
+    if (stillUsed.size() == pending.size()) {
+      auto diag = emitDefiniteFailure()
+                  << "could not erase scoped payload ops that are still used";
+      diag.attachNote(stillUsed.front()->getLoc()) << "still used";
+      return diag;
+    }
+    pending = std::move(stillUsed);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares the handle effects of as_scope: the optional insertion-point
+/// handle is only read, the scoped handle is consumed, results and the body
+/// block argument are produced, and the payload IR is modified.
+void transform::AsScopeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // `whereOp` is optional; do not record an effect on a null value.
+  if (Value whereOp = getWhereOp())
+    onlyReadsHandle(whereOp, effects);
+  consumesHandle(getWhat(), effects);
+  producesHandle(getResults(), effects);
+  Region &region = getBody();
+  if (!region.empty())
+    producesHandle(region.front().getArguments(), effects);
+  modifiesPayload(effects);
+}
+
+/// Verifies the structural invariants apply() relies on: the relative
+/// location and its handle come as a pair, the body block has a single
+/// transform handle argument (mapped to the scf.scope), and the terminator
+/// yields values of the op's result types.
+LogicalResult transform::AsScopeOp::verify() {
+  // The custom syntax groups them, but the generic form can set only one.
+  if (getWhere().has_value() != static_cast<bool>(getWhereOp()))
+    return emitOpError() << "expects the relative location and the insertion "
+                            "point handle to be both present or both absent";
+
+  Region &body = getBody();
+  Block &block = body.front();
+  if (block.getNumArguments() != 1 ||
+      !isa<TransformHandleTypeInterface>(block.getArgument(0).getType()))
+    return emitOpError()
+           << "expects the body to have a single transform handle argument";
+
+  Operation *terminator = block.getTerminator();
+  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects terminator operands to have the "
+                                 "same type as results of the operation";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 337eb9eeb8fa57..6871881a49458c 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -476,7 +476,7 @@ func.func @parallel_invalid_yield(
func.func @yield_invalid_parent_op() {
"my.op"() ({
- // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}}
+ // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.scope, scf.while'}}
scf.yield
}) : () -> ()
return
@@ -747,3 +747,28 @@ func.func @parallel_missing_terminator(%0 : index) {
return
}
+// -----
+
+func.func @scope_not_isolated_from_above(%arg0 : i32, %arg1 : i32) -> (i32) {
+ // expected-note @below {{required by region isolation constraints}}
+ %p = scf.scope : () -> (i32) {
+ ^bb0():
+ // expected-error @below {{'arith.addi' op using value defined outside the region}}
+ %add = arith.addi %arg0, %arg1 : i32
+ scf.yield %add : i32
+ }...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/87352
More information about the Mlir-commits
mailing list