[llvm] [AArch64] Improve mull generation (PR #114997)

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 01:10:42 PST 2024


https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/114997

>From cc9dfdcbc1f753656ca91bbfa077fceafdd1304e Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Wed, 20 Nov 2024 09:10:29 +0000
Subject: [PATCH] [AArch64] Improve mull generation

This attempts to clean up and improve where we generate smull using known-bits.
For v2i64 types (where no mul is present), we try to create mull more
aggressively to avoid scalarization.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 103 ++++--------------
 llvm/test/CodeGen/AArch64/aarch64-smull.ll    |  92 +++++-----------
 2 files changed, 51 insertions(+), 144 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ad1d1237aa25a9..2ec0e0bb7dff71 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5186,40 +5186,6 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
   return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
 }
 
-static EVT getExtensionTo64Bits(const EVT &OrigVT) {
-  if (OrigVT.getSizeInBits() >= 64)
-    return OrigVT;
-
-  assert(OrigVT.isSimple() && "Expecting a simple value type");
-
-  MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
-  switch (OrigSimpleTy) {
-  default: llvm_unreachable("Unexpected Vector Type");
-  case MVT::v2i8:
-  case MVT::v2i16:
-     return MVT::v2i32;
-  case MVT::v4i8:
-    return  MVT::v4i16;
-  }
-}
-
-static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
-                                                 const EVT &OrigTy,
-                                                 const EVT &ExtTy,
-                                                 unsigned ExtOpcode) {
-  // The vector originally had a size of OrigTy. It was then extended to ExtTy.
-  // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
-  // 64-bits we need to insert a new extension so that it will be 64-bits.
-  assert(ExtTy.is128BitVector() && "Unexpected extension size");
-  if (OrigTy.getSizeInBits() >= 64)
-    return N;
-
-  // Must extend size to at least 64 bits to be used as an operand for VMULL.
-  EVT NewVT = getExtensionTo64Bits(OrigTy);
-
-  return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
-}
-
 // Returns lane if Op extracts from a two-element vector and lane is constant
 // (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
 static std::optional<uint64_t>
@@ -5265,31 +5231,11 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
 static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
   EVT VT = N.getValueType();
   assert(VT.is128BitVector() && "Unexpected vector MULL size");
-
-  unsigned NumElts = VT.getVectorNumElements();
-  unsigned OrigEltSize = VT.getScalarSizeInBits();
-  unsigned EltSize = OrigEltSize / 2;
-  MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
-
-  APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
-  if (DAG.MaskedValueIsZero(N, HiBits))
-    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
-
-  if (ISD::isExtOpcode(N.getOpcode()))
-    return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
-                                             N.getOperand(0).getValueType(), VT,
-                                             N.getOpcode());
-
-  assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
-  SDLoc dl(N);
-  SmallVector<SDValue, 8> Ops;
-  for (unsigned i = 0; i != NumElts; ++i) {
-    const APInt &CInt = N.getConstantOperandAPInt(i);
-    // Element types smaller than 32 bits are not legal, so use i32 elements.
-    // The values are implicitly truncated so sext vs. zext doesn't matter.
-    Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
-  }
-  return DAG.getBuildVector(TruncVT, dl, Ops);
+  EVT HalfVT = EVT::getVectorVT(
+      *DAG.getContext(),
+      VT.getScalarType().getHalfSizedIntegerVT(*DAG.getContext()),
+      VT.getVectorElementCount());
+  return DAG.getNode(ISD::TRUNCATE, SDLoc(N), HalfVT, N);
 }
 
 static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -5465,33 +5411,26 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
   if (IsN0ZExt && IsN1ZExt)
     return AArch64ISD::UMULL;
 
-  // Select SMULL if we can replace zext with sext.
-  if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
-      !isExtendedBUILD_VECTOR(N0, DAG, false) &&
-      !isExtendedBUILD_VECTOR(N1, DAG, false)) {
-    SDValue ZextOperand;
-    if (IsN0ZExt)
-      ZextOperand = N0.getOperand(0);
-    else
-      ZextOperand = N1.getOperand(0);
-    if (DAG.SignBitIsZero(ZextOperand)) {
-      SDValue NewSext =
-          DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
-      if (IsN0ZExt)
-        N0 = NewSext;
-      else
-        N1 = NewSext;
-      return AArch64ISD::SMULL;
-    }
-  }
-
   // Select UMULL if we can replace the other operand with an extend.
+  EVT VT = N0.getValueType();
+  unsigned EltSize = VT.getScalarSizeInBits();
+  APInt Mask = APInt::getHighBitsSet(EltSize, EltSize / 2);
   if (IsN0ZExt || IsN1ZExt) {
-    EVT VT = N0.getValueType();
-    APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
-                                       VT.getScalarSizeInBits() / 2);
     if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
       return AArch64ISD::UMULL;
+  } else if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) &&
+             DAG.MaskedValueIsZero(N1, Mask)) {
+    // For v2i64 we look more aggressively at both operands having known-zero
+    // high halves, to avoid scalarization.
+    return AArch64ISD::UMULL;
+  }
+
+  if (IsN0SExt || IsN1SExt) {
+    if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) > EltSize / 2)
+      return AArch64ISD::SMULL;
+  } else if (VT == MVT::v2i64 && DAG.ComputeNumSignBits(N0) > EltSize / 2 &&
+             DAG.ComputeNumSignBits(N1) > EltSize / 2) {
+    return AArch64ISD::SMULL;
   }
 
   if (!IsN1SExt && !IsN1ZExt)
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
index 69a69dbd3b18ba..0fe2bbe2c449f2 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -231,29 +231,24 @@ define <4 x i32> @smull_zext_v4i16_v4i32(ptr %A, ptr %B) nounwind {
 define <2 x i64> @smull_zext_v2i32_v2i64(ptr %A, ptr %B) nounwind {
 ; CHECK-NEON-LABEL: smull_zext_v2i32_v2i64:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    ldr d0, [x1]
-; CHECK-NEON-NEXT:    ldrh w9, [x0]
-; CHECK-NEON-NEXT:    ldrh w10, [x0, #2]
-; CHECK-NEON-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEON-NEXT:    fmov x11, d0
-; CHECK-NEON-NEXT:    mov x8, v0.d[1]
-; CHECK-NEON-NEXT:    smull x9, w9, w11
-; CHECK-NEON-NEXT:    smull x8, w10, w8
-; CHECK-NEON-NEXT:    fmov d0, x9
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    ldrh w8, [x0]
+; CHECK-NEON-NEXT:    ldrh w9, [x0, #2]
+; CHECK-NEON-NEXT:    ldr d1, [x1]
+; CHECK-NEON-NEXT:    fmov d0, x8
+; CHECK-NEON-NEXT:    mov v0.d[1], x9
+; CHECK-NEON-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEON-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: smull_zext_v2i32_v2i64:
 ; CHECK-SVE:       // %bb.0:
 ; CHECK-SVE-NEXT:    ldrh w8, [x0]
 ; CHECK-SVE-NEXT:    ldrh w9, [x0, #2]
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    ldr d0, [x1]
-; CHECK-SVE-NEXT:    fmov d1, x8
-; CHECK-SVE-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-SVE-NEXT:    mov v1.d[1], x9
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    ldr d1, [x1]
+; CHECK-SVE-NEXT:    fmov d0, x8
+; CHECK-SVE-NEXT:    mov v0.d[1], x9
+; CHECK-SVE-NEXT:    xtn v0.2s, v0.2d
+; CHECK-SVE-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: smull_zext_v2i32_v2i64:
@@ -2404,25 +2399,16 @@ define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) {
 define <2 x i64> @lsr(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-NEON-LABEL: lsr:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    ushr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT:    ushr v1.2d, v1.2d, #32
-; CHECK-NEON-NEXT:    fmov x10, d1
-; CHECK-NEON-NEXT:    fmov x11, d0
-; CHECK-NEON-NEXT:    mov x8, v1.d[1]
-; CHECK-NEON-NEXT:    mov x9, v0.d[1]
-; CHECK-NEON-NEXT:    umull x10, w11, w10
-; CHECK-NEON-NEXT:    umull x8, w9, w8
-; CHECK-NEON-NEXT:    fmov d0, x10
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-NEON-NEXT:    umull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: lsr:
 ; CHECK-SVE:       // %bb.0:
-; CHECK-SVE-NEXT:    ushr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT:    ushr v1.2d, v1.2d, #32
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-SVE-NEXT:    umull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: lsr:
@@ -2481,25 +2467,16 @@ define <2 x i64> @lsr_const(<2 x i64> %a, <2 x i64> %b) {
 define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-NEON-LABEL: asr:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT:    sshr v1.2d, v1.2d, #32
-; CHECK-NEON-NEXT:    fmov x10, d1
-; CHECK-NEON-NEXT:    fmov x11, d0
-; CHECK-NEON-NEXT:    mov x8, v1.d[1]
-; CHECK-NEON-NEXT:    mov x9, v0.d[1]
-; CHECK-NEON-NEXT:    smull x10, w11, w10
-; CHECK-NEON-NEXT:    smull x8, w9, w8
-; CHECK-NEON-NEXT:    fmov d0, x10
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-NEON-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: asr:
 ; CHECK-SVE:       // %bb.0:
-; CHECK-SVE-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT:    sshr v1.2d, v1.2d, #32
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-SVE-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: asr:
@@ -2524,25 +2501,16 @@ define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
 define <2 x i64> @asr_const(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-NEON-LABEL: asr_const:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT:    fmov x9, d0
-; CHECK-NEON-NEXT:    mov x8, v0.d[1]
-; CHECK-NEON-NEXT:    lsl x10, x9, #5
-; CHECK-NEON-NEXT:    lsl x11, x8, #5
-; CHECK-NEON-NEXT:    sub x9, x10, x9
-; CHECK-NEON-NEXT:    fmov d0, x9
-; CHECK-NEON-NEXT:    sub x8, x11, x8
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    movi v1.2s, #31
+; CHECK-NEON-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: asr_const:
 ; CHECK-SVE:       // %bb.0:
-; CHECK-SVE-NEXT:    mov w8, #31 // =0x1f
-; CHECK-SVE-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    dup v1.2d, x8
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    movi v1.2s, #31
+; CHECK-SVE-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: asr_const:



More information about the llvm-commits mailing list