[Mlir-commits] [mlir] a90151d - [mlir][SCCP] Add support for propagating across symbol based calls

River Riddle llvmlistbot at llvm.org
Mon Apr 27 13:06:17 PDT 2020


Author: River Riddle
Date: 2020-04-27T13:04:49-07:00
New Revision: a90151d67e23c8a2c8362c95b44340e19f955a51

URL: https://github.com/llvm/llvm-project/commit/a90151d67e23c8a2c8362c95b44340e19f955a51
DIFF: https://github.com/llvm/llvm-project/commit/a90151d67e23c8a2c8362c95b44340e19f955a51.diff

LOG: [mlir][SCCP] Add support for propagating across symbol based calls

This revision adds support for propagating constants across symbol-based callgraph edges. It uses the existing Call/CallableOpInterfaces to detect the dataflow edges, and propagates constants through arguments and out of returns.

Differential Revision: https://reviews.llvm.org/D78592

Added: 
    mlir/test/Transforms/sccp-callgraph.mlir

Modified: 
    mlir/include/mlir/IR/SymbolTable.h
    mlir/include/mlir/Interfaces/CallInterfaces.td
    mlir/lib/IR/SymbolTable.cpp
    mlir/lib/Transforms/Inliner.cpp
    mlir/lib/Transforms/SCCP.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 216948b2b3df..0b035836ec61 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -86,6 +86,15 @@ class SymbolTable {
   /// nullptr if no valid parent symbol table could be found.
   static Operation *getNearestSymbolTable(Operation *from);
 
+  /// Walks all symbol table operations nested within, and including, `op`. For
+  /// each symbol table operation, the provided callback is invoked with the op
+  /// and a boolean signifying if the symbols within that symbol table can be
+  /// treated as if all uses within the IR are visible to the caller.
+  /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
+  /// within `op` are visible.
+  static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
+                               function_ref<void(Operation *, bool)> callback);
+
   /// Returns the operation registered with the given symbol name with the
   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
   /// with the 'OpTrait::SymbolTable' trait.

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 0ff189de6800..81ab52f197aa 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -34,7 +34,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
     InterfaceMethod<[{
         Returns the callee of this call-like operation. A `callee` is either a
         reference to a symbol, via SymbolRefAttr, or a reference to a defined
-        SSA value.
+        SSA value. If the reference is an SSA value, the SSA value corresponds
+        to a region of a lambda-like operation.
       }],
       "CallInterfaceCallable", "getCallableForCallee"
     >,

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 487b51de8dc9..dc4186eaf129 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -207,6 +207,35 @@ Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
   return from;
 }
 
+/// Walks all symbol table operations nested within, and including, `op`. For
+/// each symbol table operation, the provided callback is invoked with the op
+/// and a boolean signifying if the symbols within that symbol table can be
+/// treated as if all uses are visible. `allSymUsesVisible` identifies whether
+/// all of the symbol uses of symbols within `op` are visible.
+void SymbolTable::walkSymbolTables(
+    Operation *op, bool allSymUsesVisible,
+    function_ref<void(Operation *, bool)> callback) {
+  bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
+  if (isSymbolTable) {
+    SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
+    allSymUsesVisible |= !symbol || symbol.isPrivate();
+  } else {
+    // Otherwise if 'op' is not a symbol table, any nested symbols are
+    // guaranteed to be hidden.
+    allSymUsesVisible = true;
+  }
+
+  for (Region &region : op->getRegions())
+    for (Block &block : region)
+      for (Operation &nestedOp : block)
+        walkSymbolTables(&nestedOp, allSymUsesVisible, callback);
+
+  // If 'op' had the symbol table trait, visit it after any nested symbol
+  // tables.
+  if (isSymbolTable)
+    callback(op, allSymUsesVisible);
+}
+
 /// Returns the operation registered with the given symbol name with the
 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 28c8216f8333..c0f89da300f1 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -31,29 +31,6 @@ using namespace mlir;
 // Symbol Use Tracking
 //===----------------------------------------------------------------------===//
 
-/// Walk all of the symbol table operations nested with 'op' along with a
-/// boolean signifying if the symbols within can be treated as if all uses are
-/// visible. The provided callback is invoked with the symbol table operation,
-/// and a boolean signaling if all of the uses within the symbol table are
-/// visible.
-static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
-                             function_ref<void(Operation *, bool)> callback) {
-  if (op->hasTrait<OpTrait::SymbolTable>()) {
-    SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
-    allSymUsesVisible = allSymUsesVisible || !symbol || symbol.isPrivate();
-    callback(op, allSymUsesVisible);
-  } else {
-    // Otherwise if 'op' is not a symbol table, any nested symbols are
-    // guaranteed to be hidden.
-    allSymUsesVisible = true;
-  }
-
-  for (Region &region : op->getRegions())
-    for (Block &block : region)
-      for (Operation &nested : block)
-        walkSymbolTables(&nested, allSymUsesVisible, callback);
-}
-
 /// Walk all of the used symbol callgraph nodes referenced with the given op.
 static void walkReferencedSymbolNodes(
     Operation *op, CallGraph &cg,
@@ -164,7 +141,8 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
       }
     }
   };
-  walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), walkFn);
+  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
+                                walkFn);
 
   // Drop the use information for any discardable nodes that are always live.
   for (auto &it : alwaysLiveNodes)

diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 1d0a279cc592..c9fc4ba2f395 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -116,12 +116,56 @@ class LatticeValue {
   Dialect *constantDialect;
 };
 
+/// This class contains various state used when computing the lattice of a
+/// callable operation.
+class CallableLatticeState {
+public:
+  /// Build a lattice state with a given callable region, and a specified number
+  /// of results to be initialized to the default lattice value (Unknown).
+  CallableLatticeState(Region *callableRegion, unsigned numResults)
+      : callableArguments(callableRegion->front().getArguments()),
+        resultLatticeValues(numResults) {}
+
+  /// Returns the arguments to the callable region.
+  Block::BlockArgListType getCallableArguments() const {
+    return callableArguments;
+  }
+
+  /// Returns the lattice value for the results of the callable region.
+  MutableArrayRef<LatticeValue> getResultLatticeValues() {
+    return resultLatticeValues;
+  }
+
+  /// Add a call to this callable. This is only used if the callable defines a
+  /// symbol.
+  void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
+
+  /// Return the calls that reference this callable. This is only used
+  /// if the callable defines a symbol.
+  ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
+
+private:
+  /// The arguments of the callable region.
+  Block::BlockArgListType callableArguments;
+
+  /// The lattice state for each of the results of this region. The return
+  /// values of the callable aren't SSA values, so we need to track them
+  /// separately.
+  SmallVector<LatticeValue, 4> resultLatticeValues;
+
+  /// The calls referencing this callable if this callable defines a symbol.
+  /// This removes the need to recompute symbol references during propagation.
+  /// Value based references are trivial to resolve, so they can be done
+  /// in-place.
+  SmallVector<Operation *, 4> symbolCalls;
+};
+
 /// This class represents the solver for the SCCP analysis. This class acts as
 /// the propagation engine for computing which values form constants.
 class SCCPSolver {
 public:
-  /// Initialize the solver with a given set of regions.
-  SCCPSolver(MutableArrayRef<Region> regions);
+  /// Initialize the solver with the given top-level operation.
+  SCCPSolver(Operation *op);
 
   /// Run the solver until it converges.
   void solve();
@@ -132,6 +176,11 @@ class SCCPSolver {
   void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);
 
 private:
+  /// Initialize the set of symbol defining callables that can have their
+  /// arguments and results tracked. 'op' is the top-level operation that SCCP
+  /// is operating on.
+  void initializeSymbolCallables(Operation *op);
+
   /// Replace the given value with a constant if the corresponding lattice
   /// represents a constant. Returns success if the value was replaced, failure
   /// otherwise.
@@ -149,6 +198,13 @@ class SCCPSolver {
   /// Visit the given operation and compute any necessary lattice state.
   void visitOperation(Operation *op);
 
+  /// Visit the given call operation and compute any necessary lattice state.
+  void visitCallOperation(CallOpInterface op);
+
+  /// Visit the given callable operation and compute any necessary lattice
+  /// state.
+  void visitCallableOperation(Operation *op);
+
   /// Visit the given operation, which defines regions, and compute any
   /// necessary lattice state. This also resolves the lattice state of both the
   /// operation results and any nested regions.
@@ -168,6 +224,11 @@ class SCCPSolver {
   void visitTerminatorOperation(Operation *op,
                                 ArrayRef<Attribute> constantOperands);
 
+  /// Visit the given terminator operation that exits a callable region. These
+  /// are terminators with no CFG successors.
+  void visitCallableTerminatorOperation(Operation *callable,
+                                        Operation *terminator);
+
   /// Visit the given block and compute any necessary lattice state.
   void visitBlock(Block *block);
 
@@ -235,11 +296,20 @@ class SCCPSolver {
 
   /// A worklist of operations that need to be processed.
   SmallVector<Operation *, 64> opWorklist;
+
+  /// The callable operations that have their argument/result state tracked.
+  DenseMap<Operation *, CallableLatticeState> callableLatticeState;
+
+  /// A map between a call operation and the resolved symbol callable. This
+  /// avoids re-resolving symbol references during propagation. Value based
+  /// callables are trivial to resolve, so they can be done in-place.
+  DenseMap<Operation *, Operation *> callToSymbolCallable;
 };
 } // end anonymous namespace
 
-SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
-  for (Region &region : regions) {
+SCCPSolver::SCCPSolver(Operation *op) {
+  /// Initialize the solver with the regions within this operation.
+  for (Region &region : op->getRegions()) {
     if (region.empty())
       continue;
     Block *entryBlock = &region.front();
@@ -251,6 +321,7 @@ SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
     // as overdefined.
     markAllOverdefined(entryBlock->getArguments());
   }
+  initializeSymbolCallables(op);
 }
 
 void SCCPSolver::solve() {
@@ -310,6 +381,73 @@ void SCCPSolver::rewrite(MLIRContext *context,
   }
 }
 
+void SCCPSolver::initializeSymbolCallables(Operation *op) {
+  // Initialize the set of symbol callables that can have their state tracked.
+  // This tracks which symbol callable operations we can propagate within and
+  // out of.
+  auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
+    Region &symbolTableRegion = symTable->getRegion(0);
+    Block *symbolTableBlock = &symbolTableRegion.front();
+    for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
+      // We won't be able to track external callables.
+      Region *callableRegion = callable.getCallableRegion();
+      if (!callableRegion)
+        continue;
+      // We only care about symbol defining callables here.
+      auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
+      if (!symbol)
+        continue;
+      callableLatticeState.try_emplace(callable, callableRegion,
+                                       callable.getCallableResults().size());
+
+      // If not all of the uses of this symbol are visible, we can't track the
+      // state of the arguments.
+      if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
+        markAllOverdefined(callableRegion->front().getArguments());
+    }
+    if (callableLatticeState.empty())
+      return;
+
+    // After computing the valid callables, walk any symbol uses to check
+    // for non-call references. We won't be able to track the lattice state
+    // for arguments to these callables, as we can't guarantee that we can see
+    // all of its calls.
+    Optional<SymbolTable::UseRange> uses =
+        SymbolTable::getSymbolUses(&symbolTableRegion);
+    if (!uses) {
+      // If we couldn't gather the symbol uses, conservatively assume that
+      // we can't track information for any nested symbols.
+      op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
+      return;
+    }
+
+    for (const SymbolTable::SymbolUse &use : *uses) {
+      // If the use is a call, track it to avoid the need to recompute the
+      // reference later.
+      if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
+        Operation *symCallable = callOp.resolveCallable();
+        auto callableLatticeIt = callableLatticeState.find(symCallable);
+        if (callableLatticeIt != callableLatticeState.end()) {
+          callToSymbolCallable.try_emplace(callOp, symCallable);
+
+          // We only need to record the call in the lattice if it produces any
+          // values.
+          if (callOp.getOperation()->getNumResults())
+            callableLatticeIt->second.addSymbolCall(callOp);
+        }
+        continue;
+      }
+      // This use isn't a call, so we don't know all of the callers.
+      auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef());
+      auto it = callableLatticeState.find(symbol);
+      if (it != callableLatticeState.end())
+        markAllOverdefined(it->second.getCallableArguments());
+    }
+  };
+  SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
+                                walkFn);
+}
+
 LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
                                               OperationFolder &folder,
                                               Value value) {
@@ -347,6 +485,16 @@ void SCCPSolver::visitOperation(Operation *op) {
   if (op->isKnownTerminator())
     visitTerminatorOperation(op, operandConstants);
 
+  // Process call operations. The call visitor processes result values, so we
+  // can exit afterwards.
+  if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
+    return visitCallOperation(call);
+
+  // Process callable operations. These are specially handled region operations
+  // that track dataflow via calls.
+  if (isa<CallableOpInterface>(op))
+    return visitCallableOperation(op);
+
   // Process region holding operations. The region visitor processes result
   // values, so we can exit afterwards.
   if (op->getNumRegions())
@@ -399,6 +547,62 @@ void SCCPSolver::visitOperation(Operation *op) {
   }
 }
 
+void SCCPSolver::visitCallableOperation(Operation *op) {
+  // Mark the regions as executable.
+  bool isTrackingLatticeState = callableLatticeState.count(op);
+  for (Region &region : op->getRegions()) {
+    if (region.empty())
+      continue;
+    Block *entryBlock = &region.front();
+    markBlockExecutable(entryBlock);
+
+    // If we aren't tracking lattice state for this callable, mark all of the
+    // region arguments as overdefined.
+    if (!isTrackingLatticeState)
+      markAllOverdefined(entryBlock->getArguments());
+  }
+
+  // TODO: Add support for non-symbol callables when necessary. If the callable
+  // has non-call uses we would mark overdefined, otherwise allow for
+  // propagating the return values out.
+  markAllOverdefined(op, op->getResults());
+}
+
+void SCCPSolver::visitCallOperation(CallOpInterface op) {
+  ResultRange callResults = op.getOperation()->getResults();
+
+  // Resolve the callable operation for this call.
+  Operation *callableOp = nullptr;
+  if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
+    callableOp = callableValue.getDefiningOp();
+  else
+    callableOp = callToSymbolCallable.lookup(op);
+
+  // The callable of this call can't be resolved, mark any results overdefined.
+  if (!callableOp)
+    return markAllOverdefined(op, callResults);
+
+  // If this callable is tracking state, merge the argument operands with the
+  // arguments of the callable.
+  auto callableLatticeIt = callableLatticeState.find(callableOp);
+  if (callableLatticeIt == callableLatticeState.end())
+    return markAllOverdefined(op, callResults);
+
+  OperandRange callOperands = op.getArgOperands();
+  auto callableArgs = callableLatticeIt->second.getCallableArguments();
+  for (auto it : llvm::zip(callOperands, callableArgs)) {
+    BlockArgument callableArg = std::get<1>(it);
+    if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)]))
+      visitUsers(callableArg);
+  }
+
+  // Merge in the lattice state for the callable results as well.
+  auto callableResults = callableLatticeIt->second.getResultLatticeValues();
+  for (auto it : llvm::zip(callResults, callableResults))
+    meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)],
+         /*from=*/std::get<1>(it));
+}
+
 void SCCPSolver::visitRegionOperation(Operation *op,
                                       ArrayRef<Attribute> constantOperands) {
   // Check to see if we can reason about the internal control flow of this
@@ -509,9 +713,14 @@ void SCCPSolver::visitTerminatorOperation(
     Operation *op, ArrayRef<Attribute> constantOperands) {
   // If this operation has no successors, we treat it as an exiting terminator.
   if (op->getNumSuccessors() == 0) {
-    // Check to see if the parent tracks region control flow.
     Region *parentRegion = op->getParentRegion();
     Operation *parentOp = parentRegion->getParentOp();
+
+    // Check to see if this is a terminator for a callable region.
+    if (isa<CallableOpInterface>(parentOp))
+      return visitCallableTerminatorOperation(parentOp, op);
+
+    // Otherwise, check to see if the parent tracks region control flow.
     auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
     if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
       return;
@@ -552,6 +761,42 @@ void SCCPSolver::visitTerminatorOperation(
     markEdgeExecutable(block, succ);
 }
 
+void SCCPSolver::visitCallableTerminatorOperation(Operation *callable,
+                                                  Operation *terminator) {
+  // If there are no exiting values, we have nothing to track.
+  if (terminator->getNumOperands() == 0)
+    return;
+
+  // If this callable isn't tracking any lattice state there is nothing to do.
+  auto latticeIt = callableLatticeState.find(callable);
+  if (latticeIt == callableLatticeState.end())
+    return;
+  assert(callable->getNumResults() == 0 && "expected symbol callable");
+
+  // If this terminator is not "return-like", conservatively mark all of the
+  // call-site results as overdefined.
+  auto callableResultLattices = latticeIt->second.getResultLatticeValues();
+  if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
+    for (auto &it : callableResultLattices)
+      it.markOverdefined();
+    for (Operation *call : latticeIt->second.getSymbolCalls())
+      markAllOverdefined(call, call->getResults());
+    return;
+  }
+
+  // Merge the terminator operands into the results.
+  bool anyChanged = false;
+  for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices))
+    anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]);
+  if (!anyChanged)
+    return;
+
+  // If any of the result lattices changed, update the callers.
+  for (Operation *call : latticeIt->second.getSymbolCalls())
+    for (auto it : llvm::zip(call->getResults(), callableResultLattices))
+      meet(call, latticeValues[std::get<0>(it)], std::get<1>(it));
+}
+
 void SCCPSolver::visitBlock(Block *block) {
   // If the block is not the entry block we need to compute the lattice state
   // for the block arguments. Entry block argument lattices are computed
@@ -663,7 +908,7 @@ void SCCP::runOnOperation() {
   Operation *op = getOperation();
 
   // Solve for SCCP constraints within nested regions.
-  SCCPSolver solver(op->getRegions());
+  SCCPSolver solver(op);
   solver.solve();
 
   // Cleanup any operations using the solver analysis.

diff  --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir
new file mode 100644
index 000000000000..5d47a277df93
--- /dev/null
+++ b/mlir/test/Transforms/sccp-callgraph.mlir
@@ -0,0 +1,257 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -sccp -split-input-file | FileCheck %s -dump-input-on-failure
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="module(sccp)" -split-input-file | FileCheck %s --check-prefix=NESTED -dump-input-on-failure
+
+/// Check that a constant is properly propagated through the arguments and
+/// results of a private function.
+
+// CHECK-LABEL: func @private(
+func @private(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: return %[[CST]] : i32
+
+  return %arg0 : i32
+}
+
+// CHECK-LABEL: func @simple_private(
+func @simple_private() -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: return %[[CST]] : i32
+
+  %1 = constant 1 : i32
+  %result = call @private(%1) : (i32) -> i32
+  return %result : i32
+}
+
+// -----
+
+/// Check that a constant is properly propagated through the arguments and
+/// results of a visible nested function.
+
+// CHECK: func @nested(
+func @nested(%arg0 : i32) -> i32 attributes { sym_visibility = "nested" } {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: return %[[CST]] : i32
+
+  return %arg0 : i32
+}
+
+// CHECK-LABEL: func @simple_nested(
+func @simple_nested() -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: return %[[CST]] : i32
+
+  %1 = constant 1 : i32
+  %result = call @nested(%1) : (i32) -> i32
+  return %result : i32
+}
+
+// -----
+
+/// Check that non-visible nested functions do not track arguments.
+module {
+  // NESTED-LABEL: module @nested_module
+  module @nested_module attributes { sym_visibility = "public" } {
+
+    // NESTED: func @nested(
+    func @nested(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "nested" } {
+      // NESTED: %[[CST:.*]] = constant 1 : i32
+      // NESTED: return %[[CST]], %arg0 : i32, i32
+
+      %1 = constant 1 : i32
+      return %1, %arg0 : i32, i32
+    }
+
+    // NESTED: func @nested_not_all_uses_visible(
+    func @nested_not_all_uses_visible() -> (i32, i32) {
+      // NESTED: %[[CST:.*]] = constant 1 : i32
+      // NESTED: %[[CALL:.*]]:2 = call @nested
+      // NESTED: return %[[CST]], %[[CALL]]#1 : i32, i32
+
+      %1 = constant 1 : i32
+      %result:2 = call @nested(%1) : (i32) -> (i32, i32)
+      return %result#0, %result#1 : i32, i32
+    }
+  }
+}
+
+// -----
+
+/// Check that public functions do not track arguments.
+
+// CHECK-LABEL: func @public(
+func @public(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "public" } {
+  %1 = constant 1 : i32
+  return %1, %arg0 : i32, i32
+}
+
+// CHECK-LABEL: func @simple_public(
+func @simple_public() -> (i32, i32) {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: %[[CALL:.*]]:2 = call @public
+  // CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32
+
+  %1 = constant 1 : i32
+  %result:2 = call @public(%1) : (i32) -> (i32, i32)
+  return %result#0, %result#1 : i32, i32
+}
+
+// -----
+
+/// Check that functions with non-call users don't have arguments tracked.
+
+func @callable(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "private" } {
+  %1 = constant 1 : i32
+  return %1, %arg0 : i32, i32
+}
+
+// CHECK-LABEL: func @non_call_users(
+func @non_call_users() -> (i32, i32) {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: %[[CALL:.*]]:2 = call @callable
+  // CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32
+
+  %1 = constant 1 : i32
+  %result:2 = call @callable(%1) : (i32) -> (i32, i32)
+  return %result#0, %result#1 : i32, i32
+}
+
+"live.user"() {uses = [@callable]} : () -> ()
+
+// -----
+
+/// Check that return values are overdefined in the presence of an unknown terminator.
+
+func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
+  "unknown.return"(%arg0) : (i32) -> ()
+}
+
+// CHECK-LABEL: func @unknown_terminator(
+func @unknown_terminator() -> i32 {
+  // CHECK: %[[CALL:.*]] = call @callable
+  // CHECK: return %[[CALL]] : i32
+
+  %1 = constant 1 : i32
+  %result = call @callable(%1) : (i32) -> i32
+  return %result : i32
+}
+
+// -----
+
+/// Check that return values are overdefined when the constant conflicts.
+
+func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
+  "unknown.return"(%arg0) : (i32) -> ()
+}
+
+// CHECK-LABEL: func @conflicting_constant(
+func @conflicting_constant() -> (i32, i32) {
+  // CHECK: %[[CALL1:.*]] = call @callable
+  // CHECK: %[[CALL2:.*]] = call @callable
+  // CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32
+
+  %1 = constant 1 : i32
+  %2 = constant 2 : i32
+  %result = call @callable(%1) : (i32) -> i32
+  %result2 = call @callable(%2) : (i32) -> i32
+  return %result, %result2 : i32, i32
+}
+
+// -----
+
+/// Check that return values are overdefined when the constant conflicts with a
+/// non-constant.
+
+func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
+  "unknown.return"(%arg0) : (i32) -> ()
+}
+
+// CHECK-LABEL: func @conflicting_constant(
+func @conflicting_constant(%arg0 : i32) -> (i32, i32) {
+  // CHECK: %[[CALL1:.*]] = call @callable
+  // CHECK: %[[CALL2:.*]] = call @callable
+  // CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32
+
+  %1 = constant 1 : i32
+  %result = call @callable(%1) : (i32) -> i32
+  %result2 = call @callable(%arg0) : (i32) -> i32
+  return %result, %result2 : i32, i32
+}
+
+// -----
+
+/// Check a more complex interaction with calls and control flow.
+
+// CHECK-LABEL: func @complex_inner_if(
+func @complex_inner_if(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
+  // CHECK-DAG: %[[TRUE:.*]] = constant 1 : i1
+  // CHECK-DAG: %[[CST:.*]] = constant 1 : i32
+  // CHECK: cond_br %[[TRUE]], ^bb1
+
+  %cst_20 = constant 20 : i32
+  %cond = cmpi "ult", %arg0, %cst_20 : i32
+  cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  // CHECK: ^bb1:
+  // CHECK: return %[[CST]] : i32
+
+  %cst_1 = constant 1 : i32
+  return %cst_1 : i32
+
+^bb2:
+  %cst_1_2 = constant 1 : i32
+  %arg_inc = addi %arg0, %cst_1_2 : i32
+  return %arg_inc : i32
+}
+
+func @complex_cond() -> i1
+
+// CHECK-LABEL: func @complex_callee(
+func @complex_callee(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+
+  %loop_cond = call @complex_cond() : () -> i1
+  cond_br %loop_cond, ^bb1, ^bb2
+
+^bb1:
+  // CHECK: ^bb1:
+  // CHECK-NEXT: return %[[CST]] : i32
+  return %arg0 : i32
+
+^bb2:
+  // CHECK: ^bb2:
+  // CHECK: call @complex_inner_if(%[[CST]]) : (i32) -> i32
+  // CHECK: call @complex_callee(%[[CST]]) : (i32) -> i32
+  // CHECK: return %[[CST]] : i32
+
+  %updated_arg = call @complex_inner_if(%arg0) : (i32) -> i32
+  %res = call @complex_callee(%updated_arg) : (i32) -> i32
+  return %res : i32
+}
+
+// CHECK-LABEL: func @complex_caller(
+func @complex_caller(%arg0 : i32) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: return %[[CST]] : i32
+
+  %1 = constant 1 : i32
+  %result = call @complex_callee(%1) : (i32) -> i32
+  return %result : i32
+}
+
+// -----
+
+/// Check that non-symbol defining callables currently go to overdefined.
+
+// CHECK-LABEL: func @non_symbol_defining_callable
+func @non_symbol_defining_callable() -> i32 {
+  // CHECK: %[[RES:.*]] = call_indirect
+  // CHECK: return %[[RES]] : i32
+
+  %fn = "test.functional_region_op"() ({
+    %1 = constant 1 : i32
+    "test.return"(%1) : (i32) -> ()
+  }) : () -> (() -> i32)
+  %res = call_indirect %fn() : () -> (i32)
+  return %res : i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 000a5722a76a..ad8c6fb99e67 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1090,7 +1090,7 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
 //===----------------------------------------------------------------------===//
 
 def TestRegionBuilderOp : TEST_Op<"region_builder">;
-def TestReturnOp : TEST_Op<"return", [Terminator]>,
+def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 def TestCastOp : TEST_Op<"cast">,
   Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;


        


More information about the Mlir-commits mailing list