[llvm] [SandboxVec][Legality] Diamond reuse multi input (PR #123426)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 17 15:57:37 PST 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/123426

This patch implements the diamond pattern where we are vectorizing toward the top of the diamond from both edges, but the second edge may use elements from a different vector or just scalar values. This requires some additional packing code (see lit test).

>From 628c24e3755796190804ef735d3816a124b7fd25 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 22 Nov 2024 10:12:28 -0800
Subject: [PATCH] [SandboxVec][Legality] Diamond reuse multi input

This patch implements the diamond pattern where we are vectorizing toward the
top of the diamond from both edges, but the second edge may use elements from
a different vector or just scalar values. This requires some additional packing
code (see lit test).
---
 .../Vectorize/SandboxVectorizer/Legality.h    | 22 ++++++++++--
 .../Vectorize/SandboxVectorizer/Legality.cpp  |  3 +-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  | 34 +++++++++++++++++++
 .../SandboxVectorizer/bottomup_basic.ll       | 27 +++++++++++++++
 4 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 4858ebaf0770aa..f10c535aa820ee 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -81,6 +81,7 @@ enum class LegalityResultID {
   Widen,                   ///> Vectorize by combining scalars to a vector.
   DiamondReuse,            ///> Don't generate new code, reuse existing vector.
   DiamondReuseWithShuffle, ///> Reuse the existing vector but add a shuffle.
+  DiamondReuseMultiInput,  ///> Reuse more than one vector and/or scalars.
 };
 
 /// The reason for vectorizing or not vectorizing.
@@ -108,6 +109,8 @@ struct ToStr {
       return "DiamondReuse";
     case LegalityResultID::DiamondReuseWithShuffle:
       return "DiamondReuseWithShuffle";
+    case LegalityResultID::DiamondReuseMultiInput:
+      return "DiamondReuseMultiInput";
     }
     llvm_unreachable("Unknown LegalityResultID enum");
   }
@@ -287,6 +290,20 @@ class CollectDescr {
   }
 };
 
+class DiamondReuseMultiInput final : public LegalityResult {
+  friend class LegalityAnalysis; // Creates instances via the private ctor.
+  CollectDescr Descr;            ///< Per-lane source of the packed vector.
+  DiamondReuseMultiInput(CollectDescr &&D)
+      : LegalityResult(LegalityResultID::DiamondReuseMultiInput),
+        Descr(std::move(D)) {}
+
+public:
+  const CollectDescr &getCollectDescr() const { return Descr; }
+  static bool classof(const LegalityResult *R) {
+    return R->getSubclassID() == LegalityResultID::DiamondReuseMultiInput;
+  }
+};
+
 /// Performs the legality analysis and returns a LegalityResult object.
 class LegalityAnalysis {
   Scheduler Sched;
@@ -312,8 +329,9 @@ class LegalityAnalysis {
       : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
-  ResultT &createLegalityResult(ArgsT... Args) {
-    ResultPool.push_back(std::unique_ptr<ResultT>(new ResultT(Args...)));
+  ResultT &createLegalityResult(ArgsT &&...Args) {
+    ResultPool.push_back(
+        std::unique_ptr<ResultT>(new ResultT(std::forward<ArgsT>(Args)...)));
     return cast<ResultT>(*ResultPool.back());
   }
   /// Checks if it's legal to vectorize the instructions in \p Bndl.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index ad3e38e2f1d923..085f4cd67ab76e 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -223,7 +223,8 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
         return createLegalityResult<DiamondReuse>(Vec);
       return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
     }
-    llvm_unreachable("TODO: Unimplemented");
+    return createLegalityResult<DiamondReuseMultiInput>(
+        std::move(CollectDescrs));
   }
 
   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index d62023ea018846..c6ab3c1942c330 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -308,6 +308,40 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
     NewVec = createShuffle(VecOp, Mask);
     break;
   }
+  case LegalityResultID::DiamondReuseMultiInput: {
+    const auto &Descr =
+        cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
+    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
+
+    // TODO: Try to get WhereIt without creating a vector.
+    SmallVector<Value *, 4> DescrInstrs;
+    for (const auto &ElmDescr : Descr.getDescrs()) {
+      if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
+        DescrInstrs.push_back(I);
+    }
+    auto WhereIt = getInsertPointAfterInstrs(DescrInstrs);
+
+    Value *LastV = PoisonValue::get(ResTy);
+    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
+      Value *VecOp = ElmDescr.getValue();
+      Context &Ctx = VecOp->getContext();
+      Value *ValueToInsert;
+      if (ElmDescr.needsExtract()) {
+        ConstantInt *IdxC =
+            ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
+        ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
+                                                   VecOp->getContext(), "VExt");
+      } else {
+        ValueToInsert = VecOp;
+      }
+      ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
+      Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
+                                             WhereIt, Ctx, "VIns");
+      LastV = Ins;
+    }
+    NewVec = LastV;
+    break;
+  }
   case LegalityResultID::Pack: {
     // If we can't vectorize the seeds then just return.
     if (Depth == 0)
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index a3798af8399087..5b389e25d70d95 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -242,3 +242,30 @@ define void @diamondWithShuffle(ptr %ptr) {
   store float %sub1, ptr %ptr1
   ret void
 }
+
+define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
+; CHECK-LABEL: define void @diamondMultiInput(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDX:%.*]] = load float, ptr [[PTRX]], align 4
+; CHECK-NEXT:    [[VINS:%.*]] = insertelement <2 x float> poison, float [[LDX]], i32 0
+; CHECK-NEXT:    [[VEXT:%.*]] = extractelement <2 x float> [[VECL]], i32 0
+; CHECK-NEXT:    [[VINS1:%.*]] = insertelement <2 x float> [[VINS]], float [[VEXT]], i32 1
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VINS1]]
+; CHECK-NEXT:    store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+
+  %ldX = load float, ptr %ptrX
+
+  %sub0 = fsub float %ld0, %ldX
+  %sub1 = fsub float %ld1, %ld0
+  store float %sub0, ptr %ptr0
+  store float %sub1, ptr %ptr1
+  ret void
+}



More information about the llvm-commits mailing list