[Mlir-commits] [mlir] [mlir] move if-condition propagation to a standalone pass (PR #150278)
Mehdi Amini
llvmlistbot at llvm.org
Wed Sep 17 01:22:26 PDT 2025
================
@@ -0,0 +1,98 @@
+//===- IfConditionPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a pass for constant propagation of the condition of an
+// `scf.if` into its then and else regions as true and false respectively.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+/// Traverses the IR recursively (over the region tree) and replaces each use
+/// of an `scf.if` condition value with a `true` or `false` constant when that
+/// use is nested in the `then` or `else` region, respectively. This is done as
+/// a single post-order sweep over the IR (without `walk`) for efficiency
+/// reasons. `visited` is an output parameter, not a set checked here: regions
+/// collected while traversing `root` are inserted into it, and the caller tests
+/// membership in it to decide whether a use lies in a given `scf.if` branch.
+static void propagateIfConditionsImpl(Operation *root,
+ llvm::SmallPtrSet<Region *, 8> &visited) {
+ if (auto scfIf = dyn_cast<scf::IfOp>(root)) {
+ llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren;
+ // Visit the "then" region, collect children.
+ for (Block &block : scfIf.getThenRegion()) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, thenChildren);
+ }
+ }
+
+ // Visit the "else" region, collect children.
+ for (Block &block : scfIf.getElseRegion()) {
+ for (Operation &op : block) {
+ propagateIfConditionsImpl(&op, elseChildren);
+ }
+ }
+
+ // Update uses to point to constants instead.
+ OpBuilder builder(scfIf);
+ Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+ /*value=*/true, /*width=*/1);
+ Value falseValue =
+ arith::ConstantIntOp::create(builder, scfIf.getLoc(),
+ /*value=*/false, /*width=*/1);
+
+ for (OpOperand &use : scfIf.getCondition().getUses()) {
+ if (thenChildren.contains(use.getOwner()->getParentRegion()))
+ use.set(trueValue);
+ else if (elseChildren.contains(use.getOwner()->getParentRegion()))
+ use.set(falseValue);
+ }
+ if (trueValue.getUses().empty())
+ trueValue.getDefiningOp()->erase();
+ if (falseValue.getUses().empty())
+ falseValue.getDefiningOp()->erase();
+
+ // Append the two lists of children and return them.
+ visited.insert_range(thenChildren);
+ visited.insert_range(elseChildren);
----------------
joker-eph wrote:
I don't quite understand the logic for the `visited` set, the comment on the function says:
```
/// While
/// traversing, the function maintains the set of visited regions to quickly
/// identify whether the value belong to a region that is known to be nested in
/// the `then` or `else` branch of a specific loop.
```
But that does not seem to be the case, the visited set is only ever used for inserting here, never checked as far as I can see?
https://github.com/llvm/llvm-project/pull/150278
More information about the Mlir-commits
mailing list