[llvm] f51b3de - [AArch64] Introduce UDOT/SDOT DAG nodes

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 23 12:31:23 PST 2021


Author: David Green
Date: 2021-02-23T20:31:01Z
New Revision: f51b3de4e851812b5f7d7c307ddb7b6ec61c05ab

URL: https://github.com/llvm/llvm-project/commit/f51b3de4e851812b5f7d7c307ddb7b6ec61c05ab
DIFF: https://github.com/llvm/llvm-project/commit/f51b3de4e851812b5f7d7c307ddb7b6ec61c05ab.diff

LOG: [AArch64] Introduce UDOT/SDOT DAG nodes

The new AArch64ISD::UDOT and AArch64ISD::SDOT nodes are used to lower the
UDOT/SDOT instructions, as opposed to relying on the udot/sdot intrinsics.
Subsequent optimizations can then be expressed more cleanly in terms of
these nodes.

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64InstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 036932d90f78..f3ca68810800 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1842,6 +1842,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::URHADD)
     MAKE_CASE(AArch64ISD::SHADD)
     MAKE_CASE(AArch64ISD::UHADD)
+    MAKE_CASE(AArch64ISD::SDOT)
+    MAKE_CASE(AArch64ISD::UDOT)
     MAKE_CASE(AArch64ISD::SMINV)
     MAKE_CASE(AArch64ISD::UMINV)
     MAKE_CASE(AArch64ISD::SMAXV)
@@ -3860,14 +3862,19 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
-
+  case Intrinsic::aarch64_neon_sabd:
   case Intrinsic::aarch64_neon_uabd: {
-    return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(),
-                       Op.getOperand(1), Op.getOperand(2));
+    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
+                                                            : AArch64ISD::SABD;
+    return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
+                       Op.getOperand(2));
   }
-  case Intrinsic::aarch64_neon_sabd: {
-    return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(),
-                       Op.getOperand(1), Op.getOperand(2));
+  case Intrinsic::aarch64_neon_sdot:
+  case Intrinsic::aarch64_neon_udot: {
+    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT
+                                                            : AArch64ISD::SDOT;
+    return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
+                       Op.getOperand(2), Op.getOperand(3));
   }
   }
 }
@@ -11753,11 +11760,9 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
   SDLoc DL(Op0);
   SDValue Ones = DAG.getConstant(1, DL, Op0VT);
   SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
-  auto DotIntrisic = (ExtOpcode == ISD::ZERO_EXTEND)
-                         ? Intrinsic::aarch64_neon_udot
-                         : Intrinsic::aarch64_neon_sdot;
-  SDValue Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Zeros.getValueType(),
-                            DAG.getConstant(DotIntrisic, DL, MVT::i32), Zeros,
+  auto DotOpcode =
+      (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
+  SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
                             Ones, Op0.getOperand(0));
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 94aef30b21b1..4959c8c3d589 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -231,6 +231,10 @@ enum NodeType : unsigned {
   UABD,
   SABD,
 
+  // udot/sdot instructions
+  UDOT,
+  SDOT,
+
   // Vector across-lanes min/max
   // Only the lower result lane is defined.
   SMINV,

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 1dd2fb30b233..d4871c311275 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -247,6 +247,8 @@ def SDT_AArch64UnaryVec: SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
 def SDT_AArch64ExtVec: SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                           SDTCisSameAs<0,2>, SDTCisInt<3>]>;
 def SDT_AArch64vshift : SDTypeProfile<1, 2, [SDTCisSameAs<0,1>, SDTCisInt<2>]>;
+def SDT_AArch64Dot: SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisSameAs<0,1>,
+                                         SDTCisVec<2>, SDTCisSameAs<2,3>]>;
 
 def SDT_AArch64vshiftinsert : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisInt<3>,
                                                  SDTCisSameAs<0,1>,
@@ -561,6 +563,9 @@ def AArch64frecps   : SDNode<"AArch64ISD::FRECPS", SDTFPBinOp>;
 def AArch64frsqrte  : SDNode<"AArch64ISD::FRSQRTE", SDTFPUnaryOp>;
 def AArch64frsqrts  : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>;
 
+def AArch64sdot     : SDNode<"AArch64ISD::SDOT", SDT_AArch64Dot>;
+def AArch64udot     : SDNode<"AArch64ISD::UDOT", SDT_AArch64Dot>;
+
 def AArch64saddv    : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
 def AArch64uaddv    : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>;
 def AArch64sminv    : SDNode<"AArch64ISD::SMINV", SDT_AArch64UnaryVec>;
@@ -831,10 +836,10 @@ def : TokenAlias<"IALL", "iall">;
 
 // ARMv8.2-A Dot Product
 let Predicates = [HasDotProd] in {
-defm SDOT : SIMDThreeSameVectorDot<0, 0, "sdot", int_aarch64_neon_sdot>;
-defm UDOT : SIMDThreeSameVectorDot<1, 0, "udot", int_aarch64_neon_udot>;
-defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", int_aarch64_neon_sdot>;
-defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", int_aarch64_neon_udot>;
+defm SDOT : SIMDThreeSameVectorDot<0, 0, "sdot", AArch64sdot>;
+defm UDOT : SIMDThreeSameVectorDot<1, 0, "udot", AArch64udot>;
+defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
+defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
 }
 
 // ARMv8.6-A BFloat


        


More information about the llvm-commits mailing list