[Mlir-commits] [mlir] c24f88b - [mlir][SCCP] Don't visit private callables unless they are used when tracking interprocedural arguments/results
River Riddle
llvmlistbot at llvm.org
Thu Dec 10 12:57:22 PST 2020
Author: River Riddle
Date: 2020-12-10T12:53:27-08:00
New Revision: c24f88b4db2ef359f47e976d8d79334ced221288
URL: https://github.com/llvm/llvm-project/commit/c24f88b4db2ef359f47e976d8d79334ced221288
DIFF: https://github.com/llvm/llvm-project/commit/c24f88b4db2ef359f47e976d8d79334ced221288.diff
LOG: [mlir][SCCP] Don't visit private callables unless they are used when tracking interprocedural arguments/results
This fixes a subtle bug where SCCP could incorrectly optimize a private callable while waiting for its arguments to be resolved.
Fixes PR#48457
Differential Revision: https://reviews.llvm.org/D92976
Added:
Modified:
mlir/lib/Transforms/SCCP.cpp
mlir/test/Transforms/sccp-callgraph.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 919559c8a6df..9886331820e3 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -236,6 +236,11 @@ class SCCPSolver {
/// state.
void visitBlockArgument(Block *block, int i);
+ /// Mark the entry block of the given region as executable. Returns false if
+ /// the region is empty or the block was already marked executable. If
+ /// `markArgsOverdefined` is true, its arguments are also set to overdefined.
+ bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined);
+
/// Mark the given block as executable. Returns false if the block was already
/// marked executable.
bool markBlockExecutable(Block *block);
@@ -313,16 +318,9 @@ class SCCPSolver {
SCCPSolver::SCCPSolver(Operation *op) {
/// Initialize the solver with the regions within this operation.
for (Region &region : op->getRegions()) {
- if (region.empty())
- continue;
- Block *entryBlock = &region.front();
-
- // Mark the entry block as executable.
- markBlockExecutable(entryBlock);
-
- // The values passed to these regions are invisible, so mark any arguments
- // as overdefined.
- markAllOverdefined(entryBlock->getArguments());
+ // Mark the entry block as executable. The values passed to these regions
+ // are also invisible, so mark any arguments as overdefined.
+ markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
}
initializeSymbolCallables(op);
}
@@ -405,8 +403,10 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
// If not all of the uses of this symbol are visible, we can't track the
// state of the arguments.
- if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
- markAllOverdefined(callableRegion->getArguments());
+ if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
+ for (Region &region : callable->getRegions())
+ markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
+ }
}
if (callableLatticeState.empty())
return;
@@ -443,8 +443,10 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
// This use isn't a call, so don't we know all of the callers.
auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
auto it = callableLatticeState.find(symbol);
- if (it != callableLatticeState.end())
- markAllOverdefined(it->second.getCallableArguments());
+ if (it != callableLatticeState.end()) {
+ for (Region &region : it->first->getRegions())
+ markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
+ }
}
};
SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
@@ -495,8 +497,14 @@ void SCCPSolver::visitOperation(Operation *op) {
// Process callable operations. These are specially handled region operations
// that track dataflow via calls.
- if (isa<CallableOpInterface>(op))
+ if (isa<CallableOpInterface>(op)) {
+ // If this callable has a tracked lattice state, it will be visited by calls
+ // that reference it instead (see visitCallOperation). This way, we don't
+ // assume that it is executable unless there is a proper reference to it.
+ if (callableLatticeState.count(op))
+ return;
return visitCallableOperation(op);
+ }
// Process region holding operations. The region visitor processes result
// values, so we can exit afterwards.
@@ -551,19 +559,11 @@ void SCCPSolver::visitOperation(Operation *op) {
}
void SCCPSolver::visitCallableOperation(Operation *op) {
- // Mark the regions as executable.
+ // Mark the regions as executable. If we aren't tracking lattice state for
+ // this callable, mark all of the region arguments as overdefined.
bool isTrackingLatticeState = callableLatticeState.count(op);
- for (Region &region : op->getRegions()) {
- if (region.empty())
- continue;
- Block *entryBlock = &region.front();
- markBlockExecutable(entryBlock);
-
- // If we aren't tracking lattice state for this callable, mark all of the
- // region arguments as overdefined.
- if (!isTrackingLatticeState)
- markAllOverdefined(entryBlock->getArguments());
- }
+ for (Region &region : op->getRegions())
+ markEntryBlockExecutable(&region, !isTrackingLatticeState);
// TODO: Add support for non-symbol callables when necessary. If the callable
// has non-call uses we would mark overdefined, otherwise allow for
@@ -599,6 +599,9 @@ void SCCPSolver::visitCallOperation(CallOpInterface op) {
visitUsers(callableArg);
}
+ // Visit the callable now that this call references it.
+ visitCallableOperation(callableOp);
+
// Merge in the lattice state for the callable results as well.
auto callableResults = callableLatticeIt->second.getResultLatticeValues();
for (auto it : llvm::zip(callResults, callableResults))
@@ -613,13 +616,8 @@ void SCCPSolver::visitRegionOperation(Operation *op,
auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
if (!regionInterface) {
// If we can't, conservatively mark all regions as executable.
- for (Region &region : op->getRegions()) {
- if (region.empty())
- continue;
- Block *entryBlock = &region.front();
- markBlockExecutable(entryBlock);
- markAllOverdefined(entryBlock->getArguments());
- }
+ for (Region &region : op->getRegions())
+ markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
// Don't try to simulate the results of a region operation as we can't
// guarantee that folding will be out-of-place. We don't allow in-place
@@ -856,6 +854,16 @@ void SCCPSolver::visitBlockArgument(Block *block, int i) {
visitUsers(arg);
}
+bool SCCPSolver::markEntryBlockExecutable(Region *region,
+ bool markArgsOverdefined) {
+ if (!region->empty()) {
+ if (markArgsOverdefined)
+ markAllOverdefined(region->front().getArguments());
+ return markBlockExecutable(&region->front());
+ }
+ return false;
+}
+
bool SCCPSolver::markBlockExecutable(Block *block) {
bool marked = executableBlocks.insert(block).second;
if (marked)
diff --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir
index 58279e7ba329..27ac6d5c7c26 100644
--- a/mlir/test/Transforms/sccp-callgraph.mlir
+++ b/mlir/test/Transforms/sccp-callgraph.mlir
@@ -140,7 +140,7 @@ func @unknown_terminator() -> i32 {
/// Check that return values are overdefined when the constant conflicts.
func private @callable(%arg0 : i32) -> i32 {
- "unknown.return"(%arg0) : (i32) -> ()
+ return %arg0 : i32
}
// CHECK-LABEL: func @conflicting_constant(
@@ -255,3 +255,18 @@ func @non_symbol_defining_callable() -> i32 {
%res = call_indirect %fn() : () -> (i32)
return %res : i32
}
+
+// -----
+
+/// Check that private callables with no uses are left unfolded by SCCP.
+
+// CHECK-LABEL: func private @unreferenced_private_function
+func private @unreferenced_private_function() -> i32 {
+ // CHECK: %[[RES:.*]] = select
+ // CHECK: return %[[RES]] : i32
+ %true = constant true
+ %cst0 = constant 0 : i32
+ %cst1 = constant 1 : i32
+ %result = select %true, %cst0, %cst1 : i32
+ return %result : i32
+}
More information about the Mlir-commits
mailing list