[llvm] [SandboxVec][BottomUpVec] Fix diamond shuffle with multiple vector inputs (PR #126965)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 12 12:54:07 PST 2025


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/126965

>From 57dd4b194d3d5ee8177b03b2c408fc002a1c6864 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 29 Jan 2025 16:23:21 -0800
Subject: [PATCH] [SandboxVec][BottomUpVec] Fix diamond shuffle with multiple
 vector inputs

When the operand comes from multiple inputs then we need additional packing
code. When the operands are scalar then we can use a single InsertElementInst.
But when the operands are vectors then we need a chain of ExtractElementInst
and InsertElementInst instructions to insert the vector value into the
destination vector. This is what this patch implements.
---
 .../Vectorize/SandboxVectorizer/Legality.cpp  |  3 +-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  | 35 +++++++++++++++----
 .../SandboxVectorizer/bottomup_basic.ll       | 33 +++++++++++++++++
 3 files changed, 63 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 36e28b48582f7..74634372156aa 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -203,13 +203,12 @@ LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
   Vec.reserve(Bndl.size());
   for (auto [Elm, V] : enumerate(Bndl)) {
-    uint32_t VLanes = VecUtils::getNumLanes(V);
     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
       // If there is a vector containing `V`, then get the lane it came from.
       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
       // This could be a vector, like <2 x float> in which case the mask needs
       // to enumerate all lanes.
-      for (unsigned Ln = 0; Ln != VLanes; ++Ln)
+      for (unsigned Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln)
         Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
     } else {
       Vec.emplace_back(V);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 4fb029d3344b8..0ccef5aecd28b 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -335,7 +335,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
   case LegalityResultID::DiamondReuseMultiInput: {
     const auto &Descr =
         cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
-    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
+    Type *ResTy = VecUtils::getWideType(Bndl[0]->getType(), Bndl.size());
 
     // TODO: Try to get WhereIt without creating a vector.
     SmallVector<Value *, 4> DescrInstrs;
@@ -347,7 +347,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
         getInsertPointAfterInstrs(DescrInstrs, UserBB);
 
     Value *LastV = PoisonValue::get(ResTy);
-    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
+    unsigned Lane = 0;
+    for (const auto &ElmDescr : Descr.getDescrs()) {
       Value *VecOp = ElmDescr.getValue();
       Context &Ctx = VecOp->getContext();
       Value *ValueToInsert;
@@ -359,10 +360,32 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
       } else {
         ValueToInsert = VecOp;
       }
-      ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
-      Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
-                                             WhereIt, Ctx, "VIns");
-      LastV = Ins;
+      auto NumLanesToInsert = VecUtils::getNumLanes(ValueToInsert);
+      if (NumLanesToInsert == 1) {
+        // If we are inserting a scalar element then we need a single insert.
+        //   %VIns = insert %DstVec,  %SrcScalar, Lane
+        ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
+        LastV = InsertElementInst::create(LastV, ValueToInsert, LaneC, WhereIt,
+                                          Ctx, "VIns");
+      } else {
+        // If we are inserting a vector element then we need to extract and
+        // insert each vector element one by one with a chain of extracts and
+        // inserts, for example:
+        //   %VExt0 = extract %SrcVec, 0
+        //   %VIns0 = insert  %DstVec, %VExt0, Lane + 0
+        //   %VExt1 = extract %SrcVec, 1
+        //   %VIns1 = insert  %VIns0,  %VExt1, Lane + 1
+        for (unsigned LnCnt = 0; LnCnt != NumLanesToInsert; ++LnCnt) {
+          auto *ExtrIdxC = ConstantInt::get(Type::getInt32Ty(Ctx), LnCnt);
+          auto *ExtrI = ExtractElementInst::create(ValueToInsert, ExtrIdxC,
+                                                   WhereIt, Ctx, "VExt");
+          unsigned InsLane = Lane + LnCnt;
+          auto *InsLaneC = ConstantInt::get(Type::getInt32Ty(Ctx), InsLane);
+          LastV = InsertElementInst::create(LastV, ExtrI, InsLaneC, WhereIt,
+                                            Ctx, "VIns");
+        }
+      }
+      Lane += NumLanesToInsert;
     }
     NewVec = LastV;
     break;
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 301d6649669f4..6b18d4069e0ae 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -292,6 +292,39 @@ define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
   ret void
 }
 
+; Same but vectorizing <2 x float> vectors instead of scalars.
+define void @diamondMultiInputVector(ptr %ptr, ptr %ptrX) {
+; CHECK-LABEL: define void @diamondMultiInputVector(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[LDX:%.*]] = load <2 x float>, ptr [[PTRX]], align 8
+; CHECK-NEXT:    [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
+; CHECK-NEXT:    [[VEXT:%.*]] = extractelement <2 x float> [[LDX]], i32 0
+; CHECK-NEXT:    [[INSI:%.*]] = insertelement <4 x float> poison, float [[VEXT]], i32 0
+; CHECK-NEXT:    [[VEXT1:%.*]] = extractelement <2 x float> [[LDX]], i32 1
+; CHECK-NEXT:    [[INSI2:%.*]] = insertelement <4 x float> [[INSI]], float [[VEXT1]], i32 1
+; CHECK-NEXT:    [[VEXT3:%.*]] = extractelement <4 x float> [[VECL]], i32 0
+; CHECK-NEXT:    [[VINS4:%.*]] = insertelement <4 x float> [[INSI2]], float [[VEXT3]], i32 2
+; CHECK-NEXT:    [[VEXT4:%.*]] = extractelement <4 x float> [[VECL]], i32 1
+; CHECK-NEXT:    [[VINS5:%.*]] = insertelement <4 x float> [[VINS4]], float [[VEXT4]], i32 3
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VINS5]]
+; CHECK-NEXT:    store <4 x float> [[VEC]], ptr [[PTR0]], align 8
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
+  %ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
+  %ld0 = load <2 x float>, ptr %ptr0
+  %ld1 = load <2 x float>, ptr %ptr1
+
+  %ldX = load <2 x float>, ptr %ptrX
+
+  %sub0 = fsub <2 x float> %ld0, %ldX
+  %sub1 = fsub <2 x float> %ld1, %ld0
+  store <2 x float> %sub0, ptr %ptr0
+  store <2 x float> %sub1, ptr %ptr1
+  ret void
+}
+
 define void @diamondWithConstantVector(ptr %ptr) {
 ; CHECK-LABEL: define void @diamondWithConstantVector(
 ; CHECK-SAME: ptr [[PTR:%.*]]) {



More information about the llvm-commits mailing list