[llvm] [VectorCombine] Add special handling for truncating shuffles (PR #70013)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 24 01:23:59 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Nabeel Omer (omern1)

<details>
<summary>Changes</summary>

For a truncating shuffle (one whose result has fewer elements than its input vectors), the type passed to `getShuffleCost()` can be the shuffle's result type while the mask still references an element index that is out of bounds for that result vector.

For truncating shuffles, pass the type of the input vectors to `getShuffleCost()` instead, which avoids the out-of-bounds assertion.
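
The type selection described above can be sketched outside of LLVM as follows. This is only an illustrative model, not the code in `VectorCombine.cpp`: `pickCostTypeWidth` and `maskFitsWidth` are hypothetical helpers, vector types are reduced to their element counts, and the bound check (a two-source mask may only address indices below twice the vector width) is an assumption about the kind of check that fires inside the cost model.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: returns the element count of the vector type that
// would be handed to the cost query, mirroring the condition in the patch.
unsigned pickCostTypeWidth(unsigned NumResultElts, unsigned NumInputElts,
                           const std::vector<int> &Mask) {
  // A truncating shuffle produces fewer elements than each of its inputs.
  bool IsTruncatingShuffle = NumResultElts < NumInputElts;
  // Any index at or past NumInputElts selects from the second input vector.
  bool UsesSecondVec = std::any_of(Mask.begin(), Mask.end(), [&](int M) {
    return M >= static_cast<int>(NumInputElts);
  });
  // Use the result type only for a non-truncating two-source shuffle.
  return (UsesSecondVec && !IsTruncatingShuffle) ? NumResultElts
                                                 : NumInputElts;
}

// Assumed bound check (not taken from LLVM): a two-source mask over vectors
// of Width elements can address indices in [0, 2 * Width). Negative entries
// stand for undef lanes and are always accepted.
bool maskFitsWidth(unsigned Width, const std::vector<int> &Mask) {
  return std::all_of(Mask.begin(), Mask.end(), [&](int M) {
    return M < 0 || static_cast<unsigned>(M) < 2 * Width;
  });
}
```

With the shuffle from the new test (two `<8 x i32>` inputs, a `<4 x i32>` result, mask `<1, 2, 3, 9>`), `pickCostTypeWidth(4, 8, {1, 2, 3, 9})` selects the input width 8. Under the assumed bound, index 9 fits a width of 8 but not the result width of 4.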

---
Full diff: https://github.com/llvm/llvm-project/pull/70013.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+8-4) 
- (added) llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll (+17) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 16efc3b2336f2a5..943ff52bf3c1bd2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1472,21 +1472,25 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
       dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
   if (!ShuffleInputType)
     return false;
-  int NumInputElts = ShuffleInputType->getNumElements();
+  unsigned int NumInputElts = ShuffleInputType->getNumElements();
 
   // Find the mask from sorting the lanes into order. This is most likely to
   // become a identity or concat mask. Undef elements are pushed to the end.
   SmallVector<int> ConcatMask;
   Shuffle->getShuffleMask(ConcatMask);
   sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
+  // In the case of a truncating shuffle it's possible for the mask
+  // to have an index greater than the size of the resulting vector.
+  // This requires special handling.
+  bool IsTruncatingShuffle = VecType->getNumElements() < NumInputElts;
   bool UsesSecondVec =
-      any_of(ConcatMask, [&](int M) { return M >= NumInputElts; });
+      any_of(ConcatMask, [&](int M) { return M >= (int) NumInputElts; });
   InstructionCost OldCost = TTI.getShuffleCost(
       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
-      UsesSecondVec ? VecType : ShuffleInputType, Shuffle->getShuffleMask());
+      (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType, Shuffle->getShuffleMask());
   InstructionCost NewCost = TTI.getShuffleCost(
       UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc,
-      UsesSecondVec ? VecType : ShuffleInputType, ConcatMask);
+      (UsesSecondVec && !IsTruncatingShuffle) ? VecType : ShuffleInputType, ConcatMask);
 
   LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
                     << "\n");
diff --git a/llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll b/llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll
new file mode 100644
index 000000000000000..4b429b30a7f5e59
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/reduction-truncating-vecs.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
+; RUN: opt -S --passes=vector-combine -mtriple=x86_64-sie-ps5 < %s | FileCheck %s
+
+define i16 @test_spill_mixed() {
+; CHECK-LABEL: define i16 @test_spill_mixed() {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <8 x i32> zeroinitializer, <8 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 9>
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP0]])
+; CHECK-NEXT:    ret i16 0
+;
+entry:
+  %0 = shufflevector <8 x i32> zeroinitializer, <8 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 9>
+  %1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %0)
+  ret i16 0
+}
+
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/70013

