[llvm] 274ac9d - [AArch64][SVE] Lowering sve.dot to DOT node
Jun Ma via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 2 05:17:41 PDT 2021
Author: Jun Ma
Date: 2021-04-02T20:05:17+08:00
New Revision: 274ac9d40e79f25ac8c928732875708b5bac8f09
URL: https://github.com/llvm/llvm-project/commit/274ac9d40e79f25ac8c928732875708b5bac8f09
DIFF: https://github.com/llvm/llvm-project/commit/274ac9d40e79f25ac8c928732875708b5bac8f09.diff
LOG: [AArch64][SVE] Lowering sve.dot to DOT node
Differential Revision: https://reviews.llvm.org/D99699
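For illustration, the change routes the SVE sdot/udot intrinsics through the generic AArch64ISD::SDOT/UDOT nodes, so the existing add+dot DAG combine now also fires for SVE. Below is a minimal IR sketch of the pattern that is expected to select to a single instruction after this patch (adapted from the tests added in the diff; the function name is illustrative):

declare <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)

define <vscale x 4 x i32> @sdot_then_add(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
  ; dot product into a zero accumulator, then add the real accumulator
  %dot = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
  %sum = add <vscale x 4 x i32> %dot, %acc
  ret <vscale x 4 x i32> %sum
}
; per the CHECK lines added below, this now folds to a single
;   sdot z0.s, z1.b, z2.b
; instead of sdot followed by a separate add

This can be checked with an SVE-enabled build, e.g. llc -mtriple=aarch64-linux-gnu -mattr=+sve on the snippet above.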
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 943e06e7286c5..064f8cb8597ed 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -145,6 +145,9 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
SplatVal = Op0->getAPIntValue().truncOrSelf(EltSize);
return true;
+ } else if (auto *Op0 = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) {
+ SplatVal = Op0->getValueAPF().bitcastToAPInt().truncOrSelf(EltSize);
+ return true;
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9364bfbed3772..b40fb7e1a12b9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2153,6 +2153,24 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
// Lowering Code
//===----------------------------------------------------------------------===//
+/// isZerosVector - Check whether SDNode N is a zero-filled vector.
+static bool isZerosVector(const SDNode *N) {
+ // Look through a bit convert.
+ while (N->getOpcode() == ISD::BITCAST)
+ N = N->getOperand(0).getNode();
+
+ if (ISD::isConstantSplatVectorAllZeros(N))
+ return true;
+
+ if (N->getOpcode() != AArch64ISD::DUP)
+ return false;
+
+ auto Opnd0 = N->getOperand(0);
+ auto *CINT = dyn_cast<ConstantSDNode>(Opnd0);
+ auto *CFP = dyn_cast<ConstantFPSDNode>(Opnd0);
+ return (CINT && CINT->isNullValue()) || (CFP && CFP->isZero());
+}
+
/// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
/// CC
static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
@@ -3924,9 +3942,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Op.getOperand(2));
}
case Intrinsic::aarch64_neon_sdot:
- case Intrinsic::aarch64_neon_udot: {
- unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT
- : AArch64ISD::SDOT;
+ case Intrinsic::aarch64_neon_udot:
+ case Intrinsic::aarch64_sve_sdot:
+ case Intrinsic::aarch64_sve_udot: {
+ unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
+ IntNo == Intrinsic::aarch64_sve_udot)
+ ? AArch64ISD::UDOT
+ : AArch64ISD::SDOT;
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2), Op.getOperand(3));
}
@@ -13340,7 +13362,7 @@ static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
auto isZeroDot = [](SDValue Dot) {
return (Dot.getOpcode() == AArch64ISD::UDOT ||
Dot.getOpcode() == AArch64ISD::SDOT) &&
- ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode());
+ isZerosVector(Dot.getOperand(0).getNode());
};
if (!isZeroDot(Dot))
std::swap(Dot, A);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d3a607d1dddb1..df4e2cd446234 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -353,8 +353,8 @@ let Predicates = [HasSVE] in {
defm SDIV_ZPZZ : sve_int_bin_pred_sd<AArch64sdiv_p>;
defm UDIV_ZPZZ : sve_int_bin_pred_sd<AArch64udiv_p>;
- defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", int_aarch64_sve_sdot>;
- defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", int_aarch64_sve_udot>;
+ defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
+ defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
index fa67d92c2ae0e..0c8c7c2e05094 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
@@ -114,6 +114,26 @@ define <vscale x 2 x i64> @sdot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
ret <vscale x 2 x i64> %out
}
+define <vscale x 2 x i64> @test_sdot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_sdot_i64_zero:
+; CHECK: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+ %ret = add <vscale x 2 x i64> %vdot1.i, %a
+ ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_sdot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_sdot_i32_zero:
+; CHECK: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+ %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+ %ret = add <vscale x 4 x i32> %vdot1.i, %a
+ ret <vscale x 4 x i32> %ret
+}
+
; SDOT (Indexed)
define <vscale x 4 x i32> @sdot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
@@ -236,6 +256,26 @@ define <vscale x 2 x i64> @udot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
ret <vscale x 2 x i64> %out
}
+define <vscale x 2 x i64> @test_udot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_udot_i64_zero:
+; CHECK: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+ %ret = add <vscale x 2 x i64> %vdot1.i, %a
+ ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_udot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_udot_i32_zero:
+; CHECK: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+ %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+ %ret = add <vscale x 4 x i32> %vdot1.i, %a
+ ret <vscale x 4 x i32> %ret
+}
+
; UDOT (Indexed)
define <vscale x 4 x i32> @udot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {