[Mlir-commits] [mlir] 572fa96 - [mlir] Add a ControlFlowSink pass.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 24 15:08:40 PST 2022


Author: Mogball
Date: 2022-01-24T23:08:34Z
New Revision: 572fa9642cb50f3c2d79e138e789c4b23f3ab8cf

URL: https://github.com/llvm/llvm-project/commit/572fa9642cb50f3c2d79e138e789c4b23f3ab8cf
DIFF: https://github.com/llvm/llvm-project/commit/572fa9642cb50f3c2d79e138e789c4b23f3ab8cf.diff

LOG: [mlir] Add a ControlFlowSink pass.

Control-Flow Sink moves operations whose only uses are in conditionally-executed regions into those regions so that paths in which their results are not needed do not perform unnecessary computation.

Depends on D115087

Reviewed By: jpienaar, rriddle, bondhugula

Differential Revision: https://reviews.llvm.org/D115088

Added: 
    mlir/lib/Transforms/ControlFlowSink.cpp
    mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
    mlir/test/Transforms/control-flow-sink.mlir

Modified: 
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/include/mlir/Transforms/Utils.h
    mlir/lib/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Utils/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 8ff76edb3068c..806821a988bf3 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -83,6 +83,35 @@ class RegionSuccessor {
   ValueRange inputs;
 };
 
+/// This class represents upper and lower bounds on the number of times a region
+/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
+/// zero, but the upper bound may not be known.
+class InvocationBounds {
+public:
+  /// Create invocation bounds. The lower bound must be at least 0 and only the
+  /// upper bound can be unknown. When the upper bound is known it must not be
+  /// less than the lower bound (checked by assertion).
+  InvocationBounds(unsigned lb, Optional<unsigned> ub) : lower(lb), upper(ub) {
+    // The disjunction is parenthesized so the message string applies to the
+    // whole condition and `-Wparentheses` does not fire on `&&` within `||`.
+    assert((!ub || *ub >= lb) && "upper bound cannot be less than lower bound");
+  }
+
+  /// Return the lower bound.
+  unsigned getLowerBound() const { return lower; }
+
+  /// Return the upper bound, or `None` if it is not known.
+  Optional<unsigned> getUpperBound() const { return upper; }
+
+  /// Returns the unknown invocation bounds, i.e., there is no information on
+  /// how many times a region may be invoked (lower bound 0, no upper bound).
+  static InvocationBounds getUnknown() { return {0, llvm::None}; }
+
+private:
+  /// The minimum number of times the successor region will be invoked.
+  unsigned lower;
+  /// The maximum number of times the successor region will be invoked or `None`
+  /// if an upper bound is not known.
+  Optional<unsigned> upper;
+};
+
 /// Return `true` if `a` and `b` are in mutually exclusive regions as per
 /// RegionBranchOpInterface.
 bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 0633426cf50af..429a5356428f7 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -102,9 +102,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let methods = [
     InterfaceMethod<[{
         Returns the operands of this operation used as the entry arguments when
-        entering the region at `index`, which was specified as a successor of this
-        operation by `getSuccessorRegions`. These operands should correspond 1-1
-        with the successor inputs specified in `getSuccessorRegions`.
+        entering the region at `index`, which was specified as a successor of
+        this operation by `getSuccessorRegions`. These operands should
+        correspond 1-1 with the successor inputs specified in
+        `getSuccessorRegions`.
       }],
       "::mlir::OperandRange", "getSuccessorEntryOperands",
       (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{
@@ -127,9 +128,28 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
         successor region must be non-empty.
       }],
       "void", "getSuccessorRegions",
-      (ins "::mlir::Optional<unsigned>":$index, "::mlir::ArrayRef<::mlir::Attribute>":$operands,
+      (ins "::mlir::Optional<unsigned>":$index,
+           "::mlir::ArrayRef<::mlir::Attribute>":$operands,
            "::mlir::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
-    >
+    >,
+    InterfaceMethod<[{
+        Populates `invocationBounds` with the minimum and maximum number of
+        times this operation will invoke the attached regions (assuming the
+        regions yield normally, i.e. do not abort or invoke an infinite loop).
+        The minimum number of invocations is at least 0. If the maximum number
+        of invocations cannot be statically determined, then it will not have a
+        value (i.e., it is set to `llvm::None`). One entry is expected per
+        region, in region order.
+
+        `operands` is a set of optional attributes that correspond to the
+        constant value of each operand of this operation, or are null if that
+        operand is not a constant.
+
+        The default implementation reports unknown bounds (`{0, None}`) for
+        every region.
+      }],
+      "void", "getRegionInvocationBounds",
+      (ins "::mlir::ArrayRef<::mlir::Attribute>":$operands,
+           "::llvm::SmallVectorImpl<::mlir::InvocationBounds> &"
+             :$invocationBounds), [{}],
+       [{ invocationBounds.append($_op->getNumRegions(), {0, ::llvm::None}); }]
+    >,
   ];
 
   let verify = [{

diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 6aab120fb469f..46bab0047c1b6 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -74,6 +74,9 @@ createCanonicalizerPass(const GreedyRewriteConfig &config,
                         ArrayRef<std::string> disabledPatterns = llvm::None,
                         ArrayRef<std::string> enabledPatterns = llvm::None);
 
+/// Creates a pass to perform control-flow sinking: side-effect free operations
+/// whose only uses are inside a region of a `RegionBranchOpInterface` op that
+/// is executed at most once are moved into that region.
+std::unique_ptr<Pass> createControlFlowSinkPass();
+
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
 

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index ff15df372aa26..56e90644363e5 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -307,6 +307,28 @@ def Canonicalizer : Pass<"canonicalize"> {
   ] # RewritePassUtils.options;
 }
 
+def ControlFlowSink : Pass<"control-flow-sink"> {
+  let summary = "Sink operations into conditional blocks";
+  let description = [{
+    This pass implements a simple control-flow sink on operations that implement
+    `RegionBranchOpInterface` by moving dominating operations whose only uses
+    are in a single conditionally-executed region into that region so that
+    execution paths where their results are not needed do not perform
+    unnecessary computations. Only side-effect free operations are sunk, and
+    only into regions that are known to be executed at most once.
+
+    This is similar (but opposite) to loop-invariant code motion, which hoists
+    operations out of regions executed more than once.
+
+    It is recommended to run canonicalization first to remove unreachable
+    blocks: ops in unreachable blocks may prevent other operations from being
+    sunk as they may contain uses of their results.
+  }];
+  let constructor = "::mlir::createControlFlowSinkPass()";
+  let statistics = [
+    Statistic<"numSunk", "num-sunk", "Number of operations sunk">,
+  ];
+}
+
 def CSE : Pass<"cse"> {
   let summary = "Eliminate common sub-expressions";
   let description = [{

diff  --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
index 5280c2648bfae..5efbb19b08d25 100644
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -25,6 +25,7 @@ namespace mlir {
 
 class AffineApplyOp;
 class AffineForOp;
+class DominanceInfo;
 class Location;
 class OpBuilder;
 
@@ -147,6 +148,53 @@ Operation *createComposedAffineApplyOp(OpBuilder &builder, Location loc,
 void createAffineComputationSlice(Operation *opInst,
                                   SmallVectorImpl<AffineApplyOp> *sliceOps);
 
+/// Given a list of regions, perform control flow sinking on them. For each
+/// region, control-flow sinking moves operations that dominate the region but
+/// whose only users are in the region into the region so that they aren't
+/// executed on paths where their results are not needed.
+///
+/// Sunk operations are always placed at the start of the region's entry block;
+/// they are never moved into the region's other blocks, as those may form a
+/// loop. Empty regions are skipped.
+///
+/// TODO: For the moment, this is a *simple* control-flow sink, i.e., no
+/// duplicating of ops. It should be made to accept a cost model to determine
+/// whether duplicating a particular op is profitable.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = arith.addi %arg0, %arg1
+/// scf.if %cond {
+///   scf.yield %0
+/// } else {
+///   scf.yield %arg2
+/// }
+/// ```
+///
+/// After control-flow sink:
+///
+/// ```mlir
+/// scf.if %cond {
+///   %0 = arith.addi %arg0, %arg1
+///   scf.yield %0
+/// } else {
+///   scf.yield %arg2
+/// }
+/// ```
+///
+/// Users must supply a callback `shouldMoveIntoRegion` that determines whether
+/// the given operation, whose users are all in the given region, should be
+/// moved into that region.
+///
+/// Returns the number of operations sunk.
+size_t
+controlFlowSink(ArrayRef<Region *> regions, DominanceInfo &domInfo,
+                function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion);
+
+/// Populates `regions` with the regions of the provided region branch op whose
+/// invocation upper bound, as reported by `getRegionInvocationBounds` given the
+/// op's constant operands, is known and at most one. Regions with an upper
+/// bound of zero are therefore included too. These regions can be passed to
+/// `controlFlowSink` to perform sinking on the regions of the operation.
+void getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch,
+                                    SmallVectorImpl<Region *> &regions);
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_UTILS_H

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index ae1468fa17ea3..3e10b4a321311 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(MLIRTransforms
   BufferResultsToOutParams.cpp
   BufferUtils.cpp
   Canonicalizer.cpp
+  ControlFlowSink.cpp
   CSE.cpp
   Inliner.cpp
   LocationSnapshot.cpp

diff  --git a/mlir/lib/Transforms/ControlFlowSink.cpp b/mlir/lib/Transforms/ControlFlowSink.cpp
new file mode 100644
index 0000000000000..71afc5702edf0
--- /dev/null
+++ b/mlir/lib/Transforms/ControlFlowSink.cpp
@@ -0,0 +1,71 @@
+//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a basic control-flow sink pass. Control-flow sinking
+// moves operations whose only uses are in conditionally-executed blocks into
+// those blocks so that they aren't executed on paths where their results are
+// not needed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+
+namespace {
+/// A basic control-flow sink pass. For every operation that implements
+/// `RegionBranchOpInterface`, this pass selects the regions whose known
+/// invocation upper bound is at most one and sinks side-effect free candidate
+/// operations into them. The number of sunk ops is reported via `numSunk`.
+struct ControlFlowSink : public ControlFlowSinkBase<ControlFlowSink> {
+  void runOnOperation() override;
+};
+} // end anonymous namespace
+
+/// Returns true if `op` itself reports no memory effects and, when its effects
+/// are defined recursively, the same holds for every op nested in its regions.
+/// Ops that neither implement `MemoryEffectOpInterface` nor carry the
+/// `HasRecursiveSideEffects` trait are conservatively reported as not free.
+static bool isSideEffectFree(Operation *op) {
+  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+  const bool recursive = op->hasTrait<OpTrait::HasRecursiveSideEffects>();
+
+  // Neither the interface nor the trait: effects are unknown, so refuse.
+  if (!effectInterface && !recursive)
+    return false;
+  // The op reports an effect of its own, so it cannot be moved.
+  if (effectInterface && !effectInterface.hasNoEffect())
+    return false;
+  // Reaching here without the trait means the interface reported no effects
+  // and nested ops do not contribute, so the op is free.
+  if (!recursive)
+    return true;
+
+  // Effects are recursive: every nested op must be side-effect free as well.
+  for (Region &region : op->getRegions()) {
+    for (Operation &nested : region.getOps()) {
+      if (!isSideEffectFree(&nested))
+        return false;
+    }
+  }
+  return true;
+}
+
+/// Walks every `RegionBranchOpInterface` op nested under the pass root, sinks
+/// side-effect free ops into those of its regions that run at most once, and
+/// accumulates the number of sunk ops into the `numSunk` statistic.
+void ControlFlowSink::runOnOperation() {
+  auto &domInfo = getAnalysis<DominanceInfo>();
+  getOperation()->walk([&](RegionBranchOpInterface branch) {
+    SmallVector<Region *> regionsToSink;
+    getSinglyExecutedRegionsToSink(branch, regionsToSink);
+    // Accumulate rather than assign: the statistic must count the ops sunk
+    // across *all* branch ops, not just the last one visited by the walk.
+    numSunk += mlir::controlFlowSink(
+        regionsToSink, domInfo,
+        [](Operation *op, Region *) { return isSideEffectFree(op); });
+  });
+}
+
+/// Returns a new, default-constructed instance of the control-flow sink pass.
+std::unique_ptr<Pass> mlir::createControlFlowSinkPass() {
+  return std::make_unique<ControlFlowSink>();
+}

diff  --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 33deb17fe1378..c42d45325e1b2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRTransformUtils
+  ControlFlowSinkUtils.cpp
   DialectConversion.cpp
   FoldUtils.cpp
   GreedyPatternRewriteDriver.cpp

diff  --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
new file mode 100644
index 0000000000000..cffbd922f88c8
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
@@ -0,0 +1,152 @@
+//===- ControlFlowSinkUtils.cpp - Code to perform control-flow sinking ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for control-flow sinking. Control-flow
+// sinking moves operations whose only uses are in conditionally-executed blocks
+// into those blocks so that they aren't executed on paths where their results
+// are not needed.
+//
+// Control-flow sinking is not implemented on BranchOpInterface because
+// sinking ops into the successors of branch operations may move ops into loops.
+// It is idiomatic MLIR to perform optimizations at IR levels that readily
+// provide the necessary information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Transforms/Utils.h"
+#include <vector>
+
+#define DEBUG_TYPE "cf-sink"
+
+using namespace mlir;
+
+namespace {
+/// A helper class for control-flow sinking. It is meant to be used as a
+/// temporary: `sinkRegions` is rvalue-qualified.
+class Sinker {
+public:
+  /// Create an operation sinker with the given move predicate and dominance
+  /// info. Both are held by reference (`function_ref` / `DominanceInfo &`), so
+  /// they must outlive the sinker.
+  Sinker(function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion,
+         DominanceInfo &domInfo)
+      : shouldMoveIntoRegion(shouldMoveIntoRegion), domInfo(domInfo),
+        numSunk(0) {}
+
+  /// Given a list of regions, find operations to sink and sink them. Return the
+  /// number of operations sunk. Empty regions are skipped. This can only be
+  /// called on an rvalue sinker (e.g. a temporary).
+  size_t sinkRegions(ArrayRef<Region *> regions) &&;
+
+private:
+  /// Given a region and an op which dominates the region, returns true if all
+  /// users of the given op are dominated by the entry block of the region, and
+  /// thus the operation can be sunk into the region.
+  bool allUsersDominatedBy(Operation *op, Region *region);
+
+  /// Given a region and a top-level op (an op whose parent region is the given
+  /// region), determine whether the defining ops of the op's operands can be
+  /// sunk into the region.
+  ///
+  /// Ops that are sunk are moved to the start of the region's entry block and
+  /// pushed onto `stack` so that their own operands are examined later.
+  void tryToSinkPredecessors(Operation *user, Region *region,
+                             std::vector<Operation *> &stack);
+
+  /// Iterate over all the ops in a region and try to sink their predecessors.
+  /// Sunk ops are revisited through an explicit LIFO work stack, so entire
+  /// subgraphs can be sunk.
+  void sinkRegion(Region *region);
+
+  /// The callback to determine whether an op should be moved into a region.
+  function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion;
+  /// Dominance info to determine op user dominance with respect to regions.
+  DominanceInfo &domInfo;
+  /// The number of operations sunk.
+  size_t numSunk;
+};
+} // end anonymous namespace
+
+/// Returns true if the block of every user of `op` is dominated by the entry
+/// block of `region`. `op` must be defined outside of `region`.
+bool Sinker::allUsersDominatedBy(Operation *op, Region *region) {
+  assert(region->findAncestorOpInRegion(*op) == nullptr &&
+         "expected op to be defined outside the region");
+  for (Operation *user : op->getUsers()) {
+    // A single user outside the region's dominance scope blocks sinking.
+    if (!domInfo.dominates(&region->front(), user->getBlock()))
+      return false;
+  }
+  return true;
+}
+
+/// For each operand of `user`, try to sink its defining op into `region`. An
+/// op is sunk only when all of its users are dominated by the region's entry
+/// block and `shouldMoveIntoRegion` accepts it. Each sunk op is counted in
+/// `numSunk` and pushed onto `stack` so its own operands get processed.
+void Sinker::tryToSinkPredecessors(Operation *user, Region *region,
+                                   std::vector<Operation *> &stack) {
+  LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n"));
+  for (Value value : user->getOperands()) {
+    Operation *op = value.getDefiningOp();
+    // Ignore block arguments and ops that are already inside the region (this
+    // also skips an op that was just sunk for an earlier operand of `user`).
+    if (!op || op->getParentRegion() == region)
+      continue;
+    LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n"));
+
+    // If the op's users are all in the region and it can be moved, then do so.
+    if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) {
+      // Move the op into the region's entry block. Within a subgraph, users
+      // are sunk before the ops defining their operands, so inserting at the
+      // start of the block places each definition before its uses and keeps
+      // dominance intact. Ops can only be safely moved into the entry block as
+      // the region's other blocks may form a loop.
+      op->moveBefore(&region->front(), region->front().begin());
+      ++numSunk;
+      // Add the op to the work queue.
+      stack.push_back(op);
+    }
+  }
+}
+
+/// Seeds a worklist with the ops of `region` and repeatedly tries to sink the
+/// defining ops of each worklist entry's operands into `region`.
+void Sinker::sinkRegion(Region *region) {
+  std::vector<Operation *> worklist;
+  for (Operation &regionOp : region->getOps())
+    worklist.push_back(&regionOp);
+
+  // Always take the most recently added op: newly sunk ops are examined right
+  // away, so the operands of a subgraph are sunk in a depth-first order.
+  for (; !worklist.empty();) {
+    Operation *current = worklist.back();
+    worklist.pop_back();
+    tryToSinkPredecessors(current, region, worklist);
+  }
+}
+
+/// Sinks into every non-empty region in `regions` and returns the running count
+/// of sunk ops.
+size_t Sinker::sinkRegions(ArrayRef<Region *> regions) && {
+  for (Region *region : regions) {
+    // A region without blocks has no entry block to sink into.
+    if (region->empty())
+      continue;
+    sinkRegion(region);
+  }
+  return numSunk;
+}
+
+/// Public entry point: builds a one-shot `Sinker` and runs it over `regions`.
+size_t mlir::controlFlowSink(
+    ArrayRef<Region *> regions, DominanceInfo &domInfo,
+    function_ref<bool(Operation *, Region *)> shouldMoveIntoRegion) {
+  Sinker sinker(shouldMoveIntoRegion, domInfo);
+  // `sinkRegions` is rvalue-qualified, so hand it the sinker as an rvalue.
+  return std::move(sinker).sinkRegions(regions);
+}
+
+void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch,
+                                          SmallVectorImpl<Region *> &regions) {
+  // Record the constant value of each operand (or a null attribute when the
+  // operand is not a constant) so the op can refine its invocation bounds.
+  unsigned numOperands = branch->getNumOperands();
+  SmallVector<Attribute> operands(numOperands, Attribute());
+  for (unsigned i = 0; i < numOperands; ++i)
+    matchPattern(branch->getOperand(i), m_Constant(&operands[i]));
+
+  SmallVector<InvocationBounds> bounds;
+  branch.getRegionInvocationBounds(operands, bounds);
+
+  // A simple control-flow sink only targets regions with a known upper bound
+  // of at most one invocation. Only indices covered by both the op's regions
+  // and the reported bounds are considered.
+  unsigned numRegions = branch->getNumRegions();
+  for (unsigned i = 0, e = bounds.size(); i < e && i < numRegions; ++i) {
+    Optional<unsigned> upper = bounds[i].getUpperBound();
+    if (upper && *upper <= 1)
+      regions.push_back(&branch->getRegion(i));
+  }
+}

diff  --git a/mlir/test/Transforms/control-flow-sink.mlir b/mlir/test/Transforms/control-flow-sink.mlir
new file mode 100644
index 0000000000000..58dffe19c0aae
--- /dev/null
+++ b/mlir/test/Transforms/control-flow-sink.mlir
@@ -0,0 +1,210 @@
+// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s
+
+// Test that operations can be sunk. Each op used only inside one of the
+// `then`/`else`/`join` regions is moved into that region; %1 is used in both
+// the `else` and `join` regions, so it stays in the function body.
+
+// CHECK-LABEL: @test_simple_sink
+// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = arith.subi %[[ARG2]], %[[ARG1]]
+// CHECK-NEXT: %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT:   %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT:   test.region_if_yield %[[V2]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT:   %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK-NEXT:   %[[V3:.*]] = arith.addi %[[V0]], %[[V2]]
+// CHECK-NEXT:   test.region_if_yield %[[V3]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT:   %[[V2:.*]] = arith.addi %[[ARG2]], %[[ARG2]]
+// CHECK-NEXT:   %[[V3:.*]] = arith.addi %[[V2]], %[[V0]]
+// CHECK-NEXT:   test.region_if_yield %[[V3]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V1]]
+func @test_simple_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+  %0 = arith.subi %arg1, %arg2 : i32
+  %1 = arith.subi %arg2, %arg1 : i32
+  %2 = arith.addi %arg1, %arg1 : i32
+  %3 = arith.addi %arg2, %arg2 : i32
+  %4 = test.region_if %arg0: i1 -> i32 then {
+    test.region_if_yield %0 : i32
+  } else {
+    %5 = arith.addi %1, %2 : i32
+    test.region_if_yield %5 : i32
+  } join {
+    %5 = arith.addi %3, %1 : i32
+    test.region_if_yield %5 : i32
+  }
+  return %4 : i32
+}
+
+// -----
+
+// Test that a region op can be sunk. %0 ends up in the `else` region of the
+// first `test.region_if`, and that whole op ends up in the `then` region of
+// the second one.
+
+// CHECK-LABEL: @test_region_sink
+// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT:   %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT:     test.region_if_yield %[[ARG1]]
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     %[[V2:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT:     test.region_if_yield %[[V2]]
+// CHECK-NEXT:   } join {
+// CHECK-NEXT:     test.region_if_yield %[[ARG2]]
+// CHECK-NEXT:   }
+// CHECK-NEXT:   test.region_if_yield %[[V1]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT:   test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT:   test.region_if_yield %[[ARG2]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V0]]
+func @test_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+  %0 = arith.subi %arg1, %arg2 : i32
+  %1 = test.region_if %arg0: i1 -> i32 then {
+    test.region_if_yield %arg1 : i32
+  } else {
+    test.region_if_yield %0 : i32
+  } join {
+    test.region_if_yield %arg2 : i32
+  }
+  %2 = test.region_if %arg0: i1 -> i32 then {
+    test.region_if_yield %1 : i32
+  } else {
+    test.region_if_yield %arg1 : i32
+  } join {
+    test.region_if_yield %arg2 : i32
+  }
+  return %2 : i32
+}
+
+// -----
+
+// Test that an entire subgraph can be sunk: the six arithmetic ops feeding %5
+// are only used (transitively) inside the `then` region, so the whole chain is
+// moved there.
+
+// CHECK-LABEL: @test_subgraph_sink
+// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT:   %[[V1:.*]] = arith.subi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT:   %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT:   %[[V3:.*]] = arith.subi %[[ARG2]], %[[ARG1]]
+// CHECK-NEXT:   %[[V4:.*]] = arith.muli %[[V3]], %[[V3]]
+// CHECK-NEXT:   %[[V5:.*]] = arith.muli %[[V2]], %[[V1]]
+// CHECK-NEXT:   %[[V6:.*]] = arith.addi %[[V5]], %[[V4]]
+// CHECK-NEXT:   test.region_if_yield %[[V6]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT:   test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT:   test.region_if_yield %[[ARG2]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V0]]
+func @test_subgraph_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+  %0 = arith.addi %arg1, %arg2 : i32
+  %1 = arith.subi %arg1, %arg2 : i32
+  %2 = arith.subi %arg2, %arg1 : i32
+  %3 = arith.muli %0, %1 : i32
+  %4 = arith.muli %2, %2 : i32
+  %5 = arith.addi %3, %4 : i32
+  %6 = test.region_if %arg0: i1 -> i32 then {
+    test.region_if_yield %5 : i32
+  } else {
+    test.region_if_yield %arg1 : i32
+  } join {
+    test.region_if_yield %arg2 : i32
+  }
+  return %6 : i32
+}
+
+// -----
+
+// Test that ops can be sunk into regions with multiple blocks. %1 and %2 are
+// only used inside the `test.any_cond` region and are moved into its entry
+// block; %0 is also used after the region (by %4), so it stays outside.
+
+// CHECK-LABEL: @test_multiblock_region_sink
+// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+// CHECK-NEXT: %[[V0:.*]] = arith.addi %[[ARG1]], %[[ARG2]]
+// CHECK-NEXT: %[[V1:.*]] = "test.any_cond"() ({
+// CHECK-NEXT:   %[[V3:.*]] = arith.addi %[[V0]], %[[ARG2]]
+// CHECK-NEXT:   %[[V4:.*]] = arith.addi %[[V3]], %[[ARG1]]
+// CHECK-NEXT:   br ^bb1(%[[V4]] : i32)
+// CHECK-NEXT: ^bb1(%[[V5:.*]]: i32):
+// CHECK-NEXT:   %[[V6:.*]] = arith.addi %[[V5]], %[[V4]]
+// CHECK-NEXT:   "test.yield"(%[[V6]])
+// CHECK-NEXT: })
+// CHECK-NEXT: %[[V2:.*]] = arith.addi %[[V0]], %[[V1]]
+// CHECK-NEXT: return %[[V2]]
+func @test_multiblock_region_sink(%arg0: i1, %arg1: i32, %arg2: i32) -> i32 {
+  %0 = arith.addi %arg1, %arg2 : i32
+  %1 = arith.addi %0, %arg2 : i32
+  %2 = arith.addi %1, %arg1 : i32
+  %3 = "test.any_cond"() ({
+    br ^bb1(%2 : i32)
+  ^bb1(%5: i32):
+    %6 = arith.addi %5, %2 : i32
+    "test.yield"(%6) : (i32) -> ()
+  }) : () -> i32
+  %4 = arith.addi %0, %3 : i32
+  return %4 : i32
+}
+
+// -----
+
+// Test that ops can be sunk recursively into nested regions: %0 is only used
+// in the `then` region of the inner `test.region_if`, and ends up there rather
+// than in the `then` region of the outer one.
+
+// CHECK-LABEL: @test_nested_region_sink
+// CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[V0:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT:   %[[V1:.*]] = test.region_if %[[ARG0]]: i1 -> i32 then {
+// CHECK-NEXT:     %[[V2:.*]] = arith.addi %[[ARG1]], %[[ARG1]]
+// CHECK-NEXT:     test.region_if_yield %[[V2]]
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     test.region_if_yield %[[ARG1]]
+// CHECK-NEXT:   } join {
+// CHECK-NEXT:     test.region_if_yield %[[ARG1]]
+// CHECK-NEXT:   }
+// CHECK-NEXT:   test.region_if_yield %[[V1]]
+// CHECK-NEXT: } else {
+// CHECK-NEXT:   test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: } join {
+// CHECK-NEXT:   test.region_if_yield %[[ARG1]]
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[V0]]
+func @test_nested_region_sink(%arg0: i1, %arg1: i32) -> i32 {
+  %0 = arith.addi %arg1, %arg1 : i32
+  %1 = test.region_if %arg0: i1 -> i32 then {
+    %2 = test.region_if %arg0: i1 -> i32 then {
+      test.region_if_yield %0 : i32
+    } else {
+      test.region_if_yield %arg1 : i32
+    } join {
+      test.region_if_yield %arg1 : i32
+    }
+    test.region_if_yield %2 : i32
+  } else {
+    test.region_if_yield %arg1 : i32
+  } join {
+    test.region_if_yield %arg1 : i32
+  }
+  return %1 : i32
+}
+
+// -----
+
+// Test that ops are only moved into the entry block, even when their only uses
+// are further along: %0 is only used in ^bb1 but is placed in the entry block
+// of the `test.any_cond` region.
+
+// CHECK-LABEL: @test_not_sunk_deeply
+// CHECK-SAME:  (%[[ARG0:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[V0:.*]] = "test.any_cond"() ({
+// CHECK-NEXT:   %[[V1:.*]] = arith.addi %[[ARG0]], %[[ARG0]]
+// CHECK-NEXT:   br ^bb1
+// CHECK-NEXT: ^bb1:
+// CHECK-NEXT:   "test.yield"(%[[V1]]) : (i32) -> ()
+// CHECK-NEXT: })
+// CHECK-NEXT: return %[[V0]]
+func @test_not_sunk_deeply(%arg0: i32) -> i32 {
+  %0 = arith.addi %arg0, %arg0 : i32
+  %1 = "test.any_cond"() ({
+    br ^bb1
+  ^bb1:
+    "test.yield"(%0) : (i32) -> ()
+  }) : () -> i32
+  return %1 : i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 2b915a9fd6c2a..d82f7af73decb 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1127,15 +1127,15 @@ static void print(OpAsmPrinter &p, RegionIfOp op) {
   p.printOperands(op.getOperands());
   p << ": " << op.getOperandTypes();
   p.printArrowTypeList(op.getResultTypes());
-  p << " then";
+  p << " then ";
   p.printRegion(op.getThenRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
-  p << " else";
+  p << " else ";
   p.printRegion(op.getElseRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
-  p << " join";
+  p << " join ";
   p.printRegion(op.getJoinRegion(),
                 /*printEntryBlockArgs=*/true,
                 /*printBlockTerminators=*/true);
@@ -1189,6 +1189,34 @@ void RegionIfOp::getSuccessorRegions(
   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
 }
 
+void RegionIfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  // The then, else, and join regions each run at most once; the operands are
+  // not consulted, so constant conditions do not refine these bounds.
+  InvocationBounds atMostOnce(/*lb=*/0, /*ub=*/1);
+  invocationBounds.assign(/*NumElts=*/3, atMostOnce);
+}
+
+//===----------------------------------------------------------------------===//
+// AnyCondOp
+//===----------------------------------------------------------------------===//
+
+void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
+                                    ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // The parent op branches into the only region, and the region branches back
+  // to the parent op. `index` is `None` when querying the successors of the
+  // parent op itself, and holds the region number when branching out of a
+  // region.
+  if (!index)
+    regions.emplace_back(&getRegion());
+  else
+    regions.emplace_back(getResults());
+}
+
+void AnyCondOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  // The single region is always executed exactly once.
+  invocationBounds.push_back(InvocationBounds(/*lb=*/1, /*ub=*/1));
+}
+
 //===----------------------------------------------------------------------===//
 // SingleNoTerminatorCustomAsmOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index cc9c950ea73fe..f171e19270069 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2342,14 +2342,15 @@ def RegionIfYieldOp : TEST_Op<"region_if_yield",
 }
 
 def RegionIfOp : TEST_Op<"region_if",
-      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                                 ["getRegionInvocationBounds"]>,
        SingleBlockImplicitTerminator<"RegionIfYieldOp">,
        RecursiveSideEffects]> {
   let description =[{
     Represents an abstract if-then-else-join pattern. In this context, the then
     and else regions jump to the join region, which finally returns to its
     parent op.
-    }];
+  }];
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseRegionIfOp(parser, result); }];
@@ -2372,6 +2373,14 @@ def RegionIfOp : TEST_Op<"region_if",
   }];
 }
 
+// A test op with a single region of arbitrary (possibly multi-block) structure
+// whose invocation bounds are reported as exactly one; it is used by the
+// control-flow sink tests.
+def AnyCondOp : TEST_Op<"any_cond",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                                 ["getRegionInvocationBounds"]>,
+       RecursiveSideEffects]> {
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region AnyRegion:$region);
+}
+
 //===----------------------------------------------------------------------===//
 // Test TableGen generated build() methods
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list