[Mlir-commits] [mlir] b73f1d2 - [mlir][cf-sink] Accept a callback for sinking operations
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 28 12:31:27 PDT 2022
Author: Mogball
Date: 2022-03-28T19:31:23Z
New Revision: b73f1d2c5d92a923c79435e423d7db9d6fa64eac
URL: https://github.com/llvm/llvm-project/commit/b73f1d2c5d92a923c79435e423d7db9d6fa64eac
DIFF: https://github.com/llvm/llvm-project/commit/b73f1d2c5d92a923c79435e423d7db9d6fa64eac.diff
LOG: [mlir][cf-sink] Accept a callback for sinking operations
(This was a TODO from the initial patch).
The control-flow sink utility accepts a callback that is used to sink an operation into a region.
The `moveIntoRegion` callback is called on the same operation and region for which `shouldMoveIntoRegion` returns true.
The callback must preserve the dominance of the operation within the region. In the default control-flow
sink implementation, this is done by moving the operation to the start of the entry block.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D122445
Added:
mlir/test/Transforms/control-flow-sink-test.mlir
mlir/test/lib/Transforms/TestControlFlowSink.cpp
Modified:
mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
mlir/lib/Transforms/ControlFlowSink.cpp
mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
mlir/test/Transforms/control-flow-sink.mlir
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
index f45d753564f45..f4dfc6dee876d 100644
--- a/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
+++ b/mlir/include/mlir/Transforms/ControlFlowSinkUtils.h
@@ -51,12 +51,19 @@ class RegionBranchOpInterface;
///
/// Users must supply a callback `shouldMoveIntoRegion` that determines whether
/// the given operation that only has users in the given operation should be
-/// moved into that region.
+/// moved into that region. If this returns true, `moveIntoRegion` is called on
+/// the same operation and region.
+///
+/// `moveIntoRegion` must move the operation into the region such that dominance
+/// of the operation is preserved; for example, by moving the operation to the
+/// start of the entry block. This ensures the preservation of SSA dominance of
+/// the operation's results.
///
/// Returns the number of operations sunk.
size_t
controlFlowSink(ArrayRef<Region *> regions, DominanceInfo &domInfo,
- function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion);
+ function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+ function_ref<void(Operation *, Region *)> moveIntoRegion);
/// Populates `regions` with regions of the provided region branch op that are
/// executed at most once at that are reachable given the current operands of
diff --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp
index 6643158ac5de0..9d996449c3fbe 100644
--- a/mlir/lib/Transforms/ControlFlowSink.cpp
+++ b/mlir/lib/Transforms/ControlFlowSink.cpp
@@ -60,9 +60,14 @@ void ControlFlowSink::runOnOperation() {
// Get the regions are that known to be executed at most once.
getSinglyExecutedRegionsToSink(branch, regionsToSink);
// Sink side-effect free operations.
- numSunk =
- controlFlowSink(regionsToSink, domInfo, [](Operation *op, Region *) {
- return isSideEffectFree(op);
+ numSunk = controlFlowSink(
+ regionsToSink, domInfo,
+ [](Operation *op, Region *) { return isSideEffectFree(op); },
+ [](Operation *op, Region *region) {
+ // Move the operation to the beginning of the region's entry block.
+ // This guarantees the preservation of SSA dominance if all of the
+ // operation's uses are in the region.
+ op->moveBefore(&region->front(), region->front().begin());
});
});
}
diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
index 34668fa682b62..01d6f6464a40f 100644
--- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
+++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
@@ -34,8 +34,10 @@ class Sinker {
public:
/// Create an operation sinker with given dominance info.
Sinker(function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+ function_ref<void(Operation *, Region *)> moveIntoRegion,
DominanceInfo &domInfo)
- : shouldMoveIntoRegion(shouldMoveIntoRegion), domInfo(domInfo) {}
+ : shouldMoveIntoRegion(shouldMoveIntoRegion),
+ moveIntoRegion(moveIntoRegion), domInfo(domInfo), numSunk(0) {}
/// Given a list of regions, find operations to sink and sink them. Return the
/// number of operations sunk.
@@ -61,6 +63,8 @@ class Sinker {
/// The callback to determine whether an op should be moved in to a region.
function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion;
+ /// The callback to move an operation into the region.
+ function_ref<void(Operation *, Region *)> moveIntoRegion;
/// Dominance info to determine op user dominance with respect to regions.
DominanceInfo &domInfo;
/// The number of operations sunk.
@@ -90,12 +94,7 @@ void Sinker::tryToSinkPredecessors(Operation *user, Region *region,
// If the op's users are all in the region and it can be moved, then do so.
if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) {
- // Move the op into the region's entry block. If the op is part of a
- // subgraph, dependee ops would have been moved first, so inserting before
- // the start of the block will ensure SSA dominance is preserved locally
- // in the subgraph. Ops can only be safely moved into the entry block as
- // the region's other blocks may for a loop.
- op->moveBefore(&region->front(), region->front().begin());
+ moveIntoRegion(op, region);
++numSunk;
// Add the op to the work queue.
stack.push_back(op);
@@ -127,8 +126,10 @@ size_t Sinker::sinkRegions(ArrayRef<Region *> regions) {
size_t mlir::controlFlowSink(
ArrayRef<Region *> regions, DominanceInfo &domInfo,
- function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion) {
- return Sinker(shouldMoveIntoRegion, domInfo).sinkRegions(regions);
+ function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+ function_ref<void(Operation *, Region *)> moveIntoRegion) {
+ return Sinker(shouldMoveIntoRegion, moveIntoRegion, domInfo)
+ .sinkRegions(regions);
}
void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch,
diff --git a/mlir/test/Transforms/control-flow-sink-test.mlir b/mlir/test/Transforms/control-flow-sink-test.mlir
new file mode 100644
index 0000000000000..868beb756b4f2
--- /dev/null
+++ b/mlir/test/Transforms/control-flow-sink-test.mlir
@@ -0,0 +1,44 @@
+// Invoke the test control-flow sink pass to test the utilities.
+// RUN: mlir-opt -test-control-flow-sink %s | FileCheck %s
+
+// CHECK-LABEL: func @test_sink
+func @test_sink() {
+ %0 = "test.sink_me"() : () -> i32
+ // CHECK-NEXT: test.sink_target
+ "test.sink_target"() ({
+ // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"() {was_sunk = 0 : i32}
+ // CHECK-NEXT: "test.use"(%[[V0]])
+ "test.use"(%0) : (i32) -> ()
+ }) : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @test_sink_first_region_only
+func @test_sink_first_region_only() {
+ %0 = "test.sink_me"() {first} : () -> i32
+ // CHECK-NEXT: %[[V1:.*]] = "test.sink_me"() {second}
+ %1 = "test.sink_me"() {second} : () -> i32
+ // CHECK-NEXT: test.sink_target
+ "test.sink_target"() ({
+ // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"() {first, was_sunk = 0 : i32}
+ // CHECK-NEXT: "test.use"(%[[V0]])
+ "test.use"(%0) : (i32) -> ()
+ }, {
+ "test.use"(%1) : (i32) -> ()
+ }) : () -> ()
+ return
+}
+
+// CHECK-LABEL: func @test_sink_targeted_op_only
+func @test_sink_targeted_op_only() {
+ %0 = "test.sink_me"() : () -> i32
+ // CHECK-NEXT: %[[V1:.*]] = "test.dont_sink_me"
+ %1 = "test.dont_sink_me"() : () -> i32
+ // CHECK-NEXT: test.sink_target
+ "test.sink_target"() ({
+ // CHECK-NEXT: %[[V0:.*]] = "test.sink_me"
+ // CHECK-NEXT: "test.use"(%[[V0]], %[[V1]])
+ "test.use"(%0, %1) : (i32, i32) -> ()
+ }) : () -> ()
+ return
+}
diff --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir
index 09ebd5b981f72..5e476e145ac29 100644
--- a/mlir/test/Transforms/control-flow-sink.mlir
+++ b/mlir/test/Transforms/control-flow-sink.mlir
@@ -1,3 +1,4 @@
+// Test the default control-flow sink pass.
// RUN: mlir-opt -control-flow-sink %s | FileCheck %s
// Test that operations can be sunk.
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index b8e65bf9a37c0..95c30c34f4950 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
+ TestControlFlowSink.cpp
TestInlining.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Transforms/TestControlFlowSink.cpp b/mlir/test/lib/Transforms/TestControlFlowSink.cpp
new file mode 100644
index 0000000000000..f6ec809341519
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestControlFlowSink.cpp
@@ -0,0 +1,65 @@
+//===- TestControlFlowSink.cpp - Test control-flow sink pass --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tests the control-flow sink utilities by implementing an example
+// control-flow sink pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/ControlFlowSinkUtils.h"
+
+using namespace mlir;
+
+namespace {
+/// An example control-flow sink pass to test the control-flow sink utilities.
+/// This pass will sink ops named `test.sink_me` into the first region of
+/// `test.sink_target` ops and tag them with a `was_sunk` attribute.
+struct TestControlFlowSinkPass
+ : public PassWrapper<TestControlFlowSinkPass, OperationPass<FuncOp>> {
+ /// Get the command-line argument of the test pass.
+ StringRef getArgument() const final { return "test-control-flow-sink"; }
+ /// Get the description of the test pass.
+ StringRef getDescription() const final {
+ return "Test control-flow sink pass";
+ }
+
+ /// Runs the pass on the function.
+ void runOnOperation() override {
+ auto &domInfo = getAnalysis<DominanceInfo>();
+ auto shouldMoveIntoRegion = [](Operation *op, Region *region) {
+ return region->getRegionNumber() == 0 &&
+ op->getName().getStringRef() == "test.sink_me";
+ };
+ auto moveIntoRegion = [](Operation *op, Region *region) {
+ Block &entry = region->front();
+ op->moveBefore(&entry, entry.begin());
+ op->setAttr("was_sunk",
+ Builder(op).getI32IntegerAttr(region->getRegionNumber()));
+ };
+
+ getOperation()->walk([&](Operation *op) {
+ if (op->getName().getStringRef() != "test.sink_target")
+ return;
+ SmallVector<Region *> regions =
+ llvm::to_vector(RegionRange(op->getRegions()));
+ controlFlowSink(regions, domInfo, shouldMoveIntoRegion, moveIntoRegion);
+ });
+ }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestControlFlowSink() {
+ PassRegistration<TestControlFlowSinkPass>();
+}
+} // end namespace test
+} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8a8d16064134b..2be6e15e3aec0 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -65,6 +65,7 @@ void registerTestAliasAnalysisPass();
void registerTestBuiltinAttributeInterfaces();
void registerTestCallGraphPass();
void registerTestConstantFold();
+void registerTestControlFlowSink();
void registerTestGpuSerializeToCubinPass();
void registerTestGpuSerializeToHsacoPass();
void registerTestDataLayoutQuery();
@@ -151,6 +152,7 @@ void registerTestPasses() {
mlir::test::registerTestBuiltinAttributeInterfaces();
mlir::test::registerTestCallGraphPass();
mlir::test::registerTestConstantFold();
+ mlir::test::registerTestControlFlowSink();
mlir::test::registerTestDiagnosticsPass();
#if MLIR_CUDA_CONVERSIONS_ENABLED
mlir::test::registerTestGpuSerializeToCubinPass();
More information about the Mlir-commits
mailing list