[llvm] 6a8d8f3 - [AArch64][DAGCombiner]: combine <2xi64> add/sub.
Hassnaa Hamdi via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 5 02:19:35 PDT 2023
Author: Hassnaa Hamdi
Date: 2023-04-05T09:18:09Z
New Revision: 6a8d8f3e28aed1b77356be74ee5109d7bdd37dd1
URL: https://github.com/llvm/llvm-project/commit/6a8d8f3e28aed1b77356be74ee5109d7bdd37dd1
DIFF: https://github.com/llvm/llvm-project/commit/6a8d8f3e28aed1b77356be74ee5109d7bdd37dd1.diff
LOG: [AArch64][DAGCombiner]: combine <2xi64> add/sub.
NEON has no vector multiply for 64-bit elements,
so SVE's MUL is used for such multiplies instead.
To improve performance further, the surrounding add/sub
is also performed with SVE, so that SVE's MLA/MLS can be selected.
This works on the patterns:
// add v1, (mul v2, v3)
// sub v1, (mul v2, v3)
Reviewed By: david-arm
Differential Revision: https://reviews.llvm.org/D147236
Added:
llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 00f2964960c6..697240a1afe5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17804,9 +17804,64 @@ static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
}
+// Try to fold, for fixed-length vectors (e.g. <1 x i64> and <2 x i64>) when
+// SVE is available:
+//   add v1, (extract_subvector (MUL_PRED ...), 0)
+//   add (extract_subvector (MUL_PRED ...), 0), v1
+//   sub v1, (extract_subvector (MUL_PRED ...), 0)
+// by performing the add/sub on the scalable type of the MUL_PRED, so that
+// instruction selection can make use of SVE's MLA/MLS for the pattern.
+//
+// SUB is not commutative: only the form where the multiply is the subtrahend
+// (operand 1) is matched. Rewriting "sub (mul v2, v3), v1" as
+// "sub v1, (mul v2, v3)" would negate the result.
+static SDValue performMulAddSubCombine(SDNode *N, SelectionDAG &DAG) {
+  // Before using SVE's features, check first if it's available.
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
+    return SDValue();
+
+  unsigned Opc = N->getOpcode();
+  if (Opc != ISD::ADD && Opc != ISD::SUB)
+    return SDValue();
+
+  EVT FixedVT = N->getValueType(0);
+  if (!FixedVT.isFixedLengthVector())
+    return SDValue();
+
+  // Try to rewrite "Opc Addend, ExtractOp", where ExtractOp is expected to be
+  // extract_subvector(MUL_PRED, 0). Returns an empty SDValue on mismatch.
+  auto TryFold = [&](SDValue Addend, SDValue ExtractOp) -> SDValue {
+    if (ExtractOp.getOpcode() != ISD::EXTRACT_SUBVECTOR)
+      return SDValue();
+
+    // Only a subvector starting at element 0 is the expected pattern.
+    if (!cast<ConstantSDNode>(ExtractOp.getOperand(1))->isZero())
+      return SDValue();
+
+    SDValue MulValue = ExtractOp.getOperand(0);
+    if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
+      return SDValue();
+
+    // Both the extract and the multiply must be used only here; if either
+    // has other users the multiply would still have to be materialized.
+    if (!ExtractOp.hasOneUse() || !MulValue.hasOneUse())
+      return SDValue();
+
+    EVT ScalableVT = MulValue.getValueType();
+    if (!ScalableVT.isScalableVector())
+      return SDValue();
+
+    SDLoc DL(N);
+    SDValue ScaledAddend = convertToScalableVector(DAG, ScalableVT, Addend);
+    SDValue NewValue =
+        DAG.getNode(Opc, DL, ScalableVT, ScaledAddend, MulValue);
+    return convertFromScalableVector(DAG, FixedVT, NewValue);
+  };
+
+  // "op v1, mul" is valid for both ADD and SUB.
+  if (SDValue Res = TryFold(N->getOperand(0), N->getOperand(1)))
+    return Res;
+
+  // "mul + v1" may be commuted to "v1 + mul"; "mul - v1" may not.
+  if (Opc == ISD::ADD)
+    return TryFold(N->getOperand(1), N->getOperand(0));
+
+  return SDValue();
+}
+
static SDValue performAddSubCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
+ if (SDValue Val = performMulAddSubCombine(N, DAG))
+ return Val;
// Try to change sum of two reductions.
if (SDValue Val = performAddUADDVCombine(N, DAG))
return Val;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
new file mode 100644
index 000000000000..f83b0ce552ff
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-add-sub-mul.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+sve | FileCheck %s
+
+; add %a, (mul %b, %c) on <2 x i64>: with +sve the multiply and the add are
+; expected to fuse into a single predicated SVE MLA (ptrue vl2 enables the two
+; lanes) that accumulates into z0, which is also the return register.
+define <2 x i64> @test_mul_add_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_add_2x64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %mul = mul <2 x i64> %b, %c
+ %add = add <2 x i64> %a, %mul
+ ret <2 x i64> %add
+}
+
+; (mul %b, %c) + %a on <1 x i64>: the multiply is the FIRST operand of the add.
+; add is commutative, so this form is also expected to fuse into a predicated
+; SVE MLA (ptrue vl1 enables the single lane) accumulating into z0.
+define <1 x i64> @test_mul_add_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_add_1x64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl1
+; CHECK-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT: ret
+ %mul = mul <1 x i64> %b, %c
+ %add = add <1 x i64> %mul, %a
+ ret <1 x i64> %add
+}
+
+; %a - (mul %b, %c) on <2 x i64>: the multiply is the subtrahend, which matches
+; SVE MLS (accumulator minus product), so a single predicated MLS into z0
+; (ptrue vl2) is expected.
+define <2 x i64> @test_mul_sub_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_2x64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT: mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %mul = mul <2 x i64> %b, %c
+ %sub = sub <2 x i64> %a, %mul
+ ret <2 x i64> %sub
+}
+
+; NOTE(review): this function computes %mul - %a (the multiply is the
+; minuend), but MLS computes accumulator - product, i.e. %a - %mul. The
+; expected assembly below was generated from a combine that commuted the
+; operands of SUB, so it encodes a sign-flipped (wrong) result. Regenerate it
+; with utils/update_llc_test_checks.py once the combine stops commuting SUB;
+; an MLS must not appear for this operand order.
+define <1 x i64> @test_mul_sub_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_1x64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl1
+; CHECK-NEXT: // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT: mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT: ret
+ %mul = mul <1 x i64> %b, %c
+ %sub = sub <1 x i64> %mul, %a
+ ret <1 x i64> %sub
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
index 618c3f3ce8ce..376214c6f65b 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll
@@ -606,7 +606,7 @@ define void @srem_v16i32(ptr %a, ptr %b) #0 {
;
; VBITS_GE_256-LABEL: srem_v16i32:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: mov x8, #8
+; VBITS_GE_256-NEXT: mov x8, #8 // =0x8
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
; VBITS_GE_256-NEXT: ld1w { z1.s }, p0/z, [x0]
@@ -680,13 +680,13 @@ define void @srem_v64i32(ptr %a, ptr %b) vscale_range(16,0) #0 {
define <1 x i64> @srem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #0 {
; CHECK-LABEL: srem_v1i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
-; CHECK-NEXT: ptrue p0.d, vl1
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl1
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: sdiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT: mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT: sub d0, d0, d1
+; CHECK-NEXT: mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: ret
%res = srem <1 x i64> %op1, %op2
ret <1 x i64> %res
@@ -697,13 +697,13 @@ define <1 x i64> @srem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #
define <2 x i64> @srem_v2i64(<2 x i64> %op1, <2 x i64> %op2) vscale_range(1,0) #0 {
; CHECK-LABEL: srem_v2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: sdiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT: mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT: sub v0.2d, v0.2d, v1.2d
+; CHECK-NEXT: mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = srem <2 x i64> %op1, %op2
ret <2 x i64> %res
@@ -730,34 +730,32 @@ define void @srem_v4i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
define void @srem_v8i64(ptr %a, ptr %b) #0 {
; VBITS_GE_128-LABEL: srem_v8i64:
; VBITS_GE_128: // %bb.0:
-; VBITS_GE_128-NEXT: ldp q4, q5, [x1]
-; VBITS_GE_128-NEXT: ptrue p0.d, vl2
-; VBITS_GE_128-NEXT: ldp q7, q6, [x1, #32]
; VBITS_GE_128-NEXT: ldp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT: ldp q2, q3, [x0]
-; VBITS_GE_128-NEXT: movprfx z16, z3
-; VBITS_GE_128-NEXT: sdiv z16.d, p0/m, z16.d, z5.d
-; VBITS_GE_128-NEXT: movprfx z17, z2
-; VBITS_GE_128-NEXT: sdiv z17.d, p0/m, z17.d, z4.d
-; VBITS_GE_128-NEXT: mul z5.d, p0/m, z5.d, z16.d
+; VBITS_GE_128-NEXT: ptrue p0.d, vl2
+; VBITS_GE_128-NEXT: ldp q2, q3, [x1, #32]
; VBITS_GE_128-NEXT: movprfx z16, z1
+; VBITS_GE_128-NEXT: sdiv z16.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT: mls z1.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT: movprfx z3, z0
+; VBITS_GE_128-NEXT: sdiv z3.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT: mls z0.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT: ldp q4, q5, [x0]
+; VBITS_GE_128-NEXT: ldp q7, q6, [x1]
+; VBITS_GE_128-NEXT: movprfx z16, z5
; VBITS_GE_128-NEXT: sdiv z16.d, p0/m, z16.d, z6.d
-; VBITS_GE_128-NEXT: mul z4.d, p0/m, z4.d, z17.d
-; VBITS_GE_128-NEXT: movprfx z17, z0
-; VBITS_GE_128-NEXT: sdiv z17.d, p0/m, z17.d, z7.d
-; VBITS_GE_128-NEXT: mul z6.d, p0/m, z6.d, z16.d
-; VBITS_GE_128-NEXT: mul z7.d, p0/m, z7.d, z17.d
-; VBITS_GE_128-NEXT: sub v0.2d, v0.2d, v7.2d
-; VBITS_GE_128-NEXT: sub v1.2d, v1.2d, v6.2d
-; VBITS_GE_128-NEXT: sub v2.2d, v2.2d, v4.2d
+; VBITS_GE_128-NEXT: movprfx z2, z4
+; VBITS_GE_128-NEXT: sdiv z2.d, p0/m, z2.d, z7.d
; VBITS_GE_128-NEXT: stp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT: sub v0.2d, v3.2d, v5.2d
-; VBITS_GE_128-NEXT: stp q2, q0, [x0]
+; VBITS_GE_128-NEXT: movprfx z0, z4
+; VBITS_GE_128-NEXT: mls z0.d, p0/m, z2.d, z7.d
+; VBITS_GE_128-NEXT: movprfx z1, z5
+; VBITS_GE_128-NEXT: mls z1.d, p0/m, z16.d, z6.d
+; VBITS_GE_128-NEXT: stp q0, q1, [x0]
; VBITS_GE_128-NEXT: ret
;
; VBITS_GE_256-LABEL: srem_v8i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: mov x8, #4
+; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
; VBITS_GE_256-NEXT: ld1d { z1.d }, p0/z, [x0]
@@ -1426,7 +1424,7 @@ define void @urem_v16i32(ptr %a, ptr %b) #0 {
;
; VBITS_GE_256-LABEL: urem_v16i32:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: mov x8, #8
+; VBITS_GE_256-NEXT: mov x8, #8 // =0x8
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
; VBITS_GE_256-NEXT: ld1w { z1.s }, p0/z, [x0]
@@ -1500,13 +1498,13 @@ define void @urem_v64i32(ptr %a, ptr %b) vscale_range(16,0) #0 {
define <1 x i64> @urem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #0 {
; CHECK-LABEL: urem_v1i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
-; CHECK-NEXT: ptrue p0.d, vl1
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl1
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $z1
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: udiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT: mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT: sub d0, d0, d1
+; CHECK-NEXT: mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-NEXT: ret
%res = urem <1 x i64> %op1, %op2
ret <1 x i64> %res
@@ -1517,13 +1515,13 @@ define <1 x i64> @urem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #
define <2 x i64> @urem_v2i64(<2 x i64> %op1, <2 x i64> %op2) vscale_range(1,0) #0 {
; CHECK-LABEL: urem_v2i64:
; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q1 killed $q1 def $z1
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: udiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT: mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT: sub v0.2d, v0.2d, v1.2d
+; CHECK-NEXT: mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
; CHECK-NEXT: ret
%res = urem <2 x i64> %op1, %op2
ret <2 x i64> %res
@@ -1550,34 +1548,32 @@ define void @urem_v4i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
define void @urem_v8i64(ptr %a, ptr %b) #0 {
; VBITS_GE_128-LABEL: urem_v8i64:
; VBITS_GE_128: // %bb.0:
-; VBITS_GE_128-NEXT: ldp q4, q5, [x1]
-; VBITS_GE_128-NEXT: ptrue p0.d, vl2
-; VBITS_GE_128-NEXT: ldp q7, q6, [x1, #32]
; VBITS_GE_128-NEXT: ldp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT: ldp q2, q3, [x0]
-; VBITS_GE_128-NEXT: movprfx z16, z3
-; VBITS_GE_128-NEXT: udiv z16.d, p0/m, z16.d, z5.d
-; VBITS_GE_128-NEXT: movprfx z17, z2
-; VBITS_GE_128-NEXT: udiv z17.d, p0/m, z17.d, z4.d
-; VBITS_GE_128-NEXT: mul z5.d, p0/m, z5.d, z16.d
+; VBITS_GE_128-NEXT: ptrue p0.d, vl2
+; VBITS_GE_128-NEXT: ldp q2, q3, [x1, #32]
; VBITS_GE_128-NEXT: movprfx z16, z1
+; VBITS_GE_128-NEXT: udiv z16.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT: mls z1.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT: movprfx z3, z0
+; VBITS_GE_128-NEXT: udiv z3.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT: mls z0.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT: ldp q4, q5, [x0]
+; VBITS_GE_128-NEXT: ldp q7, q6, [x1]
+; VBITS_GE_128-NEXT: movprfx z16, z5
; VBITS_GE_128-NEXT: udiv z16.d, p0/m, z16.d, z6.d
-; VBITS_GE_128-NEXT: mul z4.d, p0/m, z4.d, z17.d
-; VBITS_GE_128-NEXT: movprfx z17, z0
-; VBITS_GE_128-NEXT: udiv z17.d, p0/m, z17.d, z7.d
-; VBITS_GE_128-NEXT: mul z6.d, p0/m, z6.d, z16.d
-; VBITS_GE_128-NEXT: mul z7.d, p0/m, z7.d, z17.d
-; VBITS_GE_128-NEXT: sub v0.2d, v0.2d, v7.2d
-; VBITS_GE_128-NEXT: sub v1.2d, v1.2d, v6.2d
-; VBITS_GE_128-NEXT: sub v2.2d, v2.2d, v4.2d
+; VBITS_GE_128-NEXT: movprfx z2, z4
+; VBITS_GE_128-NEXT: udiv z2.d, p0/m, z2.d, z7.d
; VBITS_GE_128-NEXT: stp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT: sub v0.2d, v3.2d, v5.2d
-; VBITS_GE_128-NEXT: stp q2, q0, [x0]
+; VBITS_GE_128-NEXT: movprfx z0, z4
+; VBITS_GE_128-NEXT: mls z0.d, p0/m, z2.d, z7.d
+; VBITS_GE_128-NEXT: movprfx z1, z5
+; VBITS_GE_128-NEXT: mls z1.d, p0/m, z16.d, z6.d
+; VBITS_GE_128-NEXT: stp q0, q1, [x0]
; VBITS_GE_128-NEXT: ret
;
; VBITS_GE_256-LABEL: urem_v8i64:
; VBITS_GE_256: // %bb.0:
-; VBITS_GE_256-NEXT: mov x8, #4
+; VBITS_GE_256-NEXT: mov x8, #4 // =0x4
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
; VBITS_GE_256-NEXT: ld1d { z1.d }, p0/z, [x0]
More information about the llvm-commits
mailing list