[llvm] edf9f88 - [AArch64] Handle 64bit vector s/umull from extracts
David Green via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 14 02:25:16 PDT 2023
Author: David Green
Date: 2023-07-14T10:25:12+01:00
New Revision: edf9f885668e5d928abf6d9f58d481a37bea07cf
URL: https://github.com/llvm/llvm-project/commit/edf9f885668e5d928abf6d9f58d481a37bea07cf
DIFF: https://github.com/llvm/llvm-project/commit/edf9f885668e5d928abf6d9f58d481a37bea07cf.diff
LOG: [AArch64] Handle 64bit vector s/umull from extracts
This is similar to D153632, but for mul nodes instead of add/sub. Mul nodes get
recognised in LowerMUL in order to detect mul(ext, ext) patterns, in a way that
will work for i64 nodes as well as i16/i32. This patch extends that to look for
mul(subvector_extract(ext(x), 0), subvector_extract(ext(y), 0)), generating a
subvector_extract(mull(x,y)) if it matches.
Differential Revision: https://reviews.llvm.org/D154063
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/neon-extadd-extract.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4ce0a4e6b9ed90..2a2953c359984a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1128,12 +1128,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SMIN, VT, Custom);
}
- // AArch64 doesn't have MUL.2d:
- setOperationAction(ISD::MUL, MVT::v2i64, Expand);
// Custom handling for some quad-vector types to detect MULL.
setOperationAction(ISD::MUL, MVT::v8i16, Custom);
setOperationAction(ISD::MUL, MVT::v4i32, Custom);
setOperationAction(ISD::MUL, MVT::v2i64, Custom);
+ setOperationAction(ISD::MUL, MVT::v4i16, Custom);
+ setOperationAction(ISD::MUL, MVT::v2i32, Custom);
+ setOperationAction(ISD::MUL, MVT::v1i64, Custom);
// Saturates
for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
@@ -4592,24 +4593,44 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
// If SVE is available then i64 vector multiplications can also be made legal.
- bool OverrideNEON =
- VT == MVT::v1i64 || Subtarget->forceStreamingCompatibleSVE();
+ bool OverrideNEON = Subtarget->forceStreamingCompatibleSVE();
if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON))
return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
- // Multiplications are only custom-lowered for 128-bit vectors so that
- // VMULL can be detected. Otherwise v2i64 multiplications are not legal.
- assert(VT.is128BitVector() && VT.isInteger() &&
+ // Multiplications are only custom-lowered for 128-bit and 64-bit vectors so
+ // that VMULL can be detected. Otherwise v2i64 multiplications are not legal.
+ assert((VT.is128BitVector() || VT.is64BitVector()) && VT.isInteger() &&
"unexpected type for custom-lowering ISD::MUL");
SDNode *N0 = Op.getOperand(0).getNode();
SDNode *N1 = Op.getOperand(1).getNode();
bool isMLA = false;
+ EVT OVT = VT;
+ if (VT.is64BitVector()) {
+ if (N0->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ isNullConstant(N0->getOperand(1)) &&
+ N1->getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ isNullConstant(N1->getOperand(1))) {
+ N0 = N0->getOperand(0).getNode();
+ N1 = N1->getOperand(0).getNode();
+ VT = N0->getValueType(0);
+ } else {
+ if (VT == MVT::v1i64) {
+ if (Subtarget->hasSVE())
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
+ // Fall through to expand this. It is not legal.
+ return SDValue();
+ } else
+ // Other vector multiplications are legal.
+ return Op;
+ }
+ }
+
SDLoc DL(Op);
unsigned NewOpc = selectUmullSmull(N0, N1, DAG, DL, isMLA);
if (!NewOpc) {
- if (VT == MVT::v2i64) {
+ if (VT.getVectorElementType() == MVT::i64) {
// If SVE is available then i64 vector multiplications can also be made
// legal.
if (Subtarget->hasSVE())
@@ -4629,7 +4650,9 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
assert(Op0.getValueType().is64BitVector() &&
Op1.getValueType().is64BitVector() &&
"unexpected types for extended operands to VMULL");
- return DAG.getNode(NewOpc, DL, VT, Op0, Op1);
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OVT,
+ DAG.getNode(NewOpc, DL, VT, Op0, Op1),
+ DAG.getConstant(0, DL, MVT::i64));
}
// Optimizing (zext A + zext B) * C, to (S/UMULL A, C) + (S/UMULL B, C) during
// isel lowering to take advantage of no-stall back to back s/umul + s/umla.
@@ -4637,11 +4660,14 @@ SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
SDValue N00 = skipExtensionForVectorMULL(N0->getOperand(0).getNode(), DAG);
SDValue N01 = skipExtensionForVectorMULL(N0->getOperand(1).getNode(), DAG);
EVT Op1VT = Op1.getValueType();
- return DAG.getNode(N0->getOpcode(), DL, VT,
- DAG.getNode(NewOpc, DL, VT,
- DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
- DAG.getNode(NewOpc, DL, VT,
- DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1));
+ return DAG.getNode(
+ ISD::EXTRACT_SUBVECTOR, DL, OVT,
+ DAG.getNode(N0->getOpcode(), DL, VT,
+ DAG.getNode(NewOpc, DL, VT,
+ DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
+ DAG.getNode(NewOpc, DL, VT,
+ DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)),
+ DAG.getConstant(0, DL, MVT::i64));
}
static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
diff --git a/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll b/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll
index f09de0c5b9e1aa..d79c0720555563 100644
--- a/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll
+++ b/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll
@@ -120,9 +120,8 @@ entry:
define <4 x i16> @mulls_v8i8_0(<8 x i8> %s0, <8 x i8> %s1) {
; CHECK-LABEL: mulls_v8i8_0:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sshll v0.8h, v0.8b, #0
-; CHECK-NEXT: sshll v1.8h, v1.8b, #0
-; CHECK-NEXT: mul v0.4h, v0.4h, v1.4h
+; CHECK-NEXT: smull v0.8h, v0.8b, v1.8b
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0
; CHECK-NEXT: ret
entry:
%s0s = sext <8 x i8> %s0 to <8 x i16>
@@ -149,9 +148,8 @@ entry:
define <4 x i16> @mullu_v8i8_0(<8 x i8> %s0, <8 x i8> %s1) {
; CHECK-LABEL: mullu_v8i8_0:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-NEXT: mul v0.4h, v0.4h, v1.4h
+; CHECK-NEXT: umull v0.8h, v0.8b, v1.8b
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0
; CHECK-NEXT: ret
entry:
%s0s = zext <8 x i8> %s0 to <8 x i16>
@@ -294,9 +292,8 @@ entry:
define <2 x i32> @mulls_v4i16_0(<4 x i16> %s0, <4 x i16> %s1) {
; CHECK-LABEL: mulls_v4i16_0:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sshll v0.4s, v0.4h, #0
-; CHECK-NEXT: sshll v1.4s, v1.4h, #0
-; CHECK-NEXT: mul v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: smull v0.4s, v0.4h, v1.4h
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0
; CHECK-NEXT: ret
entry:
%s0s = sext <4 x i16> %s0 to <4 x i32>
@@ -323,9 +320,8 @@ entry:
define <2 x i32> @mullu_v4i16_0(<4 x i16> %s0, <4 x i16> %s1) {
; CHECK-LABEL: mullu_v4i16_0:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: ushll v0.4s, v0.4h, #0
-; CHECK-NEXT: ushll v1.4s, v1.4h, #0
-; CHECK-NEXT: mul v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: umull v0.4s, v0.4h, v1.4h
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0
; CHECK-NEXT: ret
entry:
%s0s = zext <4 x i16> %s0 to <4 x i32>
@@ -468,12 +464,8 @@ entry:
define <1 x i64> @mulls_v2i32_0(<2 x i32> %s0, <2 x i32> %s1) {
; CHECK-LABEL: mulls_v2i32_0:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sshll v0.2d, v0.2s, #0
-; CHECK-NEXT: sshll v1.2d, v1.2s, #0
-; CHECK-NEXT: fmov x9, d0
-; CHECK-NEXT: fmov x8, d1
-; CHECK-NEXT: smull x8, w9, w8
-; CHECK-NEXT: fmov d0, x8
+; CHECK-NEXT: smull v0.2d, v0.2s, v1.2s
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0
; CHECK-NEXT: ret
entry:
%s0s = sext <2 x i32> %s0 to <2 x i64>
@@ -504,12 +496,8 @@ entry:
define <1 x i64> @mullu_v2i32_0(<2 x i32> %s0, <2 x i32> %s1) {
; CHECK-LABEL: mullu_v2i32_0:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: ushll v0.2d, v0.2s, #0
-; CHECK-NEXT: ushll v1.2d, v1.2s, #0
-; CHECK-NEXT: fmov x9, d0
-; CHECK-NEXT: fmov x8, d1
-; CHECK-NEXT: umull x8, w9, w8
-; CHECK-NEXT: fmov d0, x8
+; CHECK-NEXT: umull v0.2d, v0.2s, v1.2s
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0
; CHECK-NEXT: ret
entry:
%s0s = zext <2 x i32> %s0 to <2 x i64>
More information about the llvm-commits
mailing list