[llvm] Fix for logic in combineExtract() (PR #108208)
Jonas Paulsson via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 24 09:29:31 PDT 2024
https://github.com/JonPsson1 updated https://github.com/llvm/llvm-project/pull/108208
>From e116cec92534165b6b86b6d7c7c2586a23de6de7 Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Wed, 11 Sep 2024 13:26:09 +0200
Subject: [PATCH 1/2] Fix for logic in combineExtract()
---
.../Target/SystemZ/SystemZISelLowering.cpp | 9 +++------
.../SystemZ/DAGCombine_extract_vector_elt.ll | 20 +++++++++++++++++++
2 files changed, 23 insertions(+), 6 deletions(-)
create mode 100644 llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index 3dabc5ef540cfb..42e564102af31b 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -6586,13 +6586,12 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
// The number of bytes being extracted.
unsigned BytesPerElement = VecVT.getVectorElementType().getStoreSize();
- for (;;) {
+ while (canTreatAsByteVector(Op.getValueType())) {
unsigned Opcode = Op.getOpcode();
if (Opcode == ISD::BITCAST)
// Look through bitcasts.
Op = Op.getOperand(0);
- else if ((Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) &&
- canTreatAsByteVector(Op.getValueType())) {
+ else if (Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) {
// Get a VPERM-like permute mask and see whether the bytes covered
// by the extracted element are a contiguous sequence from one
// source operand.
@@ -6614,8 +6613,7 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
Index = Byte / BytesPerElement;
Op = Op.getOperand(unsigned(First) / Bytes.size());
Force = true;
- } else if (Opcode == ISD::BUILD_VECTOR &&
- canTreatAsByteVector(Op.getValueType())) {
+ } else if (Opcode == ISD::BUILD_VECTOR) {
// We can only optimize this case if the BUILD_VECTOR elements are
// at least as wide as the extracted value.
EVT OpVT = Op.getValueType();
@@ -6644,7 +6642,6 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
} else if ((Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
- canTreatAsByteVector(Op.getValueType()) &&
canTreatAsByteVector(Op.getOperand(0).getValueType())) {
// Make sure that only the unextended bits are significant.
EVT ExtVT = Op.getValueType();
diff --git a/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
new file mode 100644
index 00000000000000..d568af47dbafd0
--- /dev/null
+++ b/llvm/test/CodeGen/SystemZ/DAGCombine_extract_vector_elt.ll
@@ -0,0 +1,20 @@
+; RUN: llc -mtriple=s390x-linux-gnu -mcpu=z16 < %s | FileCheck %s
+;
+; Check that DAGCombiner doesn't crash in SystemZ combineExtract()
+; when handling EXTRACT_VECTOR_ELT with a vector of i1:s.
+
+define i32 @fun(i32 %arg) {
+; CHECK-LABEL: fun:
+entry:
+ %cc = icmp eq i32 %arg, 0
+ br label %loop
+
+loop:
+ %P = phi <128 x i1> [ zeroinitializer, %entry ], [ bitcast (<2 x i64> <i64 3, i64 3> to <128 x i1>), %loop ]
+ br i1 %cc, label %exit, label %loop
+
+exit:
+ %E = extractelement <128 x i1> %P, i64 0
+ %Res = zext i1 %E to i32
+ ret i32 %Res
+}
>From f2f1051c6472ff7179b8e9f39030e4720b79955c Mon Sep 17 00:00:00 2001
From: Jonas Paulsson <paulson1 at linux.ibm.com>
Date: Tue, 24 Sep 2024 18:25:59 +0200
Subject: [PATCH 2/2] Do a check before call instead.
---
llvm/lib/Target/SystemZ/SystemZISelLowering.cpp | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index 42e564102af31b..ba105c12bc4e97 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -6586,12 +6586,13 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
// The number of bytes being extracted.
unsigned BytesPerElement = VecVT.getVectorElementType().getStoreSize();
- while (canTreatAsByteVector(Op.getValueType())) {
+ for (;;) {
unsigned Opcode = Op.getOpcode();
if (Opcode == ISD::BITCAST)
// Look through bitcasts.
Op = Op.getOperand(0);
- else if (Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) {
+ else if ((Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) &&
+ canTreatAsByteVector(Op.getValueType())) {
// Get a VPERM-like permute mask and see whether the bytes covered
// by the extracted element are a contiguous sequence from one
// source operand.
@@ -6613,7 +6614,8 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
Index = Byte / BytesPerElement;
Op = Op.getOperand(unsigned(First) / Bytes.size());
Force = true;
- } else if (Opcode == ISD::BUILD_VECTOR) {
+ } else if (Opcode == ISD::BUILD_VECTOR &&
+ canTreatAsByteVector(Op.getValueType())) {
// We can only optimize this case if the BUILD_VECTOR elements are
// at least as wide as the extracted value.
EVT OpVT = Op.getValueType();
@@ -6642,6 +6644,7 @@ SDValue SystemZTargetLowering::combineExtract(const SDLoc &DL, EVT ResVT,
} else if ((Opcode == ISD::SIGN_EXTEND_VECTOR_INREG ||
Opcode == ISD::ZERO_EXTEND_VECTOR_INREG ||
Opcode == ISD::ANY_EXTEND_VECTOR_INREG) &&
+ canTreatAsByteVector(Op.getValueType()) &&
canTreatAsByteVector(Op.getOperand(0).getValueType())) {
// Make sure that only the unextended bits are significant.
EVT ExtVT = Op.getValueType();
@@ -7358,8 +7361,9 @@ SDValue SystemZTargetLowering::combineEXTRACT_VECTOR_ELT(
if (auto *IndexN = dyn_cast<ConstantSDNode>(N->getOperand(1))) {
SDValue Op0 = N->getOperand(0);
EVT VecVT = Op0.getValueType();
- return combineExtract(SDLoc(N), N->getValueType(0), VecVT, Op0,
- IndexN->getZExtValue(), DCI, false);
+ if (canTreatAsByteVector(VecVT))
+ return combineExtract(SDLoc(N), N->getValueType(0), VecVT, Op0,
+ IndexN->getZExtValue(), DCI, false);
}
return SDValue();
}
More information about the llvm-commits mailing list.