[llvm] [InstCombine] Make backedge check in op of phi transform more precise (PR #106075)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 26 06:55:29 PDT 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/106075

The op of phi transform wants to prevent moving an operation across a backedge, as this may lead to an infinite combine loop.

Currently, this is done using isPotentiallyReachable(). The problem is that all blocks inside a loop are reachable from each other, so the op of phi transform is effectively disabled for code inside loops, even when it is not operating on a loop phi (just a phi that happens to be in a loop).
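
To make the over-conservatism concrete, here is a minimal standalone sketch. It does not use LLVM: the `CFG` type, `reachable()` and `phiOpInLoopCFG()` are hypothetical stand-ins, and `reachable()` is a plain breadth-first search rather than the real isPotentiallyReachable(). The graph is shaped like the `phi_op_in_loop` test touched by this patch (entry, loop, if, loop.latch).

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical stand-in for a CFG: block I has successors G[I].
using CFG = std::vector<std::vector<int>>;

// Plain BFS reachability. This is a simplification for illustration,
// not LLVM's isPotentiallyReachable().
bool reachable(const CFG &G, int From, int To) {
  assert(From >= 0 && static_cast<std::size_t>(From) < G.size());
  std::vector<bool> Seen(G.size(), false);
  std::queue<int> Work;
  Seen[From] = true;
  Work.push(From);
  while (!Work.empty()) {
    int BB = Work.front();
    Work.pop();
    if (BB == To)
      return true;
    for (int Succ : G[BB]) {
      if (!Seen[Succ]) {
        Seen[Succ] = true;
        Work.push(Succ);
      }
    }
  }
  return false;
}

// 0 = entry, 1 = loop, 2 = if, 3 = loop.latch.
// Edges: entry->loop, loop->if, loop->loop.latch, if->loop.latch,
// loop.latch->loop.
CFG phiOpInLoopCFG() {
  return {{1}, {2, 3}, {3}, {1}};
}
```

In this sketch the phi lives in loop.latch (3) and its non-constant incoming value comes from if (2). `reachable(G, 3, 2)` is true because of the path loop.latch -> loop -> if, so a reachability-based check refuses the transform even though the edge if -> loop.latch is an ordinary forward edge.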

Fix this by explicitly computing the backedges inside the function instead. Do this via a reverse post-order traversal (RPOT), which is a bit more efficient than using FindFunctionBackedges() (which does it without any pre-computed analyses).

For irreducible cycles, there may be multiple possible choices of backedge, and this just picks one of them. This is still sufficient to prevent combine loops.
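
The backedge computation described above can be sketched outside of LLVM as follows. This is an illustration under assumptions, not the patch itself: `CFG`, `Edge`, `reversePostOrder()`, `phiOpInLoopCFG()` and `irreducibleCFG()` are hypothetical names, and blocks are plain integers with block 0 as the entry. Only the marking loop in `computeBackEdges()` mirrors the logic in the diff: walk the blocks in reverse post-order and record an edge as a backedge when its target has already been visited.

```cpp
#include <cassert>
#include <set>
#include <utility>
#include <vector>

// Hypothetical stand-ins: block I has successors G[I]; block 0 is the entry.
using CFG = std::vector<std::vector<int>>;
using Edge = std::pair<int, int>;

// Depth-first post-order walk starting at BB.
static void postOrder(const CFG &G, int BB, std::vector<bool> &Seen,
                      std::vector<int> &Out) {
  Seen[BB] = true;
  for (int Succ : G[BB])
    if (!Seen[Succ])
      postOrder(G, Succ, Seen, Out);
  Out.push_back(BB);
}

std::vector<int> reversePostOrder(const CFG &G) {
  assert(!G.empty() && "CFG needs an entry block");
  std::vector<bool> Seen(G.size(), false);
  std::vector<int> Order;
  postOrder(G, 0, Seen, Order);
  return std::vector<int>(Order.rbegin(), Order.rend());
}

// An edge BB -> Succ is recorded as a backedge if Succ was already
// visited when BB is processed in reverse post-order.
std::set<Edge> computeBackEdges(const CFG &G) {
  std::set<Edge> BackEdges;
  std::vector<bool> Visited(G.size(), false);
  for (int BB : reversePostOrder(G)) {
    Visited[BB] = true;
    for (int Succ : G[BB])
      if (Visited[Succ])
        BackEdges.insert({BB, Succ});
  }
  return BackEdges;
}

// Reducible loop: 0 = entry, 1 = loop, 2 = if, 3 = loop.latch.
CFG phiOpInLoopCFG() {
  return {{1}, {2, 3}, {3}, {1}};
}

// Irreducible cycle: the entry jumps to both 1 and 2, and 1 <-> 2 form
// a cycle with two entry points. Swap only changes the successor order
// of the entry block.
CFG irreducibleCFG(bool Swap) {
  if (Swap)
    return {{2, 1}, {2}, {1}};
  return {{1, 2}, {2}, {1}};
}
```

For the reducible loop, only loop.latch -> loop (3 -> 1) is recorded, so an operation may still be moved along if -> loop.latch. For the irreducible cycle, the sketch records 2 -> 1 with one successor order and 1 -> 2 with the other. Exactly one edge of the cycle is marked in either case, which is what is needed to break it.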

This also removes the last use of LoopInfo in InstCombine -- I'll drop the analysis in a follow-up.

The change has some compile-time impact: http://llvm-compile-time-tracker.com/compare.php?from=4549a8d251cfa91cc6230139595f0b7efdf199d9&to=d1b84c8bb23048e02097d8489e6de0f25ceaa2b8&stat=instructions%3Au

I believe this is due to second-order changes (the regressions correlate with code size increases).

From 131c3e1ffd5b76428e60999f2efa7dcab9deb366 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Mon, 26 Aug 2024 12:54:26 +0200
Subject: [PATCH] [InstCombine] Make backedge check in op of phi transform more
 precise

The op of phi transform wants to prevent moving an operation across
a backedge, as this may lead to an infinite combine loop.

Currently, this is done using isPotentiallyReachable(). The problem
with that is that all blocks inside a loop are reachable from each
other. This means that the op of phi transform is effectively
completely disabled for code inside loops, even when it's not
actually operating on a loop phi (just a phi that happens to be
in a loop).

Fix this by explicitly computing the backedges inside the function
instead. Do this via RPOT, which is a bit more efficient than
using FindFunctionBackedges() (which does it without any pre-computed
analyses).

For irreducible cycles, there may be multiple possible choices of
backedge, and this just picks one of them. This is still sufficient
to prevent combine loops.
---
 .../llvm/Transforms/InstCombine/InstCombiner.h | 13 +++++++++++++
 .../InstCombine/InstructionCombining.cpp       | 18 ++++++++++++++----
 llvm/test/Transforms/InstCombine/phi.ll        |  6 +++---
 .../test/Transforms/LoopVectorize/induction.ll |  6 +++---
 4 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index c2ea88a107c32a..05322f7650efc7 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -94,6 +94,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   /// Order of predecessors to canonicalize phi nodes towards.
   SmallDenseMap<BasicBlock *, SmallVector<BasicBlock *>, 8> PredOrder;
 
+  /// Backedges, used to avoid pushing instructions across backedges in cases
+  /// where this may result in infinite combine loops. For irreducible loops
+  /// this picks an arbitrary backedge.
+  SmallDenseSet<std::pair<const BasicBlock *, const BasicBlock *>, 8> BackEdges;
+  bool ComputedBackEdges = false;
+
 public:
   InstCombiner(InstructionWorklist &Worklist, BuilderTy &Builder,
                bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
@@ -359,6 +365,13 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
       std::function<void(Instruction *, unsigned, APInt, APInt &)>
           SimplifyAndSetOp);
 
+  void computeBackEdges();
+  bool isBackEdge(const BasicBlock *From, const BasicBlock *To) {
+    if (!ComputedBackEdges)
+      computeBackEdges();
+    return BackEdges.contains({From, To});
+  }
+
   /// Inserts an instruction \p New before instruction \p Old
   ///
   /// Also adds the new instruction to the worklist and returns \p New so that
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8a96d1d0fb4c90..7c42c91abeb472 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1812,12 +1812,10 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
       if (cast<Instruction>(InVal)->getParent() == NonSimplifiedBB)
         return nullptr;
 
-    // If the incoming non-constant value is reachable from the phis block,
-    // we'll push the operation across a loop backedge. This could result in
+    // Do not push the operation across a loop backedge. This could result in
     // an infinite combine loop, and is generally non-profitable (especially
     // if the operation was originally outside the loop).
-    if (isPotentiallyReachable(PN->getParent(), NonSimplifiedBB, nullptr, &DT,
-                               LI))
+    if (isBackEdge(NonSimplifiedBB, PN->getParent()))
       return nullptr;
   }
 
@@ -5373,6 +5371,18 @@ bool InstCombinerImpl::prepareWorklist(Function &F) {
   return MadeIRChange;
 }
 
+void InstCombiner::computeBackEdges() {
+  // Collect backedges.
+  SmallPtrSet<BasicBlock *, 16> Visited;
+  for (BasicBlock *BB : RPOT) {
+    Visited.insert(BB);
+    for (BasicBlock *Succ : successors(BB))
+      if (Visited.contains(Succ))
+        BackEdges.insert({BB, Succ});
+  }
+  ComputedBackEdges = true;
+}
+
 static bool combineInstructionsOverFunction(
     Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA,
     AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
diff --git a/llvm/test/Transforms/InstCombine/phi.ll b/llvm/test/Transforms/InstCombine/phi.ll
index 673c8f6c9488d6..6e6c5ac8e75a61 100644
--- a/llvm/test/Transforms/InstCombine/phi.ll
+++ b/llvm/test/Transforms/InstCombine/phi.ll
@@ -2721,11 +2721,11 @@ define void @phi_op_in_loop(i1 %c, i32 %x) {
 ; CHECK:       loop:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[LOOP_LATCH:%.*]]
 ; CHECK:       if:
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[X:%.*]], 1
 ; CHECK-NEXT:    br label [[LOOP_LATCH]]
 ; CHECK:       loop.latch:
-; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ [[X:%.*]], [[IF]] ], [ 0, [[LOOP]] ]
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[PHI]], 1
-; CHECK-NEXT:    call void @use(i32 [[AND]])
+; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ [[TMP1]], [[IF]] ], [ 0, [[LOOP]] ]
+; CHECK-NEXT:    call void @use(i32 [[PHI]])
 ; CHECK-NEXT:    br label [[LOOP]]
 ;
   br label %loop
diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll
index 08d05a1e2db69f..c51ba0e5b6ea4e 100644
--- a/llvm/test/Transforms/LoopVectorize/induction.ll
+++ b/llvm/test/Transforms/LoopVectorize/induction.ll
@@ -5504,11 +5504,11 @@ define i32 @PR32419(i32 %a, i16 %b) {
 ; INTERLEAVE-NEXT:    br i1 [[VAR2]], label [[FOR_INC]], label [[FOR_COND:%.*]]
 ; INTERLEAVE:       for.cond:
 ; INTERLEAVE-NEXT:    [[VAR3:%.*]] = urem i16 [[B]], [[VAR1]]
+; INTERLEAVE-NEXT:    [[TMP50:%.*]] = sext i16 [[VAR3]] to i32
 ; INTERLEAVE-NEXT:    br label [[FOR_INC]]
 ; INTERLEAVE:       for.inc:
-; INTERLEAVE-NEXT:    [[VAR4:%.*]] = phi i16 [ [[VAR3]], [[FOR_COND]] ], [ 0, [[FOR_BODY]] ]
-; INTERLEAVE-NEXT:    [[VAR5:%.*]] = sext i16 [[VAR4]] to i32
-; INTERLEAVE-NEXT:    [[VAR6]] = or i32 [[VAR0]], [[VAR5]]
+; INTERLEAVE-NEXT:    [[VAR4:%.*]] = phi i32 [ [[TMP50]], [[FOR_COND]] ], [ 0, [[FOR_BODY]] ]
+; INTERLEAVE-NEXT:    [[VAR6]] = or i32 [[VAR0]], [[VAR4]]
 ; INTERLEAVE-NEXT:    [[I_NEXT]] = add nsw i32 [[I]], 1
 ; INTERLEAVE-NEXT:    [[COND:%.*]] = icmp eq i32 [[I_NEXT]], 0
 ; INTERLEAVE-NEXT:    br i1 [[COND]], label [[FOR_END]], label [[FOR_BODY]], !llvm.loop [[LOOP49:![0-9]+]]


