[Mlir-commits] [mlir] 572fa96 - [mlir] Add a ControlFlowSink pass.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 24 15:08:40 PST 2022
Author: Mogball
Date: 2022-01-24T23:08:34Z
New Revision: 572fa9642cb50f3c2d79e138e789c4b23f3ab8cf
URL: https://github.com/llvm/llvm-project/commit/572fa9642cb50f3c2d79e138e789c4b23f3ab8cf
DIFF: https://github.com/llvm/llvm-project/commit/572fa9642cb50f3c2d79e138e789c4b23f3ab8cf.diff
LOG: [mlir] Add a ControlFlowSink pass.
Control-Flow Sink moves operations whose only uses are in conditionally-executed regions into those regions so that paths in which their results are not needed do not perform unnecessary computation.
Depends on D115087
Reviewed By: jpienaar, rriddle, bondhugula
Differential Revision: https://reviews.llvm.org/D115088
Added:
mlir/lib/Transforms/ControlFlowSink.cpp
mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
mlir/test/Transforms/control-flow-sink.mlir
Modified:
mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/include/mlir/Transforms/Utils.h
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/Utils/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 8ff76edb3068c..806821a988bf3 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -83,6 +83,35 @@ class RegionSuccessor {
ValueRange inputs;
};
+/// This class represents upper and lower bounds on the number of times a region
+/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
+/// zero, but the upper bound may not be known.
+class InvocationBounds {
+public:
+  /// Create invocation bounds. The lower bound must be at least 0 and only the
+  /// upper bound can be unknown. When an upper bound is given, it must not be
+  /// less than the lower bound.
+  InvocationBounds(unsigned lb, Optional<unsigned> ub) : lower(lb), upper(ub) {
+    // Parenthesize the whole predicate so the message string is `&&`-ed with
+    // the complete condition rather than relying on `&&`/`||` precedence.
+    assert((!ub || *ub >= lb) &&
+           "upper bound cannot be less than lower bound");
+  }
+
+  /// Return the lower bound.
+  unsigned getLowerBound() const { return lower; }
+
+  /// Return the upper bound, or `None` if it is not known.
+  Optional<unsigned> getUpperBound() const { return upper; }
+
+  /// Returns the unknown invocation bounds, i.e., there is no information on
+  /// how many times a region may be invoked.
+  static InvocationBounds getUnknown() { return {0, llvm::None}; }
+
+private:
+  /// The minimum number of times the successor region will be invoked.
+  unsigned lower;
+  /// The maximum number of times the successor region will be invoked or `None`
+  /// if an upper bound is not known.
+  Optional<unsigned> upper;
+};
+
/// Return `true` if `a` and `b` are in mutually exclusive regions as per
/// RegionBranchOpInterface.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 0633426cf50af..429a5356428f7 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -102,9 +102,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let methods = [
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
- entering the region at `index`, which was specified as a successor of this
- operation by `getSuccessorRegions`. These operands should correspond 1-1
- with the successor inputs specified in `getSuccessorRegions`.
+ entering the region at `index`, which was specified as a successor of
+ this operation by `getSuccessorRegions`. These operands should
+ correspond 1-1 with the successor inputs specified in
+ `getSuccessorRegions`.
}],
"::mlir::OperandRange", "getSuccessorEntryOperands",
(ins "unsigned":$index), [{}], /*defaultImplementation=*/[{
@@ -127,9 +128,28 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
successor region must be non-empty.
}],
"void", "getSuccessorRegions",
- (ins "::mlir::Optional<unsigned>":$index, "::mlir::ArrayRef<::mlir::Attribute>":$operands,
+ (ins "::mlir::Optional<unsigned>":$index,
+ "::mlir::ArrayRef<::mlir::Attribute>":$operands,
"::mlir::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
- >
+ >,
+    InterfaceMethod<[{
+        Populates `invocationBounds` with the minimum and maximum number of
+        times this operation will invoke the attached regions (assuming the
+        regions yield normally, i.e. do not abort or invoke an infinite loop).
+        The minimum number of invocations is at least 0. If the maximum number
+        of invocations cannot be statically determined, then it will not have a
+        value (i.e., it is set to `llvm::None`).
+
+        `operands` is a set of optional attributes that either correspond to
+        the constant value of each operand of this operation, or null if that
+        operand is not a constant.
+
+        The default implementation appends unknown bounds for every region.
+      }],
+      "void", "getRegionInvocationBounds",
+      (ins "::mlir::ArrayRef<::mlir::Attribute>":$operands,
+           "::llvm::SmallVectorImpl<::mlir::InvocationBounds> &"
+             :$invocationBounds), [{}],
+      [{
+        invocationBounds.append($_op->getNumRegions(),
+                                ::mlir::InvocationBounds::getUnknown());
+      }]
+    >,
];
let verify = [{
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 6aab120fb469f..46bab0047c1b6 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -74,6 +74,9 @@ createCanonicalizerPass(const GreedyRewriteConfig &config,
ArrayRef<std::string> disabledPatterns = llvm::None,
ArrayRef<std::string> enabledPatterns = llvm::None);
+/// Creates a pass to perform control-flow sinking: side-effect free operations
+/// whose only users are in a region of a `RegionBranchOpInterface` op that is
+/// executed at most once are moved into that region.
+std::unique_ptr<Pass> createControlFlowSinkPass();
+
/// Creates a pass to perform common sub expression elimination.
std::unique_ptr<Pass> createCSEPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index ff15df372aa26..56e90644363e5 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -307,6 +307,28 @@ def Canonicalizer : Pass<"canonicalize"> {
] # RewritePassUtils.options;
}
+def ControlFlowSink : Pass<"control-flow-sink"> {
+  let summary = "Sink operations into conditional blocks";
+  let description = [{
+    This pass implements a simple control-flow sink on operations that implement
+    `RegionBranchOpInterface` by moving dominating operations whose only uses
+    are in a single conditionally-executed region into that region so that
+    execution paths where their results are not needed do not perform
+    unnecessary computations.
+
+    Only operations that are side-effect free (including all of their nested
+    operations) are sunk, and only into regions that the interface reports as
+    executed at most once.
+
+    This is similar (but opposite) to loop-invariant code motion, which hoists
+    operations out of regions executed more than once.
+
+    It is recommended to run canonicalization first to remove unreachable
+    blocks: ops in unreachable blocks may prevent other operations from being
+    sunk as they may contain uses of their results.
+  }];
+  let constructor = "::mlir::createControlFlowSinkPass()";
+  let statistics = [
+    Statistic<"numSunk", "num-sunk", "Number of operations sunk">,
+  ];
+}
+
def CSE : Pass<"cse"> {
let summary = "Eliminate common sub-expressions";
let description = [{
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
index 5280c2648bfae..5efbb19b08d25 100644
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -25,6 +25,7 @@ namespace mlir {
class AffineApplyOp;
class AffineForOp;
+class DominanceInfo;
class Location;
class OpBuilder;
@@ -147,6 +148,53 @@ Operation *createComposedAffineApplyOp(OpBuilder &builder, Location loc,
void createAffineComputationSlice(Operation *opInst,
SmallVectorImpl<AffineApplyOp> *sliceOps);
+/// Given a list of regions, perform control flow sinking on them. For each
+/// region, control-flow sinking moves operations that dominate the region but
+/// whose only users are in the region into that region so that they aren't
+/// executed on paths where their results are not needed. Sunk operations are
+/// placed at the start of the region's entry block; empty regions are skipped.
+///
+/// TODO: For the moment, this is a *simple* control-flow sink, i.e., no
+/// duplicating of ops. It should be made to accept a cost model to determine
+/// whether duplicating a particular op is profitable.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = arith.addi %arg0, %arg1
+/// scf.if %cond {
+///   scf.yield %0
+/// } else {
+///   scf.yield %arg2
+/// }
+/// ```
+///
+/// After control-flow sink:
+///
+/// ```mlir
+/// scf.if %cond {
+///   %0 = arith.addi %arg0, %arg1
+///   scf.yield %0
+/// } else {
+///   scf.yield %arg2
+/// }
+/// ```
+///
+/// Users must supply a callback `shouldMoveIntoRegion` that determines whether
+/// the given operation, all of whose users are in the given region, should be
+/// moved into that region.
+///
+/// Returns the number of operations sunk.
+size_t
+controlFlowSink(ArrayRef<Region *> regions, DominanceInfo &domInfo,
+ function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion);
+
+/// Populates `regions` with the regions of the provided region branch op that
+/// are executed at most once, i.e. whose upper invocation bound (as reported by
+/// `getRegionInvocationBounds`, queried with the op's constant operands) is
+/// known and at most 1. Regions are appended to `regions`; existing entries are
+/// kept. These regions can be passed to `controlFlowSink` to perform sinking on
+/// the regions of the operation.
+void getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch,
+                                    SmallVectorImpl<Region *> &regions);
+
} // namespace mlir
#endif // MLIR_TRANSFORMS_UTILS_H
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index ae1468fa17ea3..3e10b4a321311 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
BufferResultsToOutParams.cpp
BufferUtils.cpp
Canonicalizer.cpp
+ ControlFlowSink.cpp
CSE.cpp
Inliner.cpp
LocationSnapshot.cpp
diff --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp
new file mode 100644
index 0000000000000..71afc5702edf0
--- /dev/null
+++ b/mlir/lib/Transforms/ControlFlowSink.cpp
@@ -0,0 +1,71 @@
+//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a basic control-flow sink pass. Control-flow sinking
+// moves operations whose only uses are in conditionally-executed blocks into
+// those blocks so that they aren't executed on paths where their results are
+// not needed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass driver for control-flow sinking. For every op implementing
+/// `RegionBranchOpInterface`, it collects the regions that run at most once
+/// and sinks side-effect free operations into them.
+struct ControlFlowSink : ControlFlowSinkBase<ControlFlowSink> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Returns true if the given operation is side-effect free as are all of its
+/// nested operations.
+///
+/// An op implementing `MemoryEffectOpInterface` must report no effects; if it
+/// also has `HasRecursiveSideEffects`, its nested ops are checked too. An op
+/// that does not implement the interface is only accepted when it has
+/// `HasRecursiveSideEffects` and every nested op is itself side-effect free.
+static bool isSideEffectFree(Operation *op) {
+  if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
+    // If the op has side-effects, it cannot be moved.
+    if (!memInterface.hasNoEffect())
+      return false;
+    // If the op does not have recursive side effects, then it can be moved.
+    if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>())
+      return true;
+  } else if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
+    // Otherwise, if the op does not implement the memory effect interface and
+    // it does not have recursive side effects, then it cannot be known that the
+    // op is moveable.
+    return false;
+  }
+
+  // Recurse into the regions and ensure that all nested ops can also be moved.
+  // (`nestedOp` is deliberately distinct from the `op` parameter.)
+  for (Region &region : op->getRegions())
+    for (Operation &nestedOp : region.getOps())
+      if (!isSideEffectFree(&nestedOp))
+        return false;
+  return true;
+}
+
+/// Run control-flow sinking on every `RegionBranchOpInterface` op nested in the
+/// pass's operation, sinking only side-effect free operations.
+void ControlFlowSink::runOnOperation() {
+  auto &domInfo = getAnalysis<DominanceInfo>();
+  getOperation()->walk([&](RegionBranchOpInterface branch) {
+    // Only regions reported as executed at most once are candidates.
+    SmallVector<Region *> regionsToSink;
+    getSinglyExecutedRegionsToSink(branch, regionsToSink);
+    // Accumulate: the statistic must count ops sunk across *all* branch ops,
+    // not just the last one visited by the walk.
+    numSunk += mlir::controlFlowSink(
+        regionsToSink, domInfo,
+        [](Operation *op, Region *) { return isSideEffectFree(op); });
+  });
+}
+
+std::unique_ptr<Pass> mlir::createControlFlowSinkPass() {
+  return std::make_unique<ControlFlowSink>();
+}
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 33deb17fe1378..c42d45325e1b2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(MLIRTransformUtils
+ ControlFlowSinkUtils.cpp
DialectConversion.cpp
FoldUtils.cpp
GreedyPatternRewriteDriver.cpp
diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
new file mode 100644
index 0000000000000..cffbd922f88c8
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
@@ -0,0 +1,152 @@
+//===- ControlFlowSinkUtils.cpp - Code to perform control-flow sinking ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for control-flow sinking. Control-flow
+// sinking moves operations whose only uses are in conditionally-executed blocks
+// into those blocks so that they aren't executed on paths where their results
+// are not needed.
+//
+// Control-flow sinking is not implemented on BranchOpInterface because
+// sinking ops into the successors of branch operations may move ops into loops.
+// It is idiomatic MLIR to perform optimizations at IR levels that readily
+// provide the necessary information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Transforms/Utils.h"
+#include <vector>
+
+#define DEBUG_TYPE "cf-sink"
+
+using namespace mlir;
+
+namespace {
+/// Helper that performs control-flow sinking over a set of regions and keeps a
+/// count of how many operations it moved.
+class Sinker {
+public:
+  /// Build a sinker that consults `shouldMoveIntoRegion` for every candidate
+  /// and `domInfo` for user-dominance queries.
+  Sinker(function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+         DominanceInfo &domInfo)
+      : shouldMoveIntoRegion(shouldMoveIntoRegion), domInfo(domInfo) {}
+
+  /// Sink candidate operations into each of `regions` and return how many
+  /// operations were moved. Rvalue-qualified: a sinker is used only once.
+  size_t sinkRegions(ArrayRef<Region *> regions) &&;
+
+private:
+  /// True when every user of `op` (which must be defined outside `region`)
+  /// sits in a block dominated by the entry block of `region`, meaning `op`
+  /// may be sunk into `region`.
+  bool allUsersDominatedBy(Operation *op, Region *region);
+
+  /// For a top-level op `user` of `region`, try to sink the ops that define
+  /// its operands into `region`. Every op that gets moved is pushed onto
+  /// `stack` so that its own operands are examined later.
+  void tryToSinkPredecessors(Operation *user, Region *region,
+                             std::vector<Operation *> &stack);
+
+  /// Visit every top-level op of `region` and sink its predecessors, chasing
+  /// newly sunk ops through a worklist.
+  void sinkRegion(Region *region);
+
+  /// Callback deciding whether an op may be moved into a region.
+  function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion;
+  /// Dominance information used to check users against region entry blocks.
+  DominanceInfo &domInfo;
+  /// Running count of sunk operations.
+  size_t numSunk = 0;
+};
+} // namespace
+
+/// Returns true if every user of `op` is in a block dominated by the entry
+/// block of `region`. `op` itself must be defined outside `region`.
+bool Sinker::allUsersDominatedBy(Operation *op, Region *region) {
+  assert(region->findAncestorOpInRegion(*op) == nullptr &&
+         "expected op to be defined outside the region");
+  Block *entryBlock = &region->front();
+  return llvm::all_of(op->getUsers(), [&](Operation *user) {
+    // The user is dominated by the region if its containing block is dominated
+    // by the region's entry block.
+    return domInfo.dominates(entryBlock, user->getBlock());
+  });
+}
+
+/// Try to sink the defining ops of `user`'s operands into `region`. Each op
+/// that is moved is pushed onto `stack` so its own operands are visited next.
+void Sinker::tryToSinkPredecessors(Operation *user, Region *region,
+                                   std::vector<Operation *> &stack) {
+  LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n"));
+  for (Value value : user->getOperands()) {
+    Operation *op = value.getDefiningOp();
+    // Ignore block arguments and ops that are already inside the region.
+    if (!op || op->getParentRegion() == region)
+      continue;
+    LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n"));
+
+    // If the op's users are all in the region and it can be moved, then do so.
+    if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) {
+      // Move the op into the region's entry block. The ops that use this op's
+      // results were moved (or already lived) in the region before it, so
+      // inserting at the start of the block keeps every definition ahead of
+      // its uses. Ops can only be safely moved into the entry block as the
+      // region's other blocks may form a loop.
+      op->moveBefore(&region->front(), region->front().begin());
+      ++numSunk;
+      // Add the op to the work queue.
+      stack.push_back(op);
+    }
+  }
+}
+
+/// Depth-first worklist over the ops of `region`: each popped op has its
+/// operand-defining ops considered for sinking, and sunk ops are re-queued.
+void Sinker::sinkRegion(Region *region) {
+  std::vector<Operation *> worklist;
+  for (Operation &topLevelOp : region->getOps())
+    worklist.push_back(&topLevelOp);
+
+  // Popping from the back means ops sunk by `tryToSinkPredecessors` are handled
+  // before the remaining top-level ops, so subgraphs move in the right order.
+  while (!worklist.empty()) {
+    Operation *current = worklist.back();
+    worklist.pop_back();
+    tryToSinkPredecessors(current, region, worklist);
+  }
+}
+
+size_t Sinker::sinkRegions(ArrayRef<Region *> regions) && {
+  for (Region *region : regions) {
+    // A region without blocks has no entry block to sink into.
+    if (region->empty())
+      continue;
+    sinkRegion(region);
+  }
+  return numSunk;
+}
+
+size_t mlir::controlFlowSink(
+    ArrayRef<Region *> regions, DominanceInfo &domInfo,
+    function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion) {
+  // `Sinker::sinkRegions` is rvalue-qualified: build a throwaway sinker and
+  // run it once over all of `regions`.
+  return Sinker{shouldMoveIntoRegion, domInfo}.sinkRegions(regions);
+}
+
+/// Append to `regions` every region of `branch` whose upper invocation bound is
+/// known and at most 1.
+void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch,
+                                          SmallVectorImpl<Region *> &regions) {
+  // Collect constant operands; non-constant operands stay null.
+  SmallVector<Attribute> operands(branch->getNumOperands(), Attribute());
+  // Take the enumerate element by value rather than binding a non-const
+  // reference to the iterator's result object.
+  for (auto it : llvm::enumerate(branch->getOperands()))
+    matchPattern(it.value(), m_Constant(&operands[it.index()]));
+  // Get the invocation bounds.
+  SmallVector<InvocationBounds> bounds;
+  branch.getRegionInvocationBounds(operands, bounds);
+  // `zip` below would silently drop regions if an implementation reported too
+  // few bounds; catch that contract violation instead.
+  assert(bounds.size() == branch->getNumRegions() &&
+         "expected one invocation bound per region");
+
+  // For a simple control-flow sink, only consider regions that are executed at
+  // most once.
+  for (auto it : llvm::zip(branch->getRegions(), bounds)) {
+    const InvocationBounds &bound = std::get<1>(it);
+    if (bound.getUpperBound() && *bound.getUpperBound() <= 1)
+      regions.push_back(&std::get<0>(it));
+  }
+}
diff --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir
new file mode 100644
index 0000000000000..58dffe19c0aae
--- /dev/null
+++ b/mlir/test/Transforms/control-flow-sink.mlir
@@ -0,0 +1,210 @@
+// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s
+
+// Test that operations can be sunk. %0, %2 and %3 each have their only use in
+// a single region and are moved there; %1 is used in both the else and the join
+// regions, so it stays in the function body.
+
+// CHECK-LABEL: @test_simple_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = arith.subi %[[ARG2]], %[[ARG1]]
+// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT: %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT: test.region_if_yield %[[V2]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK-NEXT: %[[V3:.*]] = arith.addi %[[V0]], %[[V2]]
+// CHECK-NEXT: test.region_if_yield %[[V3]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG2]], %[[ARG2]]
+// CHECK-NEXT: %[[V3:.*]] = arith.addi %[[V2]], %[[V0]]
+// CHECK-NEXT: test.region_if_yield %[[V3]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V1]]
+func @test_simple_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+ %0 = arith.subi %arg1, %arg2 : i32
+ %1 = arith.subi %arg2, %arg1 : i32
+ %2 = arith.addi %arg1, %arg1 : i32
+ %3 = arith.addi %arg2, %arg2 : i32
+ %4 = test.region_if %arg0: i1 -> i32 then {
+ test.region_if_yield %0 : i32
+ } else {
+ %5 = arith.addi %1, %2 : i32
+ test.region_if_yield %5 : i32
+ } join {
+ %5 = arith.addi %3, %1 : i32
+ test.region_if_yield %5 : i32
+ }
+ return %4 : i32
+}
+
+// -----
+
+// Test that a region op can be sunk. %0 ends up in the else region of the
+// first test.region_if, and that whole region op ends up in the then region of
+// the second test.region_if, its only user.
+
+// CHECK-LABEL: @test_region_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT: test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT: test.region_if_yield %[[V2]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT: test.region_if_yield %[[ARG2]]
+// CHECK-NEXT: }
+// CHECK-NEXT: test.region_if_yield %[[V1]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT: test.region_if_yield %[[ARG2]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V0]]
+func @test_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+ %0 = arith.subi %arg1, %arg2 : i32
+ %1 = test.region_if %arg0: i1 -> i32 then {
+ test.region_if_yield %arg1 : i32
+ } else {
+ test.region_if_yield %0 : i32
+ } join {
+ test.region_if_yield %arg2 : i32
+ }
+ %2 = test.region_if %arg0: i1 -> i32 then {
+ test.region_if_yield %1 : i32
+ } else {
+ test.region_if_yield %arg1 : i32
+ } join {
+ test.region_if_yield %arg2 : i32
+ }
+ return %2 : i32
+}
+
+// -----
+
+// Test that an entire subgraph can be sunk. %0 through %5 are only used,
+// directly or transitively, by the then region, so the whole chain moves there
+// and every definition still precedes its uses.
+
+// CHECK-LABEL: @test_subgraph_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT: %[[V1:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT: %[[V3:.*]] = arith.subi %[[ARG2]], %[[ARG1]]
+// CHECK-NEXT: %[[V4:.*]] = arith.muli %[[V3]], %[[V3]]
+// CHECK-NEXT: %[[V5:.*]] = arith.muli %[[V2]], %[[V1]]
+// CHECK-NEXT: %[[V6:.*]] = arith.addi %[[V5]], %[[V4]]
+// CHECK-NEXT: test.region_if_yield %[[V6]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT: test.region_if_yield %[[ARG2]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V0]]
+func @test_subgraph_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+ %0 = arith.addi %arg1, %arg2 : i32
+ %1 = arith.subi %arg1, %arg2 : i32
+ %2 = arith.subi %arg2, %arg1 : i32
+ %3 = arith.muli %0, %1 : i32
+ %4 = arith.muli %2, %2 : i32
+ %5 = arith.addi %3, %4 : i32
+ %6 = test.region_if %arg0: i1 -> i32 then {
+ test.region_if_yield %5 : i32
+ } else {
+ test.region_if_yield %arg1 : i32
+ } join {
+ test.region_if_yield %arg2 : i32
+ }
+ return %6 : i32
+}
+
+// -----
+
+// Test that ops can be sunk into regions with multiple blocks. %0 has a user
+// outside the region (%4) and stays in place; %1 and %2 are moved into the
+// entry block of the test.any_cond region.
+
+// CHECK-LABEL: @test_multiblock_region_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT: %[[V1:.*]] = "test.any_cond"() ({
+// CHECK-NEXT: %[[V3:.*]] = arith.addi %[[V0]], %[[ARG2]]
+// CHECK-NEXT: %[[V4:.*]] = arith.addi %[[V3]], %[[ARG1]]
+// CHECK-NEXT: br ^bb1(%[[V4]] : i32)
+// CHECK-NEXT: ^bb1(%[[V5:.*]]: i32):
+// CHECK-NEXT: %[[V6:.*]] = arith.addi %[[V5]], %[[V4]]
+// CHECK-NEXT: "test.yield"(%[[V6]])
+// CHECK-NEXT: })
+// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]]
+// CHECK-NEXT: return %[[V2]]
+func @test_multiblock_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+ %0 = arith.addi %arg1, %arg2 : i32
+ %1 = arith.addi %0, %arg2 : i32
+ %2 = arith.addi %1, %arg1 : i32
+ %3 = "test.any_cond"() ({
+ br ^bb1(%2 : i32)
+ ^bb1(%5: i32):
+ %6 = arith.addi %5, %2 : i32
+ "test.yield"(%6) : (i32) -> ()
+ }) : () -> i32
+ %4 = arith.addi %0, %3 : i32
+ return %4 : i32
+}
+
+// -----
+
+// Test that ops can be sunk recursively into nested regions. %0's only use is
+// in the then region of the inner test.region_if, and that is where it ends up.
+
+// CHECK-LABEL: @test_nested_region_sink
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK-NEXT: test.region_if_yield %[[V2]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT: test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: }
+// CHECK-NEXT: test.region_if_yield %[[V1]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT: test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT: test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V0]]
+func @test_nested_region_sink(%arg0: i1, %arg1: i32) -> i32 {
+ %0 = arith.addi %arg1, %arg1 : i32
+ %1 = test.region_if %arg0: i1 -> i32 then {
+ %2 = test.region_if %arg0: i1 -> i32 then {
+ test.region_if_yield %0 : i32
+ } else {
+ test.region_if_yield %arg1 : i32
+ } join {
+ test.region_if_yield %arg1 : i32
+ }
+ test.region_if_yield %2 : i32
+ } else {
+ test.region_if_yield %arg1 : i32
+ } join {
+ test.region_if_yield %arg1 : i32
+ }
+ return %1 : i32
+}
+
+// -----
+
+// Test that ops are only moved into the entry block, even when their only uses
+// are further along: %0 is used only in ^bb1, yet it is placed in the entry
+// block of the test.any_cond region.
+
+// CHECK-LABEL: @test_not_sunk_deeply
+// CHECK-SAME: (%[[ARG0:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[V0:.*]] = "test.any_cond"() ({
+// CHECK-NEXT: %[[V1:.*]] = arith.addi %[[ARG0]], %[[ARG0]]
+// CHECK-NEXT: br ^bb1
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT: "test.yield"(%[[V1]]) : (i32) -> ()
+// CHECK-NEXT: })
+// CHECK-NEXT: return %[[V0]]
+func @test_not_sunk_deeply(%arg0: i32) -> i32 {
+ %0 = arith.addi %arg0, %arg0 : i32
+ %1 = "test.any_cond"() ({
+ br ^bb1
+ ^bb1:
+ "test.yield"(%0) : (i32) -> ()
+ }) : () -> i32
+ return %1 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 2b915a9fd6c2a..d82f7af73decb 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1127,15 +1127,15 @@ static void print(OpAsmPrinter &p, RegionIfOp op) {
p.printOperands(op.getOperands());
p << ": " << op.getOperandTypes();
p.printArrowTypeList(op.getResultTypes());
- p << " then";
+ p << " then ";
p.printRegion(op.getThenRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
- p << " else";
+ p << " else ";
p.printRegion(op.getElseRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
- p << " join";
+ p << " join ";
p.printRegion(op.getJoinRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
@@ -1189,6 +1189,34 @@ void RegionIfOp::getSuccessorRegions(
regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
}
+void RegionIfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  // The then, else and join regions are each invoked at most once, so all
+  // three get the bound [0, 1].
+  const InvocationBounds atMostOnce(/*lb=*/0, /*ub=*/1);
+  invocationBounds.assign(/*NumElts=*/3, atMostOnce);
+}
+
+//===----------------------------------------------------------------------===//
+// AnyCondOp
+//===----------------------------------------------------------------------===//
+
+void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
+                                    ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // The parent op branches into the only region, and the region branches back
+  // to the parent op. `index` is unset when control comes from the parent op
+  // and holds the region number when it comes from the region.
+  if (!index)
+    regions.emplace_back(&getRegion());
+  else
+    regions.emplace_back(getResults());
+}
+
+void AnyCondOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  // The single region is invoked exactly once.
+  invocationBounds.push_back(InvocationBounds(/*lb=*/1, /*ub=*/1));
+}
+
//===----------------------------------------------------------------------===//
// SingleNoTerminatorCustomAsmOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index cc9c950ea73fe..f171e19270069 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2342,14 +2342,15 @@ def RegionIfYieldOp : TEST_Op<"region_if_yield",
}
def RegionIfOp : TEST_Op<"region_if",
- [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getRegionInvocationBounds"]>,
SingleBlockImplicitTerminator<"RegionIfYieldOp">,
RecursiveSideEffects]> {
let description =[{
Represents an abstract if-then-else-join pattern. In this context, the then
and else regions jump to the join region, which finally returns to its
parent op.
- }];
+ }];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseRegionIfOp(parser, result); }];
@@ -2372,6 +2373,14 @@ def RegionIfOp : TEST_Op<"region_if",
}];
}
+// Test op with a single region of any shape (`AnyRegion`, so it may contain
+// several blocks) that implements `RegionBranchOpInterface`, including
+// `getRegionInvocationBounds`; its C++ implementation reports that the region
+// is invoked exactly once.
+def AnyCondOp : TEST_Op<"any_cond",
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getRegionInvocationBounds"]>,
+ RecursiveSideEffects]> {
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region AnyRegion:$region);
+}
+
//===----------------------------------------------------------------------===//
// Test TableGen generated build() methods
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list