[llvm] [VPlan] Replace disjoint or with add instead of dropping disjoint. (PR #83821)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 08:06:50 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/83821

>From 2a1c4712110cb6617517bbad39c6cb12e8c4a3d3 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 4 Mar 2024 10:59:21 +0000
Subject: [PATCH] [VPlan] Replace disjoint or with add instead of dropping
 disjoint.

Dropping disjoint from an OR may yield incorrect results, as some
analysis may have converted it to an Add implicitly (e.g. SCEV used
for dependence analysis). Instead, replace it with an equivalent Add.

This is possible because all users of the disjoint OR only access lanes
where the operands are disjoint; in all other lanes the result is poison.

Note that replacing all disjoint ORs with ADDs instead of dropping the
flags is not strictly necessary. It is only needed for disjoint ORs
that SCEV treated as ADDs, but those are not tracked.
---
 .../Vectorize/LoopVectorizationPlanner.h      |  3 +++
 llvm/lib/Transforms/Vectorize/VPlan.h         |  8 ++++++
 .../Transforms/Vectorize/VPlanPatternMatch.h  | 27 ++++++++++++++++---
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 18 +++++++++++++
 4 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index a7ebf78e54ceb6..b94859864fff3c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -68,6 +68,9 @@ class VPBuilder {
 public:
   VPBuilder() = default;
   VPBuilder(VPBasicBlock *InsertBB) { setInsertPoint(InsertBB); }
+  VPBuilder(VPRecipeBase *InsertPt) {
+    setInsertPoint(InsertPt->getParent(), InsertPt->getIterator());
+  }
 
   /// Clear the insertion point: created instructions will not be inserted into
   /// a block.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 16c09a83e777dd..b565b4351e16d4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1127,6 +1127,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
     return WrapFlags.HasNSW;
   }
 
+  bool isDisjoint() const {
+    assert(OpType == OperationType::DisjointOp &&
+           "recipe cannot have a disjoint flag");
+    return DisjointFlags.IsDisjoint;
+  }
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void printFlags(raw_ostream &O) const;
 #endif
@@ -2136,6 +2142,8 @@ class VPReplicateRecipe : public VPRecipeWithIRFlags {
     assert(isPredicated() && "Trying to get the mask of a unpredicated recipe");
     return getOperand(getNumOperands() - 1);
   }
+
+  unsigned getOpcode() const { return getUnderlyingInstr()->getOpcode(); }
 };
 
 /// A recipe for generating conditional branches on the bits of a mask.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index b90c588b607564..4b5b6b8cc3dbcf 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -73,12 +73,12 @@ template <typename Op0_t, unsigned Opcode> struct UnaryVPInstruction_match {
   }
 };
 
-template <typename Op0_t, typename Op1_t, unsigned Opcode>
-struct BinaryVPInstruction_match {
+template <typename RecipeTy, typename Op0_t, typename Op1_t, unsigned Opcode>
+struct BinaryRecipe_match {
   Op0_t Op0;
   Op1_t Op1;
 
-  BinaryVPInstruction_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
+  BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
 
   bool match(const VPValue *V) {
     auto *DefR = V->getDefiningRecipe();
@@ -86,15 +86,27 @@ struct BinaryVPInstruction_match {
   }
 
   bool match(const VPRecipeBase *R) {
-    auto *DefR = dyn_cast<VPInstruction>(R);
+    auto *DefR = dyn_cast<RecipeTy>(R);
     if (!DefR || DefR->getOpcode() != Opcode)
       return false;
     assert(DefR->getNumOperands() == 2 &&
            "recipe with matched opcode does not have 2 operands");
     return Op0.match(DefR->getOperand(0)) && Op1.match(DefR->getOperand(1));
   }
+
+  bool match(const VPSingleDefRecipe *R) {
+    return match(static_cast<const VPRecipeBase *>(R));
+  }
 };
 
+template <typename Op0_t, typename Op1_t, unsigned Opcode>
+using BinaryVPInstruction_match =
+    BinaryRecipe_match<VPInstruction, Op0_t, Op1_t, Opcode>;
+
+template <typename Op0_t, typename Op1_t, unsigned Opcode>
+using BinaryVPReplicate_match =
+    BinaryRecipe_match<VPReplicateRecipe, Op0_t, Op1_t, Opcode>;
+
 template <unsigned Opcode, typename Op0_t>
 inline UnaryVPInstruction_match<Op0_t, Opcode>
 m_VPInstruction(const Op0_t &Op0) {
@@ -130,6 +142,13 @@ inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
 m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
   return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
 }
+
+template <unsigned Opcode, typename Op0_t, typename Op1_t>
+inline BinaryVPReplicate_match<Op0_t, Op1_t, Opcode>
+m_VPReplicate(const Op0_t &Op0, const Op1_t &Op1) {
+  return BinaryVPReplicate_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
+}
+
 } // namespace VPlanPatternMatch
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 9d6deb802e2090..818647d5ea1bac 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1249,6 +1249,24 @@ void VPlanTransforms::dropPoisonGeneratingRecipes(
       // load/store. If the underlying instruction has poison-generating flags,
       // drop them directly.
       if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) {
+        VPValue *A, *B;
+        using namespace llvm::VPlanPatternMatch;
+        // Dropping disjoint from an OR may yield incorrect results, as some
+        // analysis may have converted it to an Add implicitly (e.g. SCEV used
+        // for dependence analysis). Instead, replace it with an equivalent Add.
+        // This is possible as all users of the disjoint OR only access lanes
+        // where the operands are disjoint or poison otherwise.
+        if (match(RecWithFlags,
+                  m_VPReplicate<Instruction::Or>(m_VPValue(A), m_VPValue(B))) &&
+            RecWithFlags->isDisjoint()) {
+          VPBuilder Builder(RecWithFlags);
+          VPInstruction *New = Builder.createOverflowingOp(
+              Instruction::Add, {A, B}, {false, false},
+              RecWithFlags->getDebugLoc());
+          RecWithFlags->replaceAllUsesWith(New);
+          RecWithFlags->eraseFromParent();
+          CurRec = RecWithFlags = New; // OR is erased; use the Add below.
+        }
         RecWithFlags->dropPoisonGeneratingFlags();
       } else {
         Instruction *Instr = dyn_cast_or_null<Instruction>(



More information about the llvm-commits mailing list