[llvm] [ValueTracking] Add `matchSimpleBinaryIntrinsicRecurrence` helper (PR #145964)
Antonio Frighetto via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 27 10:04:57 PDT 2025
https://github.com/antoniofrighetto updated https://github.com/llvm/llvm-project/pull/145964
>From c11ea449e59cfbd10b7ba2ed11a049ca5184164a Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Fri, 27 Jun 2025 19:03:46 +0200
Subject: [PATCH] [ValueTracking] Add `matchSimpleBinaryIntrinsicRecurrence`
helper
Similarly to what is already done to match simple recurrence cycle
relations, attempt to match value-accumulating recurrences of the kind:
```
%umax.acc = phi i8 [ %umax, %backedge ], [ %a, %entry ]
%umax = call i8 @llvm.umax.i8(i8 %umax.acc, i8 %b)
```
This is preliminary work to let InstCombine avoid folding such recurrences,
so that simple loop-invariant computations can get hoisted. It also takes
the opportunity to factor out code shared with `matchSimpleRecurrence`.
---
llvm/include/llvm/Analysis/ValueTracking.h | 16 ++++
llvm/lib/Analysis/ValueTracking.cpp | 73 +++++++++++--------
llvm/unittests/Analysis/ValueTrackingTest.cpp | 52 +++++++++++++
3 files changed, 111 insertions(+), 30 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index c804f551f5a75..15e23de3878dc 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -21,6 +21,7 @@
#include "llvm/IR/FMF.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/Compiler.h"
#include <cassert>
@@ -965,6 +966,21 @@ LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
LLVM_ABI bool matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
Value *&Start, Value *&Step);
+/// Attempt to match a simple value-accumulating recurrence of the form:
+/// %llvm.intrinsic.acc = phi Ty [%Init, %Entry], [%llvm.intrinsic, %backedge]
+/// %llvm.intrinsic = call Ty @llvm.intrinsic(%OtherOp, %llvm.intrinsic.acc)
+/// OR
+/// %llvm.intrinsic.acc = phi Ty [%Init, %Entry], [%llvm.intrinsic, %backedge]
+/// %llvm.intrinsic = call Ty @llvm.intrinsic(%llvm.intrinsic.acc, %OtherOp)
+///
+/// The recurrence relation is of kind:
+/// X_0 = %Init (initial value),
+/// X_i = call @llvm.binary.intrinsic(X_{i-1}, %OtherOp)
+/// Where %OtherOp is not required to be loop-invariant.
+///
+/// Only intrinsics taking exactly two arguments, both of the same type as the
+/// result, are considered. The PHI must have exactly two incoming values, in
+/// either order. On success, \p P is set to the accumulator PHI, \p Init to
+/// the PHI's other incoming value, and \p OtherOp to the intrinsic operand
+/// that is not \p P. The output parameters are only meaningful when true is
+/// returned.
+LLVM_ABI bool matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
+ PHINode *&P, Value *&Init,
+ Value *&OtherOp);
+
/// Return true if RHS is known to be implied true by LHS. Return false if
/// RHS is known to be implied false by LHS. Otherwise, return std::nullopt if
/// no implication can be made. A & B must be i1 (boolean) values or a vector of
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 93c22212a27ce..e576f4899810a 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9070,46 +9070,43 @@ llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
return {Intrinsic::not_intrinsic, false};
}
-bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
- Value *&Start, Value *&Step) {
+/// Shared matcher for two-input recurrences. Looks for an incoming value of
+/// \p PN that is an \p InstTy using \p PN as its operand 0 or operand 1.
+/// On success, \p Inst is that instruction, \p Init is PN's other incoming
+/// value, and \p OtherOp is the operand of \p Inst that is not \p PN. The
+/// first matching incoming value wins.
+/// NOTE(review): only raw operands 0 and 1 are inspected. For call-based
+/// InstTy these are presumably both call arguments only when the call has
+/// at least two arguments. Callers are expected to check arity and identity
+/// (see the II == I check in matchSimpleBinaryIntrinsicRecurrence) -- confirm.
+template <typename InstTy>
+static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst,
+ Value *&Init, Value *&OtherOp) {
// Handle the case of a simple two-predecessor recurrence PHI.
// There's a lot more that could theoretically be done here, but
// this is sufficient to catch some interesting cases.
// TODO: Expand list -- gep, uadd.sat etc.
- if (P->getNumIncomingValues() != 2)
+ if (PN->getNumIncomingValues() != 2)
return false;
- for (unsigned i = 0; i != 2; ++i) {
- Value *L = P->getIncomingValue(i);
- Value *R = P->getIncomingValue(!i);
- auto *LU = dyn_cast<BinaryOperator>(L);
- if (!LU)
- continue;
- Value *LL = LU->getOperand(0);
- Value *LR = LU->getOperand(1);
-
- // Find a recurrence.
- if (LL == P)
- L = LR;
- else if (LR == P)
- L = LL;
- else
- continue; // Check for recurrence with L and R flipped.
-
- // We have matched a recurrence of the form:
- // %iv = [R, %entry], [%iv.next, %backedge]
- // %iv.next = binop %iv, L
- // OR
- // %iv = [R, %entry], [%iv.next, %backedge]
- // %iv.next = binop L, %iv
- BO = LU;
- Start = R;
- Step = L;
- return true;
+ // Try each incoming edge as the back-edge value; the other edge is then
+ // the initial value.
+ for (unsigned I = 0; I != 2; ++I) {
+ if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(I))) {
+ Value *LHS = Operation->getOperand(0);
+ Value *RHS = Operation->getOperand(1);
+ // Not a cycle through PN; try the other incoming value.
+ if (LHS != PN && RHS != PN)
+ continue;
+
+ Inst = Operation;
+ Init = PN->getIncomingValue(!I);
+ // PN may sit on either side of the operation (commuted forms).
+ OtherOp = (LHS == PN) ? RHS : LHS;
+ return true;
+ }
}
return false;
}
+bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
+                                 Value *&Start, Value *&Step) {
+  // Recognizes a two-input PHI cycle through a binary operator, where the
+  // PHI may be either operand of the binop:
+  //   %iv      = phi [Start, %entry], [%iv.next, %backedge]
+  //   %iv.next = binop %iv, Step     (or: binop Step, %iv)
+  // The shared matcher is instantiated explicitly for BinaryOperator.
+  return matchTwoInputRecurrence<BinaryOperator>(P, BO, Start, Step);
+}
+
bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
Value *&Start, Value *&Step) {
BinaryOperator *BO = nullptr;
@@ -9119,6 +9116,22 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I;
}
+bool llvm::matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
+                                                PHINode *&P, Value *&Init,
+                                                Value *&OtherOp) {
+  // Binary intrinsics only supported for now: exactly two arguments, both of
+  // the result type, so the accumulator can feed back into either slot.
+  if (I->arg_size() != 2 || I->getType() != I->getArgOperand(0)->getType() ||
+      I->getType() != I->getArgOperand(1)->getType())
+    return false;
+
+  // Either argument may be the accumulating PHI. Try both rather than only
+  // falling back to the second argument when the first is not a PHI at all:
+  // %OtherOp need not be loop-invariant and may itself be an unrelated PHI,
+  // in which case the recurrence goes through the second argument.
+  for (Value *Arg : I->args()) {
+    auto *PN = dyn_cast<PHINode>(Arg);
+    if (!PN)
+      continue;
+    IntrinsicInst *II = nullptr;
+    // The PHI must cycle through this very intrinsic, not some other one.
+    if (matchTwoInputRecurrence(PN, II, Init, OtherOp) && II == I) {
+      P = PN;
+      return true;
+    }
+  }
+  return false;
+}
+
/// Return true if "icmp Pred LHS RHS" is always true.
static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
const Value *RHS) {
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index 6031898f7f679..dbe72286c3a0b 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -1257,6 +1257,58 @@ TEST_F(ValueTrackingTest, computePtrAlignment) {
EXPECT_EQ(getKnownAlignment(A, DL, CxtI3, &AC, &DT), Align(16));
}
+TEST_F(ValueTrackingTest, MatchBinaryIntrinsicRecurrenceUMax) {
+  auto M = parseModule(R"(
+ define i8 @test(i8 %a, i8 %b) {
+ entry:
+ br label %loop
+ loop:
+ %iv = phi i8 [ %iv.next, %loop ], [ 0, %entry ]
+ %umax.acc = phi i8 [ %umax, %loop ], [ %a, %entry ]
+ %umax = call i8 @llvm.umax.i8(i8 %umax.acc, i8 %b)
+ %iv.next = add nuw i8 %iv, 1
+ %cmp = icmp ult i8 %iv.next, 10
+ br i1 %cmp, label %loop, label %exit
+ exit:
+ ret i8 %umax
+ }
+ )");
+
+  auto *F = M->getFunction("test");
+  auto *II = &cast<IntrinsicInst>(findInstructionByName(F, "umax"));
+  auto *UMaxAcc = &cast<PHINode>(findInstructionByName(F, "umax.acc"));
+  // Initialize the out-parameters and stop on a failed match, so the checks
+  // below never read indeterminate pointers.
+  PHINode *PN = nullptr;
+  Value *Init = nullptr, *OtherOp = nullptr;
+  ASSERT_TRUE(matchSimpleBinaryIntrinsicRecurrence(II, PN, Init, OtherOp));
+  EXPECT_EQ(UMaxAcc, PN);
+  EXPECT_EQ(F->getArg(0), Init);
+  EXPECT_EQ(F->getArg(1), OtherOp);
+}
+
+TEST_F(ValueTrackingTest, MatchBinaryIntrinsicRecurrenceNegativeFSHR) {
+  auto M = parseModule(R"(
+ define i8 @test(i8 %a, i8 %b, i8 %c) {
+ entry:
+ br label %loop
+ loop:
+ %iv = phi i8 [ %iv.next, %loop ], [ 0, %entry ]
+ %fshr.acc = phi i8 [ %fshr, %loop ], [ %a, %entry ]
+ %fshr = call i8 @llvm.fshr.i8(i8 %fshr.acc, i8 %b, i8 %c)
+ %iv.next = add nuw i8 %iv, 1
+ %cmp = icmp ult i8 %iv.next, 10
+ br i1 %cmp, label %loop, label %exit
+ exit:
+ ret i8 %fshr
+ }
+ )");
+
+  Function *Fn = M->getFunction("test");
+  // fshr takes three arguments, so it is not a binary intrinsic and must be
+  // rejected even though it feeds back into its own accumulator PHI.
+  auto &FShr = cast<IntrinsicInst>(findInstructionByName(Fn, "fshr"));
+  PHINode *Phi = nullptr;
+  Value *Start = nullptr;
+  Value *Other = nullptr;
+  EXPECT_FALSE(matchSimpleBinaryIntrinsicRecurrence(&FShr, Phi, Start, Other));
+}
+
TEST_F(ComputeKnownBitsTest, ComputeKnownBits) {
parseAssembly(
"define i32 @test(i32 %a, i32 %b) {\n"
More information about the llvm-commits
mailing list