[llvm] [VPlan] Keep common flags during CSE. (PR #157664)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 9 06:04:04 PDT 2025
https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/157664
During CSE, we don't have to drop all poison-generating flags on a mismatch; we can keep the ones common to both recipes.
>From a33366f4fffe8f82d3a2eeb1283cac3c4d42def7 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 9 Sep 2025 13:57:21 +0100
Subject: [PATCH] [VPlan] Only drop flags on mis-match during CSE.
During CSE, we don't have to drop all poison-generating flags of the
re-used recipe; it is enough to keep only the flags it shares with the
to-be-replaced recipe.
---
llvm/lib/Transforms/Vectorize/VPlan.h | 4 +++
.../lib/Transforms/Vectorize/VPlanRecipes.cpp | 36 +++++++++++++++++++
.../Transforms/Vectorize/VPlanTransforms.cpp | 4 +--
.../LoopVectorize/PowerPC/vectorize-bswap.ll | 2 +-
.../LoopVectorize/X86/scatter_crash.ll | 8 ++---
llvm/test/Transforms/LoopVectorize/flags.ll | 2 +-
6 files changed, 48 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b93bdf244237e..53291a931530f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -721,6 +721,10 @@ class VPIRFlags {
AllFlags = Other.AllFlags;
}
+ /// Keep only the poison-generating flags also set in \p Other. \p Other
+ /// must have the same OpType. Cmp and Other flags must already be equal.
+ void intersectFlags(const VPIRFlags &Other);
+
/// Drop all poison-generating flags.
void dropPoisonGeneratingFlags() {
// NOTE: This needs to be kept in-sync with
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 46162a9276469..9f1311fbd0687 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -392,6 +392,42 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
}
#endif
+void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
+  assert(OpType == Other.OpType && "OpType must match");
+  switch (OpType) { // AND each flag so it survives only if set on both.
+  case OperationType::OverflowingBinOp:
+    WrapFlags.HasNUW &= Other.WrapFlags.HasNUW;
+    WrapFlags.HasNSW &= Other.WrapFlags.HasNSW;
+    break;
+  case OperationType::Trunc:
+    TruncFlags.HasNUW &= Other.TruncFlags.HasNUW;
+    TruncFlags.HasNSW &= Other.TruncFlags.HasNSW;
+    break;
+  case OperationType::DisjointOp:
+    DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
+    break;
+  case OperationType::PossiblyExactOp: // Never gain 'exact' from Other.
+    ExactFlags.IsExact &= Other.ExactFlags.IsExact;
+    break;
+  case OperationType::GEPOp:
+    GEPFlags &= Other.GEPFlags;
+    break;
+  case OperationType::FPMathOp: // Only nnan/ninf can create poison.
+    FMFs.NoNaNs &= Other.FMFs.NoNaNs;
+    FMFs.NoInfs &= Other.FMFs.NoInfs;
+    break;
+  case OperationType::NonNegOp:
+    NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
+    break;
+  case OperationType::Cmp: // The predicate is not droppable.
+    assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
+    break;
+  case OperationType::Other: // No droppable flags; must already match.
+    assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
+    break;
+  }
+}
+
FastMathFlags VPIRFlags::getFastMathFlags() const {
assert(OpType == OperationType::FPMathOp &&
"recipe doesn't have fast math flags");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 10b2f5df2e23e..d86b53dd894fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2042,9 +2042,9 @@ void VPlanTransforms::cse(VPlan &Plan) {
// V must dominate Def for a valid replacement.
if (!VPDT.dominates(V->getParent(), VPBB))
continue;
- // Drop poison-generating flags when reusing a value.
+ // Only keep flags present on both V and Def.
if (auto *RFlags = dyn_cast<VPRecipeWithIRFlags>(V))
- RFlags->dropPoisonGeneratingFlags();
+ RFlags->intersectFlags(*cast<VPRecipeWithIRFlags>(Def));
Def->replaceAllUsesWith(V);
continue;
}
diff --git a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
index 36c3a2a612d82..db1f2c71e0f77 100644
--- a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
+++ b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
@@ -16,7 +16,7 @@ define dso_local void @test(ptr %Arr, i32 signext %Len) {
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i64 [[TMP1]]
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARR:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
; CHECK-NEXT: [[TMP4:%.*]] = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> [[WIDE_LOAD]])
; CHECK-NEXT: store <4 x i32> [[TMP4]], ptr [[TMP2]], align 4
diff --git a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
index df54411f7e710..c2dfce0aa70b8 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
@@ -142,8 +142,8 @@ define void @_Z3fn1v() #0 {
; CHECK-NEXT: [[TMP32:%.*]] = add nsw <16 x i64> [[TMP30]], [[VEC_IND37]]
; CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP32]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[TMP34]])
-; CHECK-NEXT: [[TMP49:%.*]] = or <16 x i64> [[VEC_IND37]], splat (i64 1)
-; CHECK-NEXT: [[TMP36:%.*]] = add <16 x i64> [[TMP30]], [[TMP49]]
+; CHECK-NEXT: [[TMP49:%.*]] = or disjoint <16 x i64> [[VEC_IND37]], splat (i64 1)
+; CHECK-NEXT: [[TMP36:%.*]] = add nsw <16 x i64> [[TMP30]], [[TMP49]]
; CHECK-NEXT: [[TMP37:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP36]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP37]], i32 8, <16 x i1> [[TMP34]])
; CHECK-NEXT: call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 7), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[BROADCAST_SPLAT]])
@@ -191,8 +191,8 @@ define void @_Z3fn1v() #0 {
; CHECK-NEXT: [[TMP46:%.*]] = add nsw <8 x i64> [[TMP44]], [[VEC_IND70]]
; CHECK-NEXT: [[TMP47:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP46]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[TMP48]])
-; CHECK-NEXT: [[TMP54:%.*]] = or <8 x i64> [[VEC_IND70]], splat (i64 1)
-; CHECK-NEXT: [[TMP50:%.*]] = add <8 x i64> [[TMP44]], [[TMP54]]
+; CHECK-NEXT: [[TMP54:%.*]] = or disjoint <8 x i64> [[VEC_IND70]], splat (i64 1)
+; CHECK-NEXT: [[TMP50:%.*]] = add nsw <8 x i64> [[TMP44]], [[TMP54]]
; CHECK-NEXT: [[TMP51:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP50]], i64 0
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP51]], i32 8, <8 x i1> [[TMP48]])
; CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 7), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[BROADCAST_SPLAT73]])
diff --git a/llvm/test/Transforms/LoopVectorize/flags.ll b/llvm/test/Transforms/LoopVectorize/flags.ll
index cef8ea656afaa..cbdcd50476b98 100644
--- a/llvm/test/Transforms/LoopVectorize/flags.ll
+++ b/llvm/test/Transforms/LoopVectorize/flags.ll
@@ -175,7 +175,7 @@ define void @gep_with_shared_nusw_and_others(i64 %n, ptr %A) {
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr nusw float, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP1]], align 4
; CHECK-NEXT: store <4 x float> [[WIDE_LOAD]], ptr [[TMP1]], align 4
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
More information about the llvm-commits
mailing list