[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)

Gil Rapaport llvmlistbot at llvm.org
Thu Apr 18 09:57:01 PDT 2024


https://github.com/aniragil updated https://github.com/llvm/llvm-project/pull/87352

>From 57b873eaf7f21038e13dc8b2c64a6dc9209c56d1 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 16 Apr 2024 17:58:18 +0300
Subject: [PATCH 1/6] [mlir][scf] Add a scope op to the scf dialect

Add to the scf dialect an operation modeling an isolated-from-above
single basic block that is executed once. It provides a localized,
hierarchical alternative to outlining code into a function/call pair.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 37 +++++++++++++++++++++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 19 +++++++++++
 mlir/test/Dialect/SCF/invalid.mlir         | 27 +++++++++++++++-
 mlir/test/Dialect/SCF/ops.mlir             | 25 +++++++++++++++
 4 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..6c04fe4cdd651d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -931,6 +931,41 @@ def ReduceReturnOp :
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+def ScopeOp : SCF_Op<"scope",
+    [AutomaticAllocationScope,
+     RecursiveMemoryEffects,
+     IsolatedFromAbove,
+     SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+  let summary = "isolated code scope";
+  let description = [{
+    The 'scope' op encapsulates computations by providing an isolated-from-above,
+    executed-once single basic block. The op takes any number of operands, and
+    its return values are defined by its terminating `scf.yield`. Since the
+    body cannot reference values defined outside of the op, such values are
+    passed in as operands and accessed through the block arguments of the body,
+    as in the example below:
+
+    ```mlir
+    %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+    ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+      %add = arith.addi %a, %b : i32
+      %mul = arith.mulf %c, %d : f32
+      scf.yield %add, %mul : i32, f32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>:$results);
+  // Require exactly one block: the op verifier dereferences the front block
+  // of the body, which an unconstrained region would allow to be missing.
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    $operands attr-dict `:` functional-type($operands, $results) $body
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
@@ -1155,7 +1190,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
     ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
-                 "WhileOp"]>]> {
+                 "ScopeOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..6a8d28afdbc251 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3079,6 +3079,25 @@ LogicalResult ReduceReturnOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the body of the scope is consistent with the op signature:
+/// the body must exist, its block arguments must have the types of the op's
+/// operands, and the terminator operands must have the types of the op's
+/// results.
+LogicalResult scf::ScopeOp::verify() {
+  Region &body = getBody();
+  // Guard against an empty region before dereferencing its front block.
+  if (body.empty())
+    return emitOpError() << "expects a non-empty body";
+  Block &block = body.front();
+
+  // The body is isolated from above, so the operands are only reachable
+  // through the block arguments; the two lists must line up type-wise.
+  if (block.getArgumentTypes() != getOperands().getTypes())
+    return emitOpError() << "expects body block arguments to have the same "
+                            "type as operands of the operation";
+
+  Operation *terminator = block.getTerminator();
+  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects terminator operands to have the "
+                                 "same type as results of the operation";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 337eb9eeb8fa57..6871881a49458c 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -476,7 +476,7 @@ func.func @parallel_invalid_yield(
 
 func.func @yield_invalid_parent_op() {
   "my.op"() ({
-   // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}}
+   // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.scope, scf.while'}}
    scf.yield
   }) : () -> ()
   return
@@ -747,3 +747,28 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @scope_not_isolated_from_above(%arg0 : i32, %arg1 : i32) -> (i32) {
+  // expected-note @below {{required by region isolation constraints}}
+  %p = scf.scope : () -> (i32) {
+  ^bb0():
+    // expected-error @below {{'arith.addi' op using value defined outside the region}}
+    %add = arith.addi %arg0, %arg1 : i32
+    scf.yield %add : i32
+  }
+  return %p : i32
+}
+
+// -----
+
+func.func @scope_yield_results_mismatch(%arg0 : i32, %arg1 : i32) -> (i32) {
+  // expected-error @below {{'scf.scope' op expects terminator operands to have the same type as results of the operation}}
+  %p = scf.scope %arg0, %arg1 : (i32, i32) -> (i32) {
+  ^bb0(%k : i32, %t : i32):
+    %add = arith.addi %k, %t : i32
+    // expected-note @below {{terminator}}
+    scf.yield %add, %add : i32, i32
+  }
+  return %p : i32
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 7f457ef3b6ba0c..70cb72a9ec0bf6 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -441,3 +441,28 @@ func.func @switch(%arg0: index) -> i32 {
 
   return %0 : i32
 }
+
+// Round-trips an scf.scope with four operands, matching block arguments and
+// two results that are both used after the scope.
+// CHECK-LABEL: @scope
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32
+func.func @scope(%arg0 : i32, %arg1 : f32, %arg2 : f32) -> (f32) {
+  // CHECK: %[[VAL_3:.*]] = arith.constant 77 : i32
+  %c77 = arith.constant 77 : i32
+
+  // CHECK: %[[VAL_4:.*]]:2 = scf.scope %[[VAL_0]], %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : (i32, i32, f32, f32) -> (i32, f32)
+  %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+  // CHECK: ^bb0(%[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+  ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+    // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
+    %add = arith.addi %a, %b : i32
+    // CHECK: %[[VAL_10:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]] : f32
+    %mul = arith.mulf %c, %d : f32
+    // CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : i32, f32
+    scf.yield %add, %mul : i32, f32
+  }
+
+  // CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_4]]#0 : i32 to f32
+  %m = arith.sitofp %p#0 : i32 to f32
+  // CHECK: %[[VAL_13:.*]] = arith.subf %[[VAL_11]], %[[VAL_4]]#1 : f32
+  %r = arith.subf %m, %p#1 : f32
+  return %r : f32
+}

>From 40ee7e5a5c6b0c3ec49c05fe238172c86172d281 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Mon, 25 Dec 2023 13:23:10 +0200
Subject: [PATCH 2/6] [mlir][transform] Add transform.as_scope op

The transform.as_scope op creates a temporary scf.scope containing
clones of its operands and applies the transformations within its body
to that scope. If successful, the scope's results replace all uses of
any original operation outside the scope, the original operations are
erased and the scope is inlined. On failure, the scf.scope is erased.
---
 .../Dialect/Transform/IR/TransformAttrs.td    |   9 +
 .../mlir/Dialect/Transform/IR/TransformOps.td |  66 ++++
 mlir/lib/Dialect/Transform/IR/CMakeLists.txt  |   1 +
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 292 ++++++++++++++++++
 mlir/test/Dialect/Transform/ops-invalid.mlir  |  30 ++
 .../Dialect/Transform/test-interpreter.mlir   | 217 +++++++++++++
 6 files changed, 615 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index ebad2994880e75..1d888e433c168f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -33,4 +33,13 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
   let cppNamespace = "::mlir::transform";
 }
 
+def BeforeCase : I32EnumAttrCase<"Before", 1, "before">; // Position ahead of the reference op.
+def AfterCase : I32EnumAttrCase<"After", 2, "after">; // Position following the reference op.
+
+def RelativeLocation : I32EnumAttr< // Side of a reference op, spelled `before`/`after` in assembly.
+    "RelativeLocation", "Relative location specifier",
+    [BeforeCase, AfterCase]> {
+  let cppNamespace = "::mlir::transform";
+}
+
 #endif  // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 21c9595860d4c5..29da1b919add3b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -434,6 +434,72 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
   }];
 }
 
+def AsScopeOp : TransformDialectOp<"as_scope", [
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     IsolatedFromAbove,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+  let summary = "Applies transforms to payload ops isolated in a temporary scope";
+  let description = [{
+    The purpose of this op is to provide an ad-hoc scope for transformations,
+    in order to narrow their impact to specific ops.
+
+    This op creates a temporary scf.scope containing clones of the payload ops
+    of its first operand and applies the transformations within its body to
+    that scope. The body receives a handle to the scope as its single block
+    argument. If successful, the scope's results replace all uses of any
+    original operation outside the scope, the scope is inlined and the original
+    operations are erased. On failure, the scf.scope is erased, leaving payload
+    IR unmodified.
+
+    The operation takes as arguments a handle whose payload ops are to be
+    scoped in the order they are listed, and an optional pair of an attribute
+    (`before` or `after`) and a handle defining an insertion point for the
+    scf.scope. If specified, the insertion point handle must hold a single
+    payload op. If omitted, the payload ops must form an interval, and the
+    scope is created right after the last op of that interval. It is the user's
+    responsibility to make sure that the payload ops can be moved to the
+    designated location in the designated order.
+
+    This operation is useful for narrowing the scope of ops affected by some
+    transformation, either for functional reasons, e.g.
+
+    ```mlir
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+    ^bb2(%s: !transform.any_op):
+      transform.apply_patterns to %s {
+        transform.apply_patterns.canonicalization
+      } : !transform.any_op
+      transform.yield
+    }
+    ```
+
+    or for saving on compile time/memory by minimizing cloning, e.g.
+
+    ```mlir
+    transform.as_scope %what after %where : (!transform.any_op, !transform.any_op) -> () {
+    ^bb2(%s: !transform.any_op):
+      transform.alternatives %s : !transform.any_op {
+      ^bb3(%arg2: !transform.any_op):
+        // ...
+      }, {
+      ^bb3(%arg2: !transform.any_op):
+        // ...
+      }
+    }
+    ```
+  }];
+
+  // `where` and `whereOp` are written together in the custom assembly format
+  // (see the optional group below).
+  let arguments = (ins TransformHandleTypeInterface:$what,
+                       OptionalAttr<RelativeLocation>:$where,
+                       Optional<TransformHandleTypeInterface>:$whereOp);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    $what ($where $whereOp^)? `:` functional-type(operands, results)
+    attr-dict-with-keyword $body
+  }];
+  let hasVerifier = 1;
+}
+
 def CastOp : TransformDialectOp<"cast",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<CastOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 5b4989f328e690..9690004bbf8129 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRTransformDialect
   MLIRPass
   MLIRRewrite
   MLIRSideEffectInterfaces
+  MLIRSCFDialect
   MLIRTransforms
   MLIRTransformDialectInterfaces
   MLIRTransformDialectUtils
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..aeb930001ef9e8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -8,9 +8,12 @@
 
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 
+#include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -19,6 +22,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the first and last operation (in block order) of `ops` if the
+/// operations all live in the same block and, together, cover every operation
+/// between those two endpoints. Returns std::nullopt otherwise. `ops` must
+/// not be empty.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+
+  // Block order is only defined within a single block, so reject anything
+  // spanning several blocks up front.
+  Block *parentBlock = ops.front()->getBlock();
+  if (llvm::any_of(ops, [parentBlock](Operation *candidate) {
+        return candidate->getBlock() != parentBlock;
+      }))
+    return std::nullopt;
+
+  // Find the two endpoints by block order.
+  Operation *first = ops.front();
+  Operation *last = ops.front();
+  for (Operation *candidate : ops) {
+    if (candidate->isBeforeInBlock(first))
+      first = candidate;
+    if (last->isBeforeInBlock(candidate))
+      last = candidate;
+  }
+
+  // Every operation from `first` up to (excluding) `last` must be one of
+  // `ops`; `last` itself is a member by construction.
+  auto between =
+      llvm::make_range(first->getIterator(), last->getIterator());
+  if (!llvm::all_of(between,
+                    [&ops](Operation &op) { return ops.contains(&op); }))
+    return std::nullopt;
+
+  return std::make_pair(first, last);
+}
+
+/// Runs the transforms in the body on a temporary `scf.scope` that holds
+/// clones of the payload ops of `what`.
+///
+/// The payload ops are cloned, in handle order, into a new `scf.scope` created
+/// at the requested insertion point, or right after the payload ops when no
+/// insertion point is given and they form an interval. Values used by the
+/// payload ops as operands but not defined by them become scope operands and
+/// block arguments; results that have users outside the payload ops become
+/// scope results. The body's transforms are then applied with the body's block
+/// argument mapped to the scope.
+///
+/// On success the (possibly rewritten) contents of the scope are cloned back
+/// in place of the scope, uses of the original live-out values are redirected
+/// to them and the original payload ops are erased. On a silenceable failure
+/// the scope is erased and the failure is returned.
+///
+/// NOTE(review): `results` is never populated here even though the op
+/// declares results and its verifier ties their types to the terminator
+/// operands — confirm how yielded handles are meant to be forwarded.
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  Value ops = getWhat();
+  SetVector<Operation *> opsPayload;
+  SmallVector<Value> liveIns;
+  DenseMap<Value, unsigned> liveInIndex;
+  SmallVector<Value> liveOuts;
+  DenseMap<Value, unsigned> liveOutIndex;
+  SmallVector<Type> resultTypes;
+
+  for (Operation *payload : state.getPayloadOps(ops))
+    opsPayload.insert(payload);
+
+  // Determine where the scope is to be created.
+  Operation *wherePayload = nullptr;
+  Value whereOp = getWhereOp();
+  auto whereAttr = getWhere();
+  RelativeLocation where = RelativeLocation::After;
+  if (whereOp) {
+    // The attribute and the handle are declared as independently optional,
+    // so do not assume the attribute is present.
+    if (!whereAttr)
+      return emitDefiniteFailure()
+             << "expects a relative location for the insertion point";
+
+    auto wherePayloadOps = state.getPayloadOps(whereOp);
+    if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+      auto diag = emitDefiniteFailure()
+                  << "expects a single location for the scope";
+      diag.attachNote(whereOp.getLoc()) << "single location";
+      return diag;
+    }
+    wherePayload = *wherePayloadOps.begin();
+    where = *whereAttr;
+  } else {
+    // No insertion point specified, so payload ops must form an interval.
+    // An empty payload has no interval to derive an insertion point from.
+    if (opsPayload.empty())
+      return emitDefiniteFailure()
+             << "expects payload ops unless insertion point is specified";
+
+    auto interval = getInterval(opsPayload);
+    if (!interval) {
+      auto diag = emitDefiniteFailure()
+                  << "payload ops must form an interval unless insertion point "
+                     "is specified";
+      diag.attachNote(ops.getLoc()) << "not an interval";
+      return diag;
+    }
+    // Place the scope right after the last op of the interval.
+    wherePayload = interval->second;
+    where = RelativeLocation::After;
+  }
+
+  auto isInScope = [&opsPayload](Operation *operation) {
+    return opsPayload.contains(operation);
+  };
+
+  unsigned nextScopeOperandIndex = 0;
+  unsigned nextScopeResultIndex = 0;
+
+  // Collect the live-in and live-out values of the payload ops.
+  // NOTE(review): only direct operands are inspected; values captured by
+  // regions nested in a payload op are not turned into live-ins — confirm
+  // whether such payload ops are expected here.
+  for (Operation *payload : opsPayload) {
+    if (payload == getOperation()) {
+      auto diag = emitDefiniteFailure()
+                  << "scope ops must not contain the transform being applied";
+      diag.attachNote(payload->getLoc()) << "scope";
+      return diag;
+    }
+
+    for (Value operand : payload->getOperands()) {
+      if (liveInIndex.contains(operand))
+        continue; // Already set as a live in.
+
+      Operation *def = operand.getDefiningOp();
+      if (def && isInScope(def))
+        continue; // Will be defined within scope, so not a live in.
+
+      // Set this operand as the next operand of the scope op.
+      liveIns.push_back(operand);
+      liveInIndex[operand] = nextScopeOperandIndex++;
+    }
+    for (Value result : payload->getResults()) {
+      if (liveOutIndex.contains(result))
+        continue; // Already set as a live out.
+      if (llvm::all_of(result.getUsers(), isInScope))
+        continue; // All users are in scope, so not a live out.
+
+      // Set this result as the next result of the scope op.
+      liveOuts.push_back(result);
+      liveOutIndex[result] = nextScopeResultIndex++;
+      resultTypes.push_back(result.getType());
+    }
+  }
+
+  if (where == RelativeLocation::Before)
+    rewriter.setInsertionPoint(wherePayload);
+  else
+    rewriter.setInsertionPointAfter(wherePayload);
+
+  // NOTE(review): this macro is defined in the middle of the function and is
+  // never undefined, so it stays visible for the rest of the translation unit.
+#define WRAP_IN_EXECUTE_REGION
+#ifdef WRAP_IN_EXECUTE_REGION
+  // Wrap the scope in an scf.execute_region so that the scope can be found
+  // again after the transforms ran, even if they replaced it.
+  TypeRange noTypes;
+  ValueRange noValues;
+  auto executeRegion = rewriter.create<scf::ExecuteRegionOp>(
+      wherePayload->getLoc(), noTypes, noValues);
+  Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
+  rewriter.setInsertionPointToStart(&executeRegionBody);
+#endif
+
+  auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
+                                             resultTypes, liveIns);
+  Region *scopeBody = &scope.getBody();
+  // TODO: Move into builder.
+  Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
+
+  // TODO: Move into builder.
+  for (Value arg : liveIns)
+    scopeBlock->addArgument(arg.getType(), arg.getLoc());
+
+  // Clone the payload ops into the scope, rewiring live-ins to the block
+  // arguments.
+  IRMapping mapper;
+  for (Value liveIn : liveIns)
+    mapper.map(liveIn, scopeBlock->getArgument(liveInIndex[liveIn]));
+
+  rewriter.setInsertionPointToEnd(scopeBlock);
+  for (Operation *payload : opsPayload)
+    rewriter.clone(*payload, mapper);
+
+  SmallVector<Value> scopeResults;
+  for (Value liveOut : liveOuts)
+    scopeResults.push_back(mapper.lookup(liveOut));
+  rewriter.create<scf::YieldOp>(scope.getLoc(), scopeResults);
+
+  Region &body = getBody();
+
+  DiagnosedSilenceableFailure result = DiagnosedSilenceableFailure::success();
+
+  // Create a region scope for the transform::as_scope region and run the
+  // transformations it contains. This is done in a syntactic C++ scope since
+  // we'll be deleting the scf::scope, and its mapping will become invalid if
+  // left until compactOpHandles() is called.
+  {
+    auto regionScope = state.make_region_scope(body);
+    if (failed(state.mapBlockArguments(body.front().getArgument(0), {scope})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (Operation &transform : body.front().without_terminator()) {
+      result = state.applyTransform(cast<TransformOpInterface>(transform));
+      if (result.isSilenceableFailure()) {
+        LLVM_DEBUG(DBGS() << "transformation on scope failed: "
+                          << result.getMessage() << "\n");
+        break;
+      }
+
+      if (::mlir::failed(result.silence()))
+        return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+#ifdef WRAP_IN_EXECUTE_REGION
+    // Since the transformations applied may have replaced the scope, get the
+    // current scope from the wrapping scf.execute_region. The wrapper may be
+    // empty or start with some other op, so look before touching the result.
+    scope = executeRegionBody.empty()
+                ? scf::ScopeOp()
+                : dyn_cast<scf::ScopeOp>(&executeRegionBody.front());
+    if (scope)
+      scope->moveAfter(executeRegion);
+    rewriter.eraseOp(executeRegion);
+#else
+    // Since the transformations applied may have replaced the scope, get the
+    // updated payload of the block argument.
+    auto newPayloadOps = state.getPayloadOps(body.front().getArgument(0));
+    if (llvm::range_size(newPayloadOps) != 1) {
+      LLVM_DEBUG(DBGS() << "expected a single scope post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    Operation *newPayloadOp = *newPayloadOps.begin();
+    scope = dyn_cast<scf::ScopeOp>(newPayloadOp);
+#endif
+    if (!scope) {
+      LLVM_DEBUG(DBGS() << "scope missing post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    if (failed(scope.verify())) {
+      LLVM_DEBUG(DBGS() << "invalid scope post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    scopeBody = &scope.getBody();
+    if (!scopeBody->hasOneBlock()) {
+      LLVM_DEBUG(DBGS() << "multiple blocks in scope post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    scopeBlock = &scopeBody->front();
+  }
+
+  if (!result.succeeded()) {
+    // Erase the scope and return a silenceable failure.
+    rewriter.eraseOp(scope);
+    return result;
+  }
+
+  // The transformations applied to the scope succeeded. Complete the
+  // transformation by inlining the new contents of the scope and having
+  // the scope's yielded values replace the original live-out values.
+  // Finally, erase the original ops.
+
+  LLVM_DEBUG(DBGS() << "scope after transformations: " << scope << "\n");
+
+  // Update the live-outs to what the scope's terminator now yields.
+  auto scopeYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
+  SmallVector<Value> transformedLiveOuts =
+      llvm::to_vector(scopeYield.getResults());
+
+  // The replacement below pairs values positionally; a changed result count
+  // would otherwise silently leave some live-outs untouched.
+  if (transformedLiveOuts.size() != liveOuts.size()) {
+    auto diag = emitDefiniteFailure()
+                << "expects the transformed scope to yield one value per "
+                   "original live-out";
+    diag.attachNote(scope.getLoc()) << "scope";
+    return diag;
+  }
+
+  IRMapping inlineMapper;
+  for (auto [blockArg, operand] :
+       llvm::zip(scopeBlock->getArguments(), scope.getOperands()))
+    inlineMapper.map(blockArg, operand);
+
+  rewriter.setInsertionPoint(scope);
+  for (Operation &scopeOp : scopeBlock->without_terminator())
+    rewriter.clone(scopeOp, inlineMapper);
+
+  for (auto [liveOut, transformedLiveOut] :
+       llvm::zip(liveOuts, transformedLiveOuts)) {
+    Value inlinedLiveOut = inlineMapper.lookup<Value>(transformedLiveOut);
+    rewriter.replaceAllUsesWith(liveOut, inlinedLiveOut);
+  }
+
+  rewriter.eraseOp(scope);
+
+  // Erase the original payload ops, each one only once nothing uses it any
+  // more. Iterate over a snapshot: removing elements from `opsPayload` while
+  // ranging over it would leave the loop reading stale entries.
+  SmallVector<Operation *> pending(opsPayload.begin(), opsPayload.end());
+  while (!pending.empty()) {
+    SmallVector<Operation *> stillInUse;
+    for (Operation *payload : pending) {
+      if (payload->use_empty())
+        rewriter.eraseOp(payload);
+      else
+        stillInUse.push_back(payload);
+    }
+    // A pass that erased nothing will never make progress; report it instead
+    // of looping forever.
+    if (stillInUse.size() == pending.size()) {
+      auto diag = emitDefiniteFailure()
+                  << "failed to erase original payload ops that are still "
+                     "in use";
+      diag.attachNote(stillInUse.front()->getLoc()) << "still in use";
+      return diag;
+    }
+    pending.swap(stillInUse);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares the effects of the op on transform handles and payload IR: the
+/// insertion point handle (when given) is only read, the `what` handle is
+/// consumed, the op results and the body's block arguments are produced, and
+/// the payload IR is modified.
+void transform::AsScopeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The insertion point handle is optional; only record a read on it when it
+  // is actually present, so no effect gets attached to a null value.
+  if (Value whereOp = getWhereOp())
+    onlyReadsHandle(whereOp, effects);
+  consumesHandle(getWhat(), effects);
+  producesHandle(getResults(), effects);
+  Region &region = getRegion();
+  if (!region.empty())
+    producesHandle(region.front().getArguments(), effects);
+  modifiesPayload(effects);
+}
+
+/// Verifies the invariants `apply` relies on: the relative location and the
+/// insertion point handle come as a pair, the body has exactly one block
+/// argument of transform handle type (it is mapped to the temporary scope),
+/// and the terminator operands have the types of the op's results.
+LogicalResult transform::AsScopeOp::verify() {
+  // Both are declared optional independently, but one is meaningless without
+  // the other.
+  if (static_cast<bool>(getWhereOp()) != getWhere().has_value())
+    return emitOpError() << "expects the relative location and the insertion "
+                            "point handle to be specified together";
+
+  Region &body = getBody();
+  Block &block = body.front();
+  if (block.getNumArguments() != 1 ||
+      !isa<TransformHandleTypeInterface>(block.getArgument(0).getType()))
+    return emitOpError() << "expects the body to have a single argument of "
+                            "transform handle type";
+
+  Operation *terminator = block.getTerminator();
+  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects terminator operands to have the "
+                                 "same type as results of the operation";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index cc04e65420c5b7..0eeca7fdc02733 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -782,3 +782,33 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  %1 = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  %what = transform.merge_handles %0, %1 : !transform.any_op
+  %where = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  // expected-error @below {{expects terminator operands to have the same type as results of the operation}}
+  %2:2 = transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) {
+  ^bb2(%arg2: !transform.any_op):
+    // expected-note @below {{terminator}}
+    transform.yield
+  }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  // expected-note @below {{used here as operand #1}}
+  // expected-note @below {{used here as operand #0}}
+  transform.as_scope %0 before %0 : (!transform.any_op, !transform.any_op) -> () {
+  ^bb2(%arg2: !transform.any_op):
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b6850e2024d53d..61dc24bbb2caa3 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2513,3 +2513,220 @@ module attributes { transform.with_named_sequence } {
     transform.yield %arg0 : !transform.any_op
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @scope_before(
+//  CHECK-SAME:                         %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) -> (i32, i32) {
+//       CHECK:   %[[VAL_5:.*]] = arith.muli %[[VAL_3]], %[[VAL_3]] : i32
+//       CHECK:   %[[VAL_6:.*]] = arith.subi %[[VAL_0]], %[[VAL_1]] : i32
+//       CHECK:   %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_2]] : i32
+//       CHECK:   %[[VAL_8:.*]] = arith.divui %[[VAL_5]], %[[VAL_4]] : i32
+//       CHECK:   return %[[VAL_7]], %[[VAL_8]] : i32, i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_before(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+    %sub = arith.subi %arg0, %arg1 : i32
+    %add = arith.addi %sub, %arg2 : i32
+    %mul = arith.muli %arg3, %arg3 : i32
+    %div = arith.divui %mul, %arg4 : i32
+    return %add, %div : i32, i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.structured.match ops{["arith.subi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %where = transform.structured.match ops{["arith.divui"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %subi, %addi : !transform.any_op
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scope_after(
+//  CHECK-SAME:                        %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) -> (i32, i32) {
+//       CHECK:   %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_3]] : i32
+//       CHECK:   %[[VAL_5:.*]] = arith.subi %[[VAL_0]], %[[VAL_1]] : i32
+//       CHECK:   %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_2]] : i32
+//       CHECK:   %[[VAL_8:.*]] = arith.divui %[[VAL_7]], %[[VAL_4]] : i32
+//       CHECK:   return %[[VAL_6]], %[[VAL_8]] : i32, i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_after(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+    %sub = arith.subi %arg0, %arg1 : i32
+    %add = arith.addi %sub, %arg2 : i32
+    %mul = arith.muli %arg3, %arg3 : i32
+    %div = arith.divui %mul, %arg4 : i32
+    return %add, %div : i32, i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.structured.match ops{["arith.subi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %where = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %subi, %addi : !transform.any_op
+    transform.as_scope %what after %where : (!transform.any_op, !transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.yield
+    }
+    transform.yield
+  }
+}
+// -----
+
+// CHECK-LABEL: func.func @scope_partial_folding() -> i32 {
+//       CHECK:   %[[VAL_0:.*]] = arith.constant 17 : i32
+//       CHECK:   %[[VAL_1:.*]] = arith.constant 33 : i32
+//       CHECK:   %[[VAL_2:.*]] = arith.constant 99 : i32
+//       CHECK:   %[[VAL_3:.*]] = arith.constant 46 : i32
+//       CHECK:   %[[VAL_4:.*]] = arith.addi %[[VAL_0]], %[[VAL_3]] : i32
+//       CHECK:   %[[VAL_5:.*]] = arith.muli %[[VAL_1]], %[[VAL_4]] : i32
+//       CHECK:   %[[VAL_6:.*]] = arith.divui %[[VAL_5]], %[[VAL_2]] : i32
+//       CHECK:   return %[[VAL_6]] : i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_partial_folding() -> i32 {
+    %c17 = arith.constant 17 : i32
+    %c33 = arith.constant 33 : i32
+    %c75 = arith.constant 75 : i32
+    %c29 = arith.constant 29 : i32
+    %sub = arith.subi %c75, %c29 : i32
+    %add = arith.addi %sub, %c17 : i32
+    %c99 = arith.constant 99 : i32
+    %mul = arith.muli %c33, %add : i32
+    %div = arith.divui %mul, %c99 : i32
+    return %div : i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    %const0 = transform.get_producer_of_operand %subi[0] : (!transform.any_op) -> !transform.any_op
+    %const1 = transform.get_producer_of_operand %subi[1] : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %const1, %const0, %subi, %addi : !transform.any_op
+    %where = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.apply_patterns to %s {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scope_interval() -> i32 {
+//       CHECK:   %[[VAL_0:.*]] = arith.constant 17 : i32
+//       CHECK:   %[[VAL_2:.*]] = arith.constant 46 : i32
+//       CHECK:   %[[VAL_3:.*]] = arith.addi %[[VAL_0]], %[[VAL_2]] : i32
+//       CHECK:   return %[[VAL_3]] : i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_interval() -> i32 {
+    %c17 = arith.constant 17 : i32
+    %c75 = arith.constant 75 : i32
+    %c29 = arith.constant 29 : i32
+    %sub = arith.subi %c75, %c29 : i32
+    %add = arith.addi %sub, %c17 : i32
+    return %add : i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    %const0 = transform.get_producer_of_operand %subi[0] : (!transform.any_op) -> !transform.any_op
+    %const1 = transform.get_producer_of_operand %subi[1] : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %const1, %const0, %subi, %addi : !transform.any_op
+    transform.as_scope %what : (!transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.apply_patterns to %s {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  func.func @scope_not_an_interval(%a: i32, %b: i32, %c: i32) -> (i32, i32) {
+    %sub = arith.subi %a, %b : i32
+    %mul = arith.muli %a, %b : i32
+    %add = arith.addi %sub, %c : i32
+    return %add, %mul : i32, i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    // expected-note @below {{not an interval}}
+    %what = transform.merge_handles %subi, %addi : !transform.any_op
+    // expected-error @below {{payload ops must form an interval unless insertion point is specified}}
+    transform.as_scope %what : (!transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scope_with_alternatives() -> i32 {
+//       CHECK:   %[[VAL_0:.*]] = arith.constant 17 : i32
+//       CHECK:   %[[VAL_1:.*]] = arith.constant 33 : i32
+//       CHECK:   %[[VAL_2:.*]] = arith.constant 46 : i32
+//       CHECK:   %[[VAL_3:.*]] = arith.addi %[[VAL_0]], %[[VAL_2]] : i32
+//       CHECK:   %[[VAL_4:.*]] = arith.muli %[[VAL_1]], %[[VAL_3]] : i32
+//       CHECK:   return %[[VAL_4]] : i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_with_alternatives() -> i32 {
+    %c17 = arith.constant 17 : i32
+    %c33 = arith.constant 33 : i32
+    %c75 = arith.constant 75 : i32
+    %c29 = arith.constant 29 : i32
+    %sub = arith.subi %c75, %c29 : i32
+    %add = arith.addi %sub, %c17 : i32
+    %mul = arith.muli %c33, %add : i32
+    return %mul : i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    %const0 = transform.get_producer_of_operand %subi[0] : (!transform.any_op) -> !transform.any_op
+    %const1 = transform.get_producer_of_operand %subi[1] : (!transform.any_op) -> !transform.any_op
+    %where = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %const1, %const0, %subi, %addi : !transform.any_op
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+    ^bb2(%s: !transform.any_op):
+      transform.alternatives %s : !transform.any_op {
+      ^bb2(%arg2: !transform.any_op):
+        %1 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+        // This operation fails, which triggers the next alternative without
+        // reporting the error.
+        transform.test_consume_operand_of_op_kind_or_fail %1, "transform.sequence" : !transform.any_op
+        transform.yield
+      }, {
+      ^bb2(%arg2: !transform.any_op):
+        transform.apply_patterns to %arg2 {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.yield
+      }
+      transform.yield
+    }
+    transform.yield
+  }
+}

>From ffadddb0c9a7494b1c460f91e0242cbe6e2bd7c0 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Sat, 6 Apr 2024 17:35:45 +0300
Subject: [PATCH 3/6] Update
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 29da1b919add3b..1d7384c66c0efa 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -444,7 +444,7 @@ def AsScopeOp : TransformDialectOp<"as_scope", [
     The purpose of this op is to provide an ad-hoc scope for transformations,
     in order to narrow their impact to specific ops.
 
-    This op creates a temporary scf.scope containing clones of its first operand
+    This op creates a temporary scf.scope containing clones of payload operations associated with its first operand
     and applies the transformations within its body to that scope. If
     successful, the scope's results replace all uses of any original operation
     outside the scope, the scope is inlined and the original operations are

>From ca2a6f3ea90a48f5e1491cdf492da48d7f327646 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Sat, 6 Apr 2024 17:35:59 +0300
Subject: [PATCH 4/6] Update
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 1d7384c66c0efa..691e02627dd0a3 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -451,7 +451,7 @@ def AsScopeOp : TransformDialectOp<"as_scope", [
     erased. On failure, the scf.scope is erased, leaving payload IR unmodified.
 
     The operation takes as arguments a handle whose payload ops are to be
-    scoped in the order they are listed, and and optional pair of attribute,
+    scoped in the order they are listed, and an optional pair of attribute,
     handle defining an insertion point for the scf.scope. If specified, the
     insertion point handle must hold a single payload op. If omitted, the
     payload ops must form an interval. It is the user's responsibility to make

>From 2e653352acab411631761cbff7d716823f04a0de Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Sun, 7 Apr 2024 10:10:54 +0300
Subject: [PATCH 5/6] Address review comments, add test case

- Fixed and added documentation
- Fixed coding style
- Added test case of an as_scope generated over a single op
- Use createBlock() taking args
- Document and assert forward-progress condition for original payload
  ops: any remaining user after replacing their users with the
  transformed live-outs must be another original payload op.
- Limit user replacement of live-outs to users that do not belong to
  the original payload ops. This is cleaner, even though these ops will
  immediately be deleted.
- Remove hasOneBlock() check already done during verification
- Remove block argument producesHandle code
---
 .../mlir/Dialect/Transform/IR/TransformOps.td |  4 +-
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 71 +++++++++++--------
 .../Dialect/Transform/test-interpreter.mlir   | 34 +++++++++
 3 files changed, 79 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 691e02627dd0a3..0c4a47bad41d03 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -462,7 +462,7 @@ def AsScopeOp : TransformDialectOp<"as_scope", [
     transformation, either for functional reasons, e.g.
 
     ```mlir
-    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> (), !transform.any_op {
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
     ^bb2(%s: !transform.any_op):
       transform.apply_patterns to %s {
         transform.apply_patterns.canonicalization
@@ -474,7 +474,7 @@ def AsScopeOp : TransformDialectOp<"as_scope", [
     or for saving on compile time/memory by minimizing cloning, e.g.
 
     ```mlir
-    transform.as_scope %what after %where : (!transform.any_op, !transform.any_op) -> (), !transform.any_op {
+    transform.as_scope %what after %where : (!transform.any_op, !transform.any_op) -> () {
     ^bb2(%s: !transform.any_op):
       transform.alternatives %s : !transform.any_op {
       ^bb2(%arg2: !transform.any_op):
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index aeb930001ef9e8..3e7c55625c3c8a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -804,6 +804,11 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
 // AsScopeOp
 //===----------------------------------------------------------------------===//
 
+/// Helper function for getting the (first, last) pair of ops in the closed
+/// interval of ops specified as argument. All ops must belong to the same block
+/// and are not assumed to be ordered from earliest to latest. The function
+/// finds the earliest and latest ops and verifies that given ops indeed form an
+/// interval, i.e. that all ops between earliest and latest are in the set.
 static std::optional<std::pair<Operation *, Operation *>>
 getInterval(SetVector<Operation *> &ops) {
   assert(!ops.empty() && "Expected non-empty operation list");
@@ -817,12 +822,12 @@ getInterval(SetVector<Operation *> &ops) {
   for (Operation *op : ops) {
     if (op->getBlock() != block)
       return std::nullopt;
-    if (op->isBeforeInBlock(earliest))
+    if (op->isBeforeInBlock(earliest)) {
       earliest = op;
-    else if (latest->isBeforeInBlock(op))
+      continue;
+    }
+    if (latest->isBeforeInBlock(op))
       latest = op;
-    else
-      ;
   }
 
   // Make sure all operations between earliest and latest are in ops.
@@ -852,25 +857,25 @@ transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
 
   Operation *wherePayload = nullptr;
   Value whereOp = getWhereOp();
-  auto whereAttr = getWhere();
   RelativeLocation where;
   if (whereOp) {
     auto wherePayloadOps = state.getPayloadOps(whereOp);
-    if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
-      auto diag = emitDefiniteFailure()
-                  << "expects a single location for the scope";
+    if (!llvm::hasSingleElement(wherePayloadOps)) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "expects a single location for the scope";
       diag.attachNote(whereOp.getLoc()) << "single location";
       return diag;
     }
     wherePayload = *wherePayloadOps.begin();
-    where = *whereAttr;
+    where = *getWhere();
   } else {
     // No insertion point specified, so payload ops must form an interval.
     auto interval = getInterval(opsPayload);
     if (!interval) {
-      auto diag = emitDefiniteFailure()
-                  << "payload ops must form an interval unless insertion point "
-                     "is specified";
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
+          << "payload ops must form an interval unless insertion point "
+             "is specified";
       diag.attachNote(ops.getLoc()) << "not an interval";
       return diag;
     }
@@ -938,12 +943,12 @@ transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
   auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
                                              resultTypes, liveIns);
   Region *scopeBody = &scope.getBody();
-  // TODO: Move into builder.
-  Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
-
-  // TODO: Move into builder.
-  for (Value arg : liveIns)
-    scopeBlock->addArgument(arg.getType(), arg.getLoc());
+  SmallVector<Location> argLocs(liveIns.size(), scope.getLoc());
+  SmallVector<Type> argTypes;
+  for (Value liveIn : liveIns)
+    argTypes.push_back(liveIn.getType());
+  Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end(),
+                                           argTypes, argLocs);
 
   IRMapping mapper;
   for (Value liveIn : liveIns)
@@ -979,7 +984,7 @@ transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
         break;
       }
 
-      if (::mlir::failed(result.silence()))
+      if (failed(result.silence()))
         return DiagnosedSilenceableFailure::definiteFailure();
     }
 
@@ -1009,10 +1014,6 @@ transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
       return DiagnosedSilenceableFailure::definiteFailure();
     }
     scopeBody = &scope.getBody();
-    if (!scopeBody->hasOneBlock()) {
-      LLVM_DEBUG(DBGS() << "multiple blocks in scope post transformation\n");
-      return DiagnosedSilenceableFailure::definiteFailure();
-    }
     scopeBlock = &scopeBody->front();
   }
 
@@ -1043,13 +1044,30 @@ transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
   for (Operation &scopeOp : scopeBlock->without_terminator())
     rewriter.clone(scopeOp, inlineMapper);
 
+  // Replace users with the transformed live outs, except users which belong to
+  // the original payload ops. Once done, any remaining user of any payload op
+  // must itself be a payload op.
   for (auto [liveOut, transformedLiveOut] :
        llvm::zip(liveOuts, transformedLiveOuts)) {
     Value inlinedLiveOut = inlineMapper.lookup<Value>(transformedLiveOut);
-    rewriter.replaceAllUsesWith(liveOut, inlinedLiveOut);
-  }
+    rewriter.replaceUsesWithIf(liveOut, inlinedLiveOut,
+                               [&isInScope](OpOperand &operand) {
+                                 return !isInScope(operand.getOwner());
+                               });
+  }
+  assert(llvm::all_of(opsPayload,
+                      [&opsPayload](Operation *payload) {
+                        return llvm::all_of(payload->getUsers(),
+                                            [&opsPayload](Operation *user) {
+                                              return opsPayload.contains(user);
+                                            });
+                      }) &&
+         "Expected users of payload ops to be the payload ops themselves");
 
   rewriter.eraseOp(scope);
+
+  // Erase the original payload ops. Since these ops might still be using each
+  // other, they must be erased in use-before-def order.
   while (!opsPayload.empty()) {
     for (auto payload : llvm::make_early_inc_range(opsPayload)) {
       if (payload->getUsers().empty()) {
@@ -1067,9 +1085,6 @@ void transform::AsScopeOp::getEffects(
   onlyReadsHandle(getWhereOp(), effects);
   consumesHandle(getWhat(), effects);
   producesHandle(getResults(), effects);
-  Region &region = getRegion();
-  if (!region.empty())
-    producesHandle(region.front().getArguments(), effects);
   modifiesPayload(effects);
 }
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 61dc24bbb2caa3..434c00568c763d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2658,6 +2658,40 @@ module attributes { transform.with_named_sequence } {
 
 // -----
 
+// CHECK-LABEL: func.func @scope_single_op_interval() -> i32 {
+//       CHECK:   %[[VAL_0:.*]] = arith.constant 17 : i32
+//       CHECK:   %[[VAL_1:.*]] = arith.constant 75 : i32
+//       CHECK:   %[[VAL_2:.*]] = arith.constant 29 : i32
+//       CHECK:   %[[VAL_3:.*]] = arith.subi %[[VAL_1]], %[[VAL_2]] : i32
+//       CHECK:   %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_0]] : i32
+//       CHECK:   return %[[VAL_4]] : i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_single_op_interval() -> i32 {
+    %c17 = arith.constant 17 : i32
+    %c75 = arith.constant 75 : i32
+    %c29 = arith.constant 29 : i32
+    %sub = arith.subi %c75, %c29 : i32
+    %add = arith.addi %sub, %c17 : i32
+    return %add : i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %addi : !transform.any_op
+    transform.as_scope %what : (!transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.apply_patterns to %s {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
 module attributes { transform.with_named_sequence } {
   func.func @scope_not_an_interval(%a: i32, %b: i32, %c: i32) -> (i32, i32) {
     %sub = arith.subi %a, %b : i32

>From 21aa8737a6f193c66b1120652da5818df2da86c5 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 16 Apr 2024 14:07:31 +0300
Subject: [PATCH 6/6] Remove commented out code

As it is currently unclear how to retrieve the potentially replaced scope
in a robust manner, the disabled alternative code path (which read the
scope back from the payload of the body's block argument) is removed in
favor of the current implementation of wrapping the scope in an
scf.execute_region op.
---
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3e7c55625c3c8a..d2a1ead2425d0a 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -930,16 +930,12 @@ transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
   else
     rewriter.setInsertionPointAfter(wherePayload);
 
-#define WRAP_IN_EXECUTE_REGION
-#ifdef WRAP_IN_EXECUTE_REGION
   TypeRange noTypes;
   ValueRange noValues;
   auto executeRegion = rewriter.create<scf::ExecuteRegionOp>(
       wherePayload->getLoc(), noTypes, noValues);
   Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
   rewriter.setInsertionPointToStart(&executeRegionBody);
-#endif
-
   auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
                                              resultTypes, liveIns);
   Region *scopeBody = &scope.getBody();
@@ -988,23 +984,11 @@ transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
         return DiagnosedSilenceableFailure::definiteFailure();
     }
 
-#ifdef WRAP_IN_EXECUTE_REGION
     // Since the transformations applied may have replaced the scope, get the
     // current scope from the wrapping scf.execute_region.
     scope = dyn_cast<scf::ScopeOp>(*executeRegionBody.getOperations().begin());
     scope->moveAfter(executeRegion);
     rewriter.eraseOp(executeRegion);
-#else
-    // Since the transformations applied may have replaced the scope, get the
-    // updated payload of the block argument.
-    auto newPayloadOps = state.getPayloadOps(body.front().getArgument(0));
-    if (llvm::range_size(newPayloadOps) != 1) {
-      LLVM_DEBUG(DBGS() << "expected a single scope post transformation\n");
-      return DiagnosedSilenceableFailure::definiteFailure();
-    }
-    Operation *newPayloadOp = *newPayloadOps.begin();
-    scope = dyn_cast<scf::ScopeOp>(newPayloadOp);
-#endif
     if (!scope) {
       LLVM_DEBUG(DBGS() << "scope missing post transformation\n");
       return DiagnosedSilenceableFailure::definiteFailure();



More information about the Mlir-commits mailing list