[llvm] [VPlan] Keep common flags during CSE. (PR #157664)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 9 06:04:04 PDT 2025


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/157664

During CSE, we don't have to drop all poison-generating flags on a mismatch; we can keep the flags that are common to both recipes.

>From a33366f4fffe8f82d3a2eeb1283cac3c4d42def7 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 9 Sep 2025 13:57:21 +0100
Subject: [PATCH] [VPlan] Only drop flags on mis-match during CSE.

During CSE, we don't have to drop all poison-generating flags on a mismatch;
we can keep the flags common to the re-used and the to-be-replaced recipe.
---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  4 +++
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 36 +++++++++++++++++++
 .../Transforms/Vectorize/VPlanTransforms.cpp  |  4 +--
 .../LoopVectorize/PowerPC/vectorize-bswap.ll  |  2 +-
 .../LoopVectorize/X86/scatter_crash.ll        |  8 ++---
 llvm/test/Transforms/LoopVectorize/flags.ll   |  2 +-
 6 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index b93bdf244237e..53291a931530f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -721,6 +721,10 @@ class VPIRFlags {
     AllFlags = Other.AllFlags;
   }
 
+  /// Only keep flags also present in \p Other. \p Other must have the same
+  /// OpType as the current object.
+  void intersectFlags(const VPIRFlags &Other);
+
   /// Drop all poison-generating flags.
   void dropPoisonGeneratingFlags() {
     // NOTE: This needs to be kept in-sync with
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 46162a9276469..9f1311fbd0687 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -392,6 +392,42 @@ void VPPartialReductionRecipe::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
+void VPIRFlags::intersectFlags(const VPIRFlags &Other) {
+  assert(OpType == Other.OpType && "OpType must match");
+  switch (OpType) {
+  case OperationType::OverflowingBinOp:
+    WrapFlags.HasNUW &= Other.WrapFlags.HasNUW;
+    WrapFlags.HasNSW &= Other.WrapFlags.HasNSW;
+    break;
+  case OperationType::Trunc:
+    TruncFlags.HasNUW &= Other.TruncFlags.HasNUW;
+    TruncFlags.HasNSW &= Other.TruncFlags.HasNSW;
+    break;
+  case OperationType::DisjointOp:
+    DisjointFlags.IsDisjoint &= Other.DisjointFlags.IsDisjoint;
+    break;
+  case OperationType::PossiblyExactOp:
+    ExactFlags.IsExact &= Other.ExactFlags.IsExact; // Exact only if both are.
+    break;
+  case OperationType::GEPOp:
+    GEPFlags &= Other.GEPFlags;
+    break;
+  case OperationType::FPMathOp:
+    FMFs.NoNaNs &= Other.FMFs.NoNaNs;
+    FMFs.NoInfs &= Other.FMFs.NoInfs;
+    break;
+  case OperationType::NonNegOp:
+    NonNegFlags.NonNeg &= Other.NonNegFlags.NonNeg;
+    break;
+  case OperationType::Cmp:
+    assert(CmpPredicate == Other.CmpPredicate && "Cannot drop CmpPredicate");
+    break;
+  case OperationType::Other:
+    assert(AllFlags == Other.AllFlags && "Cannot drop other flags");
+    break;
+  }
+}
+
 FastMathFlags VPIRFlags::getFastMathFlags() const {
   assert(OpType == OperationType::FPMathOp &&
          "recipe doesn't have fast math flags");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 10b2f5df2e23e..d86b53dd894fb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2042,9 +2042,9 @@ void VPlanTransforms::cse(VPlan &Plan) {
         // V must dominate Def for a valid replacement.
         if (!VPDT.dominates(V->getParent(), VPBB))
           continue;
-        // Drop poison-generating flags when reusing a value.
+        // Only keep flags present on both V and Def.
         if (auto *RFlags = dyn_cast<VPRecipeWithIRFlags>(V))
-          RFlags->dropPoisonGeneratingFlags();
+          RFlags->intersectFlags(*cast<VPRecipeWithIRFlags>(Def));
         Def->replaceAllUsesWith(V);
         continue;
       }
diff --git a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
index 36c3a2a612d82..db1f2c71e0f77 100644
--- a/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
+++ b/llvm/test/Transforms/LoopVectorize/PowerPC/vectorize-bswap.ll
@@ -16,7 +16,7 @@ define dso_local void @test(ptr %Arr, i32 signext %Len) {
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[ARR:%.*]], i64 [[TMP1]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[TMP4:%.*]] = call <4 x i32> @llvm.bswap.v4i32(<4 x i32> [[WIDE_LOAD]])
 ; CHECK-NEXT:    store <4 x i32> [[TMP4]], ptr [[TMP2]], align 4
diff --git a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
index df54411f7e710..c2dfce0aa70b8 100644
--- a/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/scatter_crash.ll
@@ -142,8 +142,8 @@ define void @_Z3fn1v() #0 {
 ; CHECK-NEXT:    [[TMP32:%.*]] = add nsw <16 x i64> [[TMP30]], [[VEC_IND37]]
 ; CHECK-NEXT:    [[TMP33:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP32]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[TMP34]])
-; CHECK-NEXT:    [[TMP49:%.*]] = or <16 x i64> [[VEC_IND37]], splat (i64 1)
-; CHECK-NEXT:    [[TMP36:%.*]] = add <16 x i64> [[TMP30]], [[TMP49]]
+; CHECK-NEXT:    [[TMP49:%.*]] = or disjoint <16 x i64> [[VEC_IND37]], splat (i64 1)
+; CHECK-NEXT:    [[TMP36:%.*]] = add nsw <16 x i64> [[TMP30]], [[TMP49]]
 ; CHECK-NEXT:    [[TMP37:%.*]] = getelementptr inbounds [10 x i32], <16 x ptr> [[TMP31]], <16 x i64> [[TMP36]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 8), <16 x ptr> [[TMP37]], i32 8, <16 x i1> [[TMP34]])
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v16i32.v16p0(<16 x i32> splat (i32 7), <16 x ptr> [[TMP33]], i32 16, <16 x i1> [[BROADCAST_SPLAT]])
@@ -191,8 +191,8 @@ define void @_Z3fn1v() #0 {
 ; CHECK-NEXT:    [[TMP46:%.*]] = add nsw <8 x i64> [[TMP44]], [[VEC_IND70]]
 ; CHECK-NEXT:    [[TMP47:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP46]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[TMP48]])
-; CHECK-NEXT:    [[TMP54:%.*]] = or <8 x i64> [[VEC_IND70]], splat (i64 1)
-; CHECK-NEXT:    [[TMP50:%.*]] = add <8 x i64> [[TMP44]], [[TMP54]]
+; CHECK-NEXT:    [[TMP54:%.*]] = or disjoint <8 x i64> [[VEC_IND70]], splat (i64 1)
+; CHECK-NEXT:    [[TMP50:%.*]] = add nsw <8 x i64> [[TMP44]], [[TMP54]]
 ; CHECK-NEXT:    [[TMP51:%.*]] = getelementptr inbounds [10 x i32], <8 x ptr> [[TMP45]], <8 x i64> [[TMP50]], i64 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 8), <8 x ptr> [[TMP51]], i32 8, <8 x i1> [[TMP48]])
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> splat (i32 7), <8 x ptr> [[TMP47]], i32 16, <8 x i1> [[BROADCAST_SPLAT73]])
diff --git a/llvm/test/Transforms/LoopVectorize/flags.ll b/llvm/test/Transforms/LoopVectorize/flags.ll
index cef8ea656afaa..cbdcd50476b98 100644
--- a/llvm/test/Transforms/LoopVectorize/flags.ll
+++ b/llvm/test/Transforms/LoopVectorize/flags.ll
@@ -175,7 +175,7 @@ define void @gep_with_shared_nusw_and_others(i64 %n, ptr %A) {
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr float, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr nusw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x float>, ptr [[TMP1]], align 4
 ; CHECK-NEXT:    store <4 x float> [[WIDE_LOAD]], ptr [[TMP1]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4



More information about the llvm-commits mailing list