[Mlir-commits] [mlir] 0e98fb9 - [MLIR][transforms] Add an optimization pass to remove dead values
Srishti Srivastava
llvmlistbot at llvm.org
Wed Aug 23 16:54:50 PDT 2023
Author: Srishti Srivastava
Date: 2023-08-23T23:54:44Z
New Revision: 0e98fb9fadb056e90f2649601914942765d7bf09
URL: https://github.com/llvm/llvm-project/commit/0e98fb9fadb056e90f2649601914942765d7bf09
DIFF: https://github.com/llvm/llvm-project/commit/0e98fb9fadb056e90f2649601914942765d7bf09.diff
LOG: [MLIR][transforms] Add an optimization pass to remove dead values
Large deep learning models rely on heavy computation. However, not every
computation is necessary. And even when a computation is necessary, it
helps if the values it needs are available in registers (low latency)
rather than in memory (high latency).
Compilers can use liveness analysis to:
(1) Remove extraneous computations from a program before it executes on
hardware, and,
(2) Optimize register allocation.
Both these tasks help achieve one very important goal: reducing runtime.
Liveness analysis was recently added to MLIR. This commit uses that
utility to address task (1).
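For readers unfamiliar with liveness analysis, the following is a minimal sketch of the idea on a hypothetical straight-line toy IR. The `Instr` struct and `computeLiveness` function are illustrative assumptions, not MLIR APIs, and the sketch is much simpler than MLIR's dataflow-based liveness analysis: it has no control flow, and it treats a value as live only if some side-effecting instruction (such as a print or store) needs it, directly or transitively.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy instruction: instruction i defines value i and reads the
// values listed in `operands`. `hasSideEffect` marks sinks such as a print or
// a store, which must run no matter what.
struct Instr {
  std::vector<std::size_t> operands;
  bool hasSideEffect = false;
};

// Returns live[i] == true iff the value defined by instruction i is needed,
// directly or transitively, by some side-effecting instruction.
std::vector<bool> computeLiveness(const std::vector<Instr> &program) {
  std::vector<bool> live(program.size(), false);
  // Walk backward: every use comes after its definition, so all users of a
  // value have been visited before the value's own definition is reached.
  for (std::size_t i = program.size(); i-- > 0;) {
    const Instr &instr = program[i];
    if (!instr.hasSideEffect && !live[i])
      continue; // Non-live and side-effect free: it makes nothing else live.
    for (std::size_t operand : instr.operands) {
      assert(operand < i && "operands must be defined earlier");
      live[operand] = true;
    }
  }
  return live;
}
```

Because every definition precedes its uses in this toy IR, one backward walk is enough; once loops or branches are involved, an iterative dataflow solver is needed instead.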
It adds a pass called `remove-dead-values` whose goal is
optimization (reducing runtime) by removing unnecessary instructions.
Unlike other passes that rely on local information gathered from
patterns to accomplish optimization, this pass uses a full analysis of
the IR, specifically, liveness analysis, and is thus more powerful.
Currently, this pass performs the following optimizations:
(A) Removes function arguments that are not live,
(B) Removes function return values that are not live across all callers of
the function,
(C) Removes unnecessary operands, results, region arguments, and region
terminator operands of region branch ops, and,
(D) Removes simple and region branch ops that have all non-live results and
don't affect memory in any way,
iff
the IR doesn't have any non-function symbol ops, non-call symbol user ops
and branch ops.
Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
region branch op, branch op, region branch terminator op, or return-like.
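The precondition and the definition of a "simple op" can be read as two predicates. The sketch below models them over a hypothetical `OpTraits` struct of boolean flags; in MLIR these properties come from the interfaces and traits an op implements, so the struct and both function names are assumptions made purely for illustration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical per-op flags standing in for MLIR interfaces/traits.
struct OpTraits {
  bool isSymbol = false;                 // defines a symbol
  bool isFunction = false;               // function-like symbol
  bool isSymbolUser = false;             // references a symbol
  bool isCall = false;                   // call-like symbol user
  bool isRegionBranch = false;
  bool isBranch = false;
  bool isRegionBranchTerminator = false;
  bool isReturnLike = false;
};

// An op is "simple" iff it is none of the special kinds of op.
bool isSimple(const OpTraits &op) {
  return !op.isSymbol && !op.isSymbolUser && !op.isRegionBranch &&
         !op.isBranch && !op.isRegionBranchTerminator && !op.isReturnLike;
}

// The optimizations apply only if the IR has no non-function symbol ops,
// no non-call symbol-user ops, and no branch ops.
bool optimizationsApply(const std::vector<OpTraits> &ops) {
  for (const OpTraits &op : ops) {
    if (op.isSymbol && !op.isFunction)
      return false;
    if (op.isSymbolUser && !op.isCall)
      return false;
    if (op.isBranch)
      return false;
  }
  return true;
}
```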
Note that we do not refer to non-live values as "dead" in this file, to
avoid confusion with dead code analysis, where "dead" means unreachable
code (code that never executes on hardware). "Non-live", in contrast,
refers to code that executes on hardware but is unnecessary. Thus, while
removing dead code does little to reduce runtime, removing non-live values
can, in principle, have a significant impact (depending on how much is
removed).
It is also important to note that unlike other passes (like `canonicalize`)
that apply op-specific optimizations through patterns, this pass uses
different interfaces to handle various types of ops and tries to cover all
existing ops through these interfaces.
It is this reliance on (a) liveness analysis and (b) interfaces that makes
the pass powerful: it can optimize ops that don't have a canonicalizer,
and even when an op does have a canonicalizer, it can perform more
aggressive optimizations, as observed in the test files associated with
this pass.
Example of optimization (A):
```
int add_2_to_y(int x, int y) {
return 2 + y
}
print(add_2_to_y(3, 4))
print(add_2_to_y(5, 6))
```
becomes
```
int add_2_to_y(int y) {
return 2 + y
}
print(add_2_to_y(4))
print(add_2_to_y(6))
```
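Optimization (A) amounts to dropping the same argument positions from the callee and from every call site. A minimal sketch, assuming a hypothetical toy representation in which a call site is just its list of argument values and the liveness of each parameter has already been computed. This only makes sense when every caller is known, which is why the pass leaves public functions untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical toy call site: the argument values, in parameter order.
using CallSite = std::vector<int>;

// Removes every argument position i with argIsLive[i] == false from each
// call site, mirroring the removal of the matching callee parameter.
void eraseNonLiveArgs(const std::vector<bool> &argIsLive,
                      std::vector<CallSite> &callSites) {
  for (CallSite &call : callSites) {
    assert(call.size() == argIsLive.size() && "arity mismatch");
    CallSite kept;
    for (std::size_t i = 0; i < call.size(); ++i)
      if (argIsLive[i])
        kept.push_back(call[i]);
    call = std::move(kept);
  }
}
```

With `x` non-live and `y` live, the calls `add_2_to_y(3, 4)` and `add_2_to_y(5, 6)` shrink to `add_2_to_y(4)` and `add_2_to_y(6)`.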
Example of optimization (B):
```
int, int get_incremented_values(int y) {
store y somewhere in memory
return y + 1, y + 2
}
y1, y2 = get_incremented_values(4)
y3, y4 = get_incremented_values(6)
print(y2)
```
becomes
```
int get_incremented_values(int y) {
store y somewhere in memory
return y + 2
}
y2 = get_incremented_values(4)
y4 = get_incremented_values(6)
print(y2)
```
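For (B), a return value can be dropped only if it is non-live at every call site, so the per-caller information is combined with a logical AND (the pass does the same thing with a `BitVector`). A minimal sketch over plain `std::vector<bool>`, where `removableReturns` is a hypothetical helper:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper. liveAtCall[c][r] is true iff result r of call site c
// is live. Returns removable[r] == true iff result r is non-live at every
// call site, i.e. the AND of "non-live" over all callers.
std::vector<bool> removableReturns(
    std::size_t numReturns, const std::vector<std::vector<bool>> &liveAtCall) {
  // Start from "removable" and let any caller that uses a result veto it.
  std::vector<bool> removable(numReturns, true);
  for (const std::vector<bool> &call : liveAtCall) {
    assert(call.size() == numReturns && "arity mismatch");
    for (std::size_t r = 0; r < numReturns; ++r)
      removable[r] = removable[r] && !call[r];
  }
  return removable;
}
```

For the example above, only `y2` is live, so the first return value (`y + 1`) is removable and the second (`y + 2`) is not.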
Example of optimization (C):
Assume only `%result1` is live here. Then,
```
%result1, %result2, %result3 = scf.while (%arg1 = %operand1, %arg2 = %operand2) {
%terminator_operand2 = add %arg2, %arg2
%terminator_operand3 = mul %arg2, %arg2
%terminator_operand4 = add %arg1, %arg1
scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3, %terminator_operand4
} do {
^bb0(%arg3, %arg4, %arg5):
%terminator_operand6 = add %arg4, %arg4
%terminator_operand5 = add %arg5, %arg5
scf.yield %terminator_operand5, %terminator_operand6
}
```
becomes
```
%result1, %result2 = scf.while (%arg2 = %operand2) {
%terminator_operand2 = add %arg2, %arg2
%terminator_operand3 = mul %arg2, %arg2
scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3
} do {
^bb0(%arg3, %arg4):
%terminator_operand6 = add %arg4, %arg4
scf.yield %terminator_operand6
}
```
Note that `%result2` is not removed even though it is not live:
`%terminator_operand3` forwards to it and cannot be removed, because it
also forwards to `%arg4`, which is live.
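The reasoning above is a fixpoint computation: a forwarded operand must be kept if any value it forwards to is kept, and a result or region argument must be kept if any kept operand forwards to it. The sketch below models this on a hypothetical edge list (`ForwardEdge`, `propagateKeepFlags` are illustrative names, not the pass's code) and omits non-forwarded operands such as the `scf.condition` predicate, which the pass always keeps. The test encodes the `scf.while` example, assuming liveness reports `%result1`, `%arg2`, and `%arg4` as live (`%arg4` as stated above; `%arg2` feeds `%terminator_operand2`, which forwards to `%result1`).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical forwarding edge: the operand or terminator operand `source`
// flows into the result or region argument `dest`.
struct ForwardEdge {
  std::size_t source;
  std::size_t dest;
};

// Grows `keepSource` and `keepDest` until nothing changes:
//  - a source is kept if it forwards to any kept destination, and
//  - a destination is kept if any kept source forwards to it
//    (a kept operand still needs a slot to flow into).
void propagateKeepFlags(const std::vector<ForwardEdge> &edges,
                        std::vector<bool> &keepSource,
                        std::vector<bool> &keepDest) {
  bool changed = true;
  while (changed) { // Terminates: flags only ever flip from false to true.
    changed = false;
    for (const ForwardEdge &edge : edges) {
      assert(edge.source < keepSource.size() && edge.dest < keepDest.size());
      if (keepDest[edge.dest] && !keepSource[edge.source]) {
        keepSource[edge.source] = true;
        changed = true;
      }
      if (keepSource[edge.source] && !keepDest[edge.dest]) {
        keepDest[edge.dest] = true;
        changed = true;
      }
    }
  }
}
```

The kept values match the rewritten IR above: `%operand2`, `%terminator_operand2`, `%terminator_operand3`, and `%terminator_operand6` survive, along with `%result1`, `%result2`, `%arg2`, `%arg3`, and `%arg4`.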
Example of optimization (D):
```
int, int square_and_double_of_y(int y) {
square = y ^ 2
double = y * 2
return square, double
}
sq, do = square_and_double_of_y(5)
print(do)
```
becomes
```
int square_and_double_of_y(int y) {
double = y * 2
return double
}
do = square_and_double_of_y(5)
print(do)
```
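Optimization (D) boils down to one test per op: erase it iff it has no memory effects and none of its results are live (in the patch, `cleanSimpleOp` checks `isMemoryEffectFree` and queries liveness for this). A minimal sketch over a hypothetical `ToyOp` record whose flags are assumed to be precomputed:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy op: a name, whether it is free of memory effects, and the
// liveness of each of its results.
struct ToyOp {
  std::string name;
  bool memoryEffectFree;
  std::vector<bool> resultIsLive;
};

// True iff `op` may be erased: it touches no memory and no result is live.
bool isErasable(const ToyOp &op) {
  if (!op.memoryEffectFree)
    return false;
  return std::none_of(op.resultIsLive.begin(), op.resultIsLive.end(),
                      [](bool live) { return live; });
}

// Removes every erasable op, preserving the order of the remaining ones.
void eraseNonLiveSimpleOps(std::vector<ToyOp> &ops) {
  ops.erase(std::remove_if(ops.begin(), ops.end(), isErasable), ops.end());
}
```

A memory-effect-free op with no results at all also qualifies, since it has no live result; that matches `hasLive` returning false for an empty range.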
Signed-off-by: Srishti Srivastava <srishtisrivastava.ai at gmail.com>
Reviewed By: matthiaskramm, Mogball, jcai19
Differential Revision: https://reviews.llvm.org/D157049
Added:
mlir/lib/Transforms/RemoveDeadValues.cpp
mlir/test/Transforms/remove-dead-values.mlir
Modified:
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9110b64d55a637..320932bb999561 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -19,6 +19,7 @@
#include "mlir/Transforms/ViewOpGraph.h"
#include "llvm/Support/Debug.h"
#include <limits>
+#include <memory>
namespace mlir {
@@ -105,6 +106,9 @@ std::unique_ptr<Pass>
createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
std::function<void(OpPassManager &)> defaultPipelineBuilder);
+/// Creates an optimization pass to remove dead values.
+std::unique_ptr<Pass> createRemoveDeadValuesPass();
+
/// Creates a pass which performs sparse conditional constant propagation over
/// nested operations.
std::unique_ptr<Pass> createSCCPPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 125ba2ffbac723..26d2ff3c30ded5 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -85,6 +85,163 @@ def CSE : Pass<"cse"> {
];
}
+def RemoveDeadValues : Pass<"remove-dead-values"> {
+ let summary = "Remove dead values";
+ let description = [{
+ The goal of this pass is optimization (reducing runtime) by removing
+ unnecessary instructions. Unlike other passes that rely on local information
+ gathered from patterns to accomplish optimization, this pass uses a full
+ analysis of the IR, specifically, liveness analysis, and is thus more
+ powerful.
+
+ Currently, this pass performs the following optimizations:
+ (A) Removes function arguments that are not live,
+ (B) Removes function return values that are not live across all callers of
+ the function,
+ (C) Removes unnecessary operands, results, region arguments, and region
+ terminator operands of region branch ops, and,
+ (D) Removes simple and region branch ops that have all non-live results and
+ don't affect memory in any way,
+
+ iff
+
+ the IR doesn't have any non-function symbol ops, non-call symbol user ops
+ and branch ops.
+
+ Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
+ region branch op, branch op, region branch terminator op, or return-like.
+
+ It is noteworthy that we do not refer to non-live values as "dead" in this
+ file to avoid confusing it with dead code analysis's "dead", which refers to
+ unreachable code (code that never executes on hardware) while "non-live"
+ refers to code that executes on hardware but is unnecessary. Thus, while the
+ removal of dead code helps little in reducing runtime, removing non-live
+ values should theoretically have significant impact (depending on the amount
+ removed).
+
+ It is also important to note that unlike other passes (like `canonicalize`)
+ that apply op-specific optimizations through patterns, this pass uses
+ different interfaces to handle various types of ops and tries to cover all
+ existing ops through these interfaces.
+
+ It is this reliance on (a) liveness analysis and (b) interfaces that makes
+ the pass powerful: it can optimize ops that don't have a canonicalizer,
+ and even when an op does have a canonicalizer, it can perform more
+ aggressive optimizations, as observed in the test files associated with
+ this pass.
+
+ Example of optimization (A):-
+
+ ```
+ int add_2_to_y(int x, int y) {
+ return 2 + y
+ }
+
+ print(add_2_to_y(3, 4))
+ print(add_2_to_y(5, 6))
+ ```
+
+ becomes
+
+ ```
+ int add_2_to_y(int y) {
+ return 2 + y
+ }
+
+ print(add_2_to_y(4))
+ print(add_2_to_y(6))
+ ```
+
+ Example of optimization (B):-
+
+ ```
+ int, int get_incremented_values(int y) {
+ store y somewhere in memory
+ return y + 1, y + 2
+ }
+
+ y1, y2 = get_incremented_values(4)
+ y3, y4 = get_incremented_values(6)
+ print(y2)
+ ```
+
+ becomes
+
+ ```
+ int get_incremented_values(int y) {
+ store y somewhere in memory
+ return y + 2
+ }
+
+ y2 = get_incremented_values(4)
+ y4 = get_incremented_values(6)
+ print(y2)
+ ```
+
+ Example of optimization (C):-
+
+ Assume only `%result1` is live here. Then,
+
+ ```
+ %result1, %result2, %result3 = scf.while (%arg1 = %operand1, %arg2 = %operand2) {
+ %terminator_operand2 = add %arg2, %arg2
+ %terminator_operand3 = mul %arg2, %arg2
+ %terminator_operand4 = add %arg1, %arg1
+ scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3, %terminator_operand4
+ } do {
+ ^bb0(%arg3, %arg4, %arg5):
+ %terminator_operand6 = add %arg4, %arg4
+ %terminator_operand5 = add %arg5, %arg5
+ scf.yield %terminator_operand5, %terminator_operand6
+ }
+ ```
+
+ becomes
+
+ ```
+ %result1, %result2 = scf.while (%arg2 = %operand2) {
+ %terminator_operand2 = add %arg2, %arg2
+ %terminator_operand3 = mul %arg2, %arg2
+ scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3
+ } do {
+ ^bb0(%arg3, %arg4):
+ %terminator_operand6 = add %arg4, %arg4
+ scf.yield %terminator_operand6
+ }
+ ```
+
+ Note that `%result2` is not removed even though it is not live:
+ `%terminator_operand3` forwards to it and cannot be removed, because it
+ also forwards to `%arg4`, which is live.
+
+ Example of optimization (D):-
+
+ ```
+ int square_and_double_of_y(int y) {
+ square = y ^ 2
+ double = y * 2
+ return square, double
+ }
+
+ sq, do = square_and_double_of_y(5)
+ print(do)
+ ```
+
+ becomes
+
+ ```
+ int square_and_double_of_y(int y) {
+ double = y * 2
+ return double
+ }
+
+ do = square_and_double_of_y(5)
+ print(do)
+ ```
+ }];
+ let constructor = "mlir::createRemoveDeadValuesPass()";
+}
+
def PrintIRPass : Pass<"print-ir"> {
let summary = "Print IR on the debug stream";
let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 72dd7ab94e9097..641ae8d0befb5b 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_library(MLIRTransforms
Mem2Reg.cpp
OpStats.cpp
PrintIR.cpp
+ RemoveDeadValues.cpp
SCCP.cpp
SROA.cpp
StripDebugInfo.cpp
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
new file mode 100644
index 00000000000000..ad49a10b30b5ae
--- /dev/null
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -0,0 +1,619 @@
+//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The goal of this pass is optimization (reducing runtime) by removing
+// unnecessary instructions. Unlike other passes that rely on local information
+// gathered from patterns to accomplish optimization, this pass uses a full
+// analysis of the IR, specifically, liveness analysis, and is thus more
+// powerful.
+//
+// Currently, this pass performs the following optimizations:
+// (A) Removes function arguments that are not live,
+// (B) Removes function return values that are not live across all callers of
+// the function,
+// (C) Removes unnecessary operands, results, region arguments, and region
+// terminator operands of region branch ops, and,
+// (D) Removes simple and region branch ops that have all non-live results and
+// don't affect memory in any way,
+//
+// iff
+//
+// the IR doesn't have any non-function symbol ops, non-call symbol user ops and
+// branch ops.
+//
+// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
+// region branch op, branch op, region branch terminator op, or return-like.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include <cassert>
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <vector>
+
+namespace mlir {
+#define GEN_PASS_DEF_REMOVEDEADVALUES
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+//===----------------------------------------------------------------------===//
+// RemoveDeadValues Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Some helper functions...
+
+/// Return true iff at least one value in `values` is live, given the liveness
+/// information in `la`.
+static bool hasLive(ValueRange values, RunLivenessAnalysis &la) {
+ for (Value value : values) {
+ // If there is a null value, it implies that it was dropped during the
+ // execution of this pass, implying that it was non-live.
+ if (!value)
+ continue;
+
+ const Liveness *liveness = la.getLiveness(value);
+ if (!liveness || liveness->isLive)
+ return true;
+ }
+ return false;
+}
+
+/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
+/// i-th value in `values` is live, given the liveness information in `la`.
+static BitVector markLives(ValueRange values, RunLivenessAnalysis &la) {
+ BitVector lives(values.size(), true);
+
+ for (auto [index, value] : llvm::enumerate(values)) {
+ if (!value) {
+ lives.reset(index);
+ continue;
+ }
+
+ const Liveness *liveness = la.getLiveness(value);
+ // It is important to note that when `liveness` is null, we can't tell if
+ // `value` is live or not. So, the safe option is to consider it live. Also,
+ // the execution of this pass might create new SSA values when erasing some
+ // of the results of an op and we know that these new values are live
+ // (because they weren't erased) and also their liveness is null because
+ // liveness analysis ran before their creation.
+ if (liveness && !liveness->isLive)
+ lives.reset(index);
+ }
+
+ return lives;
+}
+
+/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
+/// is 1.
+static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
+ assert(op->getNumResults() == toErase.size() &&
+ "expected the number of results in `op` and the size of `toErase` to "
+ "be the same");
+
+ std::vector<Type> newResultTypes;
+ for (OpResult result : op->getResults())
+ if (!toErase[result.getResultNumber()])
+ newResultTypes.push_back(result.getType());
+ OpBuilder builder(op);
+ builder.setInsertionPointAfter(op);
+ OperationState state(op->getLoc(), op->getName().getStringRef(),
+ op->getOperands(), newResultTypes, op->getAttrs());
+ for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+ state.addRegion();
+ Operation *newOp = builder.create(state);
+ for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
+ Region &newRegion = newOp->getRegion(index);
+ IRMapping mapping;
+ region.cloneInto(&newRegion, mapping);
+ }
+
+ unsigned indexOfNextNewCallOpResultToReplace = 0;
+ for (auto [index, result] : llvm::enumerate(op->getResults())) {
+ assert(result && "expected result to be non-null");
+ if (toErase[index]) {
+ result.dropAllUses();
+ } else {
+ result.replaceAllUsesWith(
+ newOp->getResult(indexOfNextNewCallOpResultToReplace++));
+ }
+ }
+ op->erase();
+}
+
+/// Convert a list of `Operand`s to a list of `OpOperand`s.
+static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
+ OpOperand *values = operands.getBase();
+ SmallVector<OpOperand *> opOperands;
+ for (unsigned i = 0, e = operands.size(); i < e; i++)
+ opOperands.push_back(&values[i]);
+ return opOperands;
+}
+
+/// Clean a simple op `op`, given the liveness analysis information in `la`.
+/// Here, cleaning means:
+/// (1) Dropping all its uses, AND
+/// (2) Erasing it
+/// iff it has no memory effects and none of its results are live.
+///
+/// It is assumed that `op` is simple. Here, a simple op is one which isn't a
+/// symbol op, a symbol-user op, a region branch op, a branch op, a region
+/// branch terminator op, or return-like.
+static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) {
+ if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la))
+ return;
+
+ op->dropAllUses();
+ op->erase();
+}
+
+/// Clean a function-like op `funcOp`, given the liveness information in `la`
+/// and the IR in `module`. Here, cleaning means:
+/// (1) Dropping the uses of its unnecessary (non-live) arguments,
+/// (2) Erasing these arguments,
+/// (3) Erasing their corresponding operands from its callers,
+/// (4) Erasing its unnecessary terminator operands (return values that are
+/// non-live across all callers),
+/// (5) Dropping the uses of these return values from its callers, AND
+/// (6) Erasing these return values
+/// iff it is not public.
+static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
+ RunLivenessAnalysis &la) {
+ if (funcOp.isPublic())
+ return;
+
+ // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
+ SmallVector<Value> arguments(funcOp.getArguments());
+ BitVector nonLiveArgs = markLives(arguments, la);
+ nonLiveArgs = nonLiveArgs.flip();
+
+ // Do (1).
+ for (auto [index, arg] : llvm::enumerate(arguments))
+ if (arg && nonLiveArgs[index])
+ arg.dropAllUses();
+
+ // Do (2).
+ funcOp.eraseArguments(nonLiveArgs);
+
+ // Do (3).
+ SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
+ for (SymbolTable::SymbolUse use : uses) {
+ Operation *callOp = use.getUser();
+ assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
+ // The number of operands in the call op may not match the number of
+ // arguments in the func op.
+ BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
+ SmallVector<OpOperand *> callOpOperands =
+ operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
+ for (int index : nonLiveArgs.set_bits())
+ nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
+ callOp->eraseOperands(nonLiveCallOperands);
+ }
+
+ // Get the list of unnecessary terminator operands (return values that are
+ // non-live across all callers) in `nonLiveRets`. There is a very important
+ // subtlety here. Unnecessary terminator operands are NOT the operands of the
+ // terminator that are non-live. Instead, these are the return values of the
+ // callers such that a given return value is non-live across all callers. Such
+ // corresponding operands in the terminator could be live. An example to
+ // demonstrate this:
+ // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
+ // %c0_i32 = arith.constant 0 : i32
+ // %0 = arith.addi %c0_i32, %c0_i32 : i32
+ // memref.store %0, %arg0[] : memref<i32>
+ // return %c0_i32, %0 : i32, i32
+ // }
+ // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
+ // %1:2 = call @f(%arg1) : (memref<i32>) -> (i32, i32)
+ // return %1#0 : i32
+ // }
+ // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
+ // need to return %0. But, %0 is live. And, still, we want to stop it from
+ // being returned, in order to optimize our IR. So, this demonstrates how we
+ // can make our optimization strong by even removing a live return value (%0),
+ // since it forwards only to non-live value(s) (%1#1).
+ Operation *lastReturnOp = funcOp.back().getTerminator();
+ size_t numReturns = lastReturnOp->getNumOperands();
+ BitVector nonLiveRets(numReturns, true);
+ for (SymbolTable::SymbolUse use : uses) {
+ Operation *callOp = use.getUser();
+ assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
+ BitVector liveCallRets = markLives(callOp->getResults(), la);
+ nonLiveRets &= liveCallRets.flip();
+ }
+
+ // Do (4).
+ // Note that in the absence of control flow ops forcing the control to go from
+ // the entry (first) block to the other blocks, the control never reaches any
+ // block other than the entry block, because every block has a terminator.
+ for (Block &block : funcOp.getBlocks()) {
+ Operation *returnOp = block.getTerminator();
+ if (returnOp && returnOp->getNumOperands() == numReturns)
+ returnOp->eraseOperands(nonLiveRets);
+ }
+ funcOp.eraseResults(nonLiveRets);
+
+ // Do (5) and (6).
+ for (SymbolTable::SymbolUse use : uses) {
+ Operation *callOp = use.getUser();
+ assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
+ dropUsesAndEraseResults(callOp, nonLiveRets);
+ }
+}
+
+/// Clean a region branch op `regionBranchOp`, given the liveness information in
+/// `la`. Here, cleaning means:
+/// (1') Dropping all its uses, AND
+/// (2') Erasing it
+/// if it has no memory effects and none of its results are live, AND
+/// (1) Erasing its unnecessary operands (operands that are forwarded to
+/// unnecessary results and arguments),
+/// (2) Cleaning each of its regions,
+/// (3) Dropping the uses of its unnecessary results (results that are
+/// forwarded from unnecessary operands and terminator operands), AND
+/// (4) Erasing these results
+/// otherwise.
+/// Note that here, cleaning a region means:
+/// (2.a) Dropping the uses of its unnecessary arguments (arguments that are
+/// forwarded from unnecessary operands and terminator operands),
+/// (2.b) Erasing these arguments, AND
+/// (2.c) Erasing its unnecessary terminator operands (terminator operands
+/// that are forwarded to unnecessary results and arguments).
+/// It is important to note that values in this op flow from operands and
+/// terminator operands (successor operands) to arguments and results (successor
+/// inputs).
+static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
+ RunLivenessAnalysis &la) {
+ // Mark live results of `regionBranchOp` in `liveResults`.
+ auto markLiveResults = [&](BitVector &liveResults) {
+ liveResults = markLives(regionBranchOp->getResults(), la);
+ };
+
+ // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
+ auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
+ for (Region &region : regionBranchOp->getRegions()) {
+ SmallVector<Value> arguments(region.front().getArguments());
+ BitVector regionLiveArgs = markLives(arguments, la);
+ liveArgs[&region] = regionLiveArgs;
+ }
+ };
+
+ // Return the successors of `region` if the latter is not null. Else return
+ // the successors of `regionBranchOp`.
+ auto getSuccessors = [&](Region *region = nullptr) {
+ std::optional<unsigned> index =
+ region ? std::optional(region->getRegionNumber()) : std::nullopt;
+ SmallVector<Attribute> operandAttributes(regionBranchOp->getNumOperands(),
+ nullptr);
+ SmallVector<RegionSuccessor> successors;
+ if (!index)
+ regionBranchOp.getEntrySuccessorRegions(operandAttributes, successors);
+ else
+ regionBranchOp.getSuccessorRegions(index, successors);
+ return successors;
+ };
+
+ // Return the operands of `terminator` that are forwarded to `successor` if
+ // the former is not null. Else return the operands of `regionBranchOp`
+ // forwarded to `successor`.
+ auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
+ Operation *terminator = nullptr) {
+ Region *successorRegion = successor.getSuccessor();
+ std::optional<unsigned> index =
+ successorRegion ? std::optional(successorRegion->getRegionNumber())
+ : std::nullopt;
+ OperandRange operands =
+ terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
+ .getSuccessorOperands(index)
+ : regionBranchOp.getEntrySuccessorOperands(index);
+ SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
+ return opOperands;
+ };
+
+ // Mark the non-forwarded operands of `regionBranchOp` in
+ // `nonForwardedOperands`.
+ auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
+ nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
+ for (const RegionSuccessor &successor : getSuccessors()) {
+ for (OpOperand *opOperand : getForwardedOpOperands(successor))
+ nonForwardedOperands.reset(opOperand->getOperandNumber());
+ }
+ };
+
+ // Mark the non-forwarded terminator operands of the various regions of
+ // `regionBranchOp` in `nonForwardedRets`.
+ auto markNonForwardedReturnValues =
+ [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
+ for (Region &region : regionBranchOp->getRegions()) {
+ Operation *terminator = region.front().getTerminator();
+ nonForwardedRets[terminator] =
+ BitVector(terminator->getNumOperands(), true);
+ for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ for (OpOperand *opOperand :
+ getForwardedOpOperands(successor, terminator))
+ nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
+ }
+ }
+ };
+
+ // Update `valuesToKeep` (which is expected to correspond to operands or
+ // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
+ // `region`. When `valuesToKeep` correspond to operands, `region` is null.
+ // Else, `region` is the parent region of the terminator.
+ auto updateOperandsOrTerminatorOperandsToKeep =
+ [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
+ DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
+ Operation *terminator =
+ region ? region->front().getTerminator() : nullptr;
+
+ for (const RegionSuccessor &successor : getSuccessors(region)) {
+ Region *successorRegion = successor.getSuccessor();
+ for (auto [opOperand, input] :
+ llvm::zip(getForwardedOpOperands(successor, terminator),
+ successor.getSuccessorInputs())) {
+ size_t operandNum = opOperand->getOperandNumber();
+ bool updateBasedOn =
+ successorRegion
+ ? argsToKeep[successorRegion]
+ [cast<BlockArgument>(input).getArgNumber()]
+ : resultsToKeep[cast<OpResult>(input).getResultNumber()];
+ valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
+ }
+ }
+ };
+
+ // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
+ // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
+ // value is modified, else, false.
+ auto recomputeResultsAndArgsToKeep =
+ [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
+ BitVector &operandsToKeep,
+ DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
+ bool &resultsOrArgsToKeepChanged) {
+ resultsOrArgsToKeepChanged = false;
+
+ // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
+ for (const RegionSuccessor &successor : getSuccessors()) {
+ Region *successorRegion = successor.getSuccessor();
+ for (auto [opOperand, input] :
+ llvm::zip(getForwardedOpOperands(successor),
+ successor.getSuccessorInputs())) {
+ bool recomputeBasedOn =
+ operandsToKeep[opOperand->getOperandNumber()];
+ bool toRecompute =
+ successorRegion
+ ? argsToKeep[successorRegion]
+ [cast<BlockArgument>(input).getArgNumber()]
+ : resultsToKeep[cast<OpResult>(input).getResultNumber()];
+ if (!toRecompute && recomputeBasedOn)
+ resultsOrArgsToKeepChanged = true;
+ if (successorRegion) {
+ argsToKeep[successorRegion][cast<BlockArgument>(input)
+ .getArgNumber()] =
+ argsToKeep[successorRegion]
+ [cast<BlockArgument>(input).getArgNumber()] |
+ recomputeBasedOn;
+ } else {
+ resultsToKeep[cast<OpResult>(input).getResultNumber()] =
+ resultsToKeep[cast<OpResult>(input).getResultNumber()] |
+ recomputeBasedOn;
+ }
+ }
+ }
+
+ // Recompute `resultsToKeep` and `argsToKeep` based on
+ // `terminatorOperandsToKeep`.
+ for (Region &region : regionBranchOp->getRegions()) {
+ Operation *terminator = region.front().getTerminator();
+ for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ Region *successorRegion = successor.getSuccessor();
+ for (auto [opOperand, input] :
+ llvm::zip(getForwardedOpOperands(successor, terminator),
+ successor.getSuccessorInputs())) {
+ bool recomputeBasedOn =
+ terminatorOperandsToKeep[region.back().getTerminator()]
+ [opOperand->getOperandNumber()];
+ bool toRecompute =
+ successorRegion
+ ? argsToKeep[successorRegion]
+ [cast<BlockArgument>(input).getArgNumber()]
+ : resultsToKeep[cast<OpResult>(input).getResultNumber()];
+ if (!toRecompute && recomputeBasedOn)
+ resultsOrArgsToKeepChanged = true;
+ if (successorRegion) {
+ argsToKeep[successorRegion][cast<BlockArgument>(input)
+ .getArgNumber()] =
+ argsToKeep[successorRegion]
+ [cast<BlockArgument>(input).getArgNumber()] |
+ recomputeBasedOn;
+ } else {
+ resultsToKeep[cast<OpResult>(input).getResultNumber()] =
+ resultsToKeep[cast<OpResult>(input).getResultNumber()] |
+ recomputeBasedOn;
+ }
+ }
+ }
+ }
+ };
+
+ // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
+ // `operandsToKeep`, and `terminatorOperandsToKeep`.
+ auto markValuesToKeep =
+ [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
+ BitVector &operandsToKeep,
+ DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
+ bool resultsOrArgsToKeepChanged = true;
+ // We keep updating and recomputing the values until we reach a point
+ // where they stop changing.
+ while (resultsOrArgsToKeepChanged) {
+ // Update the operands that need to be kept.
+ updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
+ resultsToKeep, argsToKeep);
+
+ // Update the terminator operands that need to be kept.
+ for (Region &region : regionBranchOp->getRegions()) {
+ updateOperandsOrTerminatorOperandsToKeep(
+ terminatorOperandsToKeep[region.back().getTerminator()],
+ resultsToKeep, argsToKeep, &region);
+ }
+
+ // Recompute the results and arguments that need to be kept.
+ recomputeResultsAndArgsToKeep(
+ resultsToKeep, argsToKeep, operandsToKeep,
+ terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
+ }
+ };
+
+ // Do (1') and (2'). This is the only case where the entire `regionBranchOp`
+ // is removed; it does not happen in any other scenario. Note that in this
+ // case, a non-forwarded operand of `regionBranchOp` may be either live or
+ // non-live: it can never be live because of this op, but it may be live
+ // because of some other user.
+ if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
+ !hasLive(regionBranchOp->getResults(), la)) {
+ regionBranchOp->dropAllUses();
+ regionBranchOp->erase();
+ return;
+ }
+
+ // At this point, we know that every non-forwarded operand of `regionBranchOp`
+ // is live.
+
+ // Stores the results of `regionBranchOp` that we want to keep.
+ BitVector resultsToKeep;
+ // Stores the mapping from regions of `regionBranchOp` to their arguments that
+ // we want to keep.
+ DenseMap<Region *, BitVector> argsToKeep;
+ // Stores the operands of `regionBranchOp` that we want to keep.
+ BitVector operandsToKeep;
+ // Stores the mapping from region terminators in `regionBranchOp` to their
+ // operands that we want to keep.
+ DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
+
+ // Initializing the above variables...
+
+ // The live results of `regionBranchOp` definitely need to be kept.
+ markLiveResults(resultsToKeep);
+ // Similarly, the live arguments of the regions in `regionBranchOp` definitely
+ // need to be kept.
+ markLiveArgs(argsToKeep);
+ // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
+ // A live forwarded operand can be removed but no non-forwarded operand can be
+ // removed since it "controls" the flow of data in this control flow op.
+ markNonForwardedOperands(operandsToKeep);
+ // Similarly, the non-forwarded terminator operands of the regions in
+ // `regionBranchOp` definitely need to be kept.
+ markNonForwardedReturnValues(terminatorOperandsToKeep);
+
+ // Mark the values (results, arguments, operands, and terminator operands)
+ // that we want to keep.
+ markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
+ terminatorOperandsToKeep);
+
+ // Do (1).
+ regionBranchOp->eraseOperands(operandsToKeep.flip());
+
+ // Do (2.a) and (2.b).
+ for (Region &region : regionBranchOp->getRegions()) {
+ assert(!region.empty() && "expected a non-empty region in an op "
+ "implementing `RegionBranchOpInterface`");
+ for (auto [index, arg] : llvm::enumerate(region.front().getArguments())) {
+ if (argsToKeep[&region][index])
+ continue;
+ if (arg)
+ arg.dropAllUses();
+ }
+ region.front().eraseArguments(argsToKeep[&region].flip());
+ }
+
+ // Do (2.c).
+ for (Region &region : regionBranchOp->getRegions()) {
+ Operation *terminator = region.front().getTerminator();
+ terminator->eraseOperands(terminatorOperandsToKeep[terminator].flip());
+ }
+
+ // Do (3) and (4).
+ dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip());
+}
+
+struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void RemoveDeadValues::runOnOperation() {
+ auto &la = getAnalysis<RunLivenessAnalysis>();
+ Operation *module = getOperation();
+
+ // The removal of non-live values is performed iff there are no branch ops,
+ // all symbol ops present in the IR are function-like, and all symbol user ops
+ // present in the IR are call-like.
+ WalkResult acceptableIR = module->walk([&](Operation *op) {
+ if (isa<BranchOpInterface>(op) ||
+ (isa<SymbolOpInterface>(op) && !isa<FunctionOpInterface>(op)) ||
+ (isa<SymbolUserOpInterface>(op) && !isa<CallOpInterface>(op))) {
+ op->emitError() << "cannot optimize an IR with non-function symbol ops, "
+ "non-call symbol user ops or branch ops\n";
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
+ if (acceptableIR.wasInterrupted())
+ return;
+
+ module->walk([&](Operation *op) {
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ cleanFuncOp(funcOp, module, la);
+ } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ cleanRegionBranchOp(regionBranchOp, la);
+ } else if (op->hasTrait<OpTrait::ReturnLike>()) {
+ // Nothing to do because this terminator is associated with either a
+ // function op or a region branch op and gets cleaned when these ops are
+ // cleaned.
+ } else if (isa<RegionBranchTerminatorOpInterface>(op)) {
+ // Nothing to do because this terminator is associated with a region
+ // branch op and gets cleaned when the latter is cleaned.
+ } else if (isa<CallOpInterface>(op)) {
+ // Nothing to do because this op is associated with a function op and gets
+ // cleaned when the latter is cleaned.
+ } else {
+ cleanSimpleOp(op, la);
+ }
+ });
+}
+
+std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
+ return std::make_unique<RemoveDeadValues>();
+}
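The fixed-point loop in `markValuesToKeep` is easier to follow on a toy model than through the `BitVector` bookkeeping. What follows is a minimal C++ sketch under stated assumptions, not the pass's code. `ForwardedOperand`, `KeepSets`, and `computeKeepSets` are hypothetical names. Every result and region argument ("successor input") is flattened into one index space. Non-forwarded operands, which the pass always keeps, are simply left out of the model. A forwarded operand is kept if any successor input it may flow into is kept. A successor input is kept if it is live or if any kept operand flows into it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy model (not MLIR API). "Successor inputs" are the results
// of the region branch op plus the arguments of its regions, all flattened
// into one index space.
struct ForwardedOperand {
  // Successor inputs this operand may be forwarded to.
  std::vector<std::size_t> targets;
};

struct KeepSets {
  std::vector<bool> inputKeep;   // one bit per successor input
  std::vector<bool> operandKeep; // one bit per forwarded operand
};

// Iterates until no keep bit changes. Non-forwarded operands are assumed to
// be kept unconditionally and are therefore not part of `operands`.
KeepSets computeKeepSets(const std::vector<bool> &inputIsLive,
                         const std::vector<ForwardedOperand> &operands) {
  // Live successor inputs are kept from the start.
  KeepSets sets{inputIsLive, std::vector<bool>(operands.size(), false)};
  bool changed = true;
  while (changed) {
    changed = false;
    // Keep an operand if any input it flows into is kept.
    for (std::size_t i = 0; i < operands.size(); ++i) {
      bool keep = false;
      for (std::size_t target : operands[i].targets) {
        assert(target < sets.inputKeep.size() && "target out of range");
        keep = keep || sets.inputKeep[target];
      }
      sets.operandKeep[i] = keep;
    }
    // Keep every input that a kept operand flows into, since forwarded
    // operands map positionally onto successor inputs.
    for (std::size_t i = 0; i < operands.size(); ++i) {
      if (!sets.operandKeep[i])
        continue;
      for (std::size_t target : operands[i].targets) {
        if (!sets.inputKeep[target]) {
          sets.inputKeep[target] = true;
          changed = true;
        }
      }
    }
  }
  return sets;
}
```

Keep bits only ever flip from false to true, so the loop runs at most one more time than there are successor inputs. The first `scf.while` test in remove-dead-values.mlir can be encoded this way: inputs 0-2 for the results, 3-4 for `%arg3`/`%arg4`, and 5-7 for `%arg5`..`%arg7`. The model then keeps `%live`, `%non_live`, `%arg4`, `%arg5`, and `%arg6`. It drops the `%arg1` init operand, the `%non_live_1` condition operand, and the `%arg7` yield operand, which matches what that test's CHECK lines expect.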
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
new file mode 100644
index 00000000000000..22b66b464ac69a
--- /dev/null
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -0,0 +1,330 @@
+// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s
+
+// The IR remains untouched because of the presence of a non-function-like
+// symbol op (module @dont_touch_unacceptable_ir).
+//
+// expected-error @+1 {{cannot optimize an IR with non-function symbol ops, non-call symbol user ops or branch ops}}
+module @dont_touch_unacceptable_ir {
+ func.func @has_cleanable_simple_op(%arg0 : i32) {
+ %non_live = arith.addi %arg0, %arg0 : i32
+ return
+ }
+}
+
+// -----
+
+// The IR remains untouched because of the presence of a branch op `cf.cond_br`.
+//
+func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
+ %non_live = arith.constant 0 : i32
+ // expected-error @+1 {{cannot optimize an IR with non-function symbol ops, non-call symbol user ops or branch ops}}
+ cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
+^bb1(%non_live_0 : i32):
+ cf.br ^bb3
+^bb2(%non_live_1 : i32):
+ cf.br ^bb3
+^bb3:
+ return
+}
+
+// -----
+
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg0:.*]]: i32) {
+// CHECK-NEXT: call @clean_func_op_remove_argument_and_return_value() : () -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func private @clean_func_op_remove_argument_and_return_value(%arg0: i32) -> (i32) {
+ return %arg0 : i32
+}
+func.func @main(%arg0 : i32) {
+ %non_live = func.call @clean_func_op_remove_argument_and_return_value(%arg0) : (i32) -> (i32)
+ return
+}
+
+// -----
+
+// %arg0 is not live because it is never used. %arg1 is not live because its
+// user `arith.addi` doesn't have any uses and the value that it is forwarded to
+// (%non_live_0) also doesn't have any uses.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK-LABEL: func.func private @clean_func_op_remove_arguments() -> i32 {
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT: return %[[c0]]
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg2:.*]]: memref<i32>, %[[arg3:.*]]: i32, %[[DEVICE:.*]]: i32) -> (i32, memref<i32>) {
+// CHECK-NEXT: %[[live:.*]] = test.call_on_device @clean_func_op_remove_arguments(), %[[DEVICE]] : (i32) -> i32
+// CHECK-NEXT: return %[[live]], %[[arg2]]
+// CHECK-NEXT: }
+func.func private @clean_func_op_remove_arguments(%arg0 : memref<i32>, %arg1 : i32) -> (i32, i32) {
+ %c0 = arith.constant 0 : i32
+ %non_live = arith.addi %arg1, %arg1 : i32
+ return %c0, %arg1 : i32, i32
+}
+func.func @main(%arg2 : memref<i32>, %arg3 : i32, %device : i32) -> (i32, memref<i32>) {
+ %live, %non_live_0 = test.call_on_device @clean_func_op_remove_arguments(%arg2, %arg3), %device : (memref<i32>, i32, i32) -> (i32, i32)
+ return %live, %arg2 : i32, memref<i32>
+}
+
+// -----
+
+// Even though %non_live_0 is not live, the first return value of
+// @clean_func_op_remove_return_values isn't removed because %live is live
+// (liveness is checked across all callers).
+//
+// Also, the second return value of @clean_func_op_remove_return_values is
+// removed despite %c0 being live because neither %non_live nor %non_live_1 was
+// live (removal doesn't depend on the liveness of the operand itself but on the
+// liveness of where it is forwarded).
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func private @clean_func_op_remove_return_values(%[[arg0:.*]]: memref<i32>) -> i32 {
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT: memref.store %[[c0]], %[[arg0]][]
+// CHECK-NEXT: return %[[c0]]
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg1:.*]]: memref<i32>) -> i32 {
+// CHECK-NEXT: %[[live:.*]] = call @clean_func_op_remove_return_values(%[[arg1]]) : (memref<i32>) -> i32
+// CHECK-NEXT: %[[non_live_0:.*]] = call @clean_func_op_remove_return_values(%[[arg1]]) : (memref<i32>) -> i32
+// CHECK-NEXT: return %[[live]] : i32
+// CHECK-NEXT: }
+func.func private @clean_func_op_remove_return_values(%arg0 : memref<i32>) -> (i32, i32) {
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %arg0[] : memref<i32>
+ return %c0, %c0 : i32, i32
+}
+func.func @main(%arg1 : memref<i32>) -> (i32) {
+ %live, %non_live = func.call @clean_func_op_remove_return_values(%arg1) : (memref<i32>) -> (i32, i32)
+ %non_live_0, %non_live_1 = func.call @clean_func_op_remove_return_values(%arg1) : (memref<i32>) -> (i32, i32)
+ return %live : i32
+}
+
+// -----
+
+// None of the return values of @clean_func_op_dont_remove_return_values can be
+// removed because the first one is forwarded to a live value %live and the
+// second one is forwarded to a live value %live_0.
+//
+// CHECK-LABEL: func.func private @clean_func_op_dont_remove_return_values() -> (i32, i32) {
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : i32
+// CHECK-NEXT: return %[[c0]], %[[c0]] : i32, i32
+// CHECK-NEXT: }
+// CHECK-LABEL: func.func @main() -> (i32, i32) {
+// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+// CHECK-NEXT: %[[non_live_0_and_live_0:.*]]:2 = call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+// CHECK-NEXT: return %[[live_and_non_live]]#0, %[[non_live_0_and_live_0]]#1 : i32, i32
+// CHECK-NEXT: }
+func.func private @clean_func_op_dont_remove_return_values() -> (i32, i32) {
+ %c0 = arith.constant 0 : i32
+ return %c0, %c0 : i32, i32
+}
+func.func @main() -> (i32, i32) {
+ %live, %non_live = func.call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+ %non_live_0, %live_0 = func.call @clean_func_op_dont_remove_return_values() : () -> (i32, i32)
+ return %live, %live_0 : i32, i32
+}
+
+// -----
+
+// Values kept:
+// (1) %non_live is not live. Yet, it is kept because %arg4 in `scf.condition`
+// forwards to it, which has to be kept. %arg4 in `scf.condition` has to be
+// kept because it forwards to %arg6 which is live.
+//
+// (2) %arg5 is not live. Yet, it is kept because %live_0 forwards to it, which
+// also forwards to %live, which is live.
+//
+// Values not kept:
+// (1) %arg1 is not kept as an operand of `scf.while` because it only forwards
+// to %arg3, which is not kept. %arg3 is not kept because %arg3 is not live and
+// only %arg1 and %arg7 forward to it, and neither of them forwards
+// anywhere else. Thus, %arg7 is also not kept in the `scf.yield` op.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CHECK-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[live_and_non_live]]#0
+// CHECK-NEXT: }
+func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
+ %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
+ %live_0 = arith.addi %arg4, %arg4 : i32
+ %non_live_1 = arith.addi %arg3, %arg3 : i32
+ scf.condition(%arg0) %live_0, %arg4, %non_live_1 : i32, i32, i32
+ } do {
+ ^bb0(%arg5: i32, %arg6: i32, %arg7: i32):
+ %live_1 = arith.addi %arg6, %arg6 : i32
+ scf.yield %arg7, %live_1 : i32, i32
+ }
+ return %live : i32
+}
+
+// -----
+
+// Values kept:
+// (1) %live is kept because it is live.
+//
+// (2) %non_live is not live. Yet, it is kept because %arg3 in `scf.condition`
+// forwards to it and this %arg3 has to be kept. This %arg3 in `scf.condition`
+// has to be kept because it forwards to %arg6, which forwards to %arg4, which
+// forwards to %live, which is live.
+//
+// Values not kept:
+// (1) %non_live_0 is not kept because %non_live_2 in `scf.condition` forwards
+// to it, which forwards to only %non_live_0 and %arg7, where both these are
+// not live and have no other value forwarding to them.
+//
+// (2) %non_live_1 is not kept because %non_live_3 in `scf.condition` forwards
+// to it, which forwards to only %non_live_1 and %arg8, where both these are
+// not live and have no other value forwarding to them.
+//
+// (3) %c2 is not kept because it only forwards to %arg10, which is not kept.
+//
+// (4) %arg10 is not kept because only %c2 and %non_live_4 forward to it, none
+// of them forwards anywhere else, and %arg10 itself is not live.
+//
+// (5) %arg7 and %arg8 are not kept because they are not live, %non_live_2 and
+// %non_live_3 forward to them, and both only otherwise forward to %non_live_0
+// and %non_live_1 which are not live and have no other predecessors.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[c1:.*]] = arith.constant 1
+// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32
+// CHECK-NEXT: }
+func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2 : i32
+ %live, %non_live, %non_live_0, %non_live_1 = scf.while (%arg3 = %c0, %arg4 = %c1, %arg10 = %c2) : (i32, i32, i32) -> (i32, i32, i32, i32) {
+ %non_live_2 = arith.addi %arg10, %arg10 : i32
+ %non_live_3 = arith.muli %arg10, %arg10 : i32
+ scf.condition(%arg2) %arg4, %arg3, %non_live_2, %non_live_3 : i32, i32, i32, i32
+ } do {
+ ^bb0(%arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32):
+ %non_live_4 = arith.addi %arg7, %arg8 :i32
+ scf.yield %arg5, %arg6, %non_live_4 : i32, i32, i32
+ }
+ return %live : i32
+}
+
+// -----
+
+// The op isn't erased because it has memory effects but its unnecessary result
+// is removed.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-NEXT: scf.index_switch %[[arg0]]
+// CHECK-NEXT: case 1 {
+// CHECK-NEXT: %[[c10:.*]] = arith.constant 10
+// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][]
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: default {
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) {
+ %non_live = scf.index_switch %arg0 -> i32
+ case 1 {
+ %c10 = arith.constant 10 : i32
+ memref.store %c10, %arg1[] : memref<i32>
+ scf.yield %c10 : i32
+ }
+ default {
+ %c11 = arith.constant 11 : i32
+ scf.yield %c11 : i32
+ }
+ return
+}
+
+// -----
+
+// The simple ops which don't have memory effects or live results get removed.
+// %arg5 doesn't get removed from @main even though it isn't live because
+// the signature of a public function is always left untouched.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK: func.func private @clean_simple_ops(%[[arg0:.*]]: i32, %[[arg1:.*]]: memref<i32>)
+// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg0]], %[[arg0]]
+// CHECK-NEXT: %[[c2:.*]] = arith.constant 2
+// CHECK-NEXT: %[[live_1:.*]] = arith.muli %[[live_0]], %[[c2]]
+// CHECK-NEXT: %[[c3:.*]] = arith.constant 3
+// CHECK-NEXT: %[[live_2:.*]] = arith.addi %[[arg0]], %[[c3]]
+// CHECK-NEXT: memref.store %[[live_2]], %[[arg1]][]
+// CHECK-NEXT: return %[[live_1]]
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg3:.*]]: i32, %[[arg4:.*]]: memref<i32>, %[[arg5:.*]]
+// CHECK-NEXT: %[[live:.*]] = call @clean_simple_ops(%[[arg3]], %[[arg4]])
+// CHECK-NEXT: return %[[live]]
+// CHECK-NEXT: }
+func.func private @clean_simple_ops(%arg0 : i32, %arg1 : memref<i32>, %arg2 : i32) -> (i32, i32, i32, i32) {
+ %live_0 = arith.addi %arg0, %arg0 : i32
+ %c2 = arith.constant 2 : i32
+ %live_1 = arith.muli %live_0, %c2 : i32
+ %non_live_1 = arith.addi %live_1, %live_0 : i32
+ %non_live_2 = arith.constant 7 : i32
+ %non_live_3 = arith.subi %arg0, %non_live_1 : i32
+ %c3 = arith.constant 3 : i32
+ %live_2 = arith.addi %arg0, %c3 : i32
+ memref.store %live_2, %arg1[] : memref<i32>
+ return %live_1, %non_live_1, %non_live_2, %non_live_3 : i32, i32, i32, i32
+}
+
+func.func @main(%arg3 : i32, %arg4 : memref<i32>, %arg5 : i32) -> (i32) {
+ %live, %non_live_1, %non_live_2, %non_live_3 = func.call @clean_simple_ops(%arg3, %arg4, %arg5) : (i32, memref<i32>, i32) -> (i32, i32, i32, i32)
+ return %live : i32
+}
+
+// -----
+
+// The scf.while op has no memory effects and its result isn't live, so it is erased.
+//
+// Note that this cleanup cannot be done by the `canonicalize` pass.
+//
+// CHECK-LABEL: func.func private @clean_region_branch_op_erase_it() {
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @main(%[[arg3:.*]]: i32, %[[arg4:.*]]: i1) {
+// CHECK-NEXT: call @clean_region_branch_op_erase_it() : () -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func private @clean_region_branch_op_erase_it(%arg0 : i32, %arg1 : i1) -> (i32) {
+ %non_live = scf.while (%arg2 = %arg0) : (i32) -> (i32) {
+ scf.condition(%arg1) %arg2 : i32
+ } do {
+ ^bb0(%arg2: i32):
+ scf.yield %arg2 : i32
+ }
+ return %non_live : i32
+}
+
+func.func @main(%arg3 : i32, %arg4 : i1) {
+ %non_live_0 = func.call @clean_region_branch_op_erase_it(%arg3, %arg4) : (i32, i1) -> (i32)
+ return
+}
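Rule (B) is exercised by @clean_func_op_remove_return_values and @clean_func_op_dont_remove_return_values. It reduces to a per-position OR over call sites: a return value survives if it is live at any caller. The sketch below is a hypothetical, simplified C++ model of that decision only, and `returnValuesToKeep` is an illustrative name. It assumes per-call-site result liveness has already been computed (in the pass, that information comes from `RunLivenessAnalysis`). It also folds in the rule, stated in the @clean_simple_ops test, that a public function's signature is left untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not MLIR API) modelling rule (B): decide which return
// values of a function to keep. callSiteLiveness[c][i] says whether result i
// of call site c is live according to liveness analysis.
std::vector<bool>
returnValuesToKeep(const std::vector<std::vector<bool>> &callSiteLiveness,
                   std::size_t numResults, bool isPublic) {
  // Public functions keep their whole signature (as the tests state),
  // presumably because callers not visible here may depend on it.
  if (isPublic)
    return std::vector<bool>(numResults, true);
  std::vector<bool> keep(numResults, false);
  // A return value survives if it is live at any call site.
  for (const std::vector<bool> &liveness : callSiteLiveness) {
    assert(liveness.size() == numResults && "call site / signature mismatch");
    for (std::size_t i = 0; i < numResults; ++i)
      keep[i] = keep[i] || liveness[i];
  }
  return keep;
}
```

For the two tests, the call-site liveness is {{true, false}, {false, false}} and {{true, false}, {false, true}}. These give {true, false} (drop the second return value) and {true, true} (keep both), as the CHECK lines expect.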