[llvm-branch-commits] [mlir] [mlir] Consolidate patterns into `RegionBranchOpInterface` patterns (PR #174094)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Jan 3 02:49:31 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/174094
>From abdf646dfa8a2432bddf9eef89869279bf550f59 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 31 Dec 2025 14:07:51 +0000
Subject: [PATCH] [mlir][draft] Consolidate patterns into
RegionBranchOpInterface patterns
fix some tests
reorganize code
---
.../mlir/Interfaces/ControlFlowInterfaces.h | 9 +
.../mlir/Interfaces/ControlFlowInterfaces.td | 5 +
mlir/lib/Dialect/SCF/IR/SCF.cpp | 830 +-----------------
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 487 ++++++++++
mlir/test/Dialect/SCF/canonicalize.mlir | 24 +-
mlir/test/Transforms/remove-dead-values.mlir | 8 +-
6 files changed, 533 insertions(+), 830 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 566f4b8fadb5d..ea85b2d1b5cb6 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -16,6 +16,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
@@ -188,6 +189,8 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
/// possible successors.) Operands that not forwarded at all are not present in
/// the mapping.
using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+using RegionBranchInverseSuccessorMapping =
+ DenseMap<Value, SmallVector<OpOperand *>>;
/// This class represents a successor of a region. A region successor can either
/// be another region, or the parent operation. If the successor is a region,
@@ -350,6 +353,12 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
/// exists.
Region *getEnclosingRepetitiveRegion(Value value);
+/// Populate canonicalization patterns that simplify successor operands/inputs
+/// of region branch operations. Only operations with the given name are
+/// matched.
+void populateRegionBranchOpInterfaceCanonicalizationPatterns(
+ RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1);
+
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 2e654ba04ffe5..70aed9e1e11c6 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -355,6 +355,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
::mlir::RegionBranchSuccessorMapping &mapping,
std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
+ /// Build a mapping from successor inputs to successor operands. This is
+ /// the same as "getSuccessorOperandInputMapping", but inverted.
+ void getSuccessorInputOperandMapping(
+ ::mlir::RegionBranchInverseSuccessorMapping &mapping);
+
/// Return all possible region branch points: the region branch op itself
/// and all region branch terminators.
::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8803a6d136f7a..95a854b655a53 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -291,102 +291,11 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};
-// Pattern to eliminate ExecuteRegionOp results which forward external
-// values from the region. In case there are multiple yield operations,
-// all of them must have the same operands in order for the pattern to be
-// applicable.
-struct ExecuteRegionForwardingEliminator
- : public OpRewritePattern<ExecuteRegionOp> {
- using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExecuteRegionOp op,
- PatternRewriter &rewriter) const override {
- if (op.getNumResults() == 0)
- return failure();
-
- SmallVector<Operation *> yieldOps;
- for (Block &block : op.getRegion()) {
- if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
- yieldOps.push_back(yield.getOperation());
- }
-
- if (yieldOps.empty())
- return failure();
-
- // Check if all yield operations have the same operands.
- auto yieldOpsOperands = yieldOps[0]->getOperands();
- for (auto *yieldOp : yieldOps) {
- if (yieldOp->getOperands() != yieldOpsOperands)
- return failure();
- }
-
- SmallVector<Value> externalValues;
- SmallVector<Value> internalValues;
- SmallVector<Value> opResultsToReplaceWithExternalValues;
- SmallVector<Value> opResultsToKeep;
- for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
- if (isValueFromInsideRegion(yieldedValue, op)) {
- internalValues.push_back(yieldedValue);
- opResultsToKeep.push_back(op.getResult(index));
- } else {
- externalValues.push_back(yieldedValue);
- opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
- }
- }
- // No yielded external values - nothing to do.
- if (externalValues.empty())
- return failure();
-
- // There are yielded external values - create a new execute_region returning
- // just the internal values.
- SmallVector<Type> resultTypes;
- for (Value value : internalValues)
- resultTypes.push_back(value.getType());
- auto newOp =
- ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
- newOp->setAttrs(op->getAttrs());
-
- // Move old op's region to the new operation.
- rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
-
- // Replace all yield operations with a new yield operation with updated
- // results. scf.execute_region must have at least one yield operation.
- for (auto *yieldOp : yieldOps) {
- rewriter.setInsertionPoint(yieldOp);
- rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
- ValueRange(internalValues));
- }
-
- // Replace the old operation with the external values directly.
- rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
- externalValues);
- // Replace the old operation's remaining results with the new operation's
- // results.
- rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
- rewriter.eraseOp(op);
- return success();
- }
-
-private:
- bool isValueFromInsideRegion(Value value,
- ExecuteRegionOp executeRegionOp) const {
- // Check if the value is defined within the execute_region
- if (Operation *defOp = value.getDefiningOp())
- return &executeRegionOp.getRegion() == defOp->getParentRegion();
-
- // If it's a block argument, check if it's from within the region
- if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
- return &executeRegionOp.getRegion() == blockArg.getParentRegion();
-
- return false; // Value is from outside the region
- }
-};
-
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
- ExecuteRegionForwardingEliminator>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+ populateRegionBranchOpInterfaceCanonicalizationPatterns(
+ results, ExecuteRegionOp::getOperationName());
}
void ExecuteRegionOp::getSuccessorRegions(
@@ -989,146 +898,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
}
namespace {
-// Fold away ForOp iter arguments when:
-// 1) The op yields the iter arguments.
-// 2) The argument's corresponding outer region iterators (inputs) are yielded.
-// 3) The iter arguments have no use and the corresponding (operation) results
-// have no use.
-//
-// These arguments must be defined outside of the ForOp region and can just be
-// forwarded after simplifying the op inits, yields and returns.
-//
-// The implementation uses `inlineBlockBefore` to steal the content of the
-// original ForOp and avoid cloning.
-struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
- using OpRewritePattern<scf::ForOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(scf::ForOp forOp,
- PatternRewriter &rewriter) const final {
- bool canonicalize = false;
-
- // An internal flat vector of block transfer
- // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
- // transformed block argument mappings. This plays the role of a
- // IRMapping for the particular use case of calling into
- // `inlineBlockBefore`.
- int64_t numResults = forOp.getNumResults();
- SmallVector<bool, 4> keepMask;
- keepMask.reserve(numResults);
- SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
- newResultValues;
- newBlockTransferArgs.reserve(1 + numResults);
- newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
- newIterArgs.reserve(forOp.getInitArgs().size());
- newYieldValues.reserve(numResults);
- newResultValues.reserve(numResults);
- DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
- for (auto [init, arg, result, yielded] :
- llvm::zip(forOp.getInitArgs(), // iter from outside
- forOp.getRegionIterArgs(), // iter inside region
- forOp.getResults(), // op results
- forOp.getYieldedValues() // iter yield
- )) {
- // Forwarded is `true` when:
- // 1) The region `iter` argument is yielded.
- // 2) The region `iter` argument the corresponding input is yielded.
- // 3) The region `iter` argument has no use, and the corresponding op
- // result has no use.
- bool forwarded = (arg == yielded) || (init == yielded) ||
- (arg.use_empty() && result.use_empty());
- if (forwarded) {
- canonicalize = true;
- keepMask.push_back(false);
- newBlockTransferArgs.push_back(init);
- newResultValues.push_back(init);
- continue;
- }
-
- // Check if a previous kept argument always has the same values for init
- // and yielded values.
- if (auto it = initYieldToArg.find({init, yielded});
- it != initYieldToArg.end()) {
- canonicalize = true;
- keepMask.push_back(false);
- auto [sameArg, sameResult] = it->second;
- rewriter.replaceAllUsesWith(arg, sameArg);
- rewriter.replaceAllUsesWith(result, sameResult);
- // The replacement value doesn't matter because there are no uses.
- newBlockTransferArgs.push_back(init);
- newResultValues.push_back(init);
- continue;
- }
-
- // This value is kept.
- initYieldToArg.insert({{init, yielded}, {arg, result}});
- keepMask.push_back(true);
- newIterArgs.push_back(init);
- newYieldValues.push_back(yielded);
- newBlockTransferArgs.push_back(Value()); // placeholder with null value
- newResultValues.push_back(Value()); // placeholder with null value
- }
-
- if (!canonicalize)
- return failure();
-
- scf::ForOp newForOp =
- scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
- forOp.getUpperBound(), forOp.getStep(), newIterArgs,
- /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
- newForOp->setAttrs(forOp->getAttrs());
- Block &newBlock = newForOp.getRegion().front();
-
- // Replace the null placeholders with newly constructed values.
- newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
- for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
- idx != e; ++idx) {
- Value &blockTransferArg = newBlockTransferArgs[1 + idx];
- Value &newResultVal = newResultValues[idx];
- assert((blockTransferArg && newResultVal) ||
- (!blockTransferArg && !newResultVal));
- if (!blockTransferArg) {
- blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
- newResultVal = newForOp.getResult(collapsedIdx++);
- }
- }
-
- Block &oldBlock = forOp.getRegion().front();
- assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
- "unexpected argument size mismatch");
-
- // No results case: the scf::ForOp builder already created a zero
- // result terminator. Merge before this terminator and just get rid of the
- // original terminator that has been merged in.
- if (newIterArgs.empty()) {
- auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
- rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
- rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
- rewriter.replaceOp(forOp, newResultValues);
- return success();
- }
-
- // No terminator case: merge and rewrite the merged terminator.
- auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(mergedTerminator);
- SmallVector<Value, 4> filteredOperands;
- filteredOperands.reserve(newResultValues.size());
- for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
- if (keepMask[idx])
- filteredOperands.push_back(mergedTerminator.getOperand(idx));
- scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
- filteredOperands);
- };
-
- rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
- auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
- cloneFilteredTerminator(mergedYieldOp);
- rewriter.eraseOp(mergedYieldOp);
- rewriter.replaceOp(forOp, newResultValues);
- return success();
- }
-};
-
/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
@@ -1235,13 +1004,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
return failure();
}
};
-
} // namespace
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
- context);
+ results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
+ populateRegionBranchOpInterfaceCanonicalizationPatterns(
+ results, ForOp::getOperationName());
}
std::optional<APInt> ForOp::getConstantStep() {
@@ -2378,35 +2147,6 @@ void IfOp::getRegionInvocationBounds(
}
namespace {
-// Pattern to remove unused IfOp results.
-struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- // Compute the list of unused results.
- BitVector toErase(op.getNumResults(), false);
- for (auto [idx, result] : llvm::enumerate(op.getResults()))
- if (result.use_empty())
- toErase[idx] = true;
- if (toErase.none())
- return rewriter.notifyMatchFailure(op, "no results to erase");
-
- // Erase results.
- auto newOp = cast<scf::IfOp>(rewriter.eraseOpResults(op, toErase));
-
- // Erase operands.
- rewriter.modifyOpInPlace(newOp.thenYield(), [&]() {
- newOp.thenYield()->eraseOperands(toErase);
- });
- rewriter.modifyOpInPlace(newOp.elseYield(), [&]() {
- newOp.elseYield()->eraseOperands(toErase);
- });
-
- return success();
- }
-};
-
struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -2977,8 +2717,10 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, RemoveUnusedResults,
- ReplaceIfYieldWithConditionOrValue>(context);
+ RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
+ context);
+ populateRegionBranchOpInterfaceCanonicalizationPatterns(
+ results, IfOp::getOperationName());
}
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3816,390 +3558,6 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
}
};
-/// Remove loop invariant arguments from `before` block of scf.while.
-/// A before block argument is considered loop invariant if :-
-/// 1. i-th yield operand is equal to the i-th while operand.
-/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
-/// condition operand AND this (k+1)-th condition operand is equal to i-th
-/// iter argument/while operand.
-/// For the arguments which are removed, their uses inside scf.while
-/// are replaced with their corresponding initial value.
-///
-/// Eg:
-/// INPUT :-
-/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
-/// ..., %argN_before = %N)
-/// {
-/// ...
-/// scf.condition(%cond) %arg1_before, %arg0_before,
-/// %arg2_before, %arg0_before, ...
-/// } do {
-/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-/// ..., %argK_after):
-/// ...
-/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
-/// }
-///
-/// OUTPUT :-
-/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
-/// %N)
-/// {
-/// ...
-/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
-/// } do {
-/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-/// ..., %argK_after):
-/// ...
-/// scf.yield %arg1_after, ..., %argN
-/// }
-///
-/// EXPLANATION:
-/// We iterate over each yield operand.
-/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
-/// %arg0_before, which in turn is the 0-th iter argument. So we
-/// remove 0-th before block argument and yield operand, and replace
-/// all uses of the 0-th before block argument with its initial value
-/// %a.
-/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
-/// value. So we remove this operand and the corresponding before
-/// block argument and replace all uses of 1-th before block argument
-/// with %b.
-struct RemoveLoopInvariantArgsFromBeforeBlock
- : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- Block &afterBlock = *op.getAfterBody();
- Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
- ConditionOp condOp = op.getConditionOp();
- OperandRange condOpArgs = condOp.getArgs();
- Operation *yieldOp = afterBlock.getTerminator();
- ValueRange yieldOpArgs = yieldOp->getOperands();
-
- bool canSimplify = false;
- for (const auto &it :
- llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
- auto index = static_cast<unsigned>(it.index());
- auto [initVal, yieldOpArg] = it.value();
- // If i-th yield operand is equal to the i-th operand of the scf.while,
- // the i-th before block argument is a loop invariant.
- if (yieldOpArg == initVal) {
- canSimplify = true;
- break;
- }
- // If the i-th yield operand is k-th after block argument, then we check
- // if the (k+1)-th condition op operand is equal to either the i-th before
- // block argument or the initial value of i-th before block argument. If
- // the comparison results `true`, i-th before block argument is a loop
- // invariant.
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
- if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
- Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
- if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
- canSimplify = true;
- break;
- }
- }
- }
-
- if (!canSimplify)
- return failure();
-
- SmallVector<Value> newInitArgs, newYieldOpArgs;
- DenseMap<unsigned, Value> beforeBlockInitValMap;
- SmallVector<Location> newBeforeBlockArgLocs;
- for (const auto &it :
- llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
- auto index = static_cast<unsigned>(it.index());
- auto [initVal, yieldOpArg] = it.value();
-
- // If i-th yield operand is equal to the i-th operand of the scf.while,
- // the i-th before block argument is a loop invariant.
- if (yieldOpArg == initVal) {
- beforeBlockInitValMap.insert({index, initVal});
- continue;
- } else {
- // If the i-th yield operand is k-th after block argument, then we check
- // if the (k+1)-th condition op operand is equal to either the i-th
- // before block argument or the initial value of i-th before block
- // argument. If the comparison results `true`, i-th before block
- // argument is a loop invariant.
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
- if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
- Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
- if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
- beforeBlockInitValMap.insert({index, initVal});
- continue;
- }
- }
- }
- newInitArgs.emplace_back(initVal);
- newYieldOpArgs.emplace_back(yieldOpArg);
- newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
- }
-
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(yieldOp);
- rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
- }
-
- auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
- newInitArgs);
-
- Block &newBeforeBlock = *rewriter.createBlock(
- &newWhile.getBefore(), /*insertPt*/ {},
- ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
-
- Block &beforeBlock = *op.getBeforeBody();
- SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
- // For each i-th before block argument we find it's replacement value as :-
- // 1. If i-th before block argument is a loop invariant, we fetch it's
- // initial value from `beforeBlockInitValMap` by querying for key `i`.
- // 2. Else we fetch j-th new before block argument as the replacement
- // value of i-th before block argument.
- for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
- // If the index 'i' argument was a loop invariant we fetch it's initial
- // value from `beforeBlockInitValMap`.
- if (beforeBlockInitValMap.count(i) != 0)
- newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
- else
- newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
- }
-
- rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
- rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
- newWhile.getAfter().begin());
-
- rewriter.replaceOp(op, newWhile.getResults());
- return success();
- }
-};
-
-/// Remove loop invariant value from result (condition op) of scf.while.
-/// A value is considered loop invariant if the final value yielded by
-/// scf.condition is defined outside of the `before` block. We remove the
-/// corresponding argument in `after` block and replace the use with the value.
-/// We also replace the use of the corresponding result of scf.while with the
-/// value.
-///
-/// Eg:
-/// INPUT :-
-/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
-/// %argN_before = %N) {
-/// ...
-/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
-/// } do {
-/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
-/// ...
-/// some_func(%arg1_after)
-/// ...
-/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
-/// }
-///
-/// OUTPUT :-
-/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
-/// ...
-/// scf.condition(%cond) %arg0, %arg1, ..., %argM
-/// } do {
-/// ^bb0(%arg0, %arg3, ..., %argM):
-/// ...
-/// some_func(%a)
-/// ...
-/// scf.yield %arg0, %b, ..., %argN
-/// }
-///
-/// EXPLANATION:
-/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
-/// before block of scf.while, so they get removed.
-/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
-/// replaced by %b.
-/// 3. The corresponding after block argument %arg1_after's uses are
-/// replaced by %a and %arg2_after's uses are replaced by %b.
-struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- Block &beforeBlock = *op.getBeforeBody();
- ConditionOp condOp = op.getConditionOp();
- OperandRange condOpArgs = condOp.getArgs();
-
- bool canSimplify = false;
- for (Value condOpArg : condOpArgs) {
- // Those values not defined within `before` block will be considered as
- // loop invariant values. We map the corresponding `index` with their
- // value.
- if (condOpArg.getParentBlock() != &beforeBlock) {
- canSimplify = true;
- break;
- }
- }
-
- if (!canSimplify)
- return failure();
-
- Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
-
- SmallVector<Value> newCondOpArgs;
- SmallVector<Type> newAfterBlockType;
- DenseMap<unsigned, Value> condOpInitValMap;
- SmallVector<Location> newAfterBlockArgLocs;
- for (const auto &it : llvm::enumerate(condOpArgs)) {
- auto index = static_cast<unsigned>(it.index());
- Value condOpArg = it.value();
- // Those values not defined within `before` block will be considered as
- // loop invariant values. We map the corresponding `index` with their
- // value.
- if (condOpArg.getParentBlock() != &beforeBlock) {
- condOpInitValMap.insert({index, condOpArg});
- } else {
- newCondOpArgs.emplace_back(condOpArg);
- newAfterBlockType.emplace_back(condOpArg.getType());
- newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
- }
- }
-
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(condOp);
- rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
- newCondOpArgs);
- }
-
- auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
- op.getOperands());
-
- Block &newAfterBlock =
- *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
- newAfterBlockType, newAfterBlockArgLocs);
-
- Block &afterBlock = *op.getAfterBody();
- // Since a new scf.condition op was created, we need to fetch the new
- // `after` block arguments which will be used while replacing operations of
- // previous scf.while's `after` blocks. We'd also be fetching new result
- // values too.
- SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
- SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
- for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
- Value afterBlockArg, result;
- // If index 'i' argument was loop invariant we fetch it's value from the
- // `condOpInitMap` map.
- if (condOpInitValMap.count(i) != 0) {
- afterBlockArg = condOpInitValMap[i];
- result = afterBlockArg;
- } else {
- afterBlockArg = newAfterBlock.getArgument(j);
- result = newWhile.getResult(j);
- j++;
- }
- newAfterBlockArgs[i] = afterBlockArg;
- newWhileResults[i] = result;
- }
-
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
- rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
- newWhile.getBefore().begin());
-
- rewriter.replaceOp(op, newWhileResults);
- return success();
- }
-};
-
-/// Remove WhileOp results that are also unused in 'after' block.
-///
-/// %0:2 = scf.while () : () -> (i32, i64) {
-/// %condition = "test.condition"() : () -> i1
-/// %v1 = "test.get_some_value"() : () -> i32
-/// %v2 = "test.get_some_value"() : () -> i64
-/// scf.condition(%condition) %v1, %v2 : i32, i64
-/// } do {
-/// ^bb0(%arg0: i32, %arg1: i64):
-/// "test.use"(%arg0) : (i32) -> ()
-/// scf.yield
-/// }
-/// return %0#0 : i32
-///
-/// becomes
-/// %0 = scf.while () : () -> (i32) {
-/// %condition = "test.condition"() : () -> i1
-/// %v1 = "test.get_some_value"() : () -> i32
-/// %v2 = "test.get_some_value"() : () -> i64
-/// scf.condition(%condition) %v1 : i32
-/// } do {
-/// ^bb0(%arg0: i32):
-/// "test.use"(%arg0) : (i32) -> ()
-/// scf.yield
-/// }
-/// return %0 : i32
-struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- auto term = op.getConditionOp();
- auto afterArgs = op.getAfterArguments();
- auto termArgs = term.getArgs();
-
- // Collect results mapping, new terminator args and new result types.
- SmallVector<unsigned> newResultsIndices;
- SmallVector<Type> newResultTypes;
- SmallVector<Value> newTermArgs;
- SmallVector<Location> newArgLocs;
- bool needUpdate = false;
- for (const auto &it :
- llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
- auto i = static_cast<unsigned>(it.index());
- Value result = std::get<0>(it.value());
- Value afterArg = std::get<1>(it.value());
- Value termArg = std::get<2>(it.value());
- if (result.use_empty() && afterArg.use_empty()) {
- needUpdate = true;
- } else {
- newResultsIndices.emplace_back(i);
- newTermArgs.emplace_back(termArg);
- newResultTypes.emplace_back(result.getType());
- newArgLocs.emplace_back(result.getLoc());
- }
- }
-
- if (!needUpdate)
- return failure();
-
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(term);
- rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
- newTermArgs);
- }
-
- auto newWhile =
- WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
-
- Block &newAfterBlock = *rewriter.createBlock(
- &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
-
- // Build new results list and new after block args (unused entries will be
- // null).
- SmallVector<Value> newResults(op.getNumResults());
- SmallVector<Value> newAfterBlockArgs(op.getNumResults());
- for (const auto &it : llvm::enumerate(newResultsIndices)) {
- newResults[it.value()] = newWhile.getResult(it.index());
- newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
- }
-
- rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
- newWhile.getBefore().begin());
-
- Block &afterBlock = *op.getAfterBody();
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
-
- rewriter.replaceOp(op, newResults);
- return success();
- }
-};
-
/// Replace operations equivalent to the condition in the do block with true,
/// since otherwise the block would not be evaluated.
///
@@ -4264,127 +3622,6 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
}
};
-/// Remove unused init/yield args.
-struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
-
- if (!llvm::any_of(op.getBeforeArguments(),
- [](Value arg) { return arg.use_empty(); }))
- return rewriter.notifyMatchFailure(op, "No args to remove");
-
- YieldOp yield = op.getYieldOp();
-
- // Collect results mapping, new terminator args and new result types.
- SmallVector<Value> newYields;
- SmallVector<Value> newInits;
- llvm::BitVector argsToErase;
-
- size_t argsCount = op.getBeforeArguments().size();
- newYields.reserve(argsCount);
- newInits.reserve(argsCount);
- argsToErase.reserve(argsCount);
- for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
- op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
- if (beforeArg.use_empty()) {
- argsToErase.push_back(true);
- } else {
- argsToErase.push_back(false);
- newYields.emplace_back(yieldValue);
- newInits.emplace_back(initValue);
- }
- }
-
- Block &beforeBlock = *op.getBeforeBody();
- Block &afterBlock = *op.getAfterBody();
-
- beforeBlock.eraseArguments(argsToErase);
-
- Location loc = op.getLoc();
- auto newWhileOp =
- WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
- /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
- Block &newBeforeBlock = *newWhileOp.getBeforeBody();
- Block &newAfterBlock = *newWhileOp.getAfterBody();
-
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(yield);
- rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
-
- rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
- newBeforeBlock.getArguments());
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
- newAfterBlock.getArguments());
-
- rewriter.replaceOp(op, newWhileOp.getResults());
- return success();
- }
-};
-
-/// Remove duplicated ConditionOp args.
-struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- ConditionOp condOp = op.getConditionOp();
- ValueRange condOpArgs = condOp.getArgs();
-
- llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
-
- if (argsSet.size() == condOpArgs.size())
- return rewriter.notifyMatchFailure(op, "No results to remove");
-
- llvm::SmallDenseMap<Value, unsigned> argsMap;
- SmallVector<Value> newArgs;
- argsMap.reserve(condOpArgs.size());
- newArgs.reserve(condOpArgs.size());
- for (Value arg : condOpArgs) {
- if (!argsMap.count(arg)) {
- auto pos = static_cast<unsigned>(argsMap.size());
- argsMap.insert({arg, pos});
- newArgs.emplace_back(arg);
- }
- }
-
- ValueRange argsRange(newArgs);
-
- Location loc = op.getLoc();
- auto newWhileOp =
- scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
- /*beforeBody*/ nullptr,
- /*afterBody*/ nullptr);
- Block &newBeforeBlock = *newWhileOp.getBeforeBody();
- Block &newAfterBlock = *newWhileOp.getAfterBody();
-
- SmallVector<Value> afterArgsMapping;
- SmallVector<Value> resultsMapping;
- for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
- auto it = argsMap.find(arg);
- assert(it != argsMap.end());
- auto pos = it->second;
- afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
- resultsMapping.emplace_back(newWhileOp->getResult(pos));
- }
-
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(condOp);
- rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
- argsRange);
-
- Block &beforeBlock = *op.getBeforeBody();
- Block &afterBlock = *op.getAfterBody();
-
- rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
- newBeforeBlock.getArguments());
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
- rewriter.replaceOp(op, resultsMapping);
- return success();
- }
-};
-
/// If both ranges contain same values return mappping indices from args2 to
/// args1. Otherwise return std::nullopt.
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
@@ -4475,11 +3712,10 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<RemoveLoopInvariantArgsFromBeforeBlock,
- RemoveLoopInvariantValueYielded, WhileConditionTruth,
- WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
- context);
+ results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
+ WhileMoveIfDown>(context);
+ populateRegionBranchOpInterfaceCanonicalizationPatterns(
+ results, WhileOp::getOperationName());
}
//===----------------------------------------------------------------------===//
@@ -4654,43 +3890,11 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
-/// Canonicalization patterns that folds away dead results of
-/// "scf.index_switch" ops.
-struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
- using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IndexSwitchOp op,
- PatternRewriter &rewriter) const override {
- // Find dead results.
- BitVector deadResults(op.getNumResults(), false);
- for (auto [idx, result] : llvm::enumerate(op.getResults()))
- if (result.use_empty())
- deadResults[idx] = true;
- if (!deadResults.any())
- return rewriter.notifyMatchFailure(op, "no dead results to fold");
-
- // Erase dead results.
- auto newOp =
- cast<scf::IndexSwitchOp>(rewriter.eraseOpResults(op, deadResults));
-
- // Erase operands from yield ops.
- auto updateCaseRegion = [&](Region &region) {
- Operation *terminator = region.front().getTerminator();
- assert(isa<YieldOp>(terminator) && "expected yield op");
- rewriter.modifyOpInPlace(
- terminator, [&]() { terminator->eraseOperands(deadResults); });
- };
- updateCaseRegion(newOp.getDefaultRegion());
- for (Region &caseRegion : newOp.getCaseRegions())
- updateCaseRegion(caseRegion);
-
- return success();
- }
-};
-
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
+ results.add<FoldConstantCase>(context);
+ populateRegionBranchOpInterfaceCanonicalizationPatterns(
+ results, IndexSwitchOp::getOperationName());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index d393ddb8d8336..ce9ce0002f6e9 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -10,7 +10,9 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -521,6 +523,23 @@ void RegionBranchOpInterface::getSuccessorOperandInputMapping(
}
}
+static RegionBranchInverseSuccessorMapping invertRegionBranchSuccessorMapping(
+ const RegionBranchSuccessorMapping &operandToInputs) {
+ RegionBranchInverseSuccessorMapping inputToOperands;
+ for (const auto &[operand, inputs] : operandToInputs) {
+ for (Value input : inputs)
+ inputToOperands[input].push_back(operand);
+ }
+ return inputToOperands;
+}
+
+void RegionBranchOpInterface::getSuccessorInputOperandMapping(
+ RegionBranchInverseSuccessorMapping &mapping) {
+ RegionBranchSuccessorMapping operandToInputs;
+ getSuccessorOperandInputMapping(operandToInputs);
+ mapping = invertRegionBranchSuccessorMapping(operandToInputs);
+}
+
SmallVector<RegionBranchPoint>
RegionBranchOpInterface::getAllRegionBranchPoints() {
SmallVector<RegionBranchPoint> branchPoints;
@@ -583,3 +602,471 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) {
LDBG() << "No enclosing repetitive region found for value";
return nullptr;
}
+
+/// Return "true" if `a` can be used in lieu of `b`, where `b` is a region
+/// successor input and `a` is a "possible value" of `b`. Possible values are
+/// successor operand values that are (maybe transitively) forwarded to `b`.
+static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
+ assert((b.getDefiningOp() == regionBranchOp ||
+ b.getParentRegion()->getParentOp() == regionBranchOp) &&
+ "b must be a region successor input");
+
+ // Case 1: `a` is defined inside of the region branch op. `a` must be
+ // directly nested in the region branch op. Otherwise, it could not have
+ // been among the possible values for a region successor input.
+ if (a.getParentRegion()->getParentOp() == regionBranchOp) {
+ // Case 1.1: If `b` is a result of the region branch op, `a` is not in
+ // scope for `b`.
+ // Example:
+ // %b = region_op({
+ // ^bb0(%a1: ...):
+ // %a2 = ...
+ // })
+ if (isa<OpResult>(b))
+ return false;
+
+ // Case 1.2: `b` is an entry block argument of a region. `a` is in scope
+ // for `b` only if it is also an entry block argument of the same region.
+ // Example:
+ // region_op({
+ // ^bb0(%b: ..., %a: ...):
+ // ...
+ // })
+ assert(isa<BlockArgument>(b) && "b must be a block argument");
+ return isa<BlockArgument>(a) && cast<BlockArgument>(a).getOwner() ==
+ cast<BlockArgument>(b).getOwner();
+ }
+
+ // Case 2: `a` is defined outside of the region branch op. In that case, we
+ // can safely assume that `a` was defined before `b`. Otherwise, it could not
+ // be among the possible values for a region successor input.
+ // Example:
+ // { <- %a1 parent region begins here.
+ // ^bb0(%a1: ...):
+ // %a2 = ...
+ // %b1 = region_op({
+ // ^bb1(%b2: ...):
+ // ...
+ // })
+ // }
+ return true;
+}
+
+/// Compute all non-successor input values that a successor input could have
+/// based on the given successor input to successor operand mapping.
+///
+/// Example 1:
+/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
+/// scf.yield %arg0 : ...
+/// }
+/// computePossibleValuesOfSuccessorInput(%arg0) = {%0}
+/// computePossibleValuesOfSuccessorInput(%r) = {%0}
+///
+/// Example 2:
+/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
+/// ...
+/// scf.yield %1 : ...
+/// }
+/// computePossibleValuesOfSuccessorInput(%arg0) = {%0, %1}
+/// computePossibleValuesOfSuccessorInput(%r) = {%0, %1}
+static llvm::SmallDenseSet<Value> computePossibleValuesOfSuccessorInput(
+ Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
+ llvm::SmallDenseSet<Value> possibleValues;
+ llvm::SmallDenseSet<Value> visited;
+ SmallVector<Value> worklist;
+
+ // Starting with the given value, trace back all predecessor values (i.e.,
+ // preceding successor operands) and add them to the set of possible values.
+ // If the successor operand is again a successor input, do not add it to the
+ // result set, but instead continue the traversal.
+ worklist.push_back(value);
+ while (!worklist.empty()) {
+ Value next = worklist.pop_back_val();
+ auto it = inputToOperands.find(next);
+ if (it == inputToOperands.end()) {
+ possibleValues.insert(next);
+ continue;
+ }
+ for (OpOperand *operand : it->second)
+ if (visited.insert(operand->get()).second)
+ worklist.push_back(operand->get());
+ }
+
+ return possibleValues;
+}
+
+namespace {
+/// Try to make successor inputs dead by replacing their uses with values that
+/// are not successor inputs. This pattern enables additional canonicalization
+/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
+///
+/// Example:
+///
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+/// scf.yield %arg1, %arg1 : ...
+/// }
+/// use(%r0, %r1)
+///
+/// possibleValues(%r0) = {%0, %1}
+/// possibleValues(%r1) = {%1} ==> replace uses of %r1 with %1.
+/// possibleValues(%arg0) = {%0, %1}
+/// possibleValues(%arg1) = {%1} ==> replace uses of %arg1 with %1.
+///
+/// IR after pattern application:
+///
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+/// scf.yield %1, %1 : ...
+/// }
+/// use(%r0, %1)
+///
+/// Note that %r1 and %arg1 are dead now. The IR can now be further
+/// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs.
+struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
+ MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
+ PatternBenefit benefit = 1)
+ : RewritePattern(name, benefit, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "isolated-from-above ops are not supported");
+
+ // Compute the mapping of successor inputs to successor operands.
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+ RegionBranchInverseSuccessorMapping inputToOperands;
+ regionBranchOp.getSuccessorInputOperandMapping(inputToOperands);
+
+ // Try to replace the uses of each successor input one-by-one.
+ bool changed = false;
+ for (Value value : inputToOperands.keys()) {
+ // Nothing to do for successor inputs that are already dead.
+ if (value.use_empty())
+ continue;
+ // Nothing to do for successor inputs that may have multiple possible
+ // values.
+ llvm::SmallDenseSet<Value> possibleValues =
+ computePossibleValuesOfSuccessorInput(value, inputToOperands);
+ if (possibleValues.size() != 1)
+ continue;
+ assert(*possibleValues.begin() != value &&
+ "successor inputs are supposed to be excluded");
+ // Do not replace `value` with the found possible value if doing so would
+ // violate dominance. Example:
+ // %r = scf.execute_region ... {
+ // %a = ...
+ // scf.yield %a : ...
+ // }
+ // use(%r)
+ // In the above example, possibleValues(%r) = {%a}, but %a cannot be used
+ // as a replacement for %r due to dominance / scope.
+ if (!isDefinedBefore(regionBranchOp, *possibleValues.begin(), value))
+ continue;
+ rewriter.replaceAllUsesWith(value, *possibleValues.begin());
+ changed = true;
+ }
+ return success(changed);
+ }
+};
+
+/// Lookup a bit vector in the given mapping (DenseMap). If the key was not
+/// found, create a new bit vector with the given size and initialize it with
+/// false.
+template <typename MappingTy, typename KeyTy>
+static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
+ unsigned size) {
+ return mapping.try_emplace(key, size, false).first->second;
+}
+
+/// Compute tied successor inputs. Tied successor inputs are successor inputs
+/// that come as a set. If you erase one value from a set, you must erase all
+/// values from the set. Otherwise, the op would become structurally invalid.
+/// Each successor input appears in exactly one set.
+///
+/// Example:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+/// ...
+/// }
+/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}.
+static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
+ const RegionBranchSuccessorMapping &operandToInputs) {
+ llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
+ for (const auto &[operand, inputs] : operandToInputs) {
+ assert(!inputs.empty() && "expected non-empty inputs");
+ Value firstInput = inputs.front();
+ tiedSuccessorInputs.insert(firstInput);
+ for (Value nextInput : llvm::drop_begin(inputs)) {
+ // As we explore more successor operand to successor input mappings,
+ // existing sets may get merged.
+ tiedSuccessorInputs.unionSets(firstInput, nextInput);
+ }
+ }
+ return tiedSuccessorInputs;
+}
+
+/// Remove dead successor inputs from region branch ops. A successor input is
+/// dead if it has no uses. Successor inputs come in sets of tied values: if
+/// you remove one value from a set, you must remove all values from the set.
+/// Furthermore, successor operands must also be removed. (Op operands are not
+/// part of the set, but the set is built based on the successor operand to
+/// successor input mapping.)
+///
+/// Example 1:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+/// scf.yield %0, %arg1 : ...
+/// }
+/// use(%0, %1)
+///
+/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first
+/// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The
+/// resulting IR is as follows:
+///
+/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... {
+/// scf.yield %arg1 : ...
+/// }
+/// use(%0, %1)
+///
+/// Example 2:
+/// %r0, %r1 = scf.while (%arg0 = %0) {
+/// scf.condition(...) %arg0, %arg0 : ...
+/// } do {
+/// ^bb0(%arg1: ..., %arg2: ...):
+/// scf.yield %arg1 : ...
+/// }
+/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%arg0}}.
+///
+/// Example 3:
+/// %r1, %r2 = scf.if ... {
+/// scf.yield %0, %1 : ...
+/// } else {
+/// scf.yield %2, %3 : ...
+/// }
+/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so each
+/// value can be removed independently of the other values.
+struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
+ RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name,
+ PatternBenefit benefit = 1)
+ : RewritePattern(name, benefit, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "isolated-from-above ops are not supported");
+
+ // Compute tied values: values that must come as a set. If you remove one,
+ // you must remove all. If a successor op operand is forwarded to two
+ // successor inputs %a and %b, both %a and %b are in the same set.
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+ RegionBranchSuccessorMapping operandToInputs;
+ regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
+ llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
+ computeTiedSuccessorInputs(operandToInputs);
+
+ // Determine which values to remove and group them by block and operation.
+ SmallVector<Value> valuesToRemove;
+ DenseMap<Block *, BitVector> blockArgsToRemove;
+ DenseMap<Operation *, BitVector> resultsToRemove;
+ // Iterate over all sets of tied successor inputs.
+ for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
+ it != e; ++it) {
+ if (!(*it)->isLeader())
+ continue;
+
+ // Value can be removed if it is dead and all other tied values are also
+ // dead.
+ bool allDead = true;
+ for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+ memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+ // Iterate over all values in the set and check their liveness.
+ if (!memberIt->use_empty()) {
+ allDead = false;
+ break;
+ }
+ }
+ if (!allDead)
+ continue;
+
+ // The entire set is dead. Group values by block and operation to
+ // simplify removal.
+ for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+ memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+ if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
+ // Set blockArgsToRemove[block][arg_number] = true.
+ BitVector &vector =
+ lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
+ arg.getOwner()->getNumArguments());
+ vector.set(arg.getArgNumber());
+ } else {
+ // Set resultsToRemove[op][result_number] = true.
+ OpResult result = cast<OpResult>(*memberIt);
+ BitVector &vector =
+ lookupOrCreateBitVector(resultsToRemove, result.getDefiningOp(),
+ result.getDefiningOp()->getNumResults());
+ vector.set(result.getResultNumber());
+ }
+ valuesToRemove.push_back(*memberIt);
+ }
+ }
+
+ if (valuesToRemove.empty())
+ return rewriter.notifyMatchFailure(op, "no values to remove");
+
+ // Find operands that must be removed together with the values.
+ RegionBranchInverseSuccessorMapping inputsToOperands =
+ invertRegionBranchSuccessorMapping(operandToInputs);
+ DenseMap<Operation *, llvm::BitVector> operandsToRemove;
+ for (Value value : valuesToRemove) {
+ for (OpOperand *operand : inputsToOperands[value]) {
+ // Set operandsToRemove[op][operand_number] = true.
+ BitVector &vector =
+ lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
+ operand->getOwner()->getNumOperands());
+ vector.set(operand->getOperandNumber());
+ }
+ }
+
+ // Erase operands.
+ for (auto &pair : operandsToRemove) {
+ Operation *op = pair.first;
+ BitVector &operands = pair.second;
+ rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
+ }
+
+ // Erase block arguments.
+ for (auto &pair : blockArgsToRemove) {
+ Block *block = pair.first;
+ BitVector &blockArg = pair.second;
+ rewriter.modifyOpInPlace(block->getParentOp(),
+ [&]() { block->eraseArguments(blockArg); });
+ }
+
+ // Erase op results.
+ for (auto [op, resultsToErase] : resultsToRemove)
+ rewriter.eraseOpResults(op, resultsToErase);
+
+ return success();
+ }
+};
+
+/// Return "true" if the two values are owned by the same operation or block.
+static bool haveSameOwner(Value a, Value b) {
+ void *aOwner, *bOwner;
+ if (auto arg = dyn_cast<BlockArgument>(a))
+ aOwner = arg.getOwner();
+ else
+ aOwner = a.getDefiningOp();
+ if (auto arg = dyn_cast<BlockArgument>(b))
+ bOwner = arg.getOwner();
+ else
+ bOwner = b.getDefiningOp();
+ return aOwner == bOwner;
+}
+
+/// Get the block argument or op result number of the given value.
+static unsigned getArgOrResultNumber(Value value) {
+ if (auto opResult = llvm::dyn_cast<OpResult>(value))
+ return opResult.getResultNumber();
+ return llvm::cast<BlockArgument>(value).getArgNumber();
+}
+
+/// Find duplicate successor inputs and make all dead except for one. Two
+/// successor inputs are "duplicate" if their corresponding successor operands
+/// have the same values. This pattern enables additional canonicalization
+/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
+///
+/// Example:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
+/// use(%arg0, %arg1)
+/// ...
+/// scf.yield %x, %x : ...
+/// }
+/// use(%r0, %r1)
+///
+/// Operands of successor input %r0: [%0, %x]
+/// Operands of successor input %r1: [%0, %x] ==> DUPLICATE!
+/// Replace %r1 with %r0.
+///
+/// Operands of successor input %arg0: [%0, %x]
+/// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE!
+/// Replace %arg1 with %arg0. (We have to make sure that we make the same decision
+/// as for the other tied successor inputs above. Otherwise, a set of tied
+/// successor inputs may not become entirely dead.)
+///
+/// The resulting IR is as follows:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
+/// use(%arg0, %arg0)
+/// ...
+/// scf.yield %x, %x : ...
+/// }
+/// use(%r0, %r0) // Note: We don't want use(%r1, %r1), which is also correct,
+/// // but does not help with further canonicalizations.
+struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
+ RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name,
+ PatternBenefit benefit = 1)
+ : RewritePattern(name, benefit, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+ "isolated-from-above ops are not supported");
+
+ // Collect all successor inputs and sort them. When dropping the uses of a
+ // successor input, we'd like to also drop the uses of the same tied
+ // successor inputs. Otherwise, a set of tied successor inputs may not
+ // become entirely dead, which is required for
+ // RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them.
+ // (Sorting is not required for correctness.)
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+ RegionBranchInverseSuccessorMapping inputsToOperands;
+ regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
+ SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
+ llvm::sort(inputs, [](Value a, Value b) {
+ return getArgOrResultNumber(a) < getArgOrResultNumber(b);
+ });
+
+ // Check every distinct pair of successor inputs for duplicates. Replace
+ // `input2` with `input1` if they are duplicates.
+ bool changed = false;
+ unsigned numInputs = inputs.size();
+ for (auto i : llvm::seq<unsigned>(0, numInputs)) {
+ Value input1 = inputs[i];
+ for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
+ Value input2 = inputs[j];
+ // Nothing to do if input2 is already dead.
+ if (input2.use_empty())
+ continue;
+ // Replace only values that belong to the same block / operation.
+ // This implies that the two values are either both block arguments or
+ // both op results.
+ if (!haveSameOwner(input1, input2))
+ continue;
+
+ // Gather the predecessor value for each predecessor (region branch
+ // point). The two inputs are duplicates if each predecessor forwards
+ // the same value.
+ DenseMap<Operation *, Value> operands1, operands2;
+ for (OpOperand *operand : inputsToOperands[input1]) {
+ assert(!operands1.contains(operand->getOwner()));
+ operands1[operand->getOwner()] = operand->get();
+ }
+ for (OpOperand *operand : inputsToOperands[input2]) {
+ assert(!operands2.contains(operand->getOwner()));
+ operands2[operand->getOwner()] = operand->get();
+ }
+ if (operands1 == operands2) {
+ rewriter.replaceAllUsesWith(input2, input1);
+ changed = true;
+ }
+ }
+ }
+ return success(changed);
+ }
+};
+} // namespace
+
+void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
+ RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
+ patterns.add<MakeRegionBranchOpSuccessorInputsDead,
+ RemoveDuplicateSuccessorInputUses,
+ RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
+ opName, benefit);
+}
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d5d0aee3bbe25..11dc4f04af32e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1071,17 +1071,17 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
-// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
+// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
// CHECK: tensor.extract %{{.*}}[]
-// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
+// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]]
// CHECK: } do {
-// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>):
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>):
// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
-// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
-// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
+// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG2]]
+// CHECK: scf.yield %[[VAL0]], %[[VAL1]]
// CHECK: }
-// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
+// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#1, %[[ZERO]]
// CHECK-LABEL: @while_loop_invariant_argument_different_order
func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
@@ -1736,11 +1736,11 @@ module {
// Test case with multiple scf.yield ops with at least one different operand, then no change.
-// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+// CHECK: %[[VAL_3:.*]] = scf.execute_region -> memref<1x120xui8> no_inline {
// CHECK: ^bb1:
-// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: scf.yield %{{.*}} : memref<1x120xui8>
// CHECK: ^bb2:
-// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: scf.yield %{{.*}} : memref<1x120xui8>
// CHECK: }
module {
@@ -2178,16 +2178,14 @@ func.func @scf_for_all_step_size_0() {
// CHECK-SAME: %[[arg0:.*]]: index
// CHECK-DAG: %[[c10:.*]] = arith.constant 10
// CHECK-DAG: %[[c11:.*]] = arith.constant 11
-// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+// CHECK: scf.index_switch %[[arg0]]
// CHECK: case 1 {
// CHECK: memref.store %[[c10]]
-// CHECK: scf.yield %[[arg0]] : index
// CHECK: }
// CHECK: default {
// CHECK: memref.store %[[c11]]
-// CHECK: scf.yield %[[arg0]] : index
// CHECK: }
-// CHECK: return %[[switch]]
+// CHECK: return %[[arg0]]
func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
%non_live, %live = scf.index_switch %arg0 -> i32, index
case 1 {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 2584573c8b4dc..ae83eac0c376f 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -250,13 +250,13 @@ func.func @main() -> (i32, i32) {
// CHECK-NEXT: }
// CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
// CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
// CHECK-CANONICALIZE-NEXT: } do {
// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
// CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE: scf.yield %[[live_1]] : i32
// CHECK-CANONICALIZE-NEXT: }
// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0
// CHECK-CANONICALIZE-NEXT: }
@@ -306,7 +306,7 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0
// CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1
-// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
// CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> ()
// CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
// CHECK-CANONICALIZE-NEXT: } do {
More information about the llvm-branch-commits
mailing list