[llvm] 9fc1a0d - [AArch64] Alter mull shuffle(ext(..)) combine to work on buildvectors

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 21 07:44:34 PST 2022


Author: David Green
Date: 2022-02-21T15:44:30Z
New Revision: 9fc1a0dcb79afb31470751651c30e843c12e9ca5

URL: https://github.com/llvm/llvm-project/commit/9fc1a0dcb79afb31470751651c30e843c12e9ca5
DIFF: https://github.com/llvm/llvm-project/commit/9fc1a0dcb79afb31470751651c30e843c12e9ca5.diff

LOG: [AArch64] Alter mull shuffle(ext(..)) combine to work on buildvectors

We have a combine for converting mul(dup(ext(..)), ...) into
mul(ext(dup(..)), ..), to allow more uses of the smull and umull
instructions. Currently it looks for insert_vector_elt and
vector_shuffle nodes to detect the element that we can convert to a
vector extend. Not all cases will have a shufflevector/insertelement
though.

This started by extending the recognition to buildvectors (whose elements
may be individually extended). The new method appears to cover all the
cases that the old method captured, as the shuffle will eventually be
lowered to a buildvector, so the old method has been removed to keep the
code a little simpler. The new code detects legal
build_vector(ext(a), ext(b), ..) nodes and converts them to
ext(build_vector(a, b, ..)), provided all the extends and types match up.

Differential Revision: https://reviews.llvm.org/D120018

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
    llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 473984c658d39..30d30e88f2740 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13448,33 +13448,17 @@ static EVT calculatePreExtendType(SDValue Extend) {
   }
 }
 
-/// Combines a dup(sext/zext) node pattern into sext/zext(dup)
+/// Combines a buildvector(sext/zext) node pattern into sext/zext(buildvector)
 /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
-static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle,
-                                                SelectionDAG &DAG) {
-  ShuffleVectorSDNode *ShuffleNode =
-      dyn_cast<ShuffleVectorSDNode>(VectorShuffle.getNode());
-  if (!ShuffleNode)
-    return SDValue();
-
-  // Ensuring the mask is zero before continuing
-  if (!ShuffleNode->isSplat() || ShuffleNode->getSplatIndex() != 0)
-    return SDValue();
-
-  SDValue InsertVectorElt = VectorShuffle.getOperand(0);
-
-  if (InsertVectorElt.getOpcode() != ISD::INSERT_VECTOR_ELT)
-    return SDValue();
-
-  SDValue InsertLane = InsertVectorElt.getOperand(2);
-  ConstantSDNode *Constant = dyn_cast<ConstantSDNode>(InsertLane.getNode());
-  // Ensures the insert is inserting into lane 0
-  if (!Constant || Constant->getZExtValue() != 0)
+static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
+  EVT VT = BV.getValueType();
+  if (BV.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
 
-  SDValue Extend = InsertVectorElt.getOperand(1);
+  // Use the first item in the buildvector to get the size of the extend, and
+  // make sure it looks valid.
+  SDValue Extend = BV->getOperand(0);
   unsigned ExtendOpcode = Extend.getOpcode();
-
   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
                 ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
                 ExtendOpcode == ISD::AssertSext;
@@ -13484,30 +13468,28 @@ static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle,
 
   // Restrict valid pre-extend data type
   EVT PreExtendType = calculatePreExtendType(Extend);
-  if (PreExtendType != MVT::i8 && PreExtendType != MVT::i16 &&
-      PreExtendType != MVT::i32)
-    return SDValue();
-
-  EVT TargetType = VectorShuffle.getValueType();
-  EVT PreExtendVT = TargetType.changeVectorElementType(PreExtendType);
-  if (TargetType.getScalarSizeInBits() != PreExtendVT.getScalarSizeInBits() * 2)
+  if (PreExtendType.getSizeInBits() != VT.getScalarSizeInBits() / 2)
     return SDValue();
 
-  SDLoc DL(VectorShuffle);
-
-  SDValue InsertVectorNode = DAG.getNode(
-      InsertVectorElt.getOpcode(), DL, PreExtendVT, DAG.getUNDEF(PreExtendVT),
-      DAG.getAnyExtOrTrunc(Extend.getOperand(0), DL, PreExtendType),
-      DAG.getConstant(0, DL, MVT::i64));
-
-  std::vector<int> ShuffleMask(TargetType.getVectorNumElements());
-
-  SDValue VectorShuffleNode =
-      DAG.getVectorShuffle(PreExtendVT, DL, InsertVectorNode,
-                           DAG.getUNDEF(PreExtendVT), ShuffleMask);
+  // Make sure all other operands are equally extended
+  for (SDValue Op : drop_begin(BV->ops())) {
+    unsigned Opc = Op.getOpcode();
+    bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
+                     Opc == ISD::AssertSext;
+    if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
+      return SDValue();
+  }
 
-  return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
-                     TargetType, VectorShuffleNode);
+  EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
+  EVT PreExtendLegalType =
+      PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+  SDLoc DL(BV);
+  SmallVector<SDValue, 8> NewOps;
+  for (SDValue Op : BV->ops())
+    NewOps.push_back(
+        DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, PreExtendLegalType));
+  SDValue NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
 }
 
 /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
@@ -13518,8 +13500,8 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
     return SDValue();
 
-  SDValue Op0 = performCommonVectorExtendCombine(Mul->getOperand(0), DAG);
-  SDValue Op1 = performCommonVectorExtendCombine(Mul->getOperand(1), DAG);
+  SDValue Op0 = performBuildVectorExtendCombine(Mul->getOperand(0), DAG);
+  SDValue Op1 = performBuildVectorExtendCombine(Mul->getOperand(1), DAG);
 
   // Neither operands have been changed, don't make any further changes
   if (!Op0 && !Op1)

diff  --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
index bc31d41a55f43..5a57e6e82dd2e 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -156,10 +156,8 @@ entry:
 define <8 x i16> @nonsplat_shuffleinsert(i8 %src, <8 x i8> %b) {
 ; CHECK-LABEL: nonsplat_shuffleinsert:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    dup v1.8h, w8
-; CHECK-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    dup v1.8b, w0
+; CHECK-NEXT:    smull v0.8h, v1.8b, v0.8b
 ; CHECK-NEXT:    ret
 entry:
     %in = sext i8 %src to i16

diff  --git a/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
index 4f999edf3d571..12b451f509f73 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
@@ -201,25 +201,22 @@ define void @larger_smull(i16* nocapture noundef readonly %x, i16 noundef %y, i3
 ; CHECK-NEXT:    b .LBB3_6
 ; CHECK-NEXT:  .LBB3_3: // %vector.ph
 ; CHECK-NEXT:    and x10, x9, #0xfffffff0
+; CHECK-NEXT:    dup v0.4h, w8
 ; CHECK-NEXT:    add x11, x2, #32
 ; CHECK-NEXT:    add x12, x0, #16
 ; CHECK-NEXT:    mov x13, x10
-; CHECK-NEXT:    dup v0.4s, w8
+; CHECK-NEXT:    dup v1.8h, w8
 ; CHECK-NEXT:  .LBB3_4: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldp q1, q2, [x12, #-16]
+; CHECK-NEXT:    ldp q2, q3, [x12, #-16]
 ; CHECK-NEXT:    subs x13, x13, #16
 ; CHECK-NEXT:    add x12, x12, #32
-; CHECK-NEXT:    sshll2 v3.4s, v1.8h, #0
-; CHECK-NEXT:    sshll v1.4s, v1.4h, #0
-; CHECK-NEXT:    sshll2 v4.4s, v2.8h, #0
-; CHECK-NEXT:    sshll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v3.4s, v0.4s, v3.4s
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v4.4s, v0.4s, v4.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
-; CHECK-NEXT:    stp q1, q3, [x11, #-32]
-; CHECK-NEXT:    stp q2, q4, [x11], #64
+; CHECK-NEXT:    smull2 v4.4s, v1.8h, v2.8h
+; CHECK-NEXT:    smull v2.4s, v0.4h, v2.4h
+; CHECK-NEXT:    smull2 v5.4s, v1.8h, v3.8h
+; CHECK-NEXT:    smull v3.4s, v0.4h, v3.4h
+; CHECK-NEXT:    stp q2, q4, [x11, #-32]
+; CHECK-NEXT:    stp q3, q5, [x11], #64
 ; CHECK-NEXT:    b.ne .LBB3_4
 ; CHECK-NEXT:  // %bb.5: // %middle.block
 ; CHECK-NEXT:    cmp x10, x9
@@ -317,25 +314,22 @@ define void @larger_umull(i16* nocapture noundef readonly %x, i16 noundef %y, i3
 ; CHECK-NEXT:    b .LBB4_6
 ; CHECK-NEXT:  .LBB4_3: // %vector.ph
 ; CHECK-NEXT:    and x10, x9, #0xfffffff0
+; CHECK-NEXT:    dup v0.4h, w8
 ; CHECK-NEXT:    add x11, x2, #32
 ; CHECK-NEXT:    add x12, x0, #16
 ; CHECK-NEXT:    mov x13, x10
-; CHECK-NEXT:    dup v0.4s, w8
+; CHECK-NEXT:    dup v1.8h, w8
 ; CHECK-NEXT:  .LBB4_4: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldp q1, q2, [x12, #-16]
+; CHECK-NEXT:    ldp q2, q3, [x12, #-16]
 ; CHECK-NEXT:    subs x13, x13, #16
 ; CHECK-NEXT:    add x12, x12, #32
-; CHECK-NEXT:    ushll2 v3.4s, v1.8h, #0
-; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    ushll2 v4.4s, v2.8h, #0
-; CHECK-NEXT:    ushll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v3.4s, v0.4s, v3.4s
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v4.4s, v0.4s, v4.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
-; CHECK-NEXT:    stp q1, q3, [x11, #-32]
-; CHECK-NEXT:    stp q2, q4, [x11], #64
+; CHECK-NEXT:    umull2 v4.4s, v1.8h, v2.8h
+; CHECK-NEXT:    umull v2.4s, v0.4h, v2.4h
+; CHECK-NEXT:    umull2 v5.4s, v1.8h, v3.8h
+; CHECK-NEXT:    umull v3.4s, v0.4h, v3.4h
+; CHECK-NEXT:    stp q2, q4, [x11, #-32]
+; CHECK-NEXT:    stp q3, q5, [x11], #64
 ; CHECK-NEXT:    b.ne .LBB4_4
 ; CHECK-NEXT:  // %bb.5: // %middle.block
 ; CHECK-NEXT:    cmp x10, x9
@@ -435,12 +429,13 @@ define i16 @red_mla_dup_ext_u8_s8_s16(i8* noalias nocapture noundef readonly %A,
 ; CHECK-NEXT:    mov w0, wzr
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB5_4: // %vector.ph
+; CHECK-NEXT:    dup v2.8b, w9
 ; CHECK-NEXT:    and x11, x10, #0xfffffff0
-; CHECK-NEXT:    add x8, x0, #8
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
-; CHECK-NEXT:    mov x12, x11
+; CHECK-NEXT:    add x8, x0, #8
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
-; CHECK-NEXT:    dup v2.8h, w9
+; CHECK-NEXT:    mov x12, x11
+; CHECK-NEXT:    sshll v2.8h, v2.8b, #0
 ; CHECK-NEXT:  .LBB5_5: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ldp d3, d4, [x8, #-8]


        


More information about the llvm-commits mailing list