[llvm] [SandboxVec][Legality] Fix mask on diamond reuse with shuffle (PR #126963)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 12 11:55:12 PST 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/126963

This patch fixes a bug in the creation of shuffle masks when vectorizing vector values in the diamond-reuse-with-shuffle case. The mask needs to enumerate every element of each vector, not treat the original vector value as a single element. For example, when vectorizing two <2 x float> vectors into a <4 x float>, the mask needs 4 indices, not just 2.

>From 3408853f7fa716bc97f837f9b7a2855aa3b78000 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 29 Jan 2025 14:37:38 -0800
Subject: [PATCH] [SandboxVec][Legality] Fix mask on diamond reuse with shuffle

This patch fixes a bug in the creation of shuffle masks when vectorizing
vector values in the diamond-reuse-with-shuffle case. The mask needs to
enumerate every element of each vector, not treat the original vector value
as a single element. For example, when vectorizing two <2 x float> vectors
into a <4 x float>, the mask needs 4 indices, not just 2.
---
 .../Vectorize/SandboxVectorizer/InstrMaps.h   |  5 +++-
 .../Vectorize/SandboxVectorizer/Legality.cpp  | 10 +++++--
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  |  2 ++
 .../SandboxVectorizer/bottomup_basic.ll       | 22 +++++++++++++++
 .../SandboxVectorizer/InstrMapsTest.cpp       | 27 +++++++++++++++++++
 5 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index c931319d3b002..9bdf940fc77b7 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -18,6 +18,7 @@
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
 #include <algorithm>
 
 namespace llvm::sandboxir {
@@ -85,11 +86,13 @@ class InstrMaps {
   /// Update the map to reflect that \p Origs got vectorized into \p Vec.
   void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
     auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
-    for (auto [Lane, Orig] : enumerate(Origs)) {
+    unsigned Lane = 0; // Lane of Vec at which the next Orig starts.
+    for (Value *Orig : Origs) {
       auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
       assert(Pair.second && "Orig already exists in the map!");
       (void)Pair;
       OrigToLaneMap[Orig] = Lane;
+      Lane += VecUtils::getNumLanes(Orig); // Vector Origs span many lanes.
     }
   }
   void clear() {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index c9329c24e1f4c..366243231379f 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -202,14 +202,20 @@ CollectDescr
 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
   Vec.reserve(Bndl.size());
-  for (auto [Lane, V] : enumerate(Bndl)) {
+  for (Value *V : Bndl) {
+    // Number of lanes of `V`: more than one when `V` is itself a vector,
+    // e.g. two for a <2 x float>. Kept as `int` to match the lane indices.
+    int VLanes = VecUtils::getNumLanes(V);
     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
       // If there is a vector containing `V`, then get the lane it came from.
       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
-      Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
+      // This could be a vector, like <2 x float> in which case the mask needs
+      // to enumerate all lanes, starting at the lane `V` was placed at.
+      for (int Ln = 0; Ln != VLanes; ++Ln)
+        Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
     } else {
       Vec.emplace_back(V);
     }
   }
   return CollectDescr(std::move(Vec));
 }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 507d163240127..4fb029d3344b8 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -328,6 +328,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
     const ShuffleMask &Mask =
         cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
     NewVec = createShuffle(VecOp, Mask, UserBB);
+    assert(NewVec->getType() == VecOp->getType() &&
+           "Expected same type! Bad mask ?"); // One mask index per lane.
     break;
   }
   case LegalityResultID::DiamondReuseMultiInput: {
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 45b937dc1b1b6..301d6649669f4 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -243,6 +243,28 @@ define void @diamondWithShuffle(ptr %ptr) {
   ret void
 }
 
+; Same as @diamondWithShuffle but with <2 x float> elements instead of scalars.
+define void @diamondWithShuffleFromVec(ptr %ptr) {
+; CHECK-LABEL: define void @diamondWithShuffleFromVec(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
+; CHECK-NEXT:    [[VSHUF:%.*]] = shufflevector <4 x float> [[VECL]], <4 x float> [[VECL]], <4 x i32> <i32 2, i32 3, i32 0, i32 1>
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VSHUF]]
+; CHECK-NEXT:    store <4 x float> [[VEC]], ptr [[PTR0]], align 8
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
+  %ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
+  %ld0 = load <2 x float>, ptr %ptr0
+  %ld1 = load <2 x float>, ptr %ptr1
+  %sub0 = fsub <2 x float> %ld0, %ld1
+  %sub1 = fsub <2 x float> %ld1, %ld0 ; Swapped operands need a shuffle.
+  store <2 x float> %sub0, ptr %ptr0
+  store <2 x float> %sub1, ptr %ptr1
+  ret void
+}
+
 define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
 ; CHECK-LABEL: define void @diamondMultiInput(
 ; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index 1d7c8f9cdde04..5b033f0edcb02 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -85,3 +85,30 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
   EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
   EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
 }
+
+TEST_F(InstrMapsTest, VectorLanes) {
+  parseIR(C, R"IR(
+define void @foo(<2 x i8> %v0, <2 x i8> %v1, <4 x i8> %v2, <4 x i8> %v3) {
+  %vadd0 = add <2 x i8> %v0, %v1
+  %vadd1 = add <2 x i8> %v0, %v1
+  %vadd2 = add <4 x i8> %v2, %v3
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+
+  auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *VAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *VAdd2 = cast<sandboxir::BinaryOperator>(&*It++);
+
+  sandboxir::InstrMaps IMaps(Ctx);
+
+  // VAdd0 and VAdd1 are <2 x i8>, so they land in lanes 0-1 and 2-3 of VAdd2.
+  IMaps.registerVector({VAdd0, VAdd1}, VAdd2);
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd0), 0U);
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd1), 2U);
+}



More information about the llvm-commits mailing list