[Mlir-commits] [mlir] b73f1d2 - [mlir][cf-sink] Accept a callback for sinking operations

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 28 12:31:27 PDT 2022


Author: Mogball
Date: 2022-03-28T19:31:23Z
New Revision: b73f1d2c5d92a923c79435e423d7db9d6fa64eac

URL: https://github.com/llvm/llvm-project/commit/b73f1d2c5d92a923c79435e423d7db9d6fa64eac
DIFF: https://github.com/llvm/llvm-project/commit/b73f1d2c5d92a923c79435e423d7db9d6fa64eac.diff

LOG: [mlir][cf-sink] Accept a callback for sinking operations

(This was a TODO from the initial patch).

The control-flow sink utility now accepts a callback that is used to sink an operation into a region.
`moveIntoRegion` is called on the same operation and region for which `shouldMoveIntoRegion` returned true.
The callback must preserve the dominance of the operation within the region. In the default control-flow
sink implementation, it does so by moving the operation to the start of the region's entry block.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D122445

Added: 
    mlir/test/Transforms/control-flow-sink-test.mlir
    mlir/test/lib/Transforms/TestControlFlowSink.cpp

Modified: 
    mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
    mlir/lib/Transforms/ControlFlowSink.cpp
    mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
    mlir/test/Transforms/control-flow-sink.mlir
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
index f45d753564f45..f4dfc6dee876d 100644
--- a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
+++ b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
@@ -51,12 +51,19 @@ class RegionBranchOpInterface;
 ///
 /// Users must supply a callback `shouldMoveIntoRegion` that determines whether
 /// the given operation that only has users in the given operation should be
-/// moved into that region.
+/// moved into that region. If this returns true, `moveIntoRegion` is called on
+/// the same operation and region.
+///
+/// `moveIntoRegion` must move the operation into the region such that dominance
+/// of the operation is preserved; for example, by moving the operation to the
+/// start of the entry block. This ensures the preservation of SSA dominance of
+/// the operation's results.
 ///
 /// Returns the number of operations sunk.
 size_t
 controlFlowSink(ArrayRef<Region *> regions, DominanceInfo &domInfo,
-                function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion);
+                function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+                function_ref<void(Operation *, Region *)> moveIntoRegion);
 
 /// Populates `regions` with regions of the provided region branch op that are
 /// executed at most once at that are reachable given the current operands of

diff  --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp
index 6643158ac5de0..9d996449c3fbe 100644
--- a/mlir/lib/Transforms/ControlFlowSink.cpp
+++ b/mlir/lib/Transforms/ControlFlowSink.cpp
@@ -60,9 +60,14 @@ void ControlFlowSink::runOnOperation() {
     // Get the regions are that known to be executed at most once.
     getSinglyExecutedRegionsToSink(branch, regionsToSink);
     // Sink side-effect free operations.
-    numSunk =
-        controlFlowSink(regionsToSink, domInfo, [](Operation *op, Region *) {
-          return isSideEffectFree(op);
+    numSunk = controlFlowSink(
+        regionsToSink, domInfo,
+        [](Operation *op, Region *) { return isSideEffectFree(op); },
+        [](Operation *op, Region *region) {
+          // Move the operation to the beginning of the region's entry block.
+          // Since all of the operation's uses are in the region, this
+          // preserves SSA dominance of the operation's results.
+          op->moveBefore(&region->front(), region->front().begin());
         });
   });
 }

diff  --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
index 34668fa682b62..01d6f6464a40f 100644
--- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
+++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
@@ -34,8 +34,10 @@ class Sinker {
 public:
   /// Create an operation sinker with given dominance info.
   Sinker(function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+         function_ref<void(Operation *, Region *)> moveIntoRegion,
          DominanceInfo &domInfo)
-      : shouldMoveIntoRegion(shouldMoveIntoRegion), domInfo(domInfo) {}
+      : shouldMoveIntoRegion(shouldMoveIntoRegion),
+        moveIntoRegion(moveIntoRegion), domInfo(domInfo), numSunk(0) {}
 
   /// Given a list of regions, find operations to sink and sink them. Return the
   /// number of operations sunk.
@@ -61,6 +63,8 @@ class Sinker {
 
   /// The callback to determine whether an op should be moved in to a region.
   function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion;
+  /// The callback to move an operation into the region.
+  function_ref<void(Operation *, Region *)> moveIntoRegion;
   /// Dominance info to determine op user dominance with respect to regions.
   DominanceInfo &domInfo;
   /// The number of operations sunk.
@@ -90,12 +94,7 @@ void Sinker::tryToSinkPredecessors(Operation *user, Region *region,
 
     // If the op's users are all in the region and it can be moved, then do so.
     if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) {
-      // Move the op into the region's entry block. If the op is part of a
-      // subgraph, dependee ops would have been moved first, so inserting before
-      // the start of the block will ensure SSA dominance is preserved locally
-      // in the subgraph. Ops can only be safely moved into the entry block as
-      // the region's other blocks may for a loop.
-      op->moveBefore(&region->front(), region->front().begin());
+      moveIntoRegion(op, region);
       ++numSunk;
       // Add the op to the work queue.
       stack.push_back(op);
@@ -127,8 +126,10 @@ size_t Sinker::sinkRegions(ArrayRef<Region *> regions) {
 
 size_t mlir::controlFlowSink(
     ArrayRef<Region *> regions, DominanceInfo &domInfo,
-    function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion) {
-  return Sinker(shouldMoveIntoRegion, domInfo).sinkRegions(regions);
+    function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+    function_ref<void(Operation *, Region *)> moveIntoRegion) {
+  return Sinker(shouldMoveIntoRegion, moveIntoRegion, domInfo)
+      .sinkRegions(regions);
 }
 
 void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch,

diff  --git a/mlir/test/Transforms/control-flow-sink-test.mlir b/mlir/test/Transforms/control-flow-sink-test.mlir
new file mode 100644
index 0000000000000..868beb756b4f2
--- /dev/null
+++ b/mlir/test/Transforms/control-flow-sink-test.mlir
@@ -0,0 +1,44 @@
+// Invoke the test control-flow sink pass to test the utilities.
+// RUN: mlir-opt -test-control-flow-sink %s | FileCheck %s
+
+// CHECK-LABEL: func @test_sink
+func @test_sink() {
+  %0 = "test.sink_me"() : () -> i32
+  // CHECK-NEXT: test.sink_target
+  "test.sink_target"() ({
+    // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"() {was_sunk = 0 : i32}
+    // CHECK-NEXT: "test.use"(%[[V0]])
+    "test.use"(%0) : (i32) -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @test_sink_first_region_only
+func @test_sink_first_region_only() {
+  %0 = "test.sink_me"() {first} : () -> i32
+  // CHECK-NEXT: %[[V1:.*]] = "test.sink_me"() {second}
+  %1 = "test.sink_me"() {second} : () -> i32
+  // CHECK-NEXT: test.sink_target
+  "test.sink_target"() ({
+    // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"() {first, was_sunk = 0 : i32}
+    // CHECK-NEXT: "test.use"(%[[V0]])
+    "test.use"(%0) : (i32) -> ()
+  }, {
+    "test.use"(%1) : (i32) -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK-LABEL: func @test_sink_targeted_op_only
+func @test_sink_targeted_op_only() {
+  %0 = "test.sink_me"() : () -> i32
+  // CHECK-NEXT: %[[V1:.*]] = "test.dont_sink_me"
+  %1 = "test.dont_sink_me"() : () -> i32
+  // CHECK-NEXT: test.sink_target
+  "test.sink_target"() ({
+    // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"
+    // CHECK-NEXT: "test.use"(%[[V0]], %[[V1]])
+    "test.use"(%0, %1) : (i32, i32) -> ()
+  }) : () -> ()
+  return
+}

diff  --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir
index 09ebd5b981f72..5e476e145ac29 100644
--- a/mlir/test/Transforms/control-flow-sink.mlir
+++ b/mlir/test/Transforms/control-flow-sink.mlir
@@ -1,3 +1,4 @@
+// Test the default control-flow sink pass.
 // RUN: mlir-opt -control-flow-sink %s | FileCheck %s
 
 // Test that operations can be sunk.

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index b8e65bf9a37c0..95c30c34f4950 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
   TestConstantFold.cpp
+  TestControlFlowSink.cpp
   TestInlining.cpp
 
   EXCLUDE_FROM_LIBMLIR

diff  --git a/mlir/test/lib/Transforms/TestControlFlowSink.cpp b/mlir/test/lib/Transforms/TestControlFlowSink.cpp
new file mode 100644
index 0000000000000..f6ec809341519
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestControlFlowSink.cpp
@@ -0,0 +1,65 @@
+//===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tests the control-flow sink utilities by implementing an example
+// control-flow sink pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/ControlFlowSinkUtils.h"
+
+using namespace mlir;
+
+namespace {
+/// An example control-flow sink pass to test the control-flow sink utilities.
+/// It sinks ops named `test.sink_me` into the first region of
+/// `test.sink_target` ops and tags each sunk op with a `was_sunk` attribute.
+struct TestControlFlowSinkPass
+    : public PassWrapper<TestControlFlowSinkPass, OperationPass<FuncOp>> {
+  /// Get the command-line argument of the test pass.
+  StringRef getArgument() const final { return "test-control-flow-sink"; }
+  /// Get the description of the test pass.
+  StringRef getDescription() const final {
+    return "Test control-flow sink pass";
+  }
+
+  /// Runs control-flow sink on every `test.sink_target` op in the function.
+  void runOnOperation() override {
+    auto &domInfo = getAnalysis<DominanceInfo>();
+    auto shouldMoveIntoRegion = [](Operation *op, Region *region) {
+      return region->getRegionNumber() == 0 &&
+             op->getName().getStringRef() == "test.sink_me";
+    };
+    auto moveIntoRegion = [](Operation *op, Region *region) {
+      Block &entry = region->front();
+      op->moveBefore(&entry, entry.begin());
+      op->setAttr("was_sunk",
+                  Builder(op).getI32IntegerAttr(region->getRegionNumber()));
+    };
+
+    getOperation()->walk([&](Operation *op) {
+      if (op->getName().getStringRef() != "test.sink_target")
+        return;
+      SmallVector<Region *> regions =
+          llvm::to_vector(RegionRange(op->getRegions()));
+      controlFlowSink(regions, domInfo, shouldMoveIntoRegion, moveIntoRegion);
+    });
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestControlFlowSink() {
+  PassRegistration<TestControlFlowSinkPass>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8a8d16064134b..2be6e15e3aec0 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -65,6 +65,7 @@ void registerTestAliasAnalysisPass();
 void registerTestBuiltinAttributeInterfaces();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
+void registerTestControlFlowSink();
 void registerTestGpuSerializeToCubinPass();
 void registerTestGpuSerializeToHsacoPass();
 void registerTestDataLayoutQuery();
@@ -151,6 +152,7 @@ void registerTestPasses() {
   mlir::test::registerTestBuiltinAttributeInterfaces();
   mlir::test::registerTestCallGraphPass();
   mlir::test::registerTestConstantFold();
+  mlir::test::registerTestControlFlowSink();
   mlir::test::registerTestDiagnosticsPass();
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   mlir::test::registerTestGpuSerializeToCubinPass();


        


More information about the Mlir-commits mailing list