[llvm] 774b571 - [AArch64] Alter mull shuffle(ext(..)) combine to work on buildvectors

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 22 15:37:27 PST 2022


Author: David Green
Date: 2022-02-22T23:37:22Z
New Revision: 774b571546915d34a7254b38833001c77745e760

URL: https://github.com/llvm/llvm-project/commit/774b571546915d34a7254b38833001c77745e760
DIFF: https://github.com/llvm/llvm-project/commit/774b571546915d34a7254b38833001c77745e760.diff

LOG: [AArch64] Alter mull shuffle(ext(..)) combine to work on buildvectors

We have a combine for converting mul(dup(ext(..)), ..) into
mul(ext(dup(..)), ..), to allow more uses of the smull and umull
instructions. Currently it looks for vector insert and vector shuffle
nodes to detect the element that we can convert to a vector extend. Not
all cases will have a shufflevector/insert element though.

This patch started as an extension of the recognition to buildvectors
(with elements that may be individually extended). The new method seems
to cover all the cases that the old method captured, as the shuffle will
eventually be lowered to buildvectors, so the old method has been
removed to keep the code a little simpler. The new code detects legal
build_vector(ext(a), ext(b), ..), converting them to
ext(build_vector(a, b, ..)) provided that all the extends/types match up.

Differential Revision: https://reviews.llvm.org/D120018

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
    llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 58d91c3412a93..2fe77449b3a07 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13448,33 +13448,17 @@ static EVT calculatePreExtendType(SDValue Extend) {
   }
 }
 
-/// Combines a dup(sext/zext) node pattern into sext/zext(dup)
+/// Combines a buildvector(sext/zext) node pattern into sext/zext(buildvector)
 /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
-static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle,
-                                                SelectionDAG &DAG) {
-  ShuffleVectorSDNode *ShuffleNode =
-      dyn_cast<ShuffleVectorSDNode>(VectorShuffle.getNode());
-  if (!ShuffleNode)
-    return SDValue();
-
-  // Ensuring the mask is zero before continuing
-  if (!ShuffleNode->isSplat() || ShuffleNode->getSplatIndex() != 0)
-    return SDValue();
-
-  SDValue InsertVectorElt = VectorShuffle.getOperand(0);
-
-  if (InsertVectorElt.getOpcode() != ISD::INSERT_VECTOR_ELT)
-    return SDValue();
-
-  SDValue InsertLane = InsertVectorElt.getOperand(2);
-  ConstantSDNode *Constant = dyn_cast<ConstantSDNode>(InsertLane.getNode());
-  // Ensures the insert is inserting into lane 0
-  if (!Constant || Constant->getZExtValue() != 0)
+static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
+  EVT VT = BV.getValueType();
+  if (BV.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
 
-  SDValue Extend = InsertVectorElt.getOperand(1);
+  // Use the first item in the buildvector to get the size of the extend, and
+  // make sure it looks valid.
+  SDValue Extend = BV->getOperand(0);
   unsigned ExtendOpcode = Extend.getOpcode();
-
   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
                 ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
                 ExtendOpcode == ISD::AssertSext;
@@ -13484,30 +13468,29 @@ static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle,
 
   // Restrict valid pre-extend data type
   EVT PreExtendType = calculatePreExtendType(Extend);
-  if (PreExtendType != MVT::i8 && PreExtendType != MVT::i16 &&
-      PreExtendType != MVT::i32)
-    return SDValue();
-
-  EVT TargetType = VectorShuffle.getValueType();
-  EVT PreExtendVT = TargetType.changeVectorElementType(PreExtendType);
-  if (TargetType.getScalarSizeInBits() != PreExtendVT.getScalarSizeInBits() * 2)
+  if (PreExtendType == MVT::Other ||
+      PreExtendType.getSizeInBits() != VT.getScalarSizeInBits() / 2)
     return SDValue();
 
-  SDLoc DL(VectorShuffle);
-
-  SDValue InsertVectorNode = DAG.getNode(
-      InsertVectorElt.getOpcode(), DL, PreExtendVT, DAG.getUNDEF(PreExtendVT),
-      DAG.getAnyExtOrTrunc(Extend.getOperand(0), DL, PreExtendType),
-      DAG.getConstant(0, DL, MVT::i64));
-
-  std::vector<int> ShuffleMask(TargetType.getVectorNumElements());
-
-  SDValue VectorShuffleNode =
-      DAG.getVectorShuffle(PreExtendVT, DL, InsertVectorNode,
-                           DAG.getUNDEF(PreExtendVT), ShuffleMask);
+  // Make sure all other operands are equally extended
+  for (SDValue Op : drop_begin(BV->ops())) {
+    unsigned Opc = Op.getOpcode();
+    bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
+                     Opc == ISD::AssertSext;
+    if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
+      return SDValue();
+  }
 
-  return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
-                     TargetType, VectorShuffleNode);
+  EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
+  EVT PreExtendLegalType =
+      PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+  SDLoc DL(BV);
+  SmallVector<SDValue, 8> NewOps;
+  for (SDValue Op : BV->ops())
+    NewOps.push_back(
+        DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, PreExtendLegalType));
+  SDValue NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
 }
 
 /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
@@ -13518,8 +13501,8 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
     return SDValue();
 
-  SDValue Op0 = performCommonVectorExtendCombine(Mul->getOperand(0), DAG);
-  SDValue Op1 = performCommonVectorExtendCombine(Mul->getOperand(1), DAG);
+  SDValue Op0 = performBuildVectorExtendCombine(Mul->getOperand(0), DAG);
+  SDValue Op1 = performBuildVectorExtendCombine(Mul->getOperand(1), DAG);
 
   // Neither operands have been changed, don't make any further changes
   if (!Op0 && !Op1)

diff  --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
index bc31d41a55f43..cceb79f97bb93 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -156,10 +156,8 @@ entry:
 define <8 x i16> @nonsplat_shuffleinsert(i8 %src, <8 x i8> %b) {
 ; CHECK-LABEL: nonsplat_shuffleinsert:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    dup v1.8h, w8
-; CHECK-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    dup v1.8b, w0
+; CHECK-NEXT:    smull v0.8h, v1.8b, v0.8b
 ; CHECK-NEXT:    ret
 entry:
     %in = sext i8 %src to i16
@@ -170,6 +168,80 @@ entry:
     ret <8 x i16> %out
 }
 
+define <4 x i32> @nonsplat_shuffleinsert2(<4 x i16> %b, i16 %b0, i16 %b1, i16 %b2, i16 %b3) {
+; CHECK-LABEL: nonsplat_shuffleinsert2:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fmov s1, w0
+; CHECK-NEXT:    mov v1.h[1], w1
+; CHECK-NEXT:    mov v1.h[2], w2
+; CHECK-NEXT:    mov v1.h[3], w3
+; CHECK-NEXT:    smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT:    ret
+entry:
+    %s0 = sext i16 %b0 to i32
+    %s1 = sext i16 %b1 to i32
+    %s2 = sext i16 %b2 to i32
+    %s3 = sext i16 %b3 to i32
+    %ext.b = sext <4 x i16> %b to <4 x i32>
+    %v0 = insertelement <4 x i32> undef, i32 %s0, i32 0
+    %v1 = insertelement <4 x i32> %v0, i32 %s1, i32 1
+    %v2 = insertelement <4 x i32> %v1, i32 %s2, i32 2
+    %v3 = insertelement <4 x i32> %v2, i32 %s3, i32 3
+    %out = mul nsw <4 x i32> %v3, %ext.b
+    ret <4 x i32> %out
+}
+
+define void @typei1_orig(i64 %a, i8* %p, <8 x i16>* %q) {
+; CHECK-LABEL: typei1_orig:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmp x0, #0
+; CHECK-NEXT:    ldr q0, [x2]
+; CHECK-NEXT:    cset w8, gt
+; CHECK-NEXT:    neg v0.8h, v0.8h
+; CHECK-NEXT:    dup v1.8h, w8
+; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NEXT:    cmtst v0.8h, v0.8h, v0.8h
+; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-NEXT:    str q0, [x1]
+; CHECK-NEXT:    ret
+    %tmp = xor <16 x i1> zeroinitializer, <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>
+    %tmp6 = load <8 x i16>, <8 x i16>* %q, align 2
+    %tmp7 = sub <8 x i16> zeroinitializer, %tmp6
+    %tmp8 = shufflevector <8 x i16> %tmp7, <8 x i16> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+    %tmp9 = icmp slt i64 0, %a
+    %tmp10 = zext i1 %tmp9 to i16
+    %tmp11 = insertelement <16 x i16> undef, i16 %tmp10, i64 0
+    %tmp12 = shufflevector <16 x i16> %tmp11, <16 x i16> undef, <16 x i32> zeroinitializer
+    %tmp13 = mul nuw <16 x i16> %tmp8, %tmp12
+    %tmp14 = icmp ne <16 x i16> %tmp13, zeroinitializer
+    %tmp15 = and <16 x i1> %tmp14, %tmp
+    %tmp16 = sext <16 x i1> %tmp15 to <16 x i8>
+    %tmp17 = bitcast i8* %p to <16 x i8>*
+    store <16 x i8> %tmp16, <16 x i8>* %tmp17, align 1
+    ret void
+}
+
+define <8 x i16> @typei1_v8i1_v8i16(i1 %src, <8 x i1> %b) {
+; CHECK-LABEL: typei1_v8i1_v8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v1.8b, #1
+; CHECK-NEXT:    and w8, w0, #0x1
+; CHECK-NEXT:    and v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    dup v1.8h, w8
+; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ret
+entry:
+    %in = zext i1 %src to i16
+    %ext.b = zext <8 x i1> %b to <8 x i16>
+    %broadcast.splatinsert = insertelement <8 x i16> undef, i16 %in, i16 0
+    %broadcast.splat = shufflevector <8 x i16> %broadcast.splatinsert, <8 x i16> undef, <8 x i32> zeroinitializer
+    %out = mul nsw <8 x i16> %broadcast.splat, %ext.b
+    ret <8 x i16> %out
+}
+
 define <8 x i16> @missing_insert(<8 x i8> %b) {
 ; CHECK-LABEL: missing_insert:
 ; CHECK:       // %bb.0: // %entry

diff  --git a/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
index 4f999edf3d571..12b451f509f73 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
@@ -201,25 +201,22 @@ define void @larger_smull(i16* nocapture noundef readonly %x, i16 noundef %y, i3
 ; CHECK-NEXT:    b .LBB3_6
 ; CHECK-NEXT:  .LBB3_3: // %vector.ph
 ; CHECK-NEXT:    and x10, x9, #0xfffffff0
+; CHECK-NEXT:    dup v0.4h, w8
 ; CHECK-NEXT:    add x11, x2, #32
 ; CHECK-NEXT:    add x12, x0, #16
 ; CHECK-NEXT:    mov x13, x10
-; CHECK-NEXT:    dup v0.4s, w8
+; CHECK-NEXT:    dup v1.8h, w8
 ; CHECK-NEXT:  .LBB3_4: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldp q1, q2, [x12, #-16]
+; CHECK-NEXT:    ldp q2, q3, [x12, #-16]
 ; CHECK-NEXT:    subs x13, x13, #16
 ; CHECK-NEXT:    add x12, x12, #32
-; CHECK-NEXT:    sshll2 v3.4s, v1.8h, #0
-; CHECK-NEXT:    sshll v1.4s, v1.4h, #0
-; CHECK-NEXT:    sshll2 v4.4s, v2.8h, #0
-; CHECK-NEXT:    sshll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v3.4s, v0.4s, v3.4s
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v4.4s, v0.4s, v4.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
-; CHECK-NEXT:    stp q1, q3, [x11, #-32]
-; CHECK-NEXT:    stp q2, q4, [x11], #64
+; CHECK-NEXT:    smull2 v4.4s, v1.8h, v2.8h
+; CHECK-NEXT:    smull v2.4s, v0.4h, v2.4h
+; CHECK-NEXT:    smull2 v5.4s, v1.8h, v3.8h
+; CHECK-NEXT:    smull v3.4s, v0.4h, v3.4h
+; CHECK-NEXT:    stp q2, q4, [x11, #-32]
+; CHECK-NEXT:    stp q3, q5, [x11], #64
 ; CHECK-NEXT:    b.ne .LBB3_4
 ; CHECK-NEXT:  // %bb.5: // %middle.block
 ; CHECK-NEXT:    cmp x10, x9
@@ -317,25 +314,22 @@ define void @larger_umull(i16* nocapture noundef readonly %x, i16 noundef %y, i3
 ; CHECK-NEXT:    b .LBB4_6
 ; CHECK-NEXT:  .LBB4_3: // %vector.ph
 ; CHECK-NEXT:    and x10, x9, #0xfffffff0
+; CHECK-NEXT:    dup v0.4h, w8
 ; CHECK-NEXT:    add x11, x2, #32
 ; CHECK-NEXT:    add x12, x0, #16
 ; CHECK-NEXT:    mov x13, x10
-; CHECK-NEXT:    dup v0.4s, w8
+; CHECK-NEXT:    dup v1.8h, w8
 ; CHECK-NEXT:  .LBB4_4: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldp q1, q2, [x12, #-16]
+; CHECK-NEXT:    ldp q2, q3, [x12, #-16]
 ; CHECK-NEXT:    subs x13, x13, #16
 ; CHECK-NEXT:    add x12, x12, #32
-; CHECK-NEXT:    ushll2 v3.4s, v1.8h, #0
-; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    ushll2 v4.4s, v2.8h, #0
-; CHECK-NEXT:    ushll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v3.4s, v0.4s, v3.4s
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v4.4s, v0.4s, v4.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
-; CHECK-NEXT:    stp q1, q3, [x11, #-32]
-; CHECK-NEXT:    stp q2, q4, [x11], #64
+; CHECK-NEXT:    umull2 v4.4s, v1.8h, v2.8h
+; CHECK-NEXT:    umull v2.4s, v0.4h, v2.4h
+; CHECK-NEXT:    umull2 v5.4s, v1.8h, v3.8h
+; CHECK-NEXT:    umull v3.4s, v0.4h, v3.4h
+; CHECK-NEXT:    stp q2, q4, [x11, #-32]
+; CHECK-NEXT:    stp q3, q5, [x11], #64
 ; CHECK-NEXT:    b.ne .LBB4_4
 ; CHECK-NEXT:  // %bb.5: // %middle.block
 ; CHECK-NEXT:    cmp x10, x9
@@ -435,12 +429,13 @@ define i16 @red_mla_dup_ext_u8_s8_s16(i8* noalias nocapture noundef readonly %A,
 ; CHECK-NEXT:    mov w0, wzr
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB5_4: // %vector.ph
+; CHECK-NEXT:    dup v2.8b, w9
 ; CHECK-NEXT:    and x11, x10, #0xfffffff0
-; CHECK-NEXT:    add x8, x0, #8
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
-; CHECK-NEXT:    mov x12, x11
+; CHECK-NEXT:    add x8, x0, #8
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
-; CHECK-NEXT:    dup v2.8h, w9
+; CHECK-NEXT:    mov x12, x11
+; CHECK-NEXT:    sshll v2.8h, v2.8b, #0
 ; CHECK-NEXT:  .LBB5_5: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ldp d3, d4, [x8, #-8]


        


More information about the llvm-commits mailing list