[Mlir-commits] [mlir] da784e7 - [mlir] Add a utility function to make a region isolated from above.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Apr 20 09:40:38 PDT 2023


Author: Mahesh Ravishankar
Date: 2023-04-20T16:40:25Z
New Revision: da784e77da7715f7c1da9918e1a2bf2239c5fd06

URL: https://github.com/llvm/llvm-project/commit/da784e77da7715f7c1da9918e1a2bf2239c5fd06
DIFF: https://github.com/llvm/llvm-project/commit/da784e77da7715f7c1da9918e1a2bf2239c5fd06.diff

LOG: [mlir] Add a utility function to make a region isolated from above.

The utility function takes a region and makes it isolated from above
by appending to the entry block arguments that represent the captured
values and replacing all uses of the captured values within the region
with the newly added arguments. The captured values are returned.

The utility function also takes an optional callback that allows
cloning operations that define the captured values into the region
during the process of making it isolated from above. A cloned value
is no longer a captured value; instead, the operands of the cloned
operation become captured values. This is done transitively to allow
cloning of a DAG of operations into the region based on the callback.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D148684

Added: 
    mlir/test/Transforms/make-isolated-from-above.mlir
    mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp

Modified: 
    mlir/include/mlir/Transforms/RegionUtils.h
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 109b76ce2a9b7..06eebff201d1b 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -51,6 +51,24 @@ void getUsedValuesDefinedAbove(Region &region, Region &limit,
 void getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
                                SetVector<Value> &values);
 
+/// Make a region isolated from above:
+/// - Capture the values that are defined above the region and used within it.
+/// - Append to the entry block one new argument per captured value; the new
+///   arguments come after the entry block's existing arguments.
+/// - Replace all uses within the region of the captured values with the
+///   newly added arguments.
+/// - `cloneOperationIntoRegion` is a callback that allows the caller to
+///   specify whether the operation defining a captured value should be cloned
+///   into the region instead of the value being captured. The operands of a
+///   cloned operation then become part of the captured values set (unless the
+///   operations that define the operands themselves are to be cloned), so a
+///   DAG of operations can be cloned transitively. The cloned operations are
+///   inserted at the start of the entry block of the region. By default no
+///   operation is cloned.
+/// Returns the captured values for the region, in the same order as the
+/// entry block arguments that were appended for them.
+SmallVector<Value> makeRegionIsolatedFromAbove(
+    RewriterBase &rewriter, Region &region,
+    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
+        [](Operation *) { return false; });
+
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e5824bd6663b0..2933758db62b9 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -8,17 +8,21 @@
 
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SmallSet.h"
 
+#include <deque>
+
 using namespace mlir;
 
 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
@@ -69,6 +73,102 @@ void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
     getUsedValuesDefinedAbove(region, region, values);
 }
 
+//===----------------------------------------------------------------------===//
+// Make region isolated from above.
+//===----------------------------------------------------------------------===//
+
+/// Make `region` isolated from above. Values defined above the region and
+/// used anywhere inside it (including in nested regions) are either captured
+/// as new trailing entry-block arguments or, when `cloneOperationIntoRegion`
+/// accepts their defining op, recomputed by cloning that op (transitively)
+/// at the start of the entry block. Returns the captured values in the order
+/// of the appended block arguments. An empty region is left untouched.
+SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
+    RewriterBase &rewriter, Region &region,
+    llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
+  // An empty region has no entry block and nothing that could be captured.
+  if (region.empty())
+    return {};
+
+  // Get initial list of values used within the region (including nested
+  // regions) but defined above.
+  llvm::SetVector<Value> initialCapturedValues;
+  mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
+
+  std::deque<Value> worklist(initialCapturedValues.begin(),
+                             initialCapturedValues.end());
+  llvm::DenseSet<Value> visited;
+  // Defining ops already examined, and the subset of those that are cloned.
+  // The callback is queried once per op; every result of a cloned op is then
+  // provided by the clone, so no result of it may become a captured value.
+  llvm::DenseSet<Operation *> visitedOps;
+  llvm::DenseSet<Operation *> opsToClone;
+
+  llvm::SetVector<Value> finalCapturedValues;
+  SmallVector<Operation *> clonedOperations;
+  while (!worklist.empty()) {
+    Value currValue = worklist.front();
+    worklist.pop_front();
+    if (!visited.insert(currValue).second)
+      continue;
+
+    Operation *definingOp = currValue.getDefiningOp();
+    if (!definingOp) {
+      // Block arguments cannot be cloned; they are always captured.
+      finalCapturedValues.insert(currValue);
+      continue;
+    }
+
+    if (!visitedOps.insert(definingOp).second) {
+      // Another result of this op was already processed. Capture this one
+      // only if the op is not being cloned into the region.
+      if (!opsToClone.count(definingOp))
+        finalCapturedValues.insert(currValue);
+      continue;
+    }
+
+    if (!cloneOperationIntoRegion(definingOp)) {
+      // Defining operation isn't cloned, so add the current value to final
+      // captured values list.
+      finalCapturedValues.insert(currValue);
+      continue;
+    }
+
+    // Add all operands of the operation to the worklist and mark the op as to
+    // be cloned.
+    for (Value operand : definingOp->getOperands()) {
+      if (visited.count(operand))
+        continue;
+      worklist.push_back(operand);
+    }
+    opsToClone.insert(definingOp);
+    clonedOperations.push_back(definingOp);
+  }
+
+  // The operations to be cloned need to be ordered in topological order
+  // so that they can be cloned into the region without violating use-def
+  // chains.
+  mlir::computeTopologicalSorting(clonedOperations);
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // Collect the argument types and locations of the existing entry block.
+  Block *entryBlock = &region.front();
+  SmallVector<Type> newArgTypes =
+      llvm::to_vector(entryBlock->getArgumentTypes());
+  SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
+      entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));
+
+  // Append the types of the captured values.
+  for (Value value : finalCapturedValues) {
+    newArgTypes.push_back(value.getType());
+    newArgLocs.push_back(value.getLoc());
+  }
+
+  // Create a new entry block.
+  Block *newEntryBlock =
+      rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
+  auto newEntryBlockArgs = newEntryBlock->getArguments();
+
+  // Create a mapping between the captured values and the new arguments added.
+  IRMapping map;
+  // Rewrite every use located in `region` or in any region nested inside it.
+  // getUsedValuesDefinedAbove reports nested uses as well, so restricting this
+  // to ops directly in `region` would leave the region not isolated.
+  auto replaceIfFn = [&](OpOperand &use) {
+    Region *useRegion = use.getOwner()->getParentRegion();
+    return useRegion && region.isAncestor(useRegion);
+  };
+  for (auto [arg, capturedVal] :
+       llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
+                 finalCapturedValues)) {
+    map.map(capturedVal, arg);
+    rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
+  }
+  rewriter.setInsertionPointToStart(newEntryBlock);
+  for (Operation *clonedOp : clonedOperations) {
+    // Cloning records clonedOp's results -> newOp's results in `map`, so ops
+    // cloned later (topologically after) pick up the cloned values.
+    Operation *newOp = rewriter.clone(*clonedOp, map);
+    rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+  }
+  rewriter.mergeBlocks(
+      entryBlock, newEntryBlock,
+      newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
+  return llvm::to_vector(finalCapturedValues);
+}
+
 //===----------------------------------------------------------------------===//
 // Unreachable Block Elimination
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir
new file mode 100644
index 0000000000000..58f6cfbc5dd65
--- /dev/null
+++ b/mlir/test/Transforms/make-isolated-from-above.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt -test-make-isolated-from-above=simple -allow-unregistered-dialect --split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-make-isolated-from-above=clone-ops-with-no-operands -allow-unregistered-dialect --split-input-file %s | FileCheck %s --check-prefix=CLONE1
+// RUN: mlir-opt -test-make-isolated-from-above=clone-ops-with-operands -allow-unregistered-dialect --split-input-file %s | FileCheck %s --check-prefix=CLONE2
+
+// Single-block region with no block arguments whose body uses four values
+// defined above it: two constants and two `tensor.dim` results derived from a
+// `tensor.empty`. Without cloning all four are captured; cloning ops with no
+// operands recomputes the constants inside the region; cloning everything
+// leaves only %arg0 and %arg1 captured.
+func.func @make_isolated_from_above_single_block(%arg0 : index, %arg1 : index) {
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1 : index
+  %empty = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
+  %d0 = tensor.dim %empty, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %empty, %c1 : tensor<?x?xf32>
+  "test.one_region_with_operands_op"() ({
+    "foo.yield"(%c0, %c1, %d0, %d1) : (index, index, index, index) -> ()
+  }) : () -> ()
+  return
+}
+// CHECK-LABEL: func @make_isolated_from_above_single_block(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+//       CHECK:   test.isolated_one_region_op %[[C0]], %[[C1]], %[[D0]], %[[D1]]
+//  CHECK-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index)
+//       CHECK:       "foo.yield"(%[[B0]], %[[B1]], %[[B2]], %[[B3]])
+
+// CLONE1-LABEL: func @make_isolated_from_above_single_block(
+//  CLONE1-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+//  CLONE1-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+//   CLONE1-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CLONE1-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CLONE1-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CLONE1-DAG:   %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+//   CLONE1-DAG:   %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+//       CLONE1:   test.isolated_one_region_op %[[D0]], %[[D1]]
+//  CLONE1-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index)
+//   CLONE1-DAG:       %[[C0_0:.+]] = arith.constant 0 : index
+//   CLONE1-DAG:       %[[C1_0:.+]] = arith.constant 1 : index
+//       CLONE1:       "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B0]], %[[B1]])
+
+// CLONE2-LABEL: func @make_isolated_from_above_single_block(
+//  CLONE2-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+//  CLONE2-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+//       CLONE2:   test.isolated_one_region_op %[[ARG0]], %[[ARG1]]
+//  CLONE2-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index)
+//   CLONE2-DAG:       %[[C0:.+]] = arith.constant 0 : index
+//   CLONE2-DAG:       %[[C1:.+]] = arith.constant 1 : index
+//   CLONE2-DAG:       %[[EMPTY:.+]] = tensor.empty(%[[B0]], %[[B1]])
+//   CLONE2-DAG:       %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+//   CLONE2-DAG:       %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+//       CLONE2:       "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]])
+
+// -----
+
+// Multi-block region whose entry block already has an argument (fed by the
+// op operand %arg2). The values used from above appear only in the second
+// block; the new arguments for them must be appended after the existing
+// entry block argument.
+func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c0 = arith.constant 0: index
+  %c1 = arith.constant 1 : index
+  %empty = tensor.empty(%arg0, %arg1) : tensor<?x?xf32>
+  %d0 = tensor.dim %empty, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %empty, %c1 : tensor<?x?xf32>
+  "test.one_region_with_operands_op"(%arg2) ({
+    ^bb0(%b0 : index):
+      cf.br ^bb1(%b0: index)
+    ^bb1(%b1 : index):
+    "foo.yield"(%c0, %c1, %d0, %d1, %b1) : (index, index, index, index, index) -> ()
+  }) : (index) -> ()
+  return
+}
+// CHECK-LABEL: func @make_isolated_from_above_multiple_blocks(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+//       CHECK:   test.isolated_one_region_op %[[ARG2]], %[[C0]], %[[C1]], %[[D0]], %[[D1]]
+//  CHECK-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index, %[[B3:[a-zA-Z0-9]+]]: index, %[[B4:[a-zA-Z0-9]+]]: index)
+//  CHECK-NEXT:       cf.br ^bb1(%[[B0]] : index)
+//       CHECK:     ^bb1(%[[B5:.+]]: index)
+//       CHECK:       "foo.yield"(%[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]])
+
+// CLONE1-LABEL: func @make_isolated_from_above_multiple_blocks(
+//  CLONE1-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+//  CLONE1-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CLONE1-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+//   CLONE1-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CLONE1-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CLONE1-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CLONE1-DAG:   %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+//   CLONE1-DAG:   %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+//       CLONE1:   test.isolated_one_region_op %[[ARG2]], %[[D0]], %[[D1]]
+//  CLONE1-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
+//   CLONE1-DAG:       %[[C0_0:.+]] = arith.constant 0 : index
+//   CLONE1-DAG:       %[[C1_0:.+]] = arith.constant 1 : index
+//  CLONE1-NEXT:       cf.br ^bb1(%[[B0]] : index)
+//       CLONE1:     ^bb1(%[[B3:.+]]: index)
+//       CLONE1:       "foo.yield"(%[[C0_0]], %[[C1_0]], %[[B1]], %[[B2]], %[[B3]])
+
+// CLONE2-LABEL: func @make_isolated_from_above_multiple_blocks(
+//  CLONE2-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+//  CLONE2-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+//  CLONE2-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+//       CLONE2:   test.isolated_one_region_op %[[ARG2]], %[[ARG0]], %[[ARG1]]
+//  CLONE2-NEXT:     ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: index, %[[B2:[a-zA-Z0-9]+]]: index)
+//   CLONE2-DAG:       %[[C0:.+]] = arith.constant 0 : index
+//   CLONE2-DAG:       %[[C1:.+]] = arith.constant 1 : index
+//   CLONE2-DAG:       %[[EMPTY:.+]] = tensor.empty(%[[B1]], %[[B2]])
+//   CLONE2-DAG:       %[[D0:.+]] = tensor.dim %[[EMPTY]], %[[C0]]
+//   CLONE2-DAG:       %[[D1:.+]] = tensor.dim %[[EMPTY]], %[[C1]]
+//  CLONE2-NEXT:       cf.br ^bb1(%[[B0]] : index)
+//       CLONE2:     ^bb1(%[[B3:.+]]: index)
+//       CLONE2:       "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B3]])

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9381409f07844..887031e6ff434 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -438,6 +438,19 @@ def VariadicRegionInferredTypesOp : TEST_Op<"variadic_region_inferred",
   }];
 }
 
+// Op with a single region that is not isolated from above and a variadic list
+// of operands. It is the input op matched by the patterns of the
+// -test-make-isolated-from-above test pass.
+def OneRegionWithOperandsOp : TEST_Op<"one_region_with_operands_op", []> {
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let regions = (region AnyRegion);
+}
+
+// IsolatedFromAbove counterpart of OneRegionWithOperandsOp. The test pass
+// creates it with the original operands followed by the values captured
+// from above, and moves the now-isolated region into it.
+def IsolatedOneRegionOp : TEST_Op<"isolated_one_region_op", [IsolatedFromAbove]> {
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let regions = (region AnyRegion:$my_region);
+  let assemblyFormat = [{
+    attr-dict-with-keyword $operands $my_region `:` type($operands)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NoTerminator Operation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 0379dcd7a1968..e032ce7200fbf 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_library(MLIRTestTransforms
   TestDialectConversion.cpp
   TestInlining.cpp
   TestIntRangeInference.cpp
+  TestMakeIsolatedFromAbove.cpp
   TestTopologicalSort.cpp
 
   EXCLUDE_FROM_LIBMLIR

diff  --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
new file mode 100644
index 0000000000000..61e1fbcf3feaf
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
@@ -0,0 +1,156 @@
+//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+using namespace mlir;
+
+/// Helper that calls `makeRegionIsolatedFromAbove` on the region of
+/// `regionOp` (a `test.one_region_with_operands_op`) and replaces the op with
+/// a `test.isolated_one_region_op`. The new op's operands are the original
+/// operands followed by the values captured from above, and the now-isolated
+/// region is moved into it. `callBack` decides which operations defining
+/// captured values are cloned into the region instead of being captured.
+/// Always succeeds.
+static LogicalResult
+makeIsolatedFromAboveImpl(RewriterBase &rewriter,
+                          test::OneRegionWithOperandsOp regionOp,
+                          llvm::function_ref<bool(Operation *)> callBack) {
+  Region &region = regionOp.getRegion();
+  SmallVector<Value> capturedValues =
+      makeRegionIsolatedFromAbove(rewriter, region, callBack);
+  // Existing operands feed the region's original entry block arguments; the
+  // captured values feed the arguments appended after them.
+  SmallVector<Value> operands = regionOp.getOperands();
+  operands.append(capturedValues);
+  auto isolatedRegionOp =
+      rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
+  rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
+                              isolatedRegionOp.getRegion().begin());
+  rewriter.eraseOp(regionOp);
+  return success();
+}
+
+namespace {
+
+/// Simple test for making region isolated from above without cloning any
+/// operations.
+struct SimpleMakeIsolatedFromAbove
+    : OpRewritePattern<test::OneRegionWithOperandsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
+                                PatternRewriter &rewriter) const override {
+    return makeIsolatedFromAboveImpl(rewriter, regionOp,
+                                     [](Operation *) { return false; });
+  }
+};
+
+/// Test for making region isolated from above while cloning operations
+/// with no operands into the region.
+struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
+    : OpRewritePattern<test::OneRegionWithOperandsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
+                                PatternRewriter &rewriter) const override {
+    return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) {
+      return op->getNumOperands() == 0;
+    });
+  }
+};
+
+/// Test for making region isolated from above while cloning every operation
+/// that defines a captured value, including operations with operands (whose
+/// operands are then cloned or captured in turn).
+struct MakeIsolatedFromAboveAndCloneOpsWithOperands
+    : OpRewritePattern<test::OneRegionWithOperandsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
+                                PatternRewriter &rewriter) const override {
+    return makeIsolatedFromAboveImpl(rewriter, regionOp,
+                                     [](Operation *) { return true; });
+  }
+};
+
+/// Test pass for testing the `makeIsolatedFromAbove` function. Exactly one
+/// of the options below selects which pattern is applied; if several are set,
+/// the first one checked in runOnOperation wins.
+struct TestMakeIsolatedFromAbovePass
+    : public PassWrapper<TestMakeIsolatedFromAbovePass,
+                         OperationPass<func::FuncOp>> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)
+
+  TestMakeIsolatedFromAbovePass() = default;
+  TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final {
+    return "test-make-isolated-from-above";
+  }
+
+  StringRef getDescription() const final {
+    return "Test making a region isolated from above";
+  }
+
+  /// Capture every value defined above; clone nothing.
+  Option<bool> simple{
+      *this, "simple",
+      llvm::cl::desc("Test simple case with no cloning of operations"),
+      llvm::cl::init(false)};
+
+  /// Clone defining operations that have no operands.
+  Option<bool> cloneOpsWithNoOperands{
+      *this, "clone-ops-with-no-operands",
+      llvm::cl::desc("Test case with cloning of operations with no operands"),
+      llvm::cl::init(false)};
+
+  /// Clone all defining operations, transitively through their operands.
+  Option<bool> cloneOpsWithOperands{
+      *this, "clone-ops-with-operands",
+      llvm::cl::desc("Test case with cloning of operations with operands"),
+      llvm::cl::init(false)};
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void TestMakeIsolatedFromAbovePass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+
+  // Pick a single pattern. Options are checked in a fixed priority order and
+  // only the first enabled one is used; with no option set the pass is a
+  // no-op and the greedy driver is never run.
+  if (simple)
+    patterns.insert<SimpleMakeIsolatedFromAbove>(ctx);
+  else if (cloneOpsWithNoOperands)
+    patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(ctx);
+  else if (cloneOpsWithOperands)
+    patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(ctx);
+  else
+    return;
+
+  func::FuncOp target = getOperation();
+  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
+    signalPassFailure();
+}
+
+namespace mlir::test {
+/// Registers `-test-make-isolated-from-above` with the global pass registry
+/// so that mlir-opt can run it.
+void registerTestMakeIsolatedFromAbovePass() {
+  PassRegistration<TestMakeIsolatedFromAbovePass>();
+}
+} // namespace mlir::test

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 12eee4ba924bd..ce9008b58ba2b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -104,6 +104,7 @@ void registerTestCFGLoopInfoPass();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
 void registerTestLowerToLLVM();
+void registerTestMakeIsolatedFromAbovePass();
 void registerTestMatchReductionPass();
 void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
@@ -218,6 +219,7 @@ void registerTestPasses() {
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
   mlir::test::registerTestLowerToLLVM();
+  mlir::test::registerTestMakeIsolatedFromAbovePass();
   mlir::test::registerTestMatchReductionPass();
   mlir::test::registerTestMathAlgebraicSimplificationPass();
   mlir::test::registerTestMathPolynomialApproximationPass();


        


More information about the Mlir-commits mailing list