[llvm] 3b98335 - [AArch64] Alter mull buildvectors(ext(..)) combine to work on shuffles

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 4 15:07:55 PDT 2022


Author: David Green
Date: 2022-04-04T23:07:47+01:00
New Revision: 3b9833597e810d4c485487d2f094a8e223af5548

URL: https://github.com/llvm/llvm-project/commit/3b9833597e810d4c485487d2f094a8e223af5548
DIFF: https://github.com/llvm/llvm-project/commit/3b9833597e810d4c485487d2f094a8e223af5548.diff

LOG: [AArch64] Alter mull buildvectors(ext(..)) combine to work on shuffles

D120018 altered this combine to work on buildvectors as opposed to
shuffle dups. This works well for dups and other things that are
expanded into buildvectors. Some shuffles are legal though, and stay as
vector_shuffle through lowering. This expands the transform to also
handle shuffles, so that we can turn mul(shuffle(sext(..))) into
mul(sext(shuffle(..))) and more readily make smull/umull instructions. This
can come up from the SLP vectorizer adding shuffles that are costed from
extends.

Differential Revision: https://reviews.llvm.org/D123012

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 38079a191f72a..f6e1c6140147e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13629,15 +13629,17 @@ static EVT calculatePreExtendType(SDValue Extend) {
   }
 }
 
-/// Combines a buildvector(sext/zext) node pattern into sext/zext(buildvector)
-/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
-static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
+/// Combines a buildvector(sext/zext) or shuffle(sext/zext, undef) node pattern
+/// into sext/zext(buildvector) or sext/zext(shuffle) making use of the vector
+/// SExt/ZExt rather than the scalar SExt/ZExt
+static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
   EVT VT = BV.getValueType();
-  if (BV.getOpcode() != ISD::BUILD_VECTOR)
+  if (BV.getOpcode() != ISD::BUILD_VECTOR &&
+      BV.getOpcode() != ISD::VECTOR_SHUFFLE)
     return SDValue();
 
-  // Use the first item in the buildvector to get the size of the extend, and
-  // make sure it looks valid.
+  // Use the first item in the buildvector/shuffle to get the size of the
+  // extend, and make sure it looks valid.
   SDValue Extend = BV->getOperand(0);
   unsigned ExtendOpcode = Extend.getOpcode();
   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
@@ -13646,15 +13648,22 @@ static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
   if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
       ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
     return SDValue();
+  // Shuffle inputs are vector, limit to SIGN_EXTEND and ZERO_EXTEND to ensure
+  // calculatePreExtendType will work without issue.
+  if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
+      ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
+    return SDValue();
 
   // Restrict valid pre-extend data type
   EVT PreExtendType = calculatePreExtendType(Extend);
   if (PreExtendType == MVT::Other ||
-      PreExtendType.getSizeInBits() != VT.getScalarSizeInBits() / 2)
+      PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
     return SDValue();
 
   // Make sure all other operands are equally extended
   for (SDValue Op : drop_begin(BV->ops())) {
+    if (Op.isUndef())
+      continue;
     unsigned Opc = Op.getOpcode();
     bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
                      Opc == ISD::AssertSext;
@@ -13662,15 +13671,26 @@ static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
       return SDValue();
   }
 
-  EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
-  EVT PreExtendLegalType =
-      PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+  SDValue NBV;
   SDLoc DL(BV);
-  SmallVector<SDValue, 8> NewOps;
-  for (SDValue Op : BV->ops())
-    NewOps.push_back(
-        DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, PreExtendLegalType));
-  SDValue NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+    EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
+    EVT PreExtendLegalType =
+        PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+    SmallVector<SDValue, 8> NewOps;
+    for (SDValue Op : BV->ops())
+      NewOps.push_back(Op.isUndef() ? DAG.getUNDEF(PreExtendLegalType)
+                                    : DAG.getAnyExtOrTrunc(Op.getOperand(0), DL,
+                                                           PreExtendLegalType));
+    NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  } else { // BV.getOpcode() == ISD::VECTOR_SHUFFLE
+    EVT PreExtendVT = VT.changeVectorElementType(PreExtendType.getScalarType());
+    NBV = DAG.getVectorShuffle(PreExtendVT, DL, BV.getOperand(0).getOperand(0),
+                               BV.getOperand(1).isUndef()
+                                   ? DAG.getUNDEF(PreExtendVT)
+                                   : BV.getOperand(1).getOperand(0),
+                               cast<ShuffleVectorSDNode>(BV)->getMask());
+  }
   return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
 }
 
@@ -13682,8 +13702,8 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
     return SDValue();
 
-  SDValue Op0 = performBuildVectorExtendCombine(Mul->getOperand(0), DAG);
-  SDValue Op1 = performBuildVectorExtendCombine(Mul->getOperand(1), DAG);
+  SDValue Op0 = performBuildShuffleExtendCombine(Mul->getOperand(0), DAG);
+  SDValue Op1 = performBuildShuffleExtendCombine(Mul->getOperand(1), DAG);
 
   // Neither operands have been changed, don't make any further changes
   if (!Op0 && !Op1)

diff  --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
index 2146221581842..e0d7759c5a66b 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -245,9 +245,8 @@ entry:
 define <8 x i16> @missing_insert(<8 x i8> %b) {
 ; CHECK-LABEL: missing_insert:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #4
-; CHECK-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ext v1.8b, v0.8b, v0.8b, #2
+; CHECK-NEXT:    smull v0.8h, v1.8b, v0.8b
 ; CHECK-NEXT:    ret
 entry:
     %ext.b = sext <8 x i8> %b to <8 x i16>
@@ -259,11 +258,8 @@ entry:
 define <8 x i16> @shufsext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ; CHECK-LABEL: shufsext_v8i8_v8i16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    rev64 v0.8h, v0.8h
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    rev64 v0.8b, v0.8b
+; CHECK-NEXT:    smull v0.8h, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <8 x i8> %src to <8 x i16>
@@ -276,17 +272,8 @@ entry:
 define <2 x i64> @shufsext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
 ; CHECK-LABEL: shufsext_v2i32_v2i64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fmov x9, d1
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    fmov x10, d0
-; CHECK-NEXT:    mov x11, v0.d[1]
-; CHECK-NEXT:    mul x9, x10, x9
-; CHECK-NEXT:    mul x8, x11, x8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    rev64 v0.2s, v0.2s
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <2 x i32> %src to <2 x i64>
@@ -299,11 +286,8 @@ entry:
 define <8 x i16> @shufzext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ; CHECK-LABEL: shufzext_v8i8_v8i16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    rev64 v0.8h, v0.8h
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    rev64 v0.8b, v0.8b
+; CHECK-NEXT:    umull v0.8h, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
 entry:
   %in = zext <8 x i8> %src to <8 x i16>
@@ -316,17 +300,8 @@ entry:
 define <2 x i64> @shufzext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
 ; CHECK-LABEL: shufzext_v2i32_v2i64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fmov x9, d1
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    fmov x10, d0
-; CHECK-NEXT:    mov x11, v0.d[1]
-; CHECK-NEXT:    mul x9, x10, x9
-; CHECK-NEXT:    mul x8, x11, x8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    rev64 v0.2s, v0.2s
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <2 x i32> %src to <2 x i64>
@@ -339,11 +314,8 @@ entry:
 define <8 x i16> @shufzext_v8i8_v8i16_twoin(<8 x i8> %src1, <8 x i8> %src2, <8 x i8> %b) {
 ; CHECK-LABEL: shufzext_v8i8_v8i16_twoin:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    trn1 v0.8h, v0.8h, v1.8h
-; CHECK-NEXT:    ushll v1.8h, v2.8b, #0
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    trn1 v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    umull v0.8h, v0.8b, v2.8b
 ; CHECK-NEXT:    ret
 entry:
   %in1 = zext <8 x i8> %src1 to <8 x i16>


        


More information about the llvm-commits mailing list