[llvm] edf9f88 - [AArch64] Handle 64bit vector s/umull from extracts

David Green via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 14 02:25:16 PDT 2023


Author: David Green
Date: 2023-07-14T10:25:12+01:00
New Revision: edf9f885668e5d928abf6d9f58d481a37bea07cf

URL: https://github.com/llvm/llvm-project/commit/edf9f885668e5d928abf6d9f58d481a37bea07cf
DIFF: https://github.com/llvm/llvm-project/commit/edf9f885668e5d928abf6d9f58d481a37bea07cf.diff

LOG: [AArch64] Handle 64bit vector s/umull from extracts

This is similar to D153632, but for mul nodes instead of add/sub. Mul nodes are
custom-lowered in LowerMUL in order to detect mul(ext, ext), in a way that
works for i64 nodes as well as i16/i32. This patch extends that lowering to
also look for mul(subvector_extract(ext(x), 0), subvector_extract(ext(y), 0)),
generating a subvector_extract(mull(x, y), 0) if it matches.

Differential Revision: https://reviews.llvm.org/D154063

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/neon-extadd-extract.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4ce0a4e6b9ed90..2a2953c359984a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1128,12 +1128,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SMIN, VT, Custom);
     }
 
-    // AArch64 doesn't have MUL.2d:
-    setOperationAction(ISD::MUL, MVT::v2i64, Expand);
     // Custom handling for some quad-vector types to detect MULL.
     setOperationAction(ISD::MUL, MVT::v8i16, Custom);
     setOperationAction(ISD::MUL, MVT::v4i32, Custom);
     setOperationAction(ISD::MUL, MVT::v2i64, Custom);
+    setOperationAction(ISD::MUL, MVT::v4i16, Custom);
+    setOperationAction(ISD::MUL, MVT::v2i32, Custom);
+    setOperationAction(ISD::MUL, MVT::v1i64, Custom);
 
     // Saturates
     for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
@@ -4592,24 +4593,44 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
 
   // If SVE is available then i64 vector multiplications can also be made legal.
-  bool OverrideNEON =
-      VT == MVT::v1i64 || Subtarget->forceStreamingCompatibleSVE();
+  bool OverrideNEON = Subtarget->forceStreamingCompatibleSVE();
 
   if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON))
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
 
-  // Multiplications are only custom-lowered for 128-bit vectors so that
-  // VMULL can be detected.  Otherwise v2i64 multiplications are not legal.
-  assert(VT.is128BitVector() && VT.isInteger() &&
+  // Multiplications are only custom-lowered for 128-bit and 64-bit vectors so
+  // that VMULL can be detected.  Otherwise v2i64 multiplications are not legal.
+  assert((VT.is128BitVector() || VT.is64BitVector()) && VT.isInteger() &&
          "unexpected type for custom-lowering ISD::MUL");
   SDNode *N0 = Op.getOperand(0).getNode();
   SDNode *N1 = Op.getOperand(1).getNode();
   bool isMLA = false;
+  EVT OVT = VT;
+  if (VT.is64BitVector()) {
+    if (N0->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        isNullConstant(N0->getOperand(1)) &&
+        N1->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        isNullConstant(N1->getOperand(1))) {
+      N0 = N0->getOperand(0).getNode();
+      N1 = N1->getOperand(0).getNode();
+      VT = N0->getValueType(0);
+    } else {
+      if (VT == MVT::v1i64) {
+        if (Subtarget->hasSVE())
+          return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
+        // Fall through to expand this.  It is not legal.
+        return SDValue();
+      } else
+        // Other vector multiplications are legal.
+        return Op;
+    }
+  }
+
   SDLoc DL(Op);
   unsigned NewOpc = selectUmullSmull(N0, N1, DAG, DL, isMLA);
 
   if (!NewOpc) {
-    if (VT == MVT::v2i64) {
+    if (VT.getVectorElementType() == MVT::i64) {
       // If SVE is available then i64 vector multiplications can also be made
       // legal.
       if (Subtarget->hasSVE())
@@ -4629,7 +4650,9 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
     assert(Op0.getValueType().is64BitVector() &&
            Op1.getValueType().is64BitVector() &&
            "unexpected types for extended operands to VMULL");
-    return DAG.getNode(NewOpc, DL, VT, Op0, Op1);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OVT,
+                       DAG.getNode(NewOpc, DL, VT, Op0, Op1),
+                       DAG.getConstant(0, DL, MVT::i64));
   }
   // Optimizing (zext A + zext B) * C, to (S/UMULL A, C) + (S/UMULL B, C) during
   // isel lowering to take advantage of no-stall back to back s/umul + s/umla.
@@ -4637,11 +4660,14 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
   SDValue N00 = skipExtensionForVectorMULL(N0->getOperand(0).getNode(), DAG);
   SDValue N01 = skipExtensionForVectorMULL(N0->getOperand(1).getNode(), DAG);
   EVT Op1VT = Op1.getValueType();
-  return DAG.getNode(N0->getOpcode(), DL, VT,
-                     DAG.getNode(NewOpc, DL, VT,
-                               DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
-                     DAG.getNode(NewOpc, DL, VT,
-                               DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1));
+  return DAG.getNode(
+      ISD::EXTRACT_SUBVECTOR, DL, OVT,
+      DAG.getNode(N0->getOpcode(), DL, VT,
+                  DAG.getNode(NewOpc, DL, VT,
+                              DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
+                  DAG.getNode(NewOpc, DL, VT,
+                              DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)),
+      DAG.getConstant(0, DL, MVT::i64));
 }
 
 static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,

diff  --git a/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll b/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll
index f09de0c5b9e1aa..d79c0720555563 100644
--- a/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll
+++ b/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll
@@ -120,9 +120,8 @@ entry:
 define <4 x i16> @mulls_v8i8_0(<8 x i8> %s0, <8 x i8> %s1) {
 ; CHECK-LABEL: mulls_v8i8_0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    mul v0.4h, v0.4h, v1.4h
+; CHECK-NEXT:    smull v0.8h, v0.8b, v1.8b
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
   %s0s = sext <8 x i8> %s0 to <8 x i16>
@@ -149,9 +148,8 @@ entry:
 define <4 x i16> @mullu_v8i8_0(<8 x i8> %s0, <8 x i8> %s1) {
 ; CHECK-LABEL: mullu_v8i8_0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    mul v0.4h, v0.4h, v1.4h
+; CHECK-NEXT:    umull v0.8h, v0.8b, v1.8b
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
   %s0s = zext <8 x i8> %s0 to <8 x i16>
@@ -294,9 +292,8 @@ entry:
 define <2 x i32> @mulls_v4i16_0(<4 x i16> %s0, <4 x i16> %s1) {
 ; CHECK-LABEL: mulls_v4i16_0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.4s, v0.4h, #0
-; CHECK-NEXT:    sshll v1.4s, v1.4h, #0
-; CHECK-NEXT:    mul v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    smull v0.4s, v0.4h, v1.4h
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
   %s0s = sext <4 x i16> %s0 to <4 x i32>
@@ -323,9 +320,8 @@ entry:
 define <2 x i32> @mullu_v4i16_0(<4 x i16> %s0, <4 x i16> %s1) {
 ; CHECK-LABEL: mullu_v4i16_0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.4s, v0.4h, #0
-; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    mul v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    umull v0.4s, v0.4h, v1.4h
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
   %s0s = zext <4 x i16> %s0 to <4 x i32>
@@ -468,12 +464,8 @@ entry:
 define <1 x i64> @mulls_v2i32_0(<2 x i32> %s0, <2 x i32> %s1) {
 ; CHECK-LABEL: mulls_v2i32_0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    fmov x9, d0
-; CHECK-NEXT:    fmov x8, d1
-; CHECK-NEXT:    smull x8, w9, w8
-; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
   %s0s = sext <2 x i32> %s0 to <2 x i64>
@@ -504,12 +496,8 @@ entry:
 define <1 x i64> @mullu_v2i32_0(<2 x i32> %s0, <2 x i32> %s1) {
 ; CHECK-LABEL: mullu_v2i32_0:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.2d, v0.2s, #0
-; CHECK-NEXT:    ushll v1.2d, v1.2s, #0
-; CHECK-NEXT:    fmov x9, d0
-; CHECK-NEXT:    fmov x8, d1
-; CHECK-NEXT:    umull x8, w9, w8
-; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    umull v0.2d, v0.2s, v1.2s
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
 entry:
   %s0s = zext <2 x i32> %s0 to <2 x i64>


        


More information about the llvm-commits mailing list