[Mlir-commits] [mlir] [mlir][scf][transform] Add scope op & transform (PR #87352)

Gil Rapaport llvmlistbot at llvm.org
Sat Apr 6 07:35:52 PDT 2024


https://github.com/aniragil updated https://github.com/llvm/llvm-project/pull/87352

>From 2196f35cb70b52e303bd2c93811279e242644d05 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Mon, 25 Dec 2023 13:23:10 +0200
Subject: [PATCH 1/2] [mlir][scf][transform] Add scope op & transform

Add scf.scope modeling a code block and transform.as_scope which applies
transformations over it. The scf.scope op is similar to scf.execute_region
except it's isolated-from-above, which allows applying transforms to a
limited region without resorting to function outlining.

The transform.as_scope op creates a temporary scf.scope containing
clones of its operands and applies the transformations within its body
to that scope. If successful, the scope's results replace all uses of
any original operation outside the scope, the original operations are
erased and the scope is inlined. On failure, the scf.scope is erased.
---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  24 +-
 .../Dialect/Transform/IR/TransformAttrs.td    |   9 +
 .../mlir/Dialect/Transform/IR/TransformOps.td |  66 ++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  19 ++
 mlir/lib/Dialect/Transform/IR/CMakeLists.txt  |   1 +
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 292 ++++++++++++++++++
 mlir/test/Dialect/SCF/invalid.mlir            |  27 +-
 mlir/test/Dialect/SCF/ops.mlir                |  25 ++
 mlir/test/Dialect/Transform/ops-invalid.mlir  |  30 ++
 .../Dialect/Transform/test-interpreter.mlir   | 217 +++++++++++++
 10 files changed, 708 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..2f7a005d96f403 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -931,6 +931,28 @@ def ReduceReturnOp :
   let hasVerifier = 1;
 }
 
+def ScopeOp : SCF_Op<"scope",
+    [AutomaticAllocationScope,
+     RecursiveMemoryEffects,
+     IsolatedFromAbove,
+     SingleBlockImplicitTerminator<"scf::YieldOp">]> {
+  let summary = "isolated code scope";
+  let description = [{
+    This op models an executed-once region, similar to `execute_region`
+    except it's isolated from above, which facilitates local scoping of
+    code without resorting to function outlining.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>:$results);
+  // Exactly one block is required: the verifier reads the body's first block.
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = [{
+    $operands attr-dict `:` functional-type($operands, $results) $body
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
@@ -1155,7 +1177,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
 
 def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
     ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
-                 "WhileOp"]>]> {
+                 "ScopeOp", "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index ebad2994880e75..1d888e433c168f 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -33,4 +33,13 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
   let cppNamespace = "::mlir::transform";
 }
 
+def BeforeCase : I32EnumAttrCase<"Before", 1, "before">; // Keyword `before`.
+def AfterCase : I32EnumAttrCase<"After", 2, "after">;    // Keyword `after`.
+
+def RelativeLocation : I32EnumAttr<
+    "RelativeLocation", "Relative location specifier",
+    [BeforeCase, AfterCase]> { // Type of transform.as_scope's `$where`.
+  let cppNamespace = "::mlir::transform";
+}
+
 #endif  // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 21c9595860d4c5..29da1b919add3b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -434,6 +434,72 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
   }];
 }
 
+def AsScopeOp : TransformDialectOp<"as_scope", [
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     IsolatedFromAbove,
+     SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
+  let summary = "Applies transforms to ops wrapped in a temporary scf.scope";
+  let description = [{
+    The purpose of this op is to provide an ad-hoc scope for transformations,
+    in order to narrow their impact to specific ops.
+
+    This op creates a temporary scf.scope containing clones of its first operand
+    and applies the transformations within its body to that scope. If
+    successful, the scope's results replace all uses of any original operation
+    outside the scope, the scope is inlined and the original operations are
+    erased. On failure, the scf.scope is erased, leaving payload IR unmodified.
+
+    The operation takes as arguments a handle whose payload ops are to be
+    scoped in the order they are listed, and an optional (attribute, handle)
+    pair defining an insertion point for the scf.scope. If specified, the
+    insertion point handle must hold a single payload op. If omitted, the
+    payload ops must form an interval. It is the user's responsibility to make
+    sure that the payload ops can be moved to the designated location in the
+    designated order.
+
+    This operation is useful for narrowing the scope of ops affected by some
+    transformation, either for functional reasons, e.g.
+
+    ```mlir
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+    ^bb2(%s: !transform.any_op):
+      transform.apply_patterns to %s {
+        transform.apply_patterns.canonicalization
+      } : !transform.any_op
+      transform.yield
+    }
+    ```
+
+    or for saving on compile time/memory by minimizing cloning, e.g.
+
+    ```mlir
+    transform.as_scope %what after %where : (!transform.any_op, !transform.any_op) -> () {
+    ^bb2(%s: !transform.any_op):
+      transform.alternatives %s : !transform.any_op {
+      ^bb2(%arg2: !transform.any_op):
+        // ...
+      }, {
+      ^bb2(%arg2: !transform.any_op):
+        // ...
+      }
+    }
+    ```
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$what,
+                       OptionalAttr<RelativeLocation>:$where,
+                       Optional<TransformHandleTypeInterface>:$whereOp);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$results);
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    $what ($where $whereOp^)? `:` functional-type(operands, results)
+    attr-dict-with-keyword $body
+  }];
+  let hasVerifier = 1;
+}
+
 def CastOp : TransformDialectOp<"cast",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<CastOpInterface>,
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5bca8e85f889d9..6a8d28afdbc251 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3079,6 +3079,25 @@ LogicalResult ReduceReturnOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies the body is non-empty and yields values of the op's result types.
+LogicalResult scf::ScopeOp::verify() {
+  if (getBody().empty()) // Don't read front() of an empty region.
+    return emitOpError("region needs to have at least one block");
+  Operation *terminator = getBody().front().getTerminator();
+  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic diag = emitOpError()
+                              << "expects terminator operands to have the "
+                                 "same type as results of the operation";
+    diag.attachNote(terminator->getLoc()) << "terminator";
+    return diag;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WhileOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 5b4989f328e690..9690004bbf8129 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRTransformDialect
   MLIRPass
   MLIRRewrite
   MLIRSideEffectInterfaces
+  MLIRSCFDialect
   MLIRTransforms
   MLIRTransformDialectInterfaces
   MLIRTransformDialectUtils
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022219e5b2..aeb930001ef9e8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -8,9 +8,12 @@
 
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 
+#include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -19,6 +22,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
@@ -796,6 +800,294 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// AsScopeOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the first and last op of `ops` in block order when `ops` covers a
+/// contiguous run of operations of a single block, std::nullopt otherwise.
+static std::optional<std::pair<Operation *, Operation *>>
+getInterval(SetVector<Operation *> &ops) {
+  assert(!ops.empty() && "Expected non-empty operation list");
+  Operation *first = ops.front();
+  Operation *last = ops.back();
+  Block *parent = first->getBlock();
+  if (last->getBlock() != parent)
+    return std::nullopt;
+  if (last->isBeforeInBlock(first))
+    std::swap(first, last);
+
+  // Widen [first, last] until it encloses every op; all must share a block.
+  for (Operation *candidate : ops) {
+    if (candidate->getBlock() != parent)
+      return std::nullopt;
+    if (candidate->isBeforeInBlock(first))
+      first = candidate;
+    else if (last->isBeforeInBlock(candidate))
+      last = candidate;
+  }
+
+  // Every op in [first, last) must belong to ops; last does by construction.
+  for (auto it = first->getIterator(), e = last->getIterator(); it != e; ++it)
+    if (!ops.contains(&*it))
+      return std::nullopt;
+  return std::make_pair(first, last);
+}
+
+/// Clones the payload of `what` into a temporary scf.scope, runs the body's
+/// transforms on it and, on success, inlines it in place of the original ops.
+DiagnosedSilenceableFailure
+transform::AsScopeOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  Value ops = getWhat();
+  SetVector<Operation *> opsPayload;
+  SmallVector<Value> liveIns;
+  DenseMap<Value, unsigned> liveInIndex;
+  SmallVector<Value> liveOuts;
+  DenseMap<Value, unsigned> liveOutIndex;
+  SmallVector<Type> resultTypes;
+
+  for (Operation *payload : state.getPayloadOps(ops))
+    opsPayload.insert(payload);
+
+  Operation *wherePayload = nullptr;
+  Value whereOp = getWhereOp();
+  auto whereAttr = getWhere();
+  RelativeLocation where;
+  if (whereOp) {
+    auto wherePayloadOps = state.getPayloadOps(whereOp);
+    if (std::distance(wherePayloadOps.begin(), wherePayloadOps.end()) != 1) {
+      auto diag = emitDefiniteFailure()
+                  << "expects a single location for the scope";
+      diag.attachNote(whereOp.getLoc()) << "single location";
+      return diag;
+    }
+    wherePayload = *wherePayloadOps.begin();
+    where = *whereAttr;
+  } else {
+    // No insertion point: a non-empty payload must form an interval.
+    std::optional<std::pair<Operation *, Operation *>> interval;
+    if (!opsPayload.empty())
+      interval = getInterval(opsPayload);
+    if (!interval) {
+      auto diag = emitDefiniteFailure()
+                  << "payload ops must form an interval unless insertion point "
+                     "is specified";
+      diag.attachNote(ops.getLoc()) << "not an interval";
+      return diag;
+    }
+    wherePayload = interval->second;
+    where = RelativeLocation::After;
+  }
+
+  auto isInScope = [&opsPayload](Operation *operation) {
+    return opsPayload.contains(operation);
+  };
+
+  for (auto payload : opsPayload) {
+    if (payload == getOperation()) {
+      auto diag = emitDefiniteFailure()
+                  << "scope ops must not contain the transform being applied";
+      diag.attachNote(payload->getLoc()) << "scope";
+      return diag;
+    }
+
+    for (Value operand : payload->getOperands()) {
+      if (liveInIndex.contains(operand))
+        continue; // Already set as a live in.
+
+      Operation *def = operand.getDefiningOp();
+      if (def && opsPayload.contains(def))
+        continue; // Will be defined within scope, so not a live in.
+
+      // Set this operand as the next operand of the scope op.
+      liveInIndex[operand] = liveIns.size();
+      liveIns.push_back(operand);
+    }
+    for (Value result : payload->getResults()) {
+      if (liveOutIndex.contains(result))
+        continue; // Already set as a live out.
+      if (llvm::all_of(result.getUsers(), isInScope))
+        continue; // All users are in scope, so not a live out.
+
+      // Set this result as the next result of the scope op.
+      liveOutIndex[result] = liveOuts.size();
+      liveOuts.push_back(result);
+      resultTypes.push_back(result.getType());
+    }
+  }
+
+  if (where == RelativeLocation::Before)
+    rewriter.setInsertionPoint(wherePayload);
+  else
+    rewriter.setInsertionPointAfter(wherePayload);
+
+#define WRAP_IN_EXECUTE_REGION
+#ifdef WRAP_IN_EXECUTE_REGION
+  auto executeRegion = rewriter.create<scf::ExecuteRegionOp>(
+      wherePayload->getLoc(), TypeRange(), ValueRange());
+  Block &executeRegionBody = executeRegion.getRegion().emplaceBlock();
+  rewriter.setInsertionPointToStart(&executeRegionBody);
+#endif
+
+  auto scope = rewriter.create<scf::ScopeOp>(wherePayload->getLoc(),
+                                             resultTypes, liveIns);
+  Region *scopeBody = &scope.getBody();
+  // TODO: Move into builder.
+  Block *scopeBlock = rewriter.createBlock(scopeBody, scopeBody->end());
+
+  // TODO: Move into builder.
+  for (Value arg : liveIns)
+    scopeBlock->addArgument(arg.getType(), arg.getLoc());
+
+  IRMapping mapper;
+  for (Value liveIn : liveIns)
+    mapper.map(liveIn, scopeBlock->getArgument(liveInIndex[liveIn]));
+
+  rewriter.setInsertionPointToEnd(scopeBlock);
+  for (auto payload : opsPayload)
+    rewriter.clone(*payload, mapper);
+
+  SmallVector<Value> scopeResults;
+  for (Value liveOut : liveOuts)
+    scopeResults.push_back(mapper.lookup(liveOut));
+  rewriter.create<scf::YieldOp>(scope.getLoc(), scopeResults);
+
+  Region &body = getBody();
+  DiagnosedSilenceableFailure result = DiagnosedSilenceableFailure::success();
+
+  // Create a region scope for the transform::as_scope region and run the
+  // transformations it contains. This is done in a syntactic C++ scope since
+  // we'll be deleting the scf::scope, and its mapping will become invalid if
+  // left until compactOpHandles() is called.
+  {
+    auto regionScope = state.make_region_scope(body);
+    if (failed(state.mapBlockArguments(body.front().getArgument(0), {scope})))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    for (Operation &transform : body.front().without_terminator()) {
+      result = state.applyTransform(cast<TransformOpInterface>(transform));
+      if (result.isSilenceableFailure()) {
+        LLVM_DEBUG(DBGS() << "transformation on scope failed: "
+                          << result.getMessage() << "\n");
+        break;
+      }
+
+      if (::mlir::failed(result.silence()))
+        return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+#ifdef WRAP_IN_EXECUTE_REGION
+    // The transformations may have replaced the scope, so re-fetch it from the
+    // wrapping scf.execute_region; it may be missing, so check before moving.
+    Operation *wrapped =
+        executeRegionBody.empty() ? nullptr : &executeRegionBody.front();
+    scope = dyn_cast_or_null<scf::ScopeOp>(wrapped);
+    if (scope)
+      scope->moveAfter(executeRegion);
+    rewriter.eraseOp(executeRegion);
+#else
+    // Since the transformations applied may have replaced the scope, get the
+    // updated payload of the block argument.
+    auto newPayloadOps = state.getPayloadOps(body.front().getArgument(0));
+    if (llvm::range_size(newPayloadOps) != 1) {
+      LLVM_DEBUG(DBGS() << "expected a single scope post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    Operation *newPayloadOp = *newPayloadOps.begin();
+    scope = dyn_cast<scf::ScopeOp>(newPayloadOp);
+#endif
+    if (!scope) {
+      LLVM_DEBUG(DBGS() << "scope missing post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    if (failed(scope.verify())) {
+      LLVM_DEBUG(DBGS() << "invalid scope post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    scopeBody = &scope.getBody();
+    if (!scopeBody->hasOneBlock()) {
+      LLVM_DEBUG(DBGS() << "multiple blocks in scope post transformation\n");
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+    scopeBlock = &scopeBody->front();
+  }
+
+  if (!result.succeeded()) {
+    // Erase the scope and return a silenceable failure.
+    rewriter.eraseOp(scope);
+    return result;
+  }
+
+  // The transformations applied to the scope succeeded. Complete the
+  // transformation by inlining the new contents of the scope and having
+  // the scope's yielded values replace the original live-out values.
+  // Finally, erase the original ops.
+  LLVM_DEBUG(DBGS() << "scope after transformations: " << scope << "\n");
+
+  // Update the live-outs to what the scope's terminator now yields.
+  auto scopeYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
+  auto transformedLiveOuts = llvm::to_vector(scopeYield.getResults());
+
+  IRMapping inlineMapper;
+  for (auto [blockArg, operand] :
+       llvm::zip(scopeBlock->getArguments(), scope.getOperands()))
+    inlineMapper.map(blockArg, operand);
+
+  rewriter.setInsertionPoint(scope);
+  for (Operation &scopeOp : scopeBlock->without_terminator())
+    rewriter.clone(scopeOp, inlineMapper);
+
+  for (auto [liveOut, transformedLiveOut] :
+       llvm::zip(liveOuts, transformedLiveOuts)) {
+    Value inlinedLiveOut = inlineMapper.lookup<Value>(transformedLiveOut);
+    rewriter.replaceAllUsesWith(liveOut, inlinedLiveOut);
+  }
+
+  rewriter.eraseOp(scope);
+  // Erase originals users-first; never mutate the SetVector while iterating.
+  while (!opsPayload.empty()) {
+    size_t remaining = opsPayload.size();
+    opsPayload.remove_if([&](Operation *payload) {
+      if (!payload->use_empty())
+        return false;
+      rewriter.eraseOp(payload);
+      return true;
+    });
+    if (opsPayload.size() == remaining) // No progress: avoid spinning forever.
+      return emitDefiniteFailure() << "scoped ops still have uses after inlining";
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::AsScopeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  if (Value insertionPoint = getWhereOp()) // Optional operand, may be absent.
+    onlyReadsHandle(insertionPoint, effects);
+  consumesHandle(getWhat(), effects);
+  producesHandle(getResults(), effects);
+  if (!getBody().empty()) // Block arguments are handles defined by this op.
+    producesHandle(getBody().front().getArguments(), effects);
+  modifiesPayload(effects);
+}
+
+/// Checks that the types forwarded by the body terminator equal the result
+/// types declared on the op.
+LogicalResult transform::AsScopeOp::verify() {
+  Operation *yield = getBody().front().getTerminator();
+  if (yield->getOperands().getTypes() != getResults().getTypes()) {
+    InFlightDiagnostic mismatch =
+        emitOpError() << "expects terminator operands to have the "
+                         "same type as results of the operation";
+    mismatch.attachNote(yield->getLoc()) << "terminator";
+    return mismatch;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 337eb9eeb8fa57..6871881a49458c 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -476,7 +476,7 @@ func.func @parallel_invalid_yield(
 
 func.func @yield_invalid_parent_op() {
   "my.op"() ({
-   // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}}
+   // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.scope, scf.while'}}
    scf.yield
   }) : () -> ()
   return
@@ -747,3 +747,28 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @scope_not_isolated_from_above(%arg0 : i32, %arg1 : i32) -> (i32) {
+  // expected-note @below {{required by region isolation constraints}}
+  %p = scf.scope : () -> (i32) {
+  ^bb0():
+    // expected-error @below {{'arith.addi' op using value defined outside the region}}
+    %add = arith.addi %arg0, %arg1 : i32
+    scf.yield %add : i32
+  }
+  return %p : i32
+}
+
+// -----
+
+func.func @scope_yield_results_mismatch(%arg0 : i32, %arg1 : i32) -> (i32) {
+  // expected-error @below {{'scf.scope' op expects terminator operands to have the same type as results of the operation}}
+  %p = scf.scope %arg0, %arg1 : (i32, i32) -> (i32) {
+  ^bb0(%k : i32, %t : i32):
+    %add = arith.addi %k, %t : i32
+    // expected-note @below {{terminator}}
+    scf.yield %add, %add : i32, i32
+  }
+  return %p : i32
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 7f457ef3b6ba0c..70cb72a9ec0bf6 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -441,3 +441,28 @@ func.func @switch(%arg0: index) -> i32 {
 
   return %0 : i32
 }
+
+// CHECK-LABEL: @scope
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32
+func.func @scope(%arg0 : i32, %arg1 : f32, %arg2 : f32) -> (f32) {
+  // CHECK: %[[VAL_3:.*]] = arith.constant 77 : i32
+  %c77 = arith.constant 77 : i32
+
+  // CHECK: %[[VAL_4:.*]]:2 = scf.scope %[[VAL_0]], %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : (i32, i32, f32, f32) -> (i32, f32)
+  %p:2 = scf.scope %arg0, %c77, %arg1, %arg2 : (i32, i32, f32, f32) -> (i32, f32) {
+  // CHECK: ^bb0(%[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+  ^bb0(%a : i32, %b : i32, %c : f32, %d : f32):
+    // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
+    %add = arith.addi %a, %b : i32
+    // CHECK: %[[VAL_10:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]] : f32
+    %mul = arith.mulf %c, %d : f32
+    // CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : i32, f32
+    scf.yield %add, %mul : i32, f32
+  }
+
+  // CHECK: %[[VAL_11:.*]] = arith.sitofp %[[VAL_4]]#0 : i32 to f32
+  %m = arith.sitofp %p#0 : i32 to f32
+  // CHECK: %[[VAL_13:.*]] = arith.subf %[[VAL_11]], %[[VAL_4]]#1 : f32
+  %r = arith.subf %m, %p#1 : f32
+  return %r : f32
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index cc04e65420c5b7..0eeca7fdc02733 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -782,3 +782,33 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  %1 = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  %what = transform.merge_handles %0, %1 : !transform.any_op
+  %where = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  // expected-error @below {{expects terminator operands to have the same type as results of the operation}}
+  %2:2 = transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) {
+  ^bb2(%arg2: !transform.any_op):
+    // expected-note @below {{terminator}}
+    transform.yield
+  }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  // expected-error @below {{result #0 has more than one potential consumer}}
+  %0 = test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+  // expected-note @below {{used here as operand #1}}
+  // expected-note @below {{used here as operand #0}}
+  transform.as_scope %0 before %0 : (!transform.any_op, !transform.any_op) -> () {
+  ^bb2(%arg2: !transform.any_op):
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index b6850e2024d53d..61dc24bbb2caa3 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2513,3 +2513,220 @@ module attributes { transform.with_named_sequence } {
     transform.yield %arg0 : !transform.any_op
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @scope_before(
+//  CHECK-SAME:                         %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) -> (i32, i32) {
+//       CHECK:   %[[VAL_5:.*]] = arith.muli %[[VAL_3]], %[[VAL_3]] : i32
+//       CHECK:   %[[VAL_6:.*]] = arith.subi %[[VAL_0]], %[[VAL_1]] : i32
+//       CHECK:   %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_2]] : i32
+//       CHECK:   %[[VAL_8:.*]] = arith.divui %[[VAL_5]], %[[VAL_4]] : i32
+//       CHECK:   return %[[VAL_7]], %[[VAL_8]] : i32, i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_before(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+    %sub = arith.subi %arg0, %arg1 : i32
+    %add = arith.addi %sub, %arg2 : i32
+    %mul = arith.muli %arg3, %arg3 : i32
+    %div = arith.divui %mul, %arg4 : i32
+    return %add, %div : i32, i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.structured.match ops{["arith.subi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %where = transform.structured.match ops{["arith.divui"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %subi, %addi : !transform.any_op
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scope_after(
+//  CHECK-SAME:                        %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) -> (i32, i32) {
+//       CHECK:   %[[VAL_7:.*]] = arith.muli %[[VAL_3]], %[[VAL_3]] : i32
+//       CHECK:   %[[VAL_5:.*]] = arith.subi %[[VAL_0]], %[[VAL_1]] : i32
+//       CHECK:   %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_2]] : i32
+//       CHECK:   %[[VAL_8:.*]] = arith.divui %[[VAL_7]], %[[VAL_4]] : i32
+//       CHECK:   return %[[VAL_6]], %[[VAL_8]] : i32, i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_after(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+    %sub = arith.subi %arg0, %arg1 : i32
+    %add = arith.addi %sub, %arg2 : i32
+    %mul = arith.muli %arg3, %arg3 : i32
+    %div = arith.divui %mul, %arg4 : i32
+    return %add, %div : i32, i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.structured.match ops{["arith.subi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %where = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %subi, %addi : !transform.any_op
+    transform.as_scope %what after %where : (!transform.any_op, !transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.yield
+    }
+    transform.yield
+  }
+}
+// -----
+
+// CHECK-LABEL: func.func @scope_partial_folding() -> i32 {
+//       CHECK:   %[[VAL_0:.*]] = arith.constant 17 : i32
+//       CHECK:   %[[VAL_1:.*]] = arith.constant 33 : i32
+//       CHECK:   %[[VAL_2:.*]] = arith.constant 99 : i32
+//       CHECK:   %[[VAL_3:.*]] = arith.constant 46 : i32
+//       CHECK:   %[[VAL_4:.*]] = arith.addi %[[VAL_0]], %[[VAL_3]] : i32
+//       CHECK:   %[[VAL_5:.*]] = arith.muli %[[VAL_1]], %[[VAL_4]] : i32
+//       CHECK:   %[[VAL_6:.*]] = arith.divui %[[VAL_5]], %[[VAL_2]] : i32
+//       CHECK:   return %[[VAL_6]] : i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_partial_folding() -> i32 {
+    %c17 = arith.constant 17 : i32
+    %c33 = arith.constant 33 : i32
+    %c75 = arith.constant 75 : i32
+    %c29 = arith.constant 29 : i32
+    %sub = arith.subi %c75, %c29 : i32
+    %add = arith.addi %sub, %c17 : i32
+    %c99 = arith.constant 99 : i32
+    %mul = arith.muli %c33, %add : i32
+    %div = arith.divui %mul, %c99 : i32
+    return %div : i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    %const0 = transform.get_producer_of_operand %subi[0] : (!transform.any_op) -> !transform.any_op
+    %const1 = transform.get_producer_of_operand %subi[1] : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %const1, %const0, %subi, %addi : !transform.any_op
+    %where = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.apply_patterns to %s {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scope_interval() -> i32 {
+//       CHECK:   %[[VAL_0:.*]] = arith.constant 17 : i32
+//       CHECK:   %[[VAL_2:.*]] = arith.constant 46 : i32
+//       CHECK:   %[[VAL_3:.*]] = arith.addi %[[VAL_0]], %[[VAL_2]] : i32
+//       CHECK:   return %[[VAL_3]] : i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_interval() -> i32 {
+    %c17 = arith.constant 17 : i32
+    %c75 = arith.constant 75 : i32
+    %c29 = arith.constant 29 : i32
+    %sub = arith.subi %c75, %c29 : i32
+    %add = arith.addi %sub, %c17 : i32
+    return %add : i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    %const0 = transform.get_producer_of_operand %subi[0] : (!transform.any_op) -> !transform.any_op
+    %const1 = transform.get_producer_of_operand %subi[1] : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %const1, %const0, %subi, %addi : !transform.any_op
+    transform.as_scope %what : (!transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.apply_patterns to %s {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+module attributes { transform.with_named_sequence } {
+  func.func @scope_not_an_interval(%a: i32, %b: i32, %c: i32) -> (i32, i32) {
+    %sub = arith.subi %a, %b : i32
+    %mul = arith.muli %a, %b : i32
+    %add = arith.addi %sub, %c : i32
+    return %add, %mul : i32, i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    // expected-note @below {{not an interval}}
+    %what = transform.merge_handles %subi, %addi : !transform.any_op
+    // expected-error @below {{payload ops must form an interval unless insertion point is specified}}
+    transform.as_scope %what : (!transform.any_op) -> () {
+      ^bb2(%s: !transform.any_op):
+        transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scope_with_alternatives() -> i32 {
+//       CHECK:   %[[VAL_0:.*]] = arith.constant 17 : i32
+//       CHECK:   %[[VAL_1:.*]] = arith.constant 33 : i32
+//       CHECK:   %[[VAL_2:.*]] = arith.constant 46 : i32
+//       CHECK:   %[[VAL_3:.*]] = arith.addi %[[VAL_0]], %[[VAL_2]] : i32
+//       CHECK:   %[[VAL_4:.*]] = arith.muli %[[VAL_1]], %[[VAL_3]] : i32
+//       CHECK:   return %[[VAL_4]] : i32
+//       CHECK: }
+module attributes { transform.with_named_sequence } {
+  func.func @scope_with_alternatives() -> i32 {
+    %c17 = arith.constant 17 : i32
+    %c33 = arith.constant 33 : i32
+    %c75 = arith.constant 75 : i32
+    %c29 = arith.constant 29 : i32
+    %sub = arith.subi %c75, %c29 : i32
+    %add = arith.addi %sub, %c17 : i32
+    %mul = arith.muli %c33, %add : i32
+    return %mul : i32
+  }
+
+  transform.named_sequence @__transform_main(%fun: !transform.any_op) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %subi = transform.get_producer_of_operand %addi[0] : (!transform.any_op) -> !transform.any_op
+    %const0 = transform.get_producer_of_operand %subi[0] : (!transform.any_op) -> !transform.any_op
+    %const1 = transform.get_producer_of_operand %subi[1] : (!transform.any_op) -> !transform.any_op
+    %where = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
+    %what = transform.merge_handles %const1, %const0, %subi, %addi : !transform.any_op
+    transform.as_scope %what before %where : (!transform.any_op, !transform.any_op) -> () {
+    ^bb2(%s: !transform.any_op):
+      transform.alternatives %s : !transform.any_op {
+      ^bb2(%arg2: !transform.any_op):
+        %1 = transform.test_produce_self_handle_or_forward_operand : () -> !transform.any_op
+        // This operation fails, which triggers the next alternative without
+        // reporting the error.
+        transform.test_consume_operand_of_op_kind_or_fail %1, "transform.sequence" : !transform.any_op
+        transform.yield
+      }, {
+      ^bb2(%arg2: !transform.any_op):
+        transform.apply_patterns to %arg2 {
+          transform.apply_patterns.canonicalization
+        } : !transform.any_op
+        transform.yield
+      }
+      transform.yield
+    }
+    transform.yield
+  }
+}

>From 82ae3f55560baa10040ed0d2aac9c75f2d30f009 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Sat, 6 Apr 2024 17:35:45 +0300
Subject: [PATCH 2/2] Update
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 29da1b919add3b..1d7384c66c0efa 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -444,7 +444,7 @@ def AsScopeOp : TransformDialectOp<"as_scope", [
     The purpose of this op is to provide an ad-hoc scope for transformations,
     in order to narrow their impact to specific ops.
 
-    This op creates a temporary scf.scope containing clones of its first operand
+    This op creates a temporary scf.scope containing clones of payload operations associated with its first operand
     and applies the transformations within its body to that scope. If
     successful, the scope's results replace all uses of any original operation
     outside the scope, the scope is inlined and the original operations are



More information about the Mlir-commits mailing list