[llvm] [X86] Fix arithmetic error in extractVector (PR #128052)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 20 13:32:21 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-x86

Author: Daniel Zabawa (daniel-zabawa)

<details>
<summary>Changes</summary>

The computation of the element count for the result VT in extractSubVector is incorrect when vectorWidth does not evenly divide VT.getSizeInBits(). This can occur when the source vector's element count is not a power of two, e.g. when extracting a 256-bit subvector (vectorWidth = 256) from a 384-bit source.

This rewrites the expression so the division is exact given that vectorWidth is a multiple of the source element size.
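
To make the arithmetic concrete, here is a minimal sketch of the old and new expressions using plain unsigned integers in place of the EVT queries. The helper names `oldResultNumElts` and `newResultNumElts` are hypothetical and do not appear in the patch; the parameters stand in for VT.getVectorNumElements(), VT.getSizeInBits(), and vectorWidth.

```cpp
// Hypothetical stand-ins for the two expressions in the diff (not LLVM API).

// Old form: divides the source width by vectorWidth first, so any
// remainder is truncated before the element count is scaled.
constexpr unsigned oldResultNumElts(unsigned NumElts, unsigned SizeInBits,
                                    unsigned VectorWidth) {
  return NumElts / (SizeInBits / VectorWidth);
}

// New form: multiplies first. Since SizeInBits == NumElts * element size,
// the quotient equals VectorWidth / element size, which is exact whenever
// VectorWidth is a multiple of the element size.
constexpr unsigned newResultNumElts(unsigned NumElts, unsigned SizeInBits,
                                    unsigned VectorWidth) {
  return (NumElts * VectorWidth) / SizeInBits;
}

// <12 x float> is 384 bits wide; extracting 256 bits should give 8 floats.
static_assert(oldResultNumElts(12, 384, 256) == 12,
              "old form: 384 / 256 truncates to 1, so all 12 elements remain");
static_assert(newResultNumElts(12, 384, 256) == 8,
              "new form: (12 * 256) / 384 == 8");

// When vectorWidth divides the source width, both forms agree.
static_assert(oldResultNumElts(12, 384, 128) == 4, "old form, 128-bit extract");
static_assert(newResultNumElts(12, 384, 128) == 4, "new form, 128-bit extract");
```

In the 256-bit case the old form yields a 12-element result type, i.e. the full 384-bit source width rather than the requested 256 bits, which is the mismatch described above.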

---
Full diff: https://github.com/llvm/llvm-project/pull/128052.diff


2 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+3-3) 
- (added) llvm/test/CodeGen/X86/pr128052.ll (+30) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 1c9d43ce4c062..d79dd9d5cdd72 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -4066,9 +4066,9 @@ static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
                                 const SDLoc &dl, unsigned vectorWidth) {
   EVT VT = Vec.getValueType();
   EVT ElVT = VT.getVectorElementType();
-  unsigned Factor = VT.getSizeInBits() / vectorWidth;
-  EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
-                                  VT.getVectorNumElements() / Factor);
+  unsigned ResultNumElts =
+      (VT.getVectorNumElements() * vectorWidth) / VT.getSizeInBits();
+  EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT, ResultNumElts);
 
   // Extract the relevant vectorWidth bits.  Generate an EXTRACT_SUBVECTOR
   unsigned ElemsPerChunk = vectorWidth / ElVT.getSizeInBits();
diff --git a/llvm/test/CodeGen/X86/pr128052.ll b/llvm/test/CodeGen/X86/pr128052.ll
new file mode 100644
index 0000000000000..1a67e64b69832
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pr128052.ll
@@ -0,0 +1,30 @@
+; Ensure assertion is not hit when folding concat of two contiguous extract_subvector operations
+; from a source with a non-power-of-two vector length.
+; RUN: llc -mattr=+avx2 < %s
+
+source_filename = "foo.c"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define void @foo(ptr noundef %pDst, ptr noundef %pSrc) {
+bb0:
+  %sptr1 = getelementptr i8, ptr %pSrc, i64 32
+  %load598 = load <12 x float>, ptr %sptr1, align 1
+  br label %bb1
+bb1:
+  %sptr0 = getelementptr i8, ptr %pSrc, i64 16
+  %load617 = load <12 x float>, ptr %sptr0, align 1
+  %42 = fsub contract <12 x float> %load617, %load598
+  %43 = shufflevector <12 x float> %42, <12 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %44 = fsub contract <12 x float> %load617, %load598
+  %45 = shufflevector <12 x float> %44, <12 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %46 = fsub contract <12 x float> %load617, %load598
+  %47 = shufflevector <12 x float> %46, <12 x float> poison, <4 x i32> <i32 8, i32 9, i32 10, i32 11>
+  %dptr0 = getelementptr i8, ptr %pDst, i64 16
+  %dptr1 = getelementptr i8, ptr %pDst, i64 32 
+  %dptr2 = getelementptr i8, ptr %pDst, i64 48
+  store <4 x float> %43, ptr %dptr0, align 1
+  store <4 x float> %45, ptr %dptr1, align 1
+  store <4 x float> %47, ptr %dptr2, align 1
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/128052

