[llvm-branch-commits] [mlir] [mlir][draft] Consolidate patterns into `RegionBranchOpInterface` patterns (PR #174094)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 1 04:13:11 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/174094
>From 8b2bcb7b6652272bc614d37ca2305000797b2b3f Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 31 Dec 2025 14:07:51 +0000
Subject: [PATCH] [mlir][draft] Consolidate patterns into
RegionBranchOpInterface patterns
fix some tests
---
.../SparseTensor/IR/SparseTensorOps.td | 7 +-
.../mlir/Interfaces/ControlFlowInterfaces.h | 2 +
.../mlir/Interfaces/ControlFlowInterfaces.td | 9 +
mlir/lib/Dialect/SCF/IR/SCF.cpp | 1134 ++++-------------
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 39 +
mlir/test/Dialect/SCF/canonicalize.mlir | 24 +-
.../Dialect/SparseTensor/sparse_kernels.mlir | 16 +-
.../test/Dialect/SparseTensor/sparse_out.mlir | 34 +-
.../Vector/vector-warp-distribute.mlir | 6 +-
mlir/test/Transforms/remove-dead-values.mlir | 8 +-
10 files changed, 377 insertions(+), 902 deletions(-)
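
A minimal sketch of how the new interface helpers are meant to be used (not part of the patch; the function name `simplifyInput` is hypothetical, and the dominance check performed via `isDefinedBefore` in the real pattern is omitted for brevity). It shows roughly how `computePossibleValuesOfSuccessorInput`, added to `RegionBranchOpInterface` below, drives the new `RemoveUsesOfIdenticalValues` canonicalization:

  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Interfaces/ControlFlowInterfaces.h"

  using namespace mlir;

  // If a successor input (a region block argument or an op result) can only
  // ever carry a single other value, forward that value to all of its uses.
  static LogicalResult simplifyInput(RegionBranchOpInterface op, Value input,
                                     PatternRewriter &rewriter) {
    llvm::DenseSet<Value> candidates =
        op.computePossibleValuesOfSuccessorInput(input);
    if (candidates.size() != 1 || *candidates.begin() == input)
      return failure();
    // The real pattern additionally requires that the replacement value is
    // defined before the input (see isDefinedBefore in SCF.cpp below).
    rewriter.replaceAllUsesWith(input, *candidates.begin());
    return success();
  }
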
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a61d90a0c39b1..f41b3694d9c79 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1304,9 +1304,10 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
let hasVerifier = 1;
}
-def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
- ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
- "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
+def SparseTensor_YieldOp : SparseTensor_Op<"yield",
+ [Pure, Terminator, ReturnLike,
+ ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
+ "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 566f4b8fadb5d..a7565f9f7bb78 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -188,6 +188,8 @@ LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
/// possible successors.) Operands that are not forwarded at all are not
/// present in the mapping.
using RegionBranchSuccessorMapping = DenseMap<OpOperand *, SmallVector<Value>>;
+using RegionBranchInverseSuccessorMapping =
+ DenseMap<Value, SmallVector<OpOperand *>>;
/// This class represents a successor of a region. A region successor can either
/// be another region, or the parent operation. If the successor is a region,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 2e654ba04ffe5..9366e5562b774 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -355,6 +355,15 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
::mlir::RegionBranchSuccessorMapping &mapping,
std::optional<::mlir::RegionBranchPoint> src = std::nullopt);
+ /// Build a mapping from successor inputs to successor operands. This is
+ /// the same as "getSuccessorOperandInputMapping", but inverted.
+ void getSuccessorInputOperandMapping(
+ ::mlir::RegionBranchInverseSuccessorMapping &mapping);
+
+ /// Compute all values that a successor input could possibly have. If the given
+ /// value is not a successor input, the returned set contains just that value.
+ ::llvm::DenseSet<Value> computePossibleValuesOfSuccessorInput(::mlir::Value value);
+
/// Return all possible region branch points: the region branch op itself
/// and all region branch terminators.
::llvm::SmallVector<::mlir::RegionBranchPoint> getAllRegionBranchPoints();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0a123112cf68f..06b542d1c1dae 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -291,102 +292,9 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};
-// Pattern to eliminate ExecuteRegionOp results which forward external
-// values from the region. In case there are multiple yield operations,
-// all of them must have the same operands in order for the pattern to be
-// applicable.
-struct ExecuteRegionForwardingEliminator
- : public OpRewritePattern<ExecuteRegionOp> {
- using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExecuteRegionOp op,
- PatternRewriter &rewriter) const override {
- if (op.getNumResults() == 0)
- return failure();
-
- SmallVector<Operation *> yieldOps;
- for (Block &block : op.getRegion()) {
- if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
- yieldOps.push_back(yield.getOperation());
- }
-
- if (yieldOps.empty())
- return failure();
-
- // Check if all yield operations have the same operands.
- auto yieldOpsOperands = yieldOps[0]->getOperands();
- for (auto *yieldOp : yieldOps) {
- if (yieldOp->getOperands() != yieldOpsOperands)
- return failure();
- }
-
- SmallVector<Value> externalValues;
- SmallVector<Value> internalValues;
- SmallVector<Value> opResultsToReplaceWithExternalValues;
- SmallVector<Value> opResultsToKeep;
- for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
- if (isValueFromInsideRegion(yieldedValue, op)) {
- internalValues.push_back(yieldedValue);
- opResultsToKeep.push_back(op.getResult(index));
- } else {
- externalValues.push_back(yieldedValue);
- opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
- }
- }
- // No yielded external values - nothing to do.
- if (externalValues.empty())
- return failure();
-
- // There are yielded external values - create a new execute_region returning
- // just the internal values.
- SmallVector<Type> resultTypes;
- for (Value value : internalValues)
- resultTypes.push_back(value.getType());
- auto newOp =
- ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
- newOp->setAttrs(op->getAttrs());
-
- // Move old op's region to the new operation.
- rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
- newOp.getRegion().end());
-
- // Replace all yield operations with a new yield operation with updated
- // results. scf.execute_region must have at least one yield operation.
- for (auto *yieldOp : yieldOps) {
- rewriter.setInsertionPoint(yieldOp);
- rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
- ValueRange(internalValues));
- }
-
- // Replace the old operation with the external values directly.
- rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
- externalValues);
- // Replace the old operation's remaining results with the new operation's
- // results.
- rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
- rewriter.eraseOp(op);
- return success();
- }
-
-private:
- bool isValueFromInsideRegion(Value value,
- ExecuteRegionOp executeRegionOp) const {
- // Check if the value is defined within the execute_region
- if (Operation *defOp = value.getDefiningOp())
- return &executeRegionOp.getRegion() == defOp->getParentRegion();
-
- // If it's a block argument, check if it's from within the region
- if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
- return &executeRegionOp.getRegion() == blockArg.getParentRegion();
-
- return false; // Value is from outside the region
- }
-};
-
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
- ExecuteRegionForwardingEliminator>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
}
void ExecuteRegionOp::getSuccessorRegions(
@@ -989,146 +897,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
}
namespace {
-// Fold away ForOp iter arguments when:
-// 1) The op yields the iter arguments.
-// 2) The argument's corresponding outer region iterators (inputs) are yielded.
-// 3) The iter arguments have no use and the corresponding (operation) results
-// have no use.
-//
-// These arguments must be defined outside of the ForOp region and can just be
-// forwarded after simplifying the op inits, yields and returns.
-//
-// The implementation uses `inlineBlockBefore` to steal the content of the
-// original ForOp and avoid cloning.
-struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
- using OpRewritePattern<scf::ForOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(scf::ForOp forOp,
- PatternRewriter &rewriter) const final {
- bool canonicalize = false;
-
- // An internal flat vector of block transfer
- // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
- // transformed block argument mappings. This plays the role of a
- // IRMapping for the particular use case of calling into
- // `inlineBlockBefore`.
- int64_t numResults = forOp.getNumResults();
- SmallVector<bool, 4> keepMask;
- keepMask.reserve(numResults);
- SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
- newResultValues;
- newBlockTransferArgs.reserve(1 + numResults);
- newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
- newIterArgs.reserve(forOp.getInitArgs().size());
- newYieldValues.reserve(numResults);
- newResultValues.reserve(numResults);
- DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
- for (auto [init, arg, result, yielded] :
- llvm::zip(forOp.getInitArgs(), // iter from outside
- forOp.getRegionIterArgs(), // iter inside region
- forOp.getResults(), // op results
- forOp.getYieldedValues() // iter yield
- )) {
- // Forwarded is `true` when:
- // 1) The region `iter` argument is yielded.
- // 2) The region `iter` argument the corresponding input is yielded.
- // 3) The region `iter` argument has no use, and the corresponding op
- // result has no use.
- bool forwarded = (arg == yielded) || (init == yielded) ||
- (arg.use_empty() && result.use_empty());
- if (forwarded) {
- canonicalize = true;
- keepMask.push_back(false);
- newBlockTransferArgs.push_back(init);
- newResultValues.push_back(init);
- continue;
- }
-
- // Check if a previous kept argument always has the same values for init
- // and yielded values.
- if (auto it = initYieldToArg.find({init, yielded});
- it != initYieldToArg.end()) {
- canonicalize = true;
- keepMask.push_back(false);
- auto [sameArg, sameResult] = it->second;
- rewriter.replaceAllUsesWith(arg, sameArg);
- rewriter.replaceAllUsesWith(result, sameResult);
- // The replacement value doesn't matter because there are no uses.
- newBlockTransferArgs.push_back(init);
- newResultValues.push_back(init);
- continue;
- }
-
- // This value is kept.
- initYieldToArg.insert({{init, yielded}, {arg, result}});
- keepMask.push_back(true);
- newIterArgs.push_back(init);
- newYieldValues.push_back(yielded);
- newBlockTransferArgs.push_back(Value()); // placeholder with null value
- newResultValues.push_back(Value()); // placeholder with null value
- }
-
- if (!canonicalize)
- return failure();
-
- scf::ForOp newForOp =
- scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
- forOp.getUpperBound(), forOp.getStep(), newIterArgs,
- /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
- newForOp->setAttrs(forOp->getAttrs());
- Block &newBlock = newForOp.getRegion().front();
-
- // Replace the null placeholders with newly constructed values.
- newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
- for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
- idx != e; ++idx) {
- Value &blockTransferArg = newBlockTransferArgs[1 + idx];
- Value &newResultVal = newResultValues[idx];
- assert((blockTransferArg && newResultVal) ||
- (!blockTransferArg && !newResultVal));
- if (!blockTransferArg) {
- blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
- newResultVal = newForOp.getResult(collapsedIdx++);
- }
- }
-
- Block &oldBlock = forOp.getRegion().front();
- assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
- "unexpected argument size mismatch");
-
- // No results case: the scf::ForOp builder already created a zero
- // result terminator. Merge before this terminator and just get rid of the
- // original terminator that has been merged in.
- if (newIterArgs.empty()) {
- auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
- rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
- rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
- rewriter.replaceOp(forOp, newResultValues);
- return success();
- }
-
- // No terminator case: merge and rewrite the merged terminator.
- auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(mergedTerminator);
- SmallVector<Value, 4> filteredOperands;
- filteredOperands.reserve(newResultValues.size());
- for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
- if (keepMask[idx])
- filteredOperands.push_back(mergedTerminator.getOperand(idx));
- scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
- filteredOperands);
- };
-
- rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
- auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
- cloneFilteredTerminator(mergedYieldOp);
- rewriter.eraseOp(mergedYieldOp);
- rewriter.replaceOp(forOp, newResultValues);
- return success();
- }
-};
-
/// Rewriting pattern that erases loops that are known not to iterate, replaces
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
@@ -1236,12 +1004,283 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
}
};
+/// Is `a` defined before `b`?
+static bool isDefinedBefore(Value a, Value b) {
+ Region *aRegion = a.getParentRegion();
+ Region *bRegion = b.getParentRegion();
+
+ if (aRegion->getParentOp()->isProperAncestor(bRegion->getParentOp())) {
+ return true;
+ }
+ if (aRegion == bRegion) {
+ Block *aBlock = a.getParentBlock();
+ Block *bBlock = b.getParentBlock();
+ if (aBlock != bBlock)
+ return false;
+ if (isa<BlockArgument>(a))
+ return true;
+ if (isa<BlockArgument>(b))
+ return false;
+ return a.getDefiningOp()->isBeforeInBlock(b.getDefiningOp());
+ }
+
+ return false;
+}
+
+/// Try to make successor inputs dead by replacing their uses with values that
+/// are not successor inputs. This pattern enables additional canonicalization
+/// opportunities for RemoveDeadValues.
+struct RemoveUsesOfIdenticalValues
+ : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+ using OpInterfaceRewritePattern<
+ RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(RegionBranchOpInterface op,
+ PatternRewriter &rewriter) const override {
+ // TODO: ForallOp data flow is modeled incompletely.
+ if (isa<ForallOp>(op))
+ return failure();
+
+ // Gather all potential successor inputs. (Other values may also be
+ // included, but we're not doing anything with them.)
+ SmallVector<Value> values;
+ llvm::append_range(values, op->getResults());
+ for (Region &r : op->getRegions())
+ llvm::append_range(values, r.getArguments());
+
+ bool changed = false;
+ for (Value value : values) {
+ if (value.use_empty())
+ continue;
+ DenseSet<Value> possibleValues =
+ op.computePossibleValuesOfSuccessorInput(value);
+ if (possibleValues.size() == 1 && *possibleValues.begin() != value &&
+ isDefinedBefore(*possibleValues.begin(), value)) {
+ // Value is same as another value.
+ rewriter.replaceAllUsesWith(value, *possibleValues.begin());
+ changed = true;
+ }
+ }
+ return success(changed);
+ }
+};
+
+/// Pattern to remove dead values from region branch ops.
+struct RemoveDeadValues
+ : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+ using OpInterfaceRewritePattern<
+ RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(RegionBranchOpInterface op,
+ PatternRewriter &rewriter) const override {
+ // TODO: ForallOp data flow is modeled incompletely.
+ if (isa<ForallOp>(op))
+ return failure();
+
+ // Compute tied values: values that must be removed as a set. If one of them
+ // is removed, all of them must be removed.
+ RegionBranchSuccessorMapping operandToInputs;
+ op.getSuccessorOperandInputMapping(operandToInputs);
+ llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
+ for (const auto &[operand, inputs] : operandToInputs) {
+ assert(!inputs.empty() && "expected non-empty inputs");
+ Value firstInput = inputs.front();
+ tiedSuccessorInputs.insert(firstInput);
+ for (Value nextInput : llvm::drop_begin(inputs))
+ tiedSuccessorInputs.unionSets(firstInput, nextInput);
+ }
+
+ // Determine which values to remove and group them by block and operation.
+ SmallVector<Value> valuesToRemove;
+ DenseMap<Block *, BitVector> blockArgsToRemove;
+ DenseMap<Operation *, BitVector> resultsToRemove;
+ for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
+ it != e; ++it) {
+ if (!(*it)->isLeader())
+ continue;
+
+ // Value can be removed if it is dead and all other tied values are also
+ // dead.
+ bool allDead = true;
+ for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+ memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+ if (!memberIt->use_empty()) {
+ allDead = false;
+ break;
+ }
+ }
+ if (!allDead)
+ continue;
+
+ // Group values by block and operation.
+ for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
+ memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
+ if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
+ BitVector &vector =
+ blockArgsToRemove
+ .try_emplace(arg.getOwner(),
+ arg.getOwner()->getNumArguments(), false)
+ .first->second;
+ vector.set(arg.getArgNumber());
+ } else {
+ OpResult result = cast<OpResult>(*memberIt);
+ BitVector &vector =
+ resultsToRemove
+ .try_emplace(result.getDefiningOp(),
+ result.getDefiningOp()->getNumResults(), false)
+ .first->second;
+ vector.set(result.getResultNumber());
+ }
+ valuesToRemove.push_back(*memberIt);
+ }
+ }
+
+ if (valuesToRemove.empty())
+ return rewriter.notifyMatchFailure(op, "no values to remove");
+
+ // Find operands that must be removed together with the values.
+ RegionBranchInverseSuccessorMapping inputsToOperands;
+ op.getSuccessorInputOperandMapping(inputsToOperands);
+ DenseMap<Operation *, llvm::BitVector> operandsToRemove;
+ for (Value value : valuesToRemove) {
+ for (OpOperand *operand : inputsToOperands[value]) {
+ BitVector &vector =
+ operandsToRemove
+ .try_emplace(operand->getOwner(),
+ operand->getOwner()->getNumOperands(), false)
+ .first->second;
+ vector.set(operand->getOperandNumber());
+ }
+ }
+
+ // Erase operands.
+ for (auto [op, operands] : operandsToRemove) {
+ rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
+ }
+
+ // Erase block arguments.
+ for (auto [block, blockArgs] : blockArgsToRemove) {
+ rewriter.modifyOpInPlace(block->getParentOp(),
+ [&]() { block->eraseArguments(blockArgs); });
+ }
+
+ // Erase op results.
+ // TODO: Can we move this to RewriterBase, so we have a uniform API,
+ // similar to eraseArguments?
+ for (auto [op, resultsToErase] : resultsToRemove) {
+ rewriter.setInsertionPoint(op);
+ SmallVector<Type> newResultTypes;
+ for (OpResult result : op->getResults())
+ if (!resultsToErase[result.getResultNumber()])
+ newResultTypes.push_back(result.getType());
+ OperationState state(op->getLoc(), op->getName().getStringRef(),
+ op->getOperands(), newResultTypes, op->getAttrs());
+ for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+ state.addRegion();
+ Operation *newOp = rewriter.create(state);
+ for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
+ // Move all blocks of `region` into `newRegion`.
+ Region &newRegion = newOp->getRegion(index);
+ rewriter.inlineRegionBefore(region, newRegion, newRegion.begin());
+ }
+
+ SmallVector<Value> newResults;
+ unsigned nextLiveResult = 0;
+ for (auto [index, result] : llvm::enumerate(op->getResults())) {
+ if (!resultsToErase[index]) {
+ newResults.push_back(newOp->getResult(nextLiveResult++));
+ } else {
+ newResults.push_back(Value());
+ }
+ }
+ rewriter.replaceOp(op, newResults);
+ }
+
+ return success();
+ }
+};
+
+void *getContainerOwnerOfValue(Value value) {
+ if (auto opResult = llvm::dyn_cast<OpResult>(value))
+ return opResult.getDefiningOp();
+ return llvm::cast<BlockArgument>(value).getOwner();
+}
+
+unsigned getArgOrResultNumber(Value value) {
+ if (auto opResult = llvm::dyn_cast<OpResult>(value))
+ return opResult.getResultNumber();
+ return llvm::cast<BlockArgument>(value).getArgNumber();
+}
+
+/// Pattern to make duplicate successor inputs dead. Two successor inputs are
+/// duplicates if their corresponding successor operands have the same values.
+struct RemoveDuplicateSuccessorInputUses
+ : public OpInterfaceRewritePattern<RegionBranchOpInterface> {
+ using OpInterfaceRewritePattern<
+ RegionBranchOpInterface>::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(RegionBranchOpInterface op,
+ PatternRewriter &rewriter) const override {
+ // TODO: ForallOp data flow is modeled incompletely.
+ if (isa<ForallOp>(op))
+ return failure();
+
+ // Collect all successor inputs and sort them. When dropping the uses of a
+ // successor input, we'd like to also drop the uses of the same tied
+ // successor inputs. Otherwise, a set of tied successor inputs may not
+ // become entirely dead, which is required for RemoveDeadValues to erase
+ // them. (Sorting is not required for correctness.)
+ RegionBranchInverseSuccessorMapping inputsToOperands;
+ op.getSuccessorInputOperandMapping(inputsToOperands);
+ SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
+ llvm::sort(inputs, [](Value a, Value b) {
+ return getArgOrResultNumber(a) < getArgOrResultNumber(b);
+ });
+
+ bool changed = false;
+ for (unsigned i = 0, e = inputs.size(); i < e; i++) {
+ Value input1 = inputs[i];
+ for (unsigned j = i + 1; j < e; j++) {
+ Value input2 = inputs[j];
+ // Nothing to do if input2 is already dead.
+ if (input2.use_empty())
+ continue;
+ // Replace only values of the same kind.
+ if (isa<BlockArgument>(input1) != isa<BlockArgument>(input2))
+ continue;
+ // Replace only values that belong to the same block / operation.
+ if (getContainerOwnerOfValue(input1) !=
+ getContainerOwnerOfValue(input2))
+ continue;
+
+ // Gather the predecessor value for each predecessor (region branch
+ // point). The two inputs are duplicates if each predecessor forwards
+ // the same value.
+ DenseMap<Operation *, Value> operands1, operands2;
+ for (OpOperand *operand : inputsToOperands[input1]) {
+ assert(!operands1.contains(operand->getOwner()));
+ operands1[operand->getOwner()] = operand->get();
+ }
+ for (OpOperand *operand : inputsToOperands[input2]) {
+ assert(!operands2.contains(operand->getOwner()));
+ operands2[operand->getOwner()] = operand->get();
+ }
+ if (operands1 == operands2) {
+ rewriter.replaceAllUsesWith(input2, input1);
+ changed = true;
+ }
+ }
+ }
+ return success(changed);
+ }
+};
} // namespace
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
- context);
+ results.add<SimplifyTrivialLoops, ForOpTensorCastFolder,
+ RemoveUsesOfIdenticalValues, RemoveDeadValues,
+ RemoveDuplicateSuccessorInputUses>(context);
}
std::optional<APInt> ForOp::getConstantStep() {
@@ -2409,61 +2448,6 @@ void IfOp::getRegionInvocationBounds(
}
namespace {
-// Pattern to remove unused IfOp results.
-struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
- PatternRewriter &rewriter) const {
- // Move all operations to the destination block.
- rewriter.mergeBlocks(source, dest);
- // Replace the yield op by one that returns only the used values.
- auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
- SmallVector<Value, 4> usedOperands;
- llvm::transform(usedResults, std::back_inserter(usedOperands),
- [&](OpResult result) {
- return yieldOp.getOperand(result.getResultNumber());
- });
- rewriter.modifyOpInPlace(yieldOp,
- [&]() { yieldOp->setOperands(usedOperands); });
- }
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- // Compute the list of used results.
- SmallVector<OpResult, 4> usedResults;
- llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
- [](OpResult result) { return !result.use_empty(); });
-
- // Replace the operation if only a subset of its results have uses.
- if (usedResults.size() == op.getNumResults())
- return failure();
-
- // Compute the result types of the replacement operation.
- SmallVector<Type, 4> newTypes;
- llvm::transform(usedResults, std::back_inserter(newTypes),
- [](OpResult result) { return result.getType(); });
-
- // Create a replacement operation with empty then and else regions.
- auto newOp =
- IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
- rewriter.createBlock(&newOp.getThenRegion());
- rewriter.createBlock(&newOp.getElseRegion());
-
- // Move the bodies and replace the terminators (note there is a then and
- // an else region since the operation returns results).
- transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
- transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
-
- // Replace the operation by the new one.
- SmallVector<Value, 4> repResults(op.getNumResults());
- for (const auto &en : llvm::enumerate(usedResults))
- repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
- rewriter.replaceOp(op, repResults);
- return success();
- }
-};
-
struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -3034,8 +3018,8 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, RemoveUnusedResults,
- ReplaceIfYieldWithConditionOrValue>(context);
+ RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
+ context);
}
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3873,390 +3857,6 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
}
};
-/// Remove loop invariant arguments from `before` block of scf.while.
-/// A before block argument is considered loop invariant if :-
-/// 1. i-th yield operand is equal to the i-th while operand.
-/// 2. i-th yield operand is k-th after block argument which is (k+1)-th
-/// condition operand AND this (k+1)-th condition operand is equal to i-th
-/// iter argument/while operand.
-/// For the arguments which are removed, their uses inside scf.while
-/// are replaced with their corresponding initial value.
-///
-/// Eg:
-/// INPUT :-
-/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
-/// ..., %argN_before = %N)
-/// {
-/// ...
-/// scf.condition(%cond) %arg1_before, %arg0_before,
-/// %arg2_before, %arg0_before, ...
-/// } do {
-/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-/// ..., %argK_after):
-/// ...
-/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
-/// }
-///
-/// OUTPUT :-
-/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
-/// %N)
-/// {
-/// ...
-/// scf.condition(%cond) %b, %a, %arg2_before, %a, ...
-/// } do {
-/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
-/// ..., %argK_after):
-/// ...
-/// scf.yield %arg1_after, ..., %argN
-/// }
-///
-/// EXPLANATION:
-/// We iterate over each yield operand.
-/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand
-/// %arg0_before, which in turn is the 0-th iter argument. So we
-/// remove 0-th before block argument and yield operand, and replace
-/// all uses of the 0-th before block argument with its initial value
-/// %a.
-/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial
-/// value. So we remove this operand and the corresponding before
-/// block argument and replace all uses of 1-th before block argument
-/// with %b.
-struct RemoveLoopInvariantArgsFromBeforeBlock
- : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- Block &afterBlock = *op.getAfterBody();
- Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
- ConditionOp condOp = op.getConditionOp();
- OperandRange condOpArgs = condOp.getArgs();
- Operation *yieldOp = afterBlock.getTerminator();
- ValueRange yieldOpArgs = yieldOp->getOperands();
-
- bool canSimplify = false;
- for (const auto &it :
- llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
- auto index = static_cast<unsigned>(it.index());
- auto [initVal, yieldOpArg] = it.value();
- // If i-th yield operand is equal to the i-th operand of the scf.while,
- // the i-th before block argument is a loop invariant.
- if (yieldOpArg == initVal) {
- canSimplify = true;
- break;
- }
- // If the i-th yield operand is k-th after block argument, then we check
- // if the (k+1)-th condition op operand is equal to either the i-th before
- // block argument or the initial value of i-th before block argument. If
- // the comparison results `true`, i-th before block argument is a loop
- // invariant.
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
- if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
- Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
- if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
- canSimplify = true;
- break;
- }
- }
- }
-
- if (!canSimplify)
- return failure();
-
- SmallVector<Value> newInitArgs, newYieldOpArgs;
- DenseMap<unsigned, Value> beforeBlockInitValMap;
- SmallVector<Location> newBeforeBlockArgLocs;
- for (const auto &it :
- llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
- auto index = static_cast<unsigned>(it.index());
- auto [initVal, yieldOpArg] = it.value();
-
- // If i-th yield operand is equal to the i-th operand of the scf.while,
- // the i-th before block argument is a loop invariant.
- if (yieldOpArg == initVal) {
- beforeBlockInitValMap.insert({index, initVal});
- continue;
- } else {
- // If the i-th yield operand is k-th after block argument, then we check
- // if the (k+1)-th condition op operand is equal to either the i-th
- // before block argument or the initial value of i-th before block
- // argument. If the comparison results `true`, i-th before block
- // argument is a loop invariant.
- auto yieldOpBlockArg = llvm::dyn_cast<BlockArgument>(yieldOpArg);
- if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
- Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
- if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
- beforeBlockInitValMap.insert({index, initVal});
- continue;
- }
- }
- }
- newInitArgs.emplace_back(initVal);
- newYieldOpArgs.emplace_back(yieldOpArg);
- newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
- }
-
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(yieldOp);
- rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
- }
-
- auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(),
- newInitArgs);
-
- Block &newBeforeBlock = *rewriter.createBlock(
- &newWhile.getBefore(), /*insertPt*/ {},
- ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
-
- Block &beforeBlock = *op.getBeforeBody();
- SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
- // For each i-th before block argument we find it's replacement value as :-
- // 1. If i-th before block argument is a loop invariant, we fetch it's
- // initial value from `beforeBlockInitValMap` by querying for key `i`.
- // 2. Else we fetch j-th new before block argument as the replacement
- // value of i-th before block argument.
- for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
- // If the index 'i' argument was a loop invariant we fetch it's initial
- // value from `beforeBlockInitValMap`.
- if (beforeBlockInitValMap.count(i) != 0)
- newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
- else
- newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
- }
-
- rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
- rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
- newWhile.getAfter().begin());
-
- rewriter.replaceOp(op, newWhile.getResults());
- return success();
- }
-};
-
-/// Remove loop invariant value from result (condition op) of scf.while.
-/// A value is considered loop invariant if the final value yielded by
-/// scf.condition is defined outside of the `before` block. We remove the
-/// corresponding argument in `after` block and replace the use with the value.
-/// We also replace the use of the corresponding result of scf.while with the
-/// value.
-///
-/// Eg:
-/// INPUT :-
-/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
-/// %argN_before = %N) {
-/// ...
-/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
-/// } do {
-/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
-/// ...
-/// some_func(%arg1_after)
-/// ...
-/// scf.yield %arg0_after, %arg2_after, ..., %argN_after
-/// }
-///
-/// OUTPUT :-
-/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
-/// ...
-/// scf.condition(%cond) %arg0, %arg1, ..., %argM
-/// } do {
-/// ^bb0(%arg0, %arg3, ..., %argM):
-/// ...
-/// some_func(%a)
-/// ...
-/// scf.yield %arg0, %b, ..., %argN
-/// }
-///
-/// EXPLANATION:
-/// 1. The 1-th and 2-th operand of scf.condition are defined outside the
-/// before block of scf.while, so they get removed.
-/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
-/// replaced by %b.
-/// 3. The corresponding after block argument %arg1_after's uses are
-/// replaced by %a and %arg2_after's uses are replaced by %b.
-struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- Block &beforeBlock = *op.getBeforeBody();
- ConditionOp condOp = op.getConditionOp();
- OperandRange condOpArgs = condOp.getArgs();
-
- bool canSimplify = false;
- for (Value condOpArg : condOpArgs) {
- // Those values not defined within `before` block will be considered as
- // loop invariant values. We map the corresponding `index` with their
- // value.
- if (condOpArg.getParentBlock() != &beforeBlock) {
- canSimplify = true;
- break;
- }
- }
-
- if (!canSimplify)
- return failure();
-
- Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
-
- SmallVector<Value> newCondOpArgs;
- SmallVector<Type> newAfterBlockType;
- DenseMap<unsigned, Value> condOpInitValMap;
- SmallVector<Location> newAfterBlockArgLocs;
- for (const auto &it : llvm::enumerate(condOpArgs)) {
- auto index = static_cast<unsigned>(it.index());
- Value condOpArg = it.value();
- // Those values not defined within `before` block will be considered as
- // loop invariant values. We map the corresponding `index` with their
- // value.
- if (condOpArg.getParentBlock() != &beforeBlock) {
- condOpInitValMap.insert({index, condOpArg});
- } else {
- newCondOpArgs.emplace_back(condOpArg);
- newAfterBlockType.emplace_back(condOpArg.getType());
- newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
- }
- }
-
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(condOp);
- rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
- newCondOpArgs);
- }
-
- auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType,
- op.getOperands());
-
- Block &newAfterBlock =
- *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
- newAfterBlockType, newAfterBlockArgLocs);
-
- Block &afterBlock = *op.getAfterBody();
- // Since a new scf.condition op was created, we need to fetch the new
- // `after` block arguments which will be used while replacing operations of
- // previous scf.while's `after` blocks. We'd also be fetching new result
- // values too.
- SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
- SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
- for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
- Value afterBlockArg, result;
- // If index 'i' argument was loop invariant we fetch it's value from the
- // `condOpInitMap` map.
- if (condOpInitValMap.count(i) != 0) {
- afterBlockArg = condOpInitValMap[i];
- result = afterBlockArg;
- } else {
- afterBlockArg = newAfterBlock.getArgument(j);
- result = newWhile.getResult(j);
- j++;
- }
- newAfterBlockArgs[i] = afterBlockArg;
- newWhileResults[i] = result;
- }
-
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
- rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
- newWhile.getBefore().begin());
-
- rewriter.replaceOp(op, newWhileResults);
- return success();
- }
-};
-
-/// Remove WhileOp results that are also unused in 'after' block.
-///
-/// %0:2 = scf.while () : () -> (i32, i64) {
-/// %condition = "test.condition"() : () -> i1
-/// %v1 = "test.get_some_value"() : () -> i32
-/// %v2 = "test.get_some_value"() : () -> i64
-/// scf.condition(%condition) %v1, %v2 : i32, i64
-/// } do {
-/// ^bb0(%arg0: i32, %arg1: i64):
-/// "test.use"(%arg0) : (i32) -> ()
-/// scf.yield
-/// }
-/// return %0#0 : i32
-///
-/// becomes
-/// %0 = scf.while () : () -> (i32) {
-/// %condition = "test.condition"() : () -> i1
-/// %v1 = "test.get_some_value"() : () -> i32
-/// %v2 = "test.get_some_value"() : () -> i64
-/// scf.condition(%condition) %v1 : i32
-/// } do {
-/// ^bb0(%arg0: i32):
-/// "test.use"(%arg0) : (i32) -> ()
-/// scf.yield
-/// }
-/// return %0 : i32
-struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- auto term = op.getConditionOp();
- auto afterArgs = op.getAfterArguments();
- auto termArgs = term.getArgs();
-
- // Collect results mapping, new terminator args and new result types.
- SmallVector<unsigned> newResultsIndices;
- SmallVector<Type> newResultTypes;
- SmallVector<Value> newTermArgs;
- SmallVector<Location> newArgLocs;
- bool needUpdate = false;
- for (const auto &it :
- llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
- auto i = static_cast<unsigned>(it.index());
- Value result = std::get<0>(it.value());
- Value afterArg = std::get<1>(it.value());
- Value termArg = std::get<2>(it.value());
- if (result.use_empty() && afterArg.use_empty()) {
- needUpdate = true;
- } else {
- newResultsIndices.emplace_back(i);
- newTermArgs.emplace_back(termArg);
- newResultTypes.emplace_back(result.getType());
- newArgLocs.emplace_back(result.getLoc());
- }
- }
-
- if (!needUpdate)
- return failure();
-
- {
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(term);
- rewriter.replaceOpWithNewOp<ConditionOp>(term, term.getCondition(),
- newTermArgs);
- }
-
- auto newWhile =
- WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits());
-
- Block &newAfterBlock = *rewriter.createBlock(
- &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs);
-
- // Build new results list and new after block args (unused entries will be
- // null).
- SmallVector<Value> newResults(op.getNumResults());
- SmallVector<Value> newAfterBlockArgs(op.getNumResults());
- for (const auto &it : llvm::enumerate(newResultsIndices)) {
- newResults[it.value()] = newWhile.getResult(it.index());
- newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
- }
-
- rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
- newWhile.getBefore().begin());
-
- Block &afterBlock = *op.getAfterBody();
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
-
- rewriter.replaceOp(op, newResults);
- return success();
- }
-};
-
/// Replace operations equivalent to the condition in the do block with true,
/// since otherwise the block would not be evaluated.
///
@@ -4321,127 +3921,6 @@ struct WhileCmpCond : public OpRewritePattern<scf::WhileOp> {
}
};
-/// Remove unused init/yield args.
-struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
- using OpRewritePattern<WhileOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
-
- if (!llvm::any_of(op.getBeforeArguments(),
- [](Value arg) { return arg.use_empty(); }))
- return rewriter.notifyMatchFailure(op, "No args to remove");
-
- YieldOp yield = op.getYieldOp();
-
- // Collect results mapping, new terminator args and new result types.
- SmallVector<Value> newYields;
- SmallVector<Value> newInits;
- llvm::BitVector argsToErase;
-
- size_t argsCount = op.getBeforeArguments().size();
- newYields.reserve(argsCount);
- newInits.reserve(argsCount);
- argsToErase.reserve(argsCount);
- for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip(
- op.getBeforeArguments(), yield.getOperands(), op.getInits())) {
- if (beforeArg.use_empty()) {
- argsToErase.push_back(true);
- } else {
- argsToErase.push_back(false);
- newYields.emplace_back(yieldValue);
- newInits.emplace_back(initValue);
- }
- }
-
- Block &beforeBlock = *op.getBeforeBody();
- Block &afterBlock = *op.getAfterBody();
-
- beforeBlock.eraseArguments(argsToErase);
-
- Location loc = op.getLoc();
- auto newWhileOp =
- WhileOp::create(rewriter, loc, op.getResultTypes(), newInits,
- /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
- Block &newBeforeBlock = *newWhileOp.getBeforeBody();
- Block &newAfterBlock = *newWhileOp.getAfterBody();
-
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(yield);
- rewriter.replaceOpWithNewOp<YieldOp>(yield, newYields);
-
- rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
- newBeforeBlock.getArguments());
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
- newAfterBlock.getArguments());
-
- rewriter.replaceOp(op, newWhileOp.getResults());
- return success();
- }
-};
-
-/// Remove duplicated ConditionOp args.
-struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(WhileOp op,
- PatternRewriter &rewriter) const override {
- ConditionOp condOp = op.getConditionOp();
- ValueRange condOpArgs = condOp.getArgs();
-
- llvm::SmallPtrSet<Value, 8> argsSet(llvm::from_range, condOpArgs);
-
- if (argsSet.size() == condOpArgs.size())
- return rewriter.notifyMatchFailure(op, "No results to remove");
-
- llvm::SmallDenseMap<Value, unsigned> argsMap;
- SmallVector<Value> newArgs;
- argsMap.reserve(condOpArgs.size());
- newArgs.reserve(condOpArgs.size());
- for (Value arg : condOpArgs) {
- if (!argsMap.count(arg)) {
- auto pos = static_cast<unsigned>(argsMap.size());
- argsMap.insert({arg, pos});
- newArgs.emplace_back(arg);
- }
- }
-
- ValueRange argsRange(newArgs);
-
- Location loc = op.getLoc();
- auto newWhileOp =
- scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(),
- /*beforeBody*/ nullptr,
- /*afterBody*/ nullptr);
- Block &newBeforeBlock = *newWhileOp.getBeforeBody();
- Block &newAfterBlock = *newWhileOp.getAfterBody();
-
- SmallVector<Value> afterArgsMapping;
- SmallVector<Value> resultsMapping;
- for (auto &&[i, arg] : llvm::enumerate(condOpArgs)) {
- auto it = argsMap.find(arg);
- assert(it != argsMap.end());
- auto pos = it->second;
- afterArgsMapping.emplace_back(newAfterBlock.getArgument(pos));
- resultsMapping.emplace_back(newWhileOp->getResult(pos));
- }
-
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(condOp);
- rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
- argsRange);
-
- Block &beforeBlock = *op.getBeforeBody();
- Block &afterBlock = *op.getAfterBody();
-
- rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
- newBeforeBlock.getArguments());
- rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
- rewriter.replaceOp(op, resultsMapping);
- return success();
- }
-};
-
/// If both ranges contain the same values, return mapping indices from args2
/// to args1. Otherwise return std::nullopt.
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
@@ -4532,11 +4011,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<RemoveLoopInvariantArgsFromBeforeBlock,
- RemoveLoopInvariantValueYielded, WhileConditionTruth,
- WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
- WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>(
- context);
+ results.add<WhileConditionTruth, WhileCmpCond, WhileOpAlignBeforeArgs,
+ WhileMoveIfDown>(context);
}
//===----------------------------------------------------------------------===//
@@ -4711,59 +4187,9 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
-/// Canonicalization patterns that folds away dead results of
-/// "scf.index_switch" ops.
-struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
- using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IndexSwitchOp op,
- PatternRewriter &rewriter) const override {
- // Find dead results.
- BitVector deadResults(op.getNumResults(), false);
- SmallVector<Type> newResultTypes;
- for (auto [idx, result] : llvm::enumerate(op.getResults())) {
- if (!result.use_empty()) {
- newResultTypes.push_back(result.getType());
- } else {
- deadResults[idx] = true;
- }
- }
- if (!deadResults.any())
- return rewriter.notifyMatchFailure(op, "no dead results to fold");
-
- // Create new op without dead results and inline case regions.
- auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
- op.getArg(), op.getCases(),
- op.getCaseRegions().size());
- auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
- rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
- // Remove respective operands from yield op.
- Operation *terminator = newRegion.front().getTerminator();
- assert(isa<YieldOp>(terminator) && "expected yield op");
- rewriter.modifyOpInPlace(
- terminator, [&]() { terminator->eraseOperands(deadResults); });
- };
- for (auto [oldRegion, newRegion] :
- llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
- inlineCaseRegion(oldRegion, newRegion);
- inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
- // Replace op with new op.
- SmallVector<Value> newResults(op.getNumResults(), Value());
- unsigned nextNewResult = 0;
- for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
- if (deadResults[idx])
- continue;
- newResults[idx] = newOp.getResult(nextNewResult++);
- }
- rewriter.replaceOp(op, newResults);
- return success();
- }
-};
-
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
+ results.add<FoldConstantCase>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index d393ddb8d8336..ed94205d32f19 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -521,6 +521,45 @@ void RegionBranchOpInterface::getSuccessorOperandInputMapping(
}
}
+void RegionBranchOpInterface::getSuccessorInputOperandMapping(
+ RegionBranchInverseSuccessorMapping &mapping) {
+ RegionBranchSuccessorMapping operandToInputs;
+ getSuccessorOperandInputMapping(operandToInputs);
+ for (const auto &[operand, inputs] : operandToInputs) {
+ for (Value input : inputs)
+ mapping[input].push_back(operand);
+ }
+}
+
+DenseSet<Value>
+RegionBranchOpInterface::computePossibleValuesOfSuccessorInput(Value value) {
+ RegionBranchInverseSuccessorMapping inputToOperands;
+ getSuccessorInputOperandMapping(inputToOperands);
+
+ DenseSet<Value> possibleValues;
+ DenseSet<Value> visited;
+ SmallVector<Value> worklist;
+
+ // Starting with the given value, trace back all predecessor values (i.e.,
+ // preceding successor operands) and add them to the set of possible values.
+ // If the successor operand is itself a successor input, do not add it to
+ // the result set, but instead continue the traversal.
+ worklist.push_back(value);
+ while (!worklist.empty()) {
+ Value next = worklist.pop_back_val();
+ auto it = inputToOperands.find(next);
+ if (it == inputToOperands.end()) {
+ possibleValues.insert(next);
+ continue;
+ }
+ for (OpOperand *operand : it->second)
+ if (visited.insert(operand->get()).second)
+ worklist.push_back(operand->get());
+ }
+
+ return possibleValues;
+}
+
SmallVector<RegionBranchPoint>
RegionBranchOpInterface::getAllRegionBranchPoints() {
SmallVector<RegionBranchPoint> branchPoints;
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index d5d0aee3bbe25..11dc4f04af32e 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1071,17 +1071,17 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
-// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
+// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
// CHECK: tensor.extract %{{.*}}[]
-// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
+// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]]
// CHECK: } do {
-// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>):
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>):
// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
-// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
-// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
+// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG2]]
+// CHECK: scf.yield %[[VAL0]], %[[VAL1]]
// CHECK: }
-// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
+// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#1, %[[ZERO]]
// CHECK-LABEL: @while_loop_invariant_argument_different_order
func.func @while_loop_invariant_argument_different_order(%arg : tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
@@ -1736,11 +1736,11 @@ module {
// Test case with multiple scf.yield ops with at least one different operand, then no change.
-// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+// CHECK: %[[VAL_3:.*]] = scf.execute_region -> memref<1x120xui8> no_inline {
// CHECK: ^bb1:
-// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: scf.yield %{{.*}} : memref<1x120xui8>
// CHECK: ^bb2:
-// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: scf.yield %{{.*}} : memref<1x120xui8>
// CHECK: }
module {
@@ -2178,16 +2178,14 @@ func.func @scf_for_all_step_size_0() {
// CHECK-SAME: %[[arg0:.*]]: index
// CHECK-DAG: %[[c10:.*]] = arith.constant 10
// CHECK-DAG: %[[c11:.*]] = arith.constant 11
-// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+// CHECK: scf.index_switch %[[arg0]]
// CHECK: case 1 {
// CHECK: memref.store %[[c10]]
-// CHECK: scf.yield %[[arg0]] : index
// CHECK: }
// CHECK: default {
// CHECK: memref.store %[[c11]]
-// CHECK: scf.yield %[[arg0]] : index
// CHECK: }
-// CHECK: return %[[switch]]
+// CHECK: return %[[arg0]]
func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
%non_live, %live = scf.index_switch %arg0 -> i32, index
case 1 {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
index 5f2aa5e3a2736..b0a1af31d8806 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -129,13 +129,13 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_32:.*]]:4 = scf.while (%[[VAL_33:.*]] = %[[VAL_27]], %[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_26]], %[[VAL_36:.*]] = %[[VAL_21]]) : (index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>) -> (index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_32:.*]]:3 = scf.while (%[[VAL_33:.*]] = %[[VAL_27]], %[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_26]]) : (index, index, index) -> (index, index, index) {
// CHECK: %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_29]] : index
// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_31]] : index
// CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_37]], %[[VAL_38]] : i1
-// CHECK: scf.condition(%[[VAL_39]]) %[[VAL_33]], %[[VAL_34]], %[[VAL_35]], %[[VAL_36]] : index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: scf.condition(%[[VAL_39]]) %[[VAL_33]], %[[VAL_34]], %[[VAL_35]] : index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index, %[[VAL_43:.*]]: tensor<4x4xf64, #sparse{{[0-9]*}}>):
+// CHECK: ^bb0(%[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index):
// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_40]]] : memref<?xindex>
// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_41]]] : memref<?xindex>
// CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_44]] : index
@@ -143,7 +143,7 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index
// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index
// CHECK: %[[VAL_50:.*]] = arith.andi %[[VAL_48]], %[[VAL_49]] : i1
-// CHECK: %[[VAL_51:.*]]:2 = scf.if %[[VAL_50]] -> (index, tensor<4x4xf64, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (index) {
// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_40]]] : memref<?xf64>
// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_41]]] : memref<?xindex>
// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index
@@ -167,9 +167,9 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
// CHECK: memref.store %[[VAL_63]], %[[VAL_23]]{{\[}}%[[VAL_59]]] : memref<?xf64>
// CHECK: scf.yield %[[VAL_68:.*]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_69:.*]], %[[VAL_43]] : index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_69:.*]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_42]], %[[VAL_43]] : index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_42]] : index
// CHECK: }
// CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index
// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_40]], %[[VAL_3]] : index
@@ -177,9 +177,9 @@ func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index
// CHECK: %[[VAL_74:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index
// CHECK: %[[VAL_75:.*]] = arith.select %[[VAL_73]], %[[VAL_74]], %[[VAL_41]] : index
-// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]]#0, %[[VAL_76]]#1 : index, index, index, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]] : index, index, index
// CHECK: }
-// CHECK: %[[VAL_77:.*]] = sparse_tensor.compress %[[VAL_23]], %[[VAL_24]], %[[VAL_25]], %[[VAL_78:.*]]#2 into %[[VAL_78]]#3{{\[}}%[[VAL_22]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_77:.*]] = sparse_tensor.compress %[[VAL_23]], %[[VAL_24]], %[[VAL_25]], %[[VAL_78:.*]]#2 into %[[VAL_21]]{{\[}}%[[VAL_22]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK: scf.yield %[[VAL_77]] : tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK: }
// CHECK: %[[VAL_79:.*]] = sparse_tensor.load %[[VAL_80:.*]] hasInserts : tensor<4x4xf64, #sparse{{[0-9]*}}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
index 4dff06b8155dd..67d1573058460 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -216,13 +216,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref<?xindex>
// CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index
// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_72]]] : memref<?xindex>
-// CHECK: %[[VAL_74:.*]]:5 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) -> (index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_74:.*]]:4 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]]) : (index, index, i32, i1) -> (index, index, i32, i1) {
// CHECK: %[[VAL_79:.*]] = arith.cmpi ult, %[[VAL_75]], %[[VAL_70]] : index
// CHECK: %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_76]], %[[VAL_73]] : index
// CHECK: %[[VAL_81:.*]] = arith.andi %[[VAL_79]], %[[VAL_80]] : i1
-// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]], %[[VAL_78]] : index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]] : index, index, i32, i1
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1, %[[VAL_85:.*]]: tensor<?x?xi32, #sparse{{[0-9]*}}>):
+// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1):
// CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_82]]] : memref<?xindex>
// CHECK: %[[VAL_87:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_83]]] : memref<?xindex>
// CHECK: %[[VAL_88:.*]] = arith.cmpi ult, %[[VAL_87]], %[[VAL_86]] : index
@@ -230,14 +230,14 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
// CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
// CHECK: %[[VAL_91:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
// CHECK: %[[VAL_92:.*]] = arith.andi %[[VAL_90]], %[[VAL_91]] : i1
-// CHECK: %[[VAL_93:.*]]:3 = scf.if %[[VAL_92]] -> (i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_93:.*]]:2 = scf.if %[[VAL_92]] -> (i32, i1) {
// CHECK: %[[VAL_94:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_82]]] : memref<?xi32>
// CHECK: %[[VAL_95:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_83]]] : memref<?xi32>
// CHECK: %[[VAL_96:.*]] = arith.muli %[[VAL_94]], %[[VAL_95]] : i32
// CHECK: %[[VAL_97:.*]] = arith.addi %[[VAL_84]], %[[VAL_96]] : i32
-// CHECK: scf.yield %[[VAL_97]], %[[VAL_TRUE]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_97]], %[[VAL_TRUE]] : i32, i1
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_84]], %[[VAL_201]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_84]], %[[VAL_201]] : i32, i1
// CHECK: }
// CHECK: %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
// CHECK: %[[VAL_99:.*]] = arith.addi %[[VAL_82]], %[[VAL_3]] : index
@@ -245,13 +245,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
// CHECK: %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
// CHECK: %[[VAL_102:.*]] = arith.addi %[[VAL_83]], %[[VAL_3]] : index
// CHECK: %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_83]] : index
-// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1, %[[VAL_104]]#2 : index, index, i32, i1, tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1 : index, index, i32, i1
// CHECK: }
// CHECK: %[[VAL_202:.*]] = scf.if %[[VAL_74]]#3 -> (tensor<?x?xi32, #sparse{{[0-9]*}}>) {
-// CHECK: %[[VAL_105:.*]] = tensor.insert %[[VAL_74]]#2 into %[[VAL_74]]#4{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_105:.*]] = tensor.insert %[[VAL_74]]#2 into %[[VAL_59]]{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
// CHECK: scf.yield %[[VAL_105]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_74]]#4 : tensor<?x?xi32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_59]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
// CHECK: }
// CHECK: scf.yield %[[VAL_202]] : tensor<?x?xi32, #sparse{{[0-9]*}}>
// CHECK: } else {
@@ -339,13 +339,13 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xindex>
// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: %[[VAL_34:.*]]:4 = scf.while (%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]], %[[VAL_38:.*]] = %[[VAL_23]]) : (index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>) -> (index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_34:.*]]:3 = scf.while (%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]]) : (index, index, index) -> (index, index, index) {
// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_31]] : index
// CHECK: %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_33]] : index
// CHECK: %[[VAL_41:.*]] = arith.andi %[[VAL_39]], %[[VAL_40]] : i1
-// CHECK: scf.condition(%[[VAL_41]]) %[[VAL_35]], %[[VAL_36]], %[[VAL_37]], %[[VAL_38]] : index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.condition(%[[VAL_41]]) %[[VAL_35]], %[[VAL_36]], %[[VAL_37]] : index, index, index
// CHECK: } do {
-// CHECK: ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, %[[VAL_44:.*]]: index, %[[VAL_45:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>):
+// CHECK: ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, %[[VAL_44:.*]]: index):
// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_42]]] : memref<?xindex>
// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_43]]] : memref<?xindex>
// CHECK: %[[VAL_48:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_46]] : index
@@ -353,7 +353,7 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
// CHECK: %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_49]] : index
// CHECK: %[[VAL_51:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_49]] : index
// CHECK: %[[VAL_52:.*]] = arith.andi %[[VAL_50]], %[[VAL_51]] : i1
-// CHECK: %[[VAL_53:.*]]:2 = scf.if %[[VAL_52]] -> (index, tensor<?x?xf32, #sparse{{[0-9]*}}>) {
+// CHECK: %[[VAL_53:.*]] = scf.if %[[VAL_52]] -> (index) {
// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_42]]] : memref<?xf32>
// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_43]]] : memref<?xindex>
// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index
@@ -377,9 +377,9 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
// CHECK: memref.store %[[VAL_65]], %[[VAL_25]]{{\[}}%[[VAL_61]]] : memref<?xf32>
// CHECK: scf.yield %[[VAL_70:.*]] : index
// CHECK: }
-// CHECK: scf.yield %[[VAL_71:.*]], %[[VAL_45]] : index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_71:.*]] : index
// CHECK: } else {
-// CHECK: scf.yield %[[VAL_44]], %[[VAL_45]] : index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_44]] : index
// CHECK: }
// CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_49]] : index
// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index
@@ -387,9 +387,9 @@ func.func @sumred(%arga: tensor<?x?x?xi32, #SparseTensor>,
// CHECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_47]], %[[VAL_49]] : index
// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index
// CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_43]] : index
-// CHECK: scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]]#0, %[[VAL_78]]#1 : index, index, index, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]] : index, index, index
// CHECK: }
-// CHECK: %[[VAL_79:.*]] = sparse_tensor.compress %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_80:.*]]#2 into %[[VAL_80]]#3{{\[}}%[[VAL_24]]] : memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[VAL_79:.*]] = sparse_tensor.compress %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_80:.*]]#2 into %[[VAL_23]]{{\[}}%[[VAL_24]]] : memref<?xf32>, memref<?xi1>, memref<?xindex>, tensor<?x?xf32, #sparse{{[0-9]*}}>
// CHECK: scf.yield %[[VAL_79]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
// CHECK: }
// CHECK: %[[VAL_81:.*]] = sparse_tensor.load %[[VAL_82:.*]] hasInserts : tensor<?x?xf32, #sparse{{[0-9]*}}>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 135db02d543ef..18fb6852f6875 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1330,11 +1330,11 @@ func.func @vector_insert_1d_broadcast(%laneid: index, %pos: index) -> (vector<96
// -----
// CHECK-PROP-LABEL: func @vector_insert_0d(
-// CHECK-PROP: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32)
+// CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (f32)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
-// CHECK-PROP: gpu.yield %[[VEC]], %[[VAL]]
-// CHECK-PROP: vector.broadcast %[[W]]#1 : f32 to vector<f32>
+// CHECK-PROP: gpu.yield %[[VAL]]
+// CHECK-PROP: vector.broadcast %[[W]] : f32 to vector<f32>
func.func @vector_insert_0d(%laneid: index) -> (vector<f32>) {
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
%0 = "some_def"() : () -> (vector<f32>)
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index b9a883dbd524e..5bf5487974d35 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -250,13 +250,13 @@ func.func @main() -> (i32, i32) {
// CHECK-NEXT: }
// CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
// CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
// CHECK-CANONICALIZE-NEXT: } do {
// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
// CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE: scf.yield %[[live_1]] : i32
// CHECK-CANONICALIZE-NEXT: }
// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0
// CHECK-CANONICALIZE-NEXT: }
@@ -306,7 +306,7 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0
// CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1
-// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
// CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> ()
// CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
// CHECK-CANONICALIZE-NEXT: } do {