[llvm] 274ac9d - [AArch64][SVE] Lowering sve.dot to DOT node

Jun Ma via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 2 05:17:41 PDT 2021


Author: Jun Ma
Date: 2021-04-02T20:05:17+08:00
New Revision: 274ac9d40e79f25ac8c928732875708b5bac8f09

URL: https://github.com/llvm/llvm-project/commit/274ac9d40e79f25ac8c928732875708b5bac8f09
DIFF: https://github.com/llvm/llvm-project/commit/274ac9d40e79f25ac8c928732875708b5bac8f09.diff

LOG: [AArch64][SVE] Lowering sve.dot to DOT node

Differential Revision: https://reviews.llvm.org/D99699

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 943e06e7286c5..064f8cb8597ed 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -145,6 +145,9 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
     if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
       SplatVal = Op0->getAPIntValue().truncOrSelf(EltSize);
       return true;
+    } else if (auto *Op0 = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) {
+      SplatVal = Op0->getValueAPF().bitcastToAPInt().truncOrSelf(EltSize);
+      return true;
     }
   }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9364bfbed3772..b40fb7e1a12b9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2153,6 +2153,24 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
 // Lowering Code
 //===----------------------------------------------------------------------===//
 
+/// isZerosVector - Check whether SDNode N is an all-zero-bits vector.
+static bool isZerosVector(const SDNode *N) {
+  // Look through bitcasts: they preserve bits, so zero stays zero.
+  while (N->getOpcode() == ISD::BITCAST)
+    N = N->getOperand(0).getNode();
+
+  if (ISD::isConstantSplatVectorAllZeros(N))
+    return true;
+
+  if (N->getOpcode() != AArch64ISD::DUP)
+    return false;
+
+  // isNullFPConstant accepts +0.0 only: -0.0 has its sign bit set, so a
+  // DUP of it is not a zero-bits vector once bitcast to an integer type.
+  SDValue Opnd0 = N->getOperand(0);
+  return isNullConstant(Opnd0) || isNullFPConstant(Opnd0);
+}
+
 /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
 /// CC
 static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
@@ -3924,9 +3942,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                        Op.getOperand(2));
   }
   case Intrinsic::aarch64_neon_sdot:
-  case Intrinsic::aarch64_neon_udot: {
-    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT
-                                                            : AArch64ISD::SDOT;
+  case Intrinsic::aarch64_neon_udot:
+  case Intrinsic::aarch64_sve_sdot:
+  case Intrinsic::aarch64_sve_udot: {
+    unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
+                       IntNo == Intrinsic::aarch64_sve_udot)
+                          ? AArch64ISD::UDOT
+                          : AArch64ISD::SDOT;
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2), Op.getOperand(3));
   }
@@ -13340,7 +13362,7 @@ static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
   auto isZeroDot = [](SDValue Dot) {
     return (Dot.getOpcode() == AArch64ISD::UDOT ||
             Dot.getOpcode() == AArch64ISD::SDOT) &&
-           ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode());
+           isZerosVector(Dot.getOperand(0).getNode());
   };
   if (!isZeroDot(Dot))
     std::swap(Dot, A);

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d3a607d1dddb1..df4e2cd446234 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -353,8 +353,8 @@ let Predicates = [HasSVE] in {
   defm SDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64sdiv_p>;
   defm UDIV_ZPZZ  : sve_int_bin_pred_sd<AArch64udiv_p>;
 
-  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", int_aarch64_sve_sdot>;
-  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", int_aarch64_sve_udot>;
+  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
+  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
 
   defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
   defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
index fa67d92c2ae0e..0c8c7c2e05094 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
@@ -114,6 +114,26 @@ define <vscale x 2 x i64> @sdot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_sdot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_sdot_i64_zero:
+; CHECK:         sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %dot = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %sum = add <vscale x 2 x i64> %dot, %a
+  ret <vscale x 2 x i64> %sum
+}
+
+define <vscale x 4 x i32> @test_sdot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_sdot_i32_zero:
+; CHECK:         sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %dot = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %sum = add <vscale x 4 x i32> %dot, %a
+  ret <vscale x 4 x i32> %sum
+}
+
 ; SDOT (Indexed)
 
 define <vscale x 4 x i32> @sdot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
@@ -236,6 +256,26 @@ define <vscale x 2 x i64> @udot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_udot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_udot_i64_zero:
+; CHECK:         udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %dot = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %sum = add <vscale x 2 x i64> %dot, %a
+  ret <vscale x 2 x i64> %sum
+}
+
+define <vscale x 4 x i32> @test_udot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_udot_i32_zero:
+; CHECK:         udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %dot = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %sum = add <vscale x 4 x i32> %dot, %a
+  ret <vscale x 4 x i32> %sum
+}
+
 ; UDOT (Indexed)
 
 define <vscale x 4 x i32> @udot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {


        


More information about the llvm-commits mailing list