[llvm] [VectorCombine] Add special handling for truncating shuffles (PR #70013)

Nabeel Omer via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 24 01:22:55 PDT 2023


https://github.com/omern1 created https://github.com/llvm/llvm-project/pull/70013

When dealing with a truncating shuffle (one whose result vector has fewer elements than its input vectors), we can end up in a situation where the type passed to `getShuffleCost()` is the type of the shuffle's result while the mask references an element index that is out of bounds for that result vector.

For truncating shuffles, this patch passes the type of the input vectors to `getShuffleCost()` instead, which avoids the out-of-bounds assertion.

>From ae8360674dcfba05a5391bb9cf7be8a408fcf200 Mon Sep 17 00:00:00 2001
From: Nabeel Omer <Nabeel.Omer at sony.com>
Date: Mon, 23 Oct 2023 22:42:38 +0100
Subject: [PATCH] [VectorCombine] Add special handling for truncating shuffles

---
 llvm/lib/Transforms/Vectorize/VectorCombine.cpp | 12 ++++++++----
 .../X86/reduction-truncating-vecs.ll            | 17 +++++++++++++++++
 2 files changed, 25 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 16efc3b2336f2a5..943ff52bf3c1bd2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1472,21 +1472,25 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
       dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
   if (!ShuffleInputType)
     return false;
-  int NumInputElts = ShuffleInputType->getNumElements();
+  unsigned int NumInputElts = ShuffleInputType->getNumElements();
 
   // Find the mask from sorting the lanes into order. This is most likely to
   // become a identity or concat mask. Undef elements are pushed to the end.
   SmallVector<int> ConcatMask;
   Shuffle->getShuffleMask(ConcatMask);
   sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
+  // In the case of a truncating shuffle it's possible for the mask
+  // to have an index greater than the size of the resulting vector.
+  // This requires special handling.
+  bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
   bool UsesSecondVec =
-      any_of(ConcatMask, [&](int M) { return M >= NumInputElts; });
+      any_of(ConcatMask, [&](int M) { return M >= (int) NumInputElts; });
   InstructionCost OldCost = TTI.getShuffleCost(
       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
-      UsesSecondVec ? VecType : ShuffleInputType, Shuffle->getShuffleMask());
+      (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType, Shuffle->getShuffleMask());
   InstructionCost NewCost = TTI.getShuffleCost(
       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
-      UsesSecondVec ? VecType : ShuffleInputType, ConcatMask);
+      (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType, ConcatMask);
 
   LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
                     << "\n");
diff --git a/llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll b/llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll
new file mode 100644
index 000000000000000..4b429b30a7f5e59
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
+; RUN: opt -S --passes=vector-combine -mtriple=x86_64-sie-ps5 < %s | FileCheck %s
+
+define i16 @test_spill_mixed() {
+; CHECK-LABEL: define i16 @test_spill_mixed() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <8 x i32> zeroinitializer, <8 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 9>
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP0]])
+; CHECK-NEXT:    ret i16 0
+;
+entry:
+  %0 = shufflevector <8 x i32> zeroinitializer, <8 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 9>
+  %1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %0)
+  ret i16 0
+}
+
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)



More information about the llvm-commits mailing list