[Mlir-commits] [mlir] 2eda87d - [mlir][SCCP] Add support for propagating constants across inter-region control flow.

River Riddle llvmlistbot at llvm.org
Tue Apr 21 03:03:25 PDT 2020


Author: River Riddle
Date: 2020-04-21T02:59:25-07:00
New Revision: 2eda87dfbe63bae43b81b22c8c76a3139147797b

URL: https://github.com/llvm/llvm-project/commit/2eda87dfbe63bae43b81b22c8c76a3139147797b
DIFF: https://github.com/llvm/llvm-project/commit/2eda87dfbe63bae43b81b22c8c76a3139147797b.diff

LOG: [mlir][SCCP] Add support for propagating constants across inter-region control flow.

This is made possible by two new ControlFlowInterface additions:

- A new interface, RegionBranchOpInterface
This interface allows for region holding operations to describe how control flows between regions. This interface initially contains two methods:

* getSuccessorEntryOperands
Returns the operands of this operation used as the entry arguments when entering the region at `index`, which was specified as a successor by `getSuccessorRegions`. These operands should correspond 1-1 with the successor inputs specified in `getSuccessorRegions`, and may be a subset of the entry arguments for that region.

*  getSuccessorRegions
Returns the viable successors of a region, or the possible successors when branching from the parent op. This allows for describing which regions may be executed when entering an operation, and which regions are executed after having executed another region of the parent op. For example, a structured loop operation may always enter into the loop body region. The loop body region may branch back to itself, or exit to the operation.

- A trait, ReturnLike
This trait signals that a terminator exits a region and forwards all of its operands as "exiting" values.

These additions allow for performing more general dataflow analysis in the presence of region holding operations.

Differential Revision: https://reviews.llvm.org/D78447

Added: 
    mlir/test/Transforms/sccp-structured.mlir

Modified: 
    mlir/include/mlir/Dialect/LoopOps/LoopOps.h
    mlir/include/mlir/Dialect/LoopOps/LoopOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
    mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
    mlir/lib/Dialect/LoopOps/LoopOps.cpp
    mlir/lib/Transforms/SCCP.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
index 09b982d93373..281b9001f18f 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffects.h"
 

diff  --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index 436a376a6f4f..4548f5aae2e8 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -13,6 +13,7 @@
 #ifndef LOOP_OPS
 #define LOOP_OPS
 
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffects.td"
 
@@ -37,6 +38,7 @@ class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
 
 def ForOp : Loop_Op<"for",
       [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveSideEffects]> {
   let summary = "for operation";
@@ -169,11 +171,18 @@ def ForOp : Loop_Op<"for",
     unsigned getNumIterOperands() {
       return getOperation()->getNumOperands() - getNumControlOperands();
     }
+
+    /// Return operands used when entering the region at 'index'. These operands
+    /// correspond to the loop iterator operands, i.e., those excluding the
+    /// induction variable. LoopOp only has one region, so 0 is the only valid
+    /// value for `index`.
+    OperandRange getSuccessorEntryOperands(unsigned index);
   }];
 }
 
 def IfOp : Loop_Op<"if",
-      [SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects]> {
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects]> {
   let summary = "if-then-else operation";
   let description = [{
     The `loop.if` operation represents an if-then-else construct for
@@ -385,7 +394,7 @@ def ReduceReturnOp :
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
-def YieldOp : Loop_Op<"yield", [NoSideEffect, Terminator]> {
+def YieldOp : Loop_Op<"yield", [NoSideEffect, ReturnLike, Terminator]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "loop.yield" yields an SSA value from a loop dialect op region and

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index c1980e687f68..38cb8dcb3d55 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1865,7 +1865,7 @@ def RemFOp : FloatArithmeticOp<"remf"> {
 // ReturnOp
 //===----------------------------------------------------------------------===//
 
-def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
+def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, ReturnLike,
                                  Terminator]> {
   let summary = "return operation";
   let description = [{

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 661d2ec029ff..e22454538343 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -19,6 +19,10 @@
 namespace mlir {
 class BranchOpInterface;
 
+//===----------------------------------------------------------------------===//
+// BranchOpInterface
+//===----------------------------------------------------------------------===//
+
 namespace detail {
 /// Erase an operand from a branch operation that is used as a successor
 /// operand. `operandIndex` is the operand within `operands` to be erased.
@@ -37,7 +41,69 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
                                             Optional<OperandRange> operands);
 } // namespace detail
 
+//===----------------------------------------------------------------------===//
+// RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+/// This class represents a successor of a region. A region successor can either
+/// be another region, or the parent operation. If the successor is a region,
+/// this class accepts the destination region, as well as a set of arguments
+/// from that region that will be populated by values from the current region.
+/// If the successor is the parent operation, this class accepts an optional set
+/// of results that will be populated by values from the current region.
+class RegionSuccessor {
+public:
+  /// Initialize a successor that branches to another region of the parent
+  /// operation.
+  RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
+      : region(region), inputs(regionInputs) {}
+  /// Initialize a successor that branches back to/out of the parent operation.
+  RegionSuccessor(Optional<Operation::result_range> results = {})
+      : region(nullptr), inputs(results ? ValueRange(*results) : ValueRange()) {
+  }
+
+  /// Return the given region successor. Returns nullptr if the successor is the
+  /// parent operation.
+  Region *getSuccessor() const { return region; }
+
+  /// Return the inputs to the successor that are remapped by the exit values of
+  /// the current region.
+  ValueRange getSuccessorInputs() const { return inputs; }
+
+private:
+  Region *region;
+  ValueRange inputs;
+};
+
+//===----------------------------------------------------------------------===//
+// ControlFlow Interfaces
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
+
+//===----------------------------------------------------------------------===//
+// ControlFlow Traits
+//===----------------------------------------------------------------------===//
+
+namespace OpTrait {
+/// This trait indicates that a terminator operation is "return-like". This
+/// means that it exits its current region and forwards its operands as "exit"
+/// values to the parent region. Operations with this trait are not permitted to
+/// contain successors or produce results.
+template <typename ConcreteType>
+struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(ConcreteType::template hasTrait<IsTerminator>(),
+                  "expected operation to be a terminator");
+    static_assert(ConcreteType::template hasTrait<ZeroResult>(),
+                  "expected operation to have zero results");
+    static_assert(ConcreteType::template hasTrait<ZeroSuccessor>(),
+                  "expected operation to have zero successors");
+    return success();
+  }
+};
+} // namespace OpTrait
+
 } // end namespace mlir
 
 #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H

diff  --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 2018995fe368..4067e2a4fcb2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -90,4 +90,55 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
+  let description = [{
+    This interface provides information for region operations that contain
+    branching behavior between held regions, i.e. this interface allows for
+    expressing control flow information for region holding operations.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+        Returns the operands of this operation used as the entry arguments when
+        entering the region at `index`, which was specified as a successor by
+        `getSuccessorRegions`. These operands should correspond 1-1 with the
+        successor inputs specified in `getSuccessorRegions`, and may be a subset of the entry arguments for that region.
+      }],
+      "OperandRange", "getSuccessorEntryOperands",
+      (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{
+        auto operandEnd = this->getOperation()->operand_end();
+        return OperandRange({operandEnd, operandEnd});
+      }]
+    >,
+    InterfaceMethod<[{
+        Returns the viable successors of a region at `index`, or the possible
+        successors when branching from the parent op if `index` is None. These
+        are the regions that may be selected during the flow of control. If
+        `index` is None, `operands` is a set of optional attributes that
+        either correspond to a constant value for each operand of this
+        operation, or null if that operand is not a constant. If `index` is
+        valid, `operands` corresponds to the exit values of the region at
+        `index`. Only a region, i.e. a valid `index`, may use the parent
+        operation as a successor. This method allows for describing which
+        regions may be executed when entering an operation, and which regions
+        are executed after having executed another region of the parent op. The
+        successor region must be non-empty.
+      }],
+      "void", "getSuccessorRegions",
+      (ins "Optional<unsigned>":$index, "ArrayRef<Attribute>":$operands,
+           "SmallVectorImpl<RegionSuccessor> &":$regions)
+    >
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// ControlFlow Traits
+//===----------------------------------------------------------------------===//
+
+// Op is "return-like".
+def ReturnLike : NativeOpTrait<"ReturnLike">;
+
 #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES

diff  --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
index 98cd49708370..80f8120a81b6 100644
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -196,6 +196,39 @@ ForOp mlir::loop::getForInductionVarOwner(Value val) {
   return dyn_cast_or_null<ForOp>(containingOp);
 }
 
+/// Return operands used when entering the region at 'index'. These operands
+/// correspond to the loop iterator operands, i.e., those excluding the
+/// induction variable. LoopOp only has one region, so 0 is the only valid value
+/// for `index`.
+OperandRange ForOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index == 0 && "invalid region index");
+
+  // The initial operands map to the loop arguments after the induction
+  // variable.
+  return initArgs();
+}
+
+/// Given the region at `index`, or the parent operation if `index` is None,
+/// return the successor regions. These are the regions that may be selected
+/// during the flow of control. `operands` is a set of optional attributes that
+/// correspond to a constant value for each operand, or null if that operand is
+/// not a constant.
+void ForOp::getSuccessorRegions(Optional<unsigned> index,
+                                ArrayRef<Attribute> operands,
+                                SmallVectorImpl<RegionSuccessor> &regions) {
+  // If the predecessor is the ForOp, branch into the body using the iterator
+  // arguments.
+  if (!index.hasValue()) {
+    regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
+    return;
+  }
+
+  // Otherwise, the loop may branch back to itself or the parent operation.
+  assert(index.getValue() == 0 && "expected loop region");
+  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
@@ -298,6 +331,42 @@ static void print(OpAsmPrinter &p, IfOp op) {
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+/// Given the region at `index`, or the parent operation if `index` is None,
+/// return the successor regions. These are the regions that may be selected
+/// during the flow of control. `operands` is a set of optional attributes that
+/// correspond to a constant value for each operand, or null if that operand is
+/// not a constant.
+void IfOp::getSuccessorRegions(Optional<unsigned> index,
+                               ArrayRef<Attribute> operands,
+                               SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `then` and the `else` region branch back to the parent operation.
+  if (index.hasValue()) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // Don't consider the else region if it is empty.
+  Region *elseRegion = &this->elseRegion();
+  if (elseRegion->empty())
+    elseRegion = nullptr;
+
+  // Otherwise, the successor is dependent on the condition.
+  bool condition;
+  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+    condition = condAttr.getValue().isOneValue();
+  } else if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
+    condition = condAttr.getValue();
+  } else {
+    // If the condition isn't constant, both regions may be executed.
+    regions.push_back(RegionSuccessor(&thenRegion()));
+    regions.push_back(RegionSuccessor(elseRegion));
+    return;
+  }
+
+  // Add the successor regions using the condition.
+  regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index aecc4d61fddd..1d0a279cc592 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -138,13 +138,30 @@ class SCCPSolver {
   LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
                                     Value value);
 
+  /// Visit the users of the given IR that reside within executable blocks.
+  template <typename T>
+  void visitUsers(T &value) {
+    for (Operation *user : value.getUsers())
+      if (isBlockExecutable(user->getBlock()))
+        visitOperation(user);
+  }
+
   /// Visit the given operation and compute any necessary lattice state.
   void visitOperation(Operation *op);
 
   /// Visit the given operation, which defines regions, and compute any
   /// necessary lattice state. This also resolves the lattice state of both the
   /// operation results and any nested regions.
-  void visitRegionOperation(Operation *op);
+  void visitRegionOperation(Operation *op,
+                            ArrayRef<Attribute> constantOperands);
+
+  /// Visit the given set of region successors, computing any necessary lattice
+  /// state. The provided function returns the input operands to the region at
+  /// the given index. If the index is 'None', the input operands correspond to
+  /// the parent operation results.
+  void visitRegionSuccessors(
+      Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+      function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
 
   /// Visit the given terminator operation and compute any necessary lattice
   /// state.
@@ -186,6 +203,16 @@ class SCCPSolver {
     markAllOverdefined(values);
     opWorklist.push_back(op);
   }
+  template <typename ValuesT>
+  void markAllOverdefinedAndVisitUsers(ValuesT values) {
+    for (auto value : values) {
+      auto &lattice = latticeValues[value];
+      if (!lattice.isOverdefined()) {
+        lattice.markOverdefined();
+        visitUsers(value);
+      }
+    }
+  }
 
   /// Returns true if the given value was marked as overdefined.
   bool isOverdefined(Value value) const;
@@ -229,15 +256,8 @@ SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
 void SCCPSolver::solve() {
   while (!blockWorklist.empty() || !opWorklist.empty()) {
     // Process any operations in the op worklist.
-    while (!opWorklist.empty()) {
-      Operation *op = opWorklist.pop_back_val();
-
-      // Visit all of the live users to propagate changes to this operation.
-      for (Operation *user : op->getUsers()) {
-        if (isBlockExecutable(user->getBlock()))
-          visitOperation(user);
-      }
-    }
+    while (!opWorklist.empty())
+      visitUsers(*opWorklist.pop_back_val());
 
     // Process any blocks in the block worklist.
     while (!blockWorklist.empty())
@@ -330,7 +350,7 @@ void SCCPSolver::visitOperation(Operation *op) {
   // Process region holding operations. The region visitor processes result
   // values, so we can exit afterwards.
   if (op->getNumRegions())
-    return visitRegionOperation(op);
+    return visitRegionOperation(op, operandConstants);
 
   // If this op produces no results, it can't produce any constants.
   if (op->getNumResults() == 0)
@@ -379,25 +399,144 @@ void SCCPSolver::visitOperation(Operation *op) {
   }
 }
 
-void SCCPSolver::visitRegionOperation(Operation *op) {
-  for (Region &region : op->getRegions()) {
-    if (region.empty())
-      continue;
-    Block *entryBlock = &region.front();
-    markBlockExecutable(entryBlock);
-    markAllOverdefined(entryBlock->getArguments());
+void SCCPSolver::visitRegionOperation(Operation *op,
+                                      ArrayRef<Attribute> constantOperands) {
+  // Check to see if we can reason about the internal control flow of this
+  // region operation.
+  auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+  if (!regionInterface) {
+    // If we can't, conservatively mark all regions as executable.
+    for (Region &region : op->getRegions()) {
+      if (region.empty())
+        continue;
+      Block *entryBlock = &region.front();
+      markBlockExecutable(entryBlock);
+      markAllOverdefined(entryBlock->getArguments());
+    }
+
+    // Don't try to simulate the results of a region operation as we can't
+    // guarantee that folding will be out-of-place. We don't allow in-place
+    // folds as the desire here is for simulated execution, and not general
+    // folding.
+    return markAllOverdefined(op, op->getResults());
   }
 
-  // Don't try to simulate the results of a region operation as we can't
-  // guarantee that folding will be out-of-place. We don't allow in-place folds
-  // as the desire here is for simulated execution, and not general folding.
-  return markAllOverdefined(op, op->getResults());
+  // Check to see which regions are executable.
+  SmallVector<RegionSuccessor, 1> successors;
+  regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands,
+                                      successors);
+
+  // If the interface identified that no region will be executed, mark
+  // any results of this operation as overdefined, as we can't reason about
+  // them.
+  // TODO: If we had an interface to detect pass through operands, we could
+  // resolve some results based on the lattice state of the operands. We could
+  // also allow for the parent operation to have itself as a region successor.
+  if (successors.empty())
+    return markAllOverdefined(op, op->getResults());
+  return visitRegionSuccessors(op, successors, [&](Optional<unsigned> index) {
+    assert(index && "expected valid region index");
+    return regionInterface.getSuccessorEntryOperands(*index);
+  });
+}
+
+void SCCPSolver::visitRegionSuccessors(
+    Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+    function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
+  for (const RegionSuccessor &it : regionSuccessors) {
+    Region *region = it.getSuccessor();
+    ValueRange succArgs = it.getSuccessorInputs();
+
+    // Check to see if this is the parent operation.
+    if (!region) {
+      ResultRange results = parentOp->getResults();
+      if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); }))
+        continue;
+
+      // Mark the results outside of the input range as overdefined.
+      if (succArgs.size() != results.size()) {
+        opWorklist.push_back(parentOp);
+        if (succArgs.empty())
+          return markAllOverdefined(results);
+
+        unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
+        markAllOverdefined(results.take_front(firstResIdx));
+        markAllOverdefined(results.drop_front(firstResIdx + succArgs.size()));
+      }
+
+      // Update the lattice for any operation results.
+      OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
+      for (auto it : llvm::zip(succArgs, operands))
+        meet(parentOp, latticeValues[std::get<0>(it)],
+             latticeValues[std::get<1>(it)]);
+      return;
+    }
+    assert(!region->empty() && "expected region to be non-empty");
+    Block *entryBlock = &region->front();
+    markBlockExecutable(entryBlock);
+
+    // If all of the arguments are already overdefined, the arguments have
+    // already been fully resolved.
+    auto arguments = entryBlock->getArguments();
+    if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); }))
+      continue;
+
+    // Mark any arguments that do not receive inputs as overdefined, we won't be
+    // able to discern if they are constant.
+    if (succArgs.size() != arguments.size()) {
+      if (succArgs.empty()) {
+        markAllOverdefined(arguments);
+        continue;
+      }
+
+      unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
+      markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx));
+      markAllOverdefinedAndVisitUsers(
+          arguments.drop_front(firstArgIdx + succArgs.size()));
+    }
+
+    // Update the lattice for arguments that have inputs from the predecessor.
+    OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
+    for (auto it : llvm::zip(succArgs, succOperands)) {
+      LatticeValue &argLattice = latticeValues[std::get<0>(it)];
+      if (argLattice.meet(latticeValues[std::get<1>(it)]))
+        visitUsers(std::get<0>(it));
+    }
+  }
 }
 
 void SCCPSolver::visitTerminatorOperation(
     Operation *op, ArrayRef<Attribute> constantOperands) {
-  if (op->getNumSuccessors() == 0)
-    return;
+  // If this operation has no successors, we treat it as an exiting terminator.
+  if (op->getNumSuccessors() == 0) {
+    // Check to see if the parent tracks region control flow.
+    Region *parentRegion = op->getParentRegion();
+    Operation *parentOp = parentRegion->getParentOp();
+    auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
+    if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
+      return;
+
+    // Query the set of successors from the current region.
+    SmallVector<RegionSuccessor, 1> regionSuccessors;
+    regionInterface.getSuccessorRegions(parentRegion->getRegionNumber(),
+                                        constantOperands, regionSuccessors);
+    if (regionSuccessors.empty())
+      return;
+
+    // If this terminator is not "return-like", conservatively mark all of the
+    // successor values as overdefined.
+    if (!op->hasTrait<OpTrait::ReturnLike>()) {
+      for (auto &it : regionSuccessors)
+        markAllOverdefinedAndVisitUsers(it.getSuccessorInputs());
+      return;
+    }
+
+    // Otherwise, propagate the operand lattice states to each of the
+    // successors.
+    OperandRange operands = op->getOperands();
+    return visitRegionSuccessors(parentOp, regionSuccessors,
+                                 [&](Optional<unsigned>) { return operands; });
+  }
 
   // Try to resolve to a specific successor with the constant operands.
   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
@@ -465,11 +604,8 @@ void SCCPSolver::visitBlockArgument(Block *block, int i) {
   }
 
   // If the lattice was updated, visit any executable users of the argument.
-  if (updatedLattice) {
-    for (Operation *user : arg.getUsers())
-      if (isBlockExecutable(user->getBlock()))
-        visitOperation(user);
-  }
+  if (updatedLattice)
+    visitUsers(arg);
 }
 
 bool SCCPSolver::markBlockExecutable(Block *block) {

diff  --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir
new file mode 100644
index 000000000000..4acb6f9c99f2
--- /dev/null
+++ b/mlir/test/Transforms/sccp-structured.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func(sccp)" -split-input-file | FileCheck %s
+
+/// Check that a constant is properly propagated when only one edge is taken.
+
+// CHECK-LABEL: func @simple(
+func @simple(%arg0 : i32) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK-NOT: loop.if
+  // CHECK: return %[[CST]] : i32
+
+  %cond = constant true
+  %res = loop.if %cond -> (i32) {
+    %1 = constant 1 : i32
+    loop.yield %1 : i32
+  } else {
+    loop.yield %arg0 : i32
+  }
+  return %res : i32
+}
+
+/// Check that a constant is properly propagated when both edges produce the
+/// same value.
+
+// CHECK-LABEL: func @simple_both_same(
+func @simple_both_same(%cond : i1) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK-NOT: loop.if
+  // CHECK: return %[[CST]] : i32
+
+  %res = loop.if %cond -> (i32) {
+    %1 = constant 1 : i32
+    loop.yield %1 : i32
+  } else {
+    %2 = constant 1 : i32
+    loop.yield %2 : i32
+  }
+  return %res : i32
+}
+
+/// Check that the arguments go to overdefined if the branch cannot detect when
+/// a specific successor is taken.
+
+// CHECK-LABEL: func @overdefined_unknown_condition(
+func @overdefined_unknown_condition(%cond : i1, %arg0 : i32) -> i32 {
+  // CHECK: %[[RES:.*]] = loop.if
+  // CHECK: return %[[RES]] : i32
+
+  %res = loop.if %cond -> (i32) {
+    %1 = constant 1 : i32
+    loop.yield %1 : i32
+  } else {
+    loop.yield %arg0 : i32
+  }
+  return %res : i32
+}
+
+/// Check that the arguments go to overdefined if there are conflicting
+/// constants.
+
+// CHECK-LABEL: func @overdefined_different_constants(
+func @overdefined_different_constants(%cond : i1) -> i32 {
+  // CHECK: %[[RES:.*]] = loop.if
+  // CHECK: return %[[RES]] : i32
+
+  %res = loop.if %cond -> (i32) {
+    %1 = constant 1 : i32
+    loop.yield %1 : i32
+  } else {
+    %2 = constant 2 : i32
+    loop.yield %2 : i32
+  }
+  return %res : i32
+}
+
+/// Check that arguments are properly merged across loop-like control flow.
+
+// CHECK-LABEL: func @simple_loop(
+func @simple_loop(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 0 : i32
+  // CHECK-NOT: loop.for
+  // CHECK: return %[[CST]] : i32
+
+  %s0 = constant 0 : i32
+  %result = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0) -> (i32) {
+    %sn = addi %si, %si : i32
+    loop.yield %sn : i32
+  }
+  return %result : i32
+}
+
+/// Check that arguments go to overdefined when loop backedges produce a
+/// conflicting value.
+
+// CHECK-LABEL: func @loop_overdefined(
+func @loop_overdefined(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
+  // CHECK: %[[RES:.*]] = loop.for
+  // CHECK: return %[[RES]] : i32
+
+  %s0 = constant 1 : i32
+  %result = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0) -> (i32) {
+    %sn = addi %si, %si : i32
+    loop.yield %sn : i32
+  }
+  return %result : i32
+}
+
+/// Test that we can properly propagate within inner control, and in situations
+/// where the executable edges within the CFG are sensitive to the current state
+/// of the analysis.
+
+// CHECK-LABEL: func @loop_inner_control_flow(
+func @loop_inner_control_flow(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK-NOT: loop.for
+  // CHECK-NOT: loop.if
+  // CHECK: return %[[CST]] : i32
+
+  %cst_1 = constant 1 : i32
+  %result = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %cst_1) -> (i32) {
+    %cst_20 = constant 20 : i32
+    %cond = cmpi "ult", %si, %cst_20 : i32
+    %inner_res = loop.if %cond -> (i32) {
+      %1 = constant 1 : i32
+      loop.yield %1 : i32
+    } else {
+      %si_inc = addi %si, %cst_1 : i32
+      loop.yield %si_inc : i32
+    }
+    loop.yield %inner_res : i32
+  }
+  return %result : i32
+}


        


More information about the Mlir-commits mailing list