[Mlir-commits] [mlir] c24f88b - [mlir][SCCP] Don't visit private callables unless they are used when tracking interprocedural arguments/results

River Riddle llvmlistbot at llvm.org
Thu Dec 10 12:57:22 PST 2020


Author: River Riddle
Date: 2020-12-10T12:53:27-08:00
New Revision: c24f88b4db2ef359f47e976d8d79334ced221288

URL: https://github.com/llvm/llvm-project/commit/c24f88b4db2ef359f47e976d8d79334ced221288
DIFF: https://github.com/llvm/llvm-project/commit/c24f88b4db2ef359f47e976d8d79334ced221288.diff

LOG: [mlir][SCCP] Don't visit private callables unless they are used when tracking interprocedural arguments/results

This fixes a subtle bug where SCCP could incorrectly optimize a private callable while waiting for its arguments to be resolved.

Fixes PR#48457

Differential Revision: https://reviews.llvm.org/D92976

Added: 
    

Modified: 
    mlir/lib/Transforms/SCCP.cpp
    mlir/test/Transforms/sccp-callgraph.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 919559c8a6df..9886331820e3 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -236,6 +236,11 @@ class SCCPSolver {
   /// state.
   void visitBlockArgument(Block *block, int i);
 
+  /// Mark the entry block of the given region as executable. Returns false if
+  /// the block was already marked executable. If `markArgsOverdefined` is true,
+  /// the arguments of the entry block are also set to overdefined.
+  bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined);
+
   /// Mark the given block as executable. Returns false if the block was already
   /// marked executable.
   bool markBlockExecutable(Block *block);
@@ -313,16 +318,9 @@ class SCCPSolver {
 SCCPSolver::SCCPSolver(Operation *op) {
   /// Initialize the solver with the regions within this operation.
   for (Region &region : op->getRegions()) {
-    if (region.empty())
-      continue;
-    Block *entryBlock = &region.front();
-
-    // Mark the entry block as executable.
-    markBlockExecutable(entryBlock);
-
-    // The values passed to these regions are invisible, so mark any arguments
-    // as overdefined.
-    markAllOverdefined(entryBlock->getArguments());
+    // Mark the entry block as executable. The values passed to these regions
+    // are also invisible, so mark any arguments as overdefined.
+    markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
   }
   initializeSymbolCallables(op);
 }
@@ -405,8 +403,10 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
 
       // If not all of the uses of this symbol are visible, we can't track the
       // state of the arguments.
-      if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
-        markAllOverdefined(callableRegion->getArguments());
+      if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
+        for (Region &region : callable->getRegions())
+          markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
+      }
     }
     if (callableLatticeState.empty())
       return;
@@ -443,8 +443,10 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
       // This use isn't a call, so don't we know all of the callers.
       auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
       auto it = callableLatticeState.find(symbol);
-      if (it != callableLatticeState.end())
-        markAllOverdefined(it->second.getCallableArguments());
+      if (it != callableLatticeState.end()) {
+        for (Region &region : it->first->getRegions())
+          markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
+      }
     }
   };
   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
@@ -495,8 +497,14 @@ void SCCPSolver::visitOperation(Operation *op) {
 
   // Process callable operations. These are specially handled region operations
   // that track dataflow via calls.
-  if (isa<CallableOpInterface>(op))
+  if (isa<CallableOpInterface>(op)) {
+    // If this callable has a tracked lattice state, it will be visited by calls
+    // that reference it instead. This way, we don't assume that it is
+    // executable unless there is a proper reference to it.
+    if (callableLatticeState.count(op))
+      return;
     return visitCallableOperation(op);
+  }
 
   // Process region holding operations. The region visitor processes result
   // values, so we can exit afterwards.
@@ -551,19 +559,11 @@ void SCCPSolver::visitOperation(Operation *op) {
 }
 
 void SCCPSolver::visitCallableOperation(Operation *op) {
-  // Mark the regions as executable.
+  // Mark the regions as executable. If we aren't tracking lattice state for
+  // this callable, mark all of the region arguments as overdefined.
   bool isTrackingLatticeState = callableLatticeState.count(op);
-  for (Region &region : op->getRegions()) {
-    if (region.empty())
-      continue;
-    Block *entryBlock = &region.front();
-    markBlockExecutable(entryBlock);
-
-    // If we aren't tracking lattice state for this callable, mark all of the
-    // region arguments as overdefined.
-    if (!isTrackingLatticeState)
-      markAllOverdefined(entryBlock->getArguments());
-  }
+  for (Region &region : op->getRegions())
+    markEntryBlockExecutable(&region, !isTrackingLatticeState);
 
   // TODO: Add support for non-symbol callables when necessary. If the callable
   // has non-call uses we would mark overdefined, otherwise allow for
@@ -599,6 +599,9 @@ void SCCPSolver::visitCallOperation(CallOpInterface op) {
       visitUsers(callableArg);
   }
 
+  // Visit the callable.
+  visitCallableOperation(callableOp);
+
   // Merge in the lattice state for the callable results as well.
   auto callableResults = callableLatticeIt->second.getResultLatticeValues();
   for (auto it : llvm::zip(callResults, callableResults))
@@ -613,13 +616,8 @@ void SCCPSolver::visitRegionOperation(Operation *op,
   auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
   if (!regionInterface) {
     // If we can't, conservatively mark all regions as executable.
-    for (Region &region : op->getRegions()) {
-      if (region.empty())
-        continue;
-      Block *entryBlock = &region.front();
-      markBlockExecutable(entryBlock);
-      markAllOverdefined(entryBlock->getArguments());
-    }
+    for (Region &region : op->getRegions())
+      markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
 
     // Don't try to simulate the results of a region operation as we can't
     // guarantee that folding will be out-of-place. We don't allow in-place
@@ -856,6 +854,16 @@ void SCCPSolver::visitBlockArgument(Block *block, int i) {
     visitUsers(arg);
 }
 
+bool SCCPSolver::markEntryBlockExecutable(Region *region,
+                                          bool markArgsOverdefined) {
+  if (!region->empty()) {
+    if (markArgsOverdefined)
+      markAllOverdefined(region->front().getArguments());
+    return markBlockExecutable(&region->front());
+  }
+  return false;
+}
+
 bool SCCPSolver::markBlockExecutable(Block *block) {
   bool marked = executableBlocks.insert(block).second;
   if (marked)

diff  --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir
index 58279e7ba329..27ac6d5c7c26 100644
--- a/mlir/test/Transforms/sccp-callgraph.mlir
+++ b/mlir/test/Transforms/sccp-callgraph.mlir
@@ -140,7 +140,7 @@ func @unknown_terminator() -> i32 {
 /// Check that return values are overdefined when the constant conflicts.
 
 func private @callable(%arg0 : i32) -> i32 {
-  "unknown.return"(%arg0) : (i32) -> ()
+  return %arg0 : i32
 }
 
 // CHECK-LABEL: func @conflicting_constant(
@@ -255,3 +255,18 @@ func @non_symbol_defining_callable() -> i32 {
   %res = call_indirect %fn() : () -> (i32)
   return %res : i32
 }
+
+// -----
+
+/// Check that private callables don't get processed if they have no uses.
+
+// CHECK-LABEL: func private @unreferenced_private_function
+func private @unreferenced_private_function() -> i32 {
+  // CHECK: %[[RES:.*]] = select
+  // CHECK: return %[[RES]] : i32
+  %true = constant true
+  %cst0 = constant 0 : i32
+  %cst1 = constant 1 : i32
+  %result = select %true, %cst0, %cst1 : i32
+  return %result : i32
+}


        


More information about the Mlir-commits mailing list