[llvm] [AArch64] Handle ANY_EXTEND in BuildShuffleExtendCombine (PR #118308)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 07:25:37 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Hari Limaye (hazzlim)

<details>
<summary>Changes</summary>

Handle ANY_EXTEND when combining a buildvector/shuffle of extended
operands: ANY_EXTEND operands can safely be ignored when checking that
the signedness of all the other extends matches.


---
Full diff: https://github.com/llvm/llvm-project/pull/118308.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+23-6) 
- (modified) llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll (+5-10) 
- (modified) llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll (+37-4) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..4e33f3e67ee742 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18501,6 +18501,7 @@ static EVT calculatePreExtendType(SDValue Extend) {
   switch (Extend.getOpcode()) {
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
+  case ISD::ANY_EXTEND:
     return Extend.getOperand(0).getValueType();
   case ISD::AssertSext:
   case ISD::AssertZext:
@@ -18545,14 +18546,15 @@ static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
   // extend, and make sure it looks valid.
   SDValue Extend = BV->getOperand(0);
   unsigned ExtendOpcode = Extend.getOpcode();
+  bool IsAnyExt = ExtendOpcode == ISD::ANY_EXTEND;
   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
                 ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
                 ExtendOpcode == ISD::AssertSext;
-  if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
+  if (!IsAnyExt && !IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
       ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
     return SDValue();
-  // Shuffle inputs are vector, limit to SIGN_EXTEND and ZERO_EXTEND to ensure
-  // calculatePreExtendType will work without issue.
+  // Shuffle inputs are vector, limit to SIGN_EXTEND/ZERO_EXTEND/ANY_EXTEND to
+  // ensure calculatePreExtendType will work without issue.
   if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
       ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
     return SDValue();
@@ -18563,15 +18565,27 @@ static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
       PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
     return SDValue();
 
-  // Make sure all other operands are equally extended
+  // Make sure all other operands are equally extended.
+  bool SeenZExtOrSExt = !IsAnyExt;
   for (SDValue Op : drop_begin(BV->ops())) {
     if (Op.isUndef())
       continue;
+
+    if (calculatePreExtendType(Op) != PreExtendType)
+      return SDValue();
+
     unsigned Opc = Op.getOpcode();
+    if (Opc == ISD::ANY_EXTEND)
+      continue;
+
     bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
                      Opc == ISD::AssertSext;
-    if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
+
+    if (SeenZExtOrSExt && OpcIsSExt != IsSExt)
       return SDValue();
+
+    IsSExt = OpcIsSExt;
+    SeenZExtOrSExt = true;
   }
 
   SDValue NBV;
@@ -18594,7 +18608,10 @@ static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
                                    : BV.getOperand(1).getOperand(0),
                                cast<ShuffleVectorSDNode>(BV)->getMask());
   }
-  return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
+  unsigned ExtOpc = !SeenZExtOrSExt ? ISD::ANY_EXTEND
+                    : IsSExt        ? ISD::SIGN_EXTEND
+                                    : ISD::ZERO_EXTEND;
+  return DAG.getNode(ExtOpc, DL, VT, NBV);
 }
 
 /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
index 482135b721da49..95c54cd8b01511 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext-crash.ll
@@ -10,18 +10,13 @@ target triple = "aarch64-unknown-linux-gnu"
 define dso_local i32 @dupext_crashtest(i32 %e) local_unnamed_addr {
 ; CHECK-LABEL: dupext_crashtest:
 ; CHECK:       // %bb.0: // %for.body.lr.ph
+; CHECK-NEXT:    dup v0.2s, w0
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldr d0, [x8]
-; CHECK-NEXT:    ushll v0.2d, v0.2s, #0
-; CHECK-NEXT:    fmov x9, d0
-; CHECK-NEXT:    mov x8, v0.d[1]
-; CHECK-NEXT:    mul w9, w0, w9
-; CHECK-NEXT:    mul w8, w0, w8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
-; CHECK-NEXT:    xtn v0.2s, v0.2d
-; CHECK-NEXT:    str d0, [x8]
+; CHECK-NEXT:    ldr d1, [x8]
+; CHECK-NEXT:    smull v1.2d, v0.2s, v1.2s
+; CHECK-NEXT:    xtn v1.2s, v1.2d
+; CHECK-NEXT:    str d1, [x8]
 ; CHECK-NEXT:    b .LBB0_1
 for.body.lr.ph:
   %conv314 = zext i32 %e to i64
diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
index 0ce92a20fb3a17..8bb5c62ad43dd5 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -158,6 +158,39 @@ entry:
   ret <2 x i64> %out
 }
 
+define <2 x i32> @dupzext_v2i32_v2i64_trunc(i32 %src, <2 x i32> %b) {
+; CHECK-SD-LABEL: dupzext_v2i32_v2i64_trunc:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    dup v1.2s, w0
+; CHECK-SD-NEXT:    smull v0.2d, v1.2s, v0.2s
+; CHECK-SD-NEXT:    xtn v0.2s, v0.2d
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: dupzext_v2i32_v2i64_trunc:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    mov w8, w0
+; CHECK-GI-NEXT:    ushll v0.2d, v0.2s, #0
+; CHECK-GI-NEXT:    dup v1.2d, x8
+; CHECK-GI-NEXT:    fmov x9, d0
+; CHECK-GI-NEXT:    mov x11, v0.d[1]
+; CHECK-GI-NEXT:    fmov x8, d1
+; CHECK-GI-NEXT:    mov x10, v1.d[1]
+; CHECK-GI-NEXT:    mul x8, x8, x9
+; CHECK-GI-NEXT:    mul x9, x10, x11
+; CHECK-GI-NEXT:    mov v0.d[0], x8
+; CHECK-GI-NEXT:    mov v0.d[1], x9
+; CHECK-GI-NEXT:    xtn v0.2s, v0.2d
+; CHECK-GI-NEXT:    ret
+entry:
+  %in = zext i32 %src to i64
+  %ext.b = zext <2 x i32> %b to <2 x i64>
+  %broadcast.splatinsert = insertelement <2 x i64> undef, i64 %in, i64 0
+  %broadcast.splat = shufflevector <2 x i64> %broadcast.splatinsert, <2 x i64> undef, <2 x i32> zeroinitializer
+  %prod = mul nuw <2 x i64> %broadcast.splat, %ext.b
+  %out = trunc <2 x i64> %prod to <2 x i32>
+  ret <2 x i32> %out
+}
+
 ; Unsupported combines
 
 define <2 x i16> @dupsext_v2i8_v2i16(i8 %src, <2 x i8> %b) {
@@ -407,10 +440,10 @@ define <8 x i16> @shufsext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ;
 ; CHECK-GI-LABEL: shufsext_v8i8_v8i16:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    adrp x8, .LCPI13_0
+; CHECK-GI-NEXT:    adrp x8, .LCPI14_0
 ; CHECK-GI-NEXT:    sshll v2.8h, v0.8b, #0
 ; CHECK-GI-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-GI-NEXT:    ldr q0, [x8, :lo12:.LCPI13_0]
+; CHECK-GI-NEXT:    ldr q0, [x8, :lo12:.LCPI14_0]
 ; CHECK-GI-NEXT:    tbl v0.16b, { v2.16b, v3.16b }, v0.16b
 ; CHECK-GI-NEXT:    mul v0.8h, v0.8h, v1.8h
 ; CHECK-GI-NEXT:    ret
@@ -460,10 +493,10 @@ define <8 x i16> @shufzext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ;
 ; CHECK-GI-LABEL: shufzext_v8i8_v8i16:
 ; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    adrp x8, .LCPI15_0
+; CHECK-GI-NEXT:    adrp x8, .LCPI16_0
 ; CHECK-GI-NEXT:    ushll v2.8h, v0.8b, #0
 ; CHECK-GI-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-GI-NEXT:    ldr q0, [x8, :lo12:.LCPI15_0]
+; CHECK-GI-NEXT:    ldr q0, [x8, :lo12:.LCPI16_0]
 ; CHECK-GI-NEXT:    tbl v0.16b, { v2.16b, v3.16b }, v0.16b
 ; CHECK-GI-NEXT:    mul v0.8h, v0.8h, v1.8h
 ; CHECK-GI-NEXT:    ret

``````````

</details>


https://github.com/llvm/llvm-project/pull/118308


More information about the llvm-commits mailing list