[llvm] Fix for logic in combineExtract() (PR #108208)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 11 05:28:51 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-systemz

Author: Jonas Paulsson (JonPsson1)

<details>
<summary>Changes</summary>

A csmith-generated test case made combineExtract() crash when the input vector was a bitcast to a vector of i1 elements. Add a canTreatAsByteVector() check for the immediate (first) Op as well, by moving the check into the loop condition. This avoids the crashing case and does not seem to change any benchmarks or tests.
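The following is a simplified, hypothetical model in plain C++, not LLVM code. It illustrates why checking every Op in the loop condition avoids the crash. It assumes that canTreatAsByteVector() essentially requires each element to be a whole number of bytes wide; the real predicate has additional conditions. It models only the BITCAST and BUILD_VECTOR cases. `Node`, `Opc`, `canTreatAsByteVectorModel` and `buildVectorReached` are made-up names for this sketch.

```cpp
// Hypothetical model of the loop in combineExtract(); not LLVM code.
enum class Opc { Bitcast, BuildVector };

struct Node {
  Opc Opcode;
  unsigned NumElts; // number of vector elements
  unsigned EltBits; // element width in bits
};

// Assumed simplification of canTreatAsByteVector(): the elements must be a
// whole number of bytes wide. The real predicate checks more than this.
constexpr bool canTreatAsByteVectorModel(const Node &N) {
  return N.NumElts > 0 && N.EltBits % 8 == 0;
}

// Chain[0] is the vector being extracted from and Chain[I + 1] is the
// operand of Chain[I]. Returns the index of the node on which the
// BUILD_VECTOR logic would run, or -1 if the loop gives up first.
// CheckEveryOp selects the patched while-loop over the old for (;;) loop.
constexpr int buildVectorReached(const Node *Chain, int Len, bool CheckEveryOp) {
  for (int I = 0; I < Len; ++I) {
    const Node &Op = Chain[I];
    // Patched loop: while (canTreatAsByteVector(Op.getValueType())).
    if (CheckEveryOp && !canTreatAsByteVectorModel(Op))
      return -1;
    // Both versions look through bitcasts.
    if (Op.Opcode == Opc::Bitcast)
      continue;
    // The old loop checked the type only here, at the BUILD_VECTOR itself.
    return canTreatAsByteVectorModel(Op) ? I : -1;
  }
  return -1;
}

// The crashing shape: <128 x i1> = bitcast of a <2 x i64> BUILD_VECTOR.
constexpr Node Crash[] = {{Opc::Bitcast, 128, 1}, {Opc::BuildVector, 2, 64}};
static_assert(buildVectorReached(Crash, 2, false) == 1,
              "old loop reaches the v2i64 BUILD_VECTOR");
static_assert(buildVectorReached(Crash, 2, true) == -1,
              "patched loop stops at the <128 x i1> bitcast");
```

In this model, a chain that consists only of byte vectors (for example a v16i8 bitcast of a v2i64 BUILD_VECTOR) still reaches the BUILD_VECTOR logic under the patched loop. This matches the observation that no benchmarks or tests changed.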

I am not sure how combineExtract() is supposed to work with hypothetical vector types such as <3 x i24>, given that it uses getStoreSize() on the vector element type. As part of a vector, an element would have a store size matching its element size, but once extracted and used as a scalar it would be promoted to the next larger legal integer type, right?
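Here is a small worked model of the sizes involved, in plain C++. It assumes that getStoreSize() rounds the bit width up to whole bytes, as EVT::getStoreSize() does. It also assumes that scalar integers of up to 64 bits are promoted to i32 or i64, the SystemZ general-register integer types. `storeSizeBytes` and `promotedScalarBits` are hypothetical helpers, not LLVM API.

```cpp
// Hypothetical helpers; not LLVM API.

// Store size: bit width rounded up to whole bytes.
constexpr unsigned storeSizeBytes(unsigned Bits) { return (Bits + 7) / 8; }

// Assumption: a scalar integer of at most 64 bits is promoted to the smaller
// of i32/i64 that can hold it (wider types are out of scope here).
constexpr unsigned promotedScalarBits(unsigned Bits) {
  return Bits <= 32 ? 32 : 64;
}

// <3 x i24>: each i24 element has a store size of 3 bytes ...
static_assert(storeSizeBytes(24) == 3, "i24 element store size");
// ... and the whole vector is 72 bits, i.e. 9 bytes.
static_assert(storeSizeBytes(3 * 24) == 9, "<3 x i24> store size");
// Once extracted and used as a scalar, the i24 would become an i32 ...
static_assert(promotedScalarBits(24) == 32, "extracted i24 becomes i32");
// ... which occupies 4 bytes rather than 3.
static_assert(storeSizeBytes(promotedScalarBits(24)) == 4, "promoted size");
```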

I guess I am confused about the use of getStoreSize() on vector elements: an i16 and an i32 element would have the same store size, so I find it a little odd how the computations work when BITCASTing between i16 and i32 vectors.

In the i1 vector case the logic clearly gets confused: 'End' becomes 128 (bytes), so the 'Op.getOperand(End / OpBytesPerElement - 1)' call uses an index of 15 (128 / 8 - 1). However, Op is 'v2i64 = BUILD_VECTOR Constant:i64<3>, Constant:i64<3>', which has only two operands.
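The arithmetic behind the crash can be checked with a few lines of plain C++. This is a sketch, not LLVM code: `impliedBytes` and `operandIndex` are hypothetical helpers. Only the operand-index expression is taken from the paragraph above.

```cpp
// Hypothetical re-creation of the numbers quoted above; not LLVM code.

// Bytes attributed to a vector when each element is counted by its store
// size, i.e. its bit width rounded up to whole bytes.
constexpr unsigned impliedBytes(unsigned NumElts, unsigned EltBits) {
  return NumElts * ((EltBits + 7) / 8);
}

// Operand index selected on the BUILD_VECTOR, following the quoted
// expression Op.getOperand(End / OpBytesPerElement - 1).
constexpr unsigned operandIndex(unsigned End, unsigned OpBytesPerElement) {
  return End / OpBytesPerElement - 1;
}

// <128 x i1>: counted per element, the vector looks like 128 bytes ...
static_assert(impliedBytes(128, 1) == 128, "128 elements x 1 byte");
// ... although it is really 128 bits, i.e. one 16-byte vector register.
static_assert(128 * 1 / 8 == 16, "actual size in bytes");
// With End == 128 and 8-byte i64 operands the index is 15, but the
// v2i64 BUILD_VECTOR only has operands 0 and 1.
static_assert(operandIndex(128, 8) == 15, "index used by the old code");
static_assert(operandIndex(128, 8) >= 2, "out of range for a v2i64 BUILD_VECTOR");
```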


---
Full diff: https://github.com/llvm/llvm-project/pull/108208.diff


2 Files Affected:

- (modified) llvm/lib/Target/SystemZ/SystemZISelLowering.cpp (+3-6) 
- (added) llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll (+20) 


``````````diff
diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index 582a8c139b2937..bcb09e59ffb0c8 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -6569,13 +6569,12 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
   // The number of bytes being extracted.
   unsigned BytesPerElement = VecVT.getVectorElementType().getStoreSize();
 
-  for (;;) {
+  while (canTreatAsByteVector(Op.getValueType())) {
     unsigned Opcode = Op.getOpcode();
     if (Opcode == ISD::BITCAST)
       // Look through bitcasts.
       Op = Op.getOperand(0);
-    else if ((Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) &&
-             canTreatAsByteVector(Op.getValueType())) {
+    else if (Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) {
       // Get a VPERM-like permute mask and see whether the bytes covered
       // by the extracted element are a contiguous sequence from one
       // source operand.
@@ -6597,8 +6596,7 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
       Index = Byte / BytesPerElement;
       Op = Op.getOperand(unsigned(First) / Bytes.size());
       Force = true;
-    } else if (Opcode == ISD::BUILD_VECTOR &&
-               canTreatAsByteVector(Op.getValueType())) {
+    } else if (Opcode == ISD::BUILD_VECTOR) {
       // We can only optimize this case if the BUILD_VECTOR elements are
       // at least as wide as the extracted value.
       EVT OpVT = Op.getValueType();
@@ -6627,7 +6625,6 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
     } else if ((Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
-               canTreatAsByteVector(Op.getValueType()) &&
                canTreatAsByteVector(Op.getOperand(0).getValueType())) {
       // Make sure that only the unextended bits are significant.
       EVT ExtVT = Op.getValueType();
diff --git a/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
new file mode 100644
index 00000000000000..d568af47dbafd0
--- /dev/null
+++ b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
@@ -0,0 +1,20 @@
+; RUN: llc -mtriple=s390x-linux-gnu -mcpu=z16 < %s  | FileCheck %s
+;
+; Check that DAGCombiner doesn't crash in SystemZ combineExtract()
+; when handling EXTRACT_VECTOR_ELT with a vector of i1:s.
+
+define i32 @fun(i32 %arg) {
+; CHECK-LABEL: fun:
+entry:
+  %cc = icmp eq i32 %arg, 0
+  br label %loop
+
+loop:
+  %P = phi <128 x i1> [ zeroinitializer, %entry ], [ bitcast (<2 x i64> <i64 3, i64 3> to <128 x i1>), %loop ]
+  br i1 %cc, label %exit, label %loop
+
+exit:
+  %E = extractelement <128 x i1> %P, i64 0
+  %Res = zext i1 %E to i32
+  ret i32 %Res
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/108208

