[llvm] [AArch64] Handle ANY_EXTEND in BuildShuffleExtendCombine (PR #118308)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 6 07:25:37 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Hari Limaye (hazzlim)
<details>
<summary>Changes</summary>
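Handle ANY_EXTEND when combining a build vector or shuffle of extended
operands. ANY_EXTEND operands can safely be ignored when checking that the
signedness of all the other extends matches. If no operand is a sign or zero
extend, the combined node is itself emitted as an ANY_EXTEND.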
---
Full diff: https://github.com/llvm/llvm-project/pull/118308.diff
3 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+23-6)
- (modified) llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll (+5-10)
- (modified) llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll (+37-4)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..4e33f3e67ee742 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18501,6 +18501,7 @@ static EVT calculatePreExtendType(SDValue Extend) {
switch (Extend.getOpcode()) {
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
+ case ISD::ANY_EXTEND:
return Extend.getOperand(0).getValueType();
case ISD::AssertSext:
case ISD::AssertZext:
@@ -18545,14 +18546,15 @@ static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
// extend, and make sure it looks valid.
SDValue Extend = BV->getOperand(0);
unsigned ExtendOpcode = Extend.getOpcode();
+ bool IsAnyExt = ExtendOpcode == ISD::ANY_EXTEND;
bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
ExtendOpcode == ISD::AssertSext;
- if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
+ if (!IsAnyExt && !IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
return SDValue();
- // Shuffle inputs are vector, limit to SIGN_EXTEND and ZERO_EXTEND to ensure
- // calculatePreExtendType will work without issue.
+ // Shuffle inputs are vector, limit to SIGN_EXTEND/ZERO_EXTEND/ANY_EXTEND to
+ // ensure calculatePreExtendType will work without issue.
if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
return SDValue();
@@ -18563,15 +18565,27 @@ static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
return SDValue();
- // Make sure all other operands are equally extended
+ // Make sure all other operands are equally extended.
+ bool SeenZExtOrSExt = !IsAnyExt;
for (SDValue Op : drop_begin(BV->ops())) {
if (Op.isUndef())
continue;
+
+ if (calculatePreExtendType(Op) != PreExtendType)
+ return SDValue();
+
unsigned Opc = Op.getOpcode();
+ if (Opc == ISD::ANY_EXTEND)
+ continue;
+
bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
Opc == ISD::AssertSext;
- if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
+
+ if (SeenZExtOrSExt && OpcIsSExt != IsSExt)
return SDValue();
+
+ IsSExt = OpcIsSExt;
+ SeenZExtOrSExt = true;
}
SDValue NBV;
@@ -18594,7 +18608,10 @@ static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
: BV.getOperand(1).getOperand(0),
cast<ShuffleVectorSDNode>(BV)->getMask());
}
- return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
+ unsigned ExtOpc = !SeenZExtOrSExt ? ISD::ANY_EXTEND
+ : IsSExt ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND;
+ return DAG.getNode(ExtOpc, DL, VT, NBV);
}
/// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
index 482135b721da49..95c54cd8b01511 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
@@ -10,18 +10,13 @@ target triple = "aarch64-unknown-linux-gnu"
define dso_local i32 @dupext_crashtest(i32 %e) local_unnamed_addr {
; CHECK-LABEL: dupext_crashtest:
; CHECK: // %bb.0: // %for.body.lr.ph
+; CHECK-NEXT: dup v0.2s, w0
; CHECK-NEXT: .LBB0_1: // %vector.body
; CHECK-NEXT: // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT: ldr d0, [x8]
-; CHECK-NEXT: ushll v0.2d, v0.2s, #0
-; CHECK-NEXT: fmov x9, d0
-; CHECK-NEXT: mov x8, v0.d[1]
-; CHECK-NEXT: mul w9, w0, w9
-; CHECK-NEXT: mul w8, w0, w8
-; CHECK-NEXT: fmov d0, x9
-; CHECK-NEXT: mov v0.d[1], x8
-; CHECK-NEXT: xtn v0.2s, v0.2d
-; CHECK-NEXT: str d0, [x8]
+; CHECK-NEXT: ldr d1, [x8]
+; CHECK-NEXT: smull v1.2d, v0.2s, v1.2s
+; CHECK-NEXT: xtn v1.2s, v1.2d
+; CHECK-NEXT: str d1, [x8]
; CHECK-NEXT: b .LBB0_1
for.body.lr.ph:
%conv314 = zext i32 %e to i64
diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
index 0ce92a20fb3a17..8bb5c62ad43dd5 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -158,6 +158,39 @@ entry:
ret <2 x i64> %out
}
+define <2 x i32> @dupzext_v2i32_v2i64_trunc(i32 %src, <2 x i32> %b) {
+; CHECK-SD-LABEL: dupzext_v2i32_v2i64_trunc:
+; CHECK-SD: // %bb.0: // %entry
+; CHECK-SD-NEXT: dup v1.2s, w0
+; CHECK-SD-NEXT: smull v0.2d, v1.2s, v0.2s
+; CHECK-SD-NEXT: xtn v0.2s, v0.2d
+; CHECK-SD-NEXT: ret
+;
+; CHECK-GI-LABEL: dupzext_v2i32_v2i64_trunc:
+; CHECK-GI: // %bb.0: // %entry
+; CHECK-GI-NEXT: mov w8, w0
+; CHECK-GI-NEXT: ushll v0.2d, v0.2s, #0
+; CHECK-GI-NEXT: dup v1.2d, x8
+; CHECK-GI-NEXT: fmov x9, d0
+; CHECK-GI-NEXT: mov x11, v0.d[1]
+; CHECK-GI-NEXT: fmov x8, d1
+; CHECK-GI-NEXT: mov x10, v1.d[1]
+; CHECK-GI-NEXT: mul x8, x8, x9
+; CHECK-GI-NEXT: mul x9, x10, x11
+; CHECK-GI-NEXT: mov v0.d[0], x8
+; CHECK-GI-NEXT: mov v0.d[1], x9
+; CHECK-GI-NEXT: xtn v0.2s, v0.2d
+; CHECK-GI-NEXT: ret
+entry:
+ %in = zext i32 %src to i64
+ %ext.b = zext <2 x i32> %b to <2 x i64>
+ %broadcast.splatinsert = insertelement <2 x i64> undef, i64 %in, i64 0
+ %broadcast.splat = shufflevector <2 x i64> %broadcast.splatinsert, <2 x i64> undef, <2 x i32> zeroinitializer
+ %prod = mul nuw <2 x i64> %broadcast.splat, %ext.b
+ %out = trunc <2 x i64> %prod to <2 x i32>
+ ret <2 x i32> %out
+}
+
; Unsupported combines
define <2 x i16> @dupsext_v2i8_v2i16(i8 %src, <2 x i8> %b) {
@@ -407,10 +440,10 @@ define <8 x i16> @shufsext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
;
; CHECK-GI-LABEL: shufsext_v8i8_v8i16:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: adrp x8, .LCPI13_0
+; CHECK-GI-NEXT: adrp x8, .LCPI14_0
; CHECK-GI-NEXT: sshll v2.8h, v0.8b, #0
; CHECK-GI-NEXT: sshll v1.8h, v1.8b, #0
-; CHECK-GI-NEXT: ldr q0, [x8, :lo12:.LCPI13_0]
+; CHECK-GI-NEXT: ldr q0, [x8, :lo12:.LCPI14_0]
; CHECK-GI-NEXT: tbl v0.16b, { v2.16b, v3.16b }, v0.16b
; CHECK-GI-NEXT: mul v0.8h, v0.8h, v1.8h
; CHECK-GI-NEXT: ret
@@ -460,10 +493,10 @@ define <8 x i16> @shufzext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
;
; CHECK-GI-LABEL: shufzext_v8i8_v8i16:
; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: adrp x8, .LCPI15_0
+; CHECK-GI-NEXT: adrp x8, .LCPI16_0
; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-GI-NEXT: ldr q0, [x8, :lo12:.LCPI15_0]
+; CHECK-GI-NEXT: ldr q0, [x8, :lo12:.LCPI16_0]
; CHECK-GI-NEXT: tbl v0.16b, { v2.16b, v3.16b }, v0.16b
; CHECK-GI-NEXT: mul v0.8h, v0.8h, v1.8h
; CHECK-GI-NEXT: ret
``````````
</details>
https://github.com/llvm/llvm-project/pull/118308
More information about the llvm-commits mailing list