[llvm] [AArch64] Improve mull generation (PR #114997)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 14 10:48:22 PST 2024
https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/114997
>From 26308822b5bc5c5029b2b72f9556a16540132c1c Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Thu, 14 Nov 2024 18:46:43 +0000
Subject: [PATCH] [AArch64] Improve mull generation
This attempts to clean up and improve where we generate smull/umull using known bits.
For v2i64 types (for which there is no vector mul instruction), we try to create
mull more aggressively to avoid scalarization.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 107 ++++--------------
llvm/test/CodeGen/AArch64/aarch64-smull.ll | 92 +++++----------
2 files changed, 53 insertions(+), 146 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e7923ff02de704..7bb27b6a0da900 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5173,40 +5173,6 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
}
-static EVT getExtensionTo64Bits(const EVT &OrigVT) {
- if (OrigVT.getSizeInBits() >= 64)
- return OrigVT;
-
- assert(OrigVT.isSimple() && "Expecting a simple value type");
-
- MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
- switch (OrigSimpleTy) {
- default: llvm_unreachable("Unexpected Vector Type");
- case MVT::v2i8:
- case MVT::v2i16:
- return MVT::v2i32;
- case MVT::v4i8:
- return MVT::v4i16;
- }
-}
-
-static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
- const EVT &OrigTy,
- const EVT &ExtTy,
- unsigned ExtOpcode) {
- // The vector originally had a size of OrigTy. It was then extended to ExtTy.
- // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
- // 64-bits we need to insert a new extension so that it will be 64-bits.
- assert(ExtTy.is128BitVector() && "Unexpected extension size");
- if (OrigTy.getSizeInBits() >= 64)
- return N;
-
- // Must extend size to at least 64 bits to be used as an operand for VMULL.
- EVT NewVT = getExtensionTo64Bits(OrigTy);
-
- return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
-}
-
// Returns lane if Op extracts from a two-element vector and lane is constant
// (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
static std::optional<uint64_t>
@@ -5252,31 +5218,11 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
EVT VT = N.getValueType();
assert(VT.is128BitVector() && "Unexpected vector MULL size");
-
- unsigned NumElts = VT.getVectorNumElements();
- unsigned OrigEltSize = VT.getScalarSizeInBits();
- unsigned EltSize = OrigEltSize / 2;
- MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
-
- APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
- if (DAG.MaskedValueIsZero(N, HiBits))
- return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
-
- if (ISD::isExtOpcode(N.getOpcode()))
- return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
- N.getOperand(0).getValueType(), VT,
- N.getOpcode());
-
- assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
- SDLoc dl(N);
- SmallVector<SDValue, 8> Ops;
- for (unsigned i = 0; i != NumElts; ++i) {
- const APInt &CInt = N.getConstantOperandAPInt(i);
- // Element types smaller than 32 bits are not legal, so use i32 elements.
- // The values are implicitly truncated so sext vs. zext doesn't matter.
- Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
- }
- return DAG.getBuildVector(TruncVT, dl, Ops);
+ EVT HalfVT = EVT::getVectorVT(
+ *DAG.getContext(),
+ VT.getScalarType().getHalfSizedIntegerVT(*DAG.getContext()),
+ VT.getVectorElementCount());
+ return DAG.getNode(ISD::TRUNCATE, SDLoc(N), HalfVT, N);
}
static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -5452,34 +5398,27 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
if (IsN0ZExt && IsN1ZExt)
return AArch64ISD::UMULL;
- // Select SMULL if we can replace zext with sext.
- if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
- !isExtendedBUILD_VECTOR(N0, DAG, false) &&
- !isExtendedBUILD_VECTOR(N1, DAG, false)) {
- SDValue ZextOperand;
- if (IsN0ZExt)
- ZextOperand = N0.getOperand(0);
- else
- ZextOperand = N1.getOperand(0);
- if (DAG.SignBitIsZero(ZextOperand)) {
- SDValue NewSext =
- DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
- if (IsN0ZExt)
- N0 = NewSext;
- else
- N1 = NewSext;
- return AArch64ISD::SMULL;
- }
- }
-
// Select UMULL if we can replace the other operand with an extend.
- if (IsN0ZExt || IsN1ZExt) {
- EVT VT = N0.getValueType();
- APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
- VT.getScalarSizeInBits() / 2);
+ EVT VT = N0.getValueType();
+ APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
+ VT.getScalarSizeInBits() / 2);
+ if (IsN0ZExt || IsN1ZExt)
if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
return AArch64ISD::UMULL;
- }
+ // For v2i64 we look more aggressively at whether both operands have known-zero
+ // high halves, to avoid scalarization.
+ if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) &&
+ DAG.MaskedValueIsZero(N1, Mask))
+ return AArch64ISD::UMULL;
+
+ if (IsN0SExt || IsN1SExt)
+ if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) >
+ VT.getScalarSizeInBits() / 2)
+ return AArch64ISD::SMULL;
+ if (VT == MVT::v2i64 &&
+ DAG.ComputeNumSignBits(N0) > VT.getScalarSizeInBits() / 2 &&
+ DAG.ComputeNumSignBits(N1) > VT.getScalarSizeInBits() / 2)
+ return AArch64ISD::SMULL;
if (!IsN1SExt && !IsN1ZExt)
return 0;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
index 3c4901ade972ec..0a38fd1488e2eb 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -231,29 +231,24 @@ define <4 x i32> @smull_zext_v4i16_v4i32(ptr %A, ptr %B) nounwind {
define <2 x i64> @smull_zext_v2i32_v2i64(ptr %A, ptr %B) nounwind {
; CHECK-NEON-LABEL: smull_zext_v2i32_v2i64:
; CHECK-NEON: // %bb.0:
-; CHECK-NEON-NEXT: ldr d0, [x1]
-; CHECK-NEON-NEXT: ldrh w9, [x0]
-; CHECK-NEON-NEXT: ldrh w10, [x0, #2]
-; CHECK-NEON-NEXT: sshll v0.2d, v0.2s, #0
-; CHECK-NEON-NEXT: fmov x11, d0
-; CHECK-NEON-NEXT: mov x8, v0.d[1]
-; CHECK-NEON-NEXT: smull x9, w9, w11
-; CHECK-NEON-NEXT: smull x8, w10, w8
-; CHECK-NEON-NEXT: fmov d0, x9
-; CHECK-NEON-NEXT: mov v0.d[1], x8
+; CHECK-NEON-NEXT: ldrh w8, [x0]
+; CHECK-NEON-NEXT: ldrh w9, [x0, #2]
+; CHECK-NEON-NEXT: ldr d1, [x1]
+; CHECK-NEON-NEXT: fmov d0, x8
+; CHECK-NEON-NEXT: mov v0.d[1], x9
+; CHECK-NEON-NEXT: xtn v0.2s, v0.2d
+; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
; CHECK-NEON-NEXT: ret
;
; CHECK-SVE-LABEL: smull_zext_v2i32_v2i64:
; CHECK-SVE: // %bb.0:
; CHECK-SVE-NEXT: ldrh w8, [x0]
; CHECK-SVE-NEXT: ldrh w9, [x0, #2]
-; CHECK-SVE-NEXT: ptrue p0.d, vl2
-; CHECK-SVE-NEXT: ldr d0, [x1]
-; CHECK-SVE-NEXT: fmov d1, x8
-; CHECK-SVE-NEXT: sshll v0.2d, v0.2s, #0
-; CHECK-SVE-NEXT: mov v1.d[1], x9
-; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT: ldr d1, [x1]
+; CHECK-SVE-NEXT: fmov d0, x8
+; CHECK-SVE-NEXT: mov v0.d[1], x9
+; CHECK-SVE-NEXT: xtn v0.2s, v0.2d
+; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
; CHECK-SVE-NEXT: ret
;
; CHECK-GI-LABEL: smull_zext_v2i32_v2i64:
@@ -2405,25 +2400,16 @@ define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) {
define <2 x i64> @lsr(<2 x i64> %a, <2 x i64> %b) {
; CHECK-NEON-LABEL: lsr:
; CHECK-NEON: // %bb.0:
-; CHECK-NEON-NEXT: ushr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT: ushr v1.2d, v1.2d, #32
-; CHECK-NEON-NEXT: fmov x10, d1
-; CHECK-NEON-NEXT: fmov x11, d0
-; CHECK-NEON-NEXT: mov x8, v1.d[1]
-; CHECK-NEON-NEXT: mov x9, v0.d[1]
-; CHECK-NEON-NEXT: umull x10, w11, w10
-; CHECK-NEON-NEXT: umull x8, w9, w8
-; CHECK-NEON-NEXT: fmov d0, x10
-; CHECK-NEON-NEXT: mov v0.d[1], x8
+; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
+; CHECK-NEON-NEXT: umull v0.2d, v0.2s, v1.2s
; CHECK-NEON-NEXT: ret
;
; CHECK-SVE-LABEL: lsr:
; CHECK-SVE: // %bb.0:
-; CHECK-SVE-NEXT: ushr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT: ushr v1.2d, v1.2d, #32
-; CHECK-SVE-NEXT: ptrue p0.d, vl2
-; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
+; CHECK-SVE-NEXT: umull v0.2d, v0.2s, v1.2s
; CHECK-SVE-NEXT: ret
;
; CHECK-GI-LABEL: lsr:
@@ -2482,25 +2468,16 @@ define <2 x i64> @lsr_const(<2 x i64> %a, <2 x i64> %b) {
define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
; CHECK-NEON-LABEL: asr:
; CHECK-NEON: // %bb.0:
-; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT: sshr v1.2d, v1.2d, #32
-; CHECK-NEON-NEXT: fmov x10, d1
-; CHECK-NEON-NEXT: fmov x11, d0
-; CHECK-NEON-NEXT: mov x8, v1.d[1]
-; CHECK-NEON-NEXT: mov x9, v0.d[1]
-; CHECK-NEON-NEXT: smull x10, w11, w10
-; CHECK-NEON-NEXT: smull x8, w9, w8
-; CHECK-NEON-NEXT: fmov d0, x10
-; CHECK-NEON-NEXT: mov v0.d[1], x8
+; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
+; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
; CHECK-NEON-NEXT: ret
;
; CHECK-SVE-LABEL: asr:
; CHECK-SVE: // %bb.0:
-; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT: sshr v1.2d, v1.2d, #32
-; CHECK-SVE-NEXT: ptrue p0.d, vl2
-; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
+; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
; CHECK-SVE-NEXT: ret
;
; CHECK-GI-LABEL: asr:
@@ -2525,25 +2502,16 @@ define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
define <2 x i64> @asr_const(<2 x i64> %a, <2 x i64> %b) {
; CHECK-NEON-LABEL: asr_const:
; CHECK-NEON: // %bb.0:
-; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT: fmov x9, d0
-; CHECK-NEON-NEXT: mov x8, v0.d[1]
-; CHECK-NEON-NEXT: lsl x10, x9, #5
-; CHECK-NEON-NEXT: lsl x11, x8, #5
-; CHECK-NEON-NEXT: sub x9, x10, x9
-; CHECK-NEON-NEXT: fmov d0, x9
-; CHECK-NEON-NEXT: sub x8, x11, x8
-; CHECK-NEON-NEXT: mov v0.d[1], x8
+; CHECK-NEON-NEXT: movi v1.2s, #31
+; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
; CHECK-NEON-NEXT: ret
;
; CHECK-SVE-LABEL: asr_const:
; CHECK-SVE: // %bb.0:
-; CHECK-SVE-NEXT: mov w8, #31 // =0x1f
-; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT: ptrue p0.d, vl2
-; CHECK-SVE-NEXT: dup v1.2d, x8
-; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT: movi v1.2s, #31
+; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
; CHECK-SVE-NEXT: ret
;
; CHECK-GI-LABEL: asr_const:
More information about the llvm-commits
mailing list