[llvm] Fix for logic in combineExtract() (PR #108208)

Jonas Paulsson via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 09:29:31 PDT 2024


https://github.com/JonPsson1 updated https://github.com/llvm/llvm-project/pull/108208

>From e116cec92534165b6b86b6d7c7c2586a23de6de7 Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Wed, 11 Sep 2024 13:26:09 +0200
Subject: [PATCH 1/2] Fix for logic in combineExtract()

---
 .../Target/SystemZ/SystemZISelLowering.cpp    |  9 +++------
 .../SystemZ/DAGCombine_extract_vector_elt.ll  | 20 +++++++++++++++++++
 2 files changed, 23 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll

diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index 3dabc5ef540cfb..42e564102af31b 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -6586,13 +6586,12 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
   // The number of bytes being extracted.
   unsigned BytesPerElement = VecVT.getVectorElementType().getStoreSize();
 
-  for (;;) {
+  while (canTreatAsByteVector(Op.getValueType())) {
     unsigned Opcode = Op.getOpcode();
     if (Opcode == ISD::BITCAST)
       // Look through bitcasts.
       Op = Op.getOperand(0);
-    else if ((Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) &&
-             canTreatAsByteVector(Op.getValueType())) {
+    else if (Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) {
       // Get a VPERM-like permute mask and see whether the bytes covered
       // by the extracted element are a contiguous sequence from one
       // source operand.
@@ -6614,8 +6613,7 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
       Index = Byte / BytesPerElement;
       Op = Op.getOperand(unsigned(First) / Bytes.size());
       Force = true;
-    } else if (Opcode == ISD::BUILD_VECTOR &&
-               canTreatAsByteVector(Op.getValueType())) {
+    } else if (Opcode == ISD::BUILD_VECTOR) {
       // We can only optimize this case if the BUILD_VECTOR elements are
       // at least as wide as the extracted value.
       EVT OpVT = Op.getValueType();
@@ -6644,7 +6642,6 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
     } else if ((Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
-               canTreatAsByteVector(Op.getValueType()) &&
                canTreatAsByteVector(Op.getOperand(0).getValueType())) {
       // Make sure that only the unextended bits are significant.
       EVT ExtVT = Op.getValueType();
diff --git a/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
new file mode 100644
index 00000000000000..d568af47dbafd0
--- /dev/null
+++ b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
@@ -0,0 +1,20 @@
+; RUN: llc -mtriple=s390x-linux-gnu -mcpu=z16 < %s  | FileCheck %s
+;
+; Check that DAGCombiner doesn't crash in SystemZ combineExtract()
+; when handling EXTRACT_VECTOR_ELT with a vector of i1 elements.
+
+define i32 @fun(i32 %arg) {
+; CHECK-LABEL: fun:
+entry:
+  %cc = icmp eq i32 %arg, 0
+  br label %loop
+
+loop:
+  %P = phi <128 x i1> [ zeroinitializer, %entry ], [ bitcast (<2 x i64> <i64 3, i64 3> to <128 x i1>), %loop ]
+  br i1 %cc, label %exit, label %loop
+
+exit:
+  %E = extractelement <128 x i1> %P, i64 0
+  %Res = zext i1 %E to i32
+  ret i32 %Res
+}

>From f2f1051c6472ff7179b8e9f39030e4720b79955c Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Tue, 24 Sep 2024 18:25:59 +0200
Subject: [PATCH 2/2] Do a check before call instead.

---
 llvm/lib/Target/SystemZ/SystemZISelLowering.cpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index 42e564102af31b..ba105c12bc4e97 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -6586,12 +6586,13 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
   // The number of bytes being extracted.
   unsigned BytesPerElement = VecVT.getVectorElementType().getStoreSize();
 
-  while (canTreatAsByteVector(Op.getValueType())) {
+  for (;;) {
     unsigned Opcode = Op.getOpcode();
     if (Opcode == ISD::BITCAST)
       // Look through bitcasts.
       Op = Op.getOperand(0);
-    else if (Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) {
+    else if ((Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) &&
+             canTreatAsByteVector(Op.getValueType())) {
       // Get a VPERM-like permute mask and see whether the bytes covered
       // by the extracted element are a contiguous sequence from one
       // source operand.
@@ -6613,7 +6614,8 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
       Index = Byte / BytesPerElement;
       Op = Op.getOperand(unsigned(First) / Bytes.size());
       Force = true;
-    } else if (Opcode == ISD::BUILD_VECTOR) {
+    } else if (Opcode == ISD::BUILD_VECTOR &&
+               canTreatAsByteVector(Op.getValueType())) {
       // We can only optimize this case if the BUILD_VECTOR elements are
       // at least as wide as the extracted value.
       EVT OpVT = Op.getValueType();
@@ -6642,6 +6644,7 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
     } else if ((Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
                 Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
+               canTreatAsByteVector(Op.getValueType()) &&
                canTreatAsByteVector(Op.getOperand(0).getValueType())) {
       // Make sure that only the unextended bits are significant.
       EVT ExtVT = Op.getValueType();
@@ -7358,8 +7361,9 @@ SDValue SystemZTargetLowering::combineEXTRACT_VECTOR_ELT(
   if (auto *IndexN = dyn_cast<ConstantSDNode>(N->getOperand(1))) {
     SDValue Op0 = N->getOperand(0);
     EVT VecVT = Op0.getValueType();
-    return combineExtract(SDLoc(N), N->getValueType(0), VecVT, Op0,
-                          IndexN->getZExtValue(), DCI, false);
+    if (canTreatAsByteVector(VecVT))
+      return combineExtract(SDLoc(N), N->getValueType(0), VecVT, Op0,
+                            IndexN->getZExtValue(), DCI, false);
   }
   return SDValue();
 }



More information about the llvm-commits mailing list