[llvm] [AArch64][SVE] Add dot product lowering for PARTIAL_REDUCE_MLA node (PR #130933)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 18 03:50:49 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

<details>
<summary>Changes</summary>

Add TableGen lowering for the PARTIAL_REDUCE_U/SMLA ISD nodes. This lowering only applies once the DAG combine has produced these nodes. Also add a check so that the DAG combine is only performed when the resulting node can eventually be lowered; as a result, the NEON tests change too.

@<!-- -->JamesChesterman is the original author

---

Patch is 32.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130933.diff


10 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+35) 
- (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+9) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+5-2) 
- (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+2-3) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+15) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+3) 
- (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+11) 
- (modified) llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll (+66-93) 
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+35-151) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2089d47e9cbc8..a9d7d596e6869 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
            getCondCodeAction(CC, VT) == Custom;
   }
 
+  /// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
+  /// InputVT should be treated; extended (non-simple) types report Expand.
+  LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
+    if (!AccVT.isSimple() || !InputVT.isSimple())
+      return Expand;
+    unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
+    unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
+    assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
+           "Table isn't big enough!");
+    return PartialReduceMLAActions[AccI][InputI];
+  }
+
+  /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+  /// legal or custom for this target.
+  bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
+    LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
+    return Action == Legal || Action == Custom;
+  }
+
   /// If the action for this operation is to promote, this method returns the
   /// ValueType to promote to.
   MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2712,6 +2731,16 @@ class TargetLoweringBase {
       setCondCodeAction(CCs, VT, Action);
   }
 
+  /// Record how a PARTIAL_REDUCE_U/SMLA node with accumulator type AccVT and
+  /// input type InputVT should be treated by the target (Legal, Promote,
+  /// Expand or Custom). A later call for the same pair overwrites an earlier
+  /// one, so the last action set for a given (AccVT, InputVT) pair wins.
+  void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+                                 LegalizeAction Action) {
+    assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
+    PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
+  }
+
   /// If Opc/OrigVT is specified as being promoted, the promotion code defaults
   /// to trying a larger integer/fp until it can find one that works. If that
   /// default is insufficient, this method can be used by the target to override
@@ -3658,6 +3687,12 @@ class TargetLoweringBase {
   /// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
   uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
 
+  /// LegalizeAction for ISD::PARTIAL_REDUCE_U/SMLA, indexed by [result type]
+  /// [input type]. NOTE(review): a dense VALUETYPE_SIZE^2 table in every target
+  /// lowering object; confirm the footprint is acceptable vs. a sparse map.
+  LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
+                                        [MVT::VALUETYPE_SIZE];
+
   ValueTypeActionImpl ValueTypeActions;
 
 private:
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 42a5fbec95174..64c27dbace397 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -313,6 +313,10 @@ def SDTSubVecInsert : SDTypeProfile<1, 3, [ // subvector insert
   SDTCisSubVecOfVec<2, 1>, SDTCisSameAs<0,1>, SDTCisInt<3>
 ]>;
 
+def SDTPartialReduceMLA : SDTypeProfile<1, 3, [ // partial reduce mla
+  SDTCisVec<0>, SDTCisInt<0>, SDTCisSameAs<0, 1>,   // result == accumulator
+  SDTCisVec<2>, SDTCisInt<2>, SDTCisSameAs<2, 3>]>; // both mul inputs match
+
 def SDTPrefetch : SDTypeProfile<0, 4, [     // prefetch
   SDTCisPtrTy<0>, SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>, SDTCisInt<1>
 ]>;
@@ -513,6 +517,11 @@ def vecreduce_fmax  : SDNode<"ISD::VECREDUCE_FMAX", SDTFPVecReduce>;
 def vecreduce_fminimum : SDNode<"ISD::VECREDUCE_FMINIMUM", SDTFPVecReduce>;
 def vecreduce_fmaximum : SDNode<"ISD::VECREDUCE_FMAXIMUM", SDTFPVecReduce>;
 
+def partial_reduce_umla : SDNode<"ISD::PARTIAL_REDUCE_UMLA",
+                                 SDTPartialReduceMLA>; // unsigned form
+def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
+                                 SDTPartialReduceMLA>; // signed form
+
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;
 def fmul       : SDNode<"ISD::FMUL"       , SDTFPBinOp, [SDNPCommutative]>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0e17897cf60b0..5aaa6cc31efd8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12528,8 +12528,10 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
     return SDValue();
 
-  // FIXME: Add a check to only perform the DAG combine if there is lowering
-  // provided by the target
+  // Only perform the DAG combine if the target marks the resulting node as
+  // Legal or Custom for these types.
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), LHSExtOpVT))
+    return SDValue();
 
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 27bde7b96c857..c61e5b263a967 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECTOR_COMPRESS:
   case ISD::SCMP:
   case ISD::UCMP:
-  case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
@@ -530,6 +528,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
     break;
   }
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
+                                           Node->getOperand(1).getValueType());
+    break;
 
 #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...)                          \
   case ISD::VPID: {                                                            \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f5ea3c0b47d6a..af97ce20fdb10 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SET_FPENV, VT, Expand);
     setOperationAction(ISD::RESET_FPENV, VT, Expand);
 
-    // PartialReduceMLA operations default to expand.
-    setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
-                       Expand);
+    for (MVT InputVT : MVT::all_valuetypes())
+      setPartialReduceMLAAction(VT, InputVT, Expand);
   }
 
   // Most targets ignore the @llvm.prefetch intrinsic.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2dca8c0da4756..d1bfd9b78fd00 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1585,6 +1585,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::MSTORE, VT, Custom);
     }
 
+    for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
+      if (!EnablePartialReduceNodes)
+        break;
+      for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) {
+        ElementCount VTElemCount = VT.getVectorElementCount();
+        if (VTElemCount.getKnownMinValue() == 1)
+          continue;
+        if (VTElemCount * 4 == InnerVT.getVectorElementCount())
+          setPartialReduceMLAAction(VT, InnerVT, Custom);
+        if (InnerVT.getVectorElementType().getSizeInBits() * 4 ==
+            VT.getVectorElementType().getSizeInBits())
+          setPartialReduceMLAAction(VT, InnerVT, Legal);
+      }
+    }
+
     // Firstly, exclude all scalable vector extending loads/truncating stores,
     // include both integer and floating scalable vector.
     for (MVT VT : MVT::scalable_vector_valuetypes()) {
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index c836f3138a45f..6459ec9e4fae9 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -143,6 +143,9 @@ def HasFuseAES       : Predicate<"Subtarget->hasFuseAES()">,
                                  "fuse-aes">;
 def HasSVE           : Predicate<"Subtarget->isSVEAvailable()">,
                                  AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">;
+def HasSVEorStreamingSVE // NOTE(review): looks redundant with HasSVE_or_SME
+                     : Predicate<"Subtarget->isSVEorStreamingSVEAvailable()">,
+                                 AssemblerPredicateWithAll<(any_of FeatureSVE, FeatureSME), "sve or sme">;
 def HasSVEB16B16     : Predicate<"Subtarget->isSVEorStreamingSVEAvailable() && Subtarget->hasSVEB16B16()">,
                                  AssemblerPredicateWithAll<(all_of FeatureSVEB16B16), "sve-b16b16">;
 def HasSVE2          : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2()">,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 3ee71c14c6bd4..c72bc31c46878 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -655,6 +655,17 @@ let Predicates = [HasSVE_or_SME] in {
   defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
   defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
 
+  // Select PARTIAL_REDUCE_U/SMLA as SVE dot products; like the UDOT/SDOT_ZZZ
+  // instructions they map to, these are guarded by the enclosing predicate.
+  def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
+            (UDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
+  def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv16i8:$MulLHS, nxv16i8:$MulRHS)),
+            (SDOT_ZZZ_S $Acc, $MulLHS, $MulRHS)>;
+  def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+            (UDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
+  def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv8i16:$MulLHS, nxv8i16:$MulRHS)),
+            (SDOT_ZZZ_D $Acc, $MulLHS, $MulRHS)>;
+
   defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
   defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
 
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 1c9849bdaed3c..0645c7d46d861 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -12,15 +12,13 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    umlal v0.4s, v4.4h, v3.4h
-; CHECK-NODOT-NEXT:    umull v5.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    umlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    umlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    umull v3.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOT-NEXT:    uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <16 x i8> %u to <16 x i32>
   %s.wide = zext <16 x i8> %s to <16 x i32>
@@ -52,20 +50,18 @@ define <4 x i32> @udot_in_loop(ptr %p1, ptr %p2){
 ; CHECK-NODOT-NEXT:    mov x8, xzr
 ; CHECK-NODOT-NEXT:  .LBB1_1: // %vector.body
 ; CHECK-NODOT-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NODOT-NEXT:    ldr q0, [x1, x8]
-; CHECK-NODOT-NEXT:    ldr q2, [x0, x8]
+; CHECK-NODOT-NEXT:    ldr q0, [x0, x8]
+; CHECK-NODOT-NEXT:    ldr q2, [x1, x8]
 ; CHECK-NODOT-NEXT:    add x8, x8, #16
 ; CHECK-NODOT-NEXT:    cmp x8, #16
-; CHECK-NODOT-NEXT:    ushll2 v3.8h, v0.16b, #0
-; CHECK-NODOT-NEXT:    ushll2 v4.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    ushll v5.8h, v0.8b, #0
-; CHECK-NODOT-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    umull v3.8h, v0.8b, v2.8b
+; CHECK-NODOT-NEXT:    umull2 v2.8h, v0.16b, v2.16b
 ; CHECK-NODOT-NEXT:    mov v0.16b, v1.16b
-; CHECK-NODOT-NEXT:    umull v6.4s, v4.4h, v3.4h
-; CHECK-NODOT-NEXT:    umlal v1.4s, v2.4h, v5.4h
-; CHECK-NODOT-NEXT:    umlal2 v6.4s, v2.8h, v5.8h
-; CHECK-NODOT-NEXT:    umlal2 v1.4s, v4.8h, v3.8h
-; CHECK-NODOT-NEXT:    add v1.4s, v6.4s, v1.4s
+; CHECK-NODOT-NEXT:    ushll v1.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v4.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    uaddw2 v1.4s, v1.4s, v3.8h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v4.4s, v2.8h
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
 ; CHECK-NODOT-NEXT:    b.ne .LBB1_1
 ; CHECK-NODOT-NEXT:  // %bb.2: // %end
 ; CHECK-NODOT-NEXT:    ret
@@ -99,19 +95,17 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: udot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    umull v1.8h, v2.8b, v1.8b
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    umull v3.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    umull2 v4.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    umlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    uaddw v0.4s, v0.4s, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NODOT-NEXT:    umlal v3.4s, v6.4h, v5.4h
-; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = zext <8 x i8> %u to <8 x i32>
   %s.wide = zext <8 x i8> %s to <8 x i32>
@@ -128,15 +122,13 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    sshll v4.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    smlal v0.4s, v4.4h, v3.4h
-; CHECK-NODOT-NEXT:    smull v5.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    smlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    smlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NODOT-NEXT:    add v0.4s, v5.4s, v0.4s
+; CHECK-NODOT-NEXT:    smull v3.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT:    smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOT-NEXT:    saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT:    add v0.4s, v2.4s, v0.4s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <16 x i8> %u to <16 x i32>
   %s.wide = sext <16 x i8> %s to <16 x i32>
@@ -153,19 +145,17 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ;
 ; CHECK-NODOT-LABEL: sdot_narrow:
 ; CHECK-NODOT:       // %bb.0:
-; CHECK-NODOT-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NODOT-NEXT:    sshll v2.8h, v2.8b, #0
+; CHECK-NODOT-NEXT:    smull v1.8h, v2.8b, v1.8b
 ; CHECK-NODOT-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT:    smull v3.4s, v2.4h, v1.4h
-; CHECK-NODOT-NEXT:    smull2 v4.4s, v2.8h, v1.8h
-; CHECK-NODOT-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT:    ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT:    smlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT:    sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT:    sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT:    saddw v0.4s, v0.4s, v1.4h
 ; CHECK-NODOT-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT:    ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NODOT-NEXT:    smlal v3.4s, v6.4h, v5.4h
-; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT:    ext v2.16b, v2.16b, v2.16b, #8
 ; CHECK-NODOT-NEXT:    add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT:    saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT:    add v0.2s, v1.2s, v0.2s
 ; CHECK-NODOT-NEXT:    ret
   %u.wide = sext <8 x i8> %u to <8 x i32>
   %s.wide = sext <8 x i8> %s to <8 x i32>
@@ -417,27 +407,19 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
 ;
 ; CHECK-NODOT-LABEL: udot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
-; CHECK-NODOT-NEXT:    ushll v4.8h, v3.8b, #0
-; CHECK-NODOT-NEXT:    ushll v5.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    ushll2 v3.8h, v3.16b, #0
-; CHECK-NODOT-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    ushll v6.4s, v4.4h, #0
-; CHECK-NODOT-NEXT:    ushll v7.4s, v5.4h, #0
+; CHECK-NODOT-NEXT:    umull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT:    umull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT:    ushll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    ushll v5.4s, v2.4h, #0
 ; CHECK-NODOT-NEXT:    ushll2 v4.4s, v4.8h, #0
-; CHECK-NODOT-NEXT:    ushll2 v5.4s, v5.8h, #0
-; CHECK-NODOT-NEXT:    ushll2 v16.4s, v3.8h, #0
-; CHECK-NODOT-NEXT:    ushll2 v17.4s, v2.8h, #0
-; CHECK-NODOT-NEXT:    ushll v3.4s, v3.4h, #0
-; CHECK-NODOT-NEXT:    ushll v2.4s, v2.4h, #0
-; CHECK-NODOT-NEXT:    umlal2 v1.2d, v7.4s, v6.4s
-; CHECK-NODOT-NEXT:    umlal v0.2d, v7.2s, v6.2s
-; CHECK-NODOT-NEXT:    umull2 v18.2d, v5.4s, v4.4s
-; CHECK-NODOT-NEXT:    umull v4.2d, v5.2s, v4.2s
-; CHECK-NODOT-NEXT:    umlal2 v1.2d, v17.4s, v16.4s
-; CHECK-NODOT-NEXT:    umlal v0.2d, v17.2s, v16.2s
-; CHECK-NODOT-NEXT:    umlal2 v18.2d, v2.4s, v3.4s
-; CHECK-NODOT-NEXT:    umlal v4.2d, v2.2s, v3.2s
-; CHECK-NODOT-NEXT:    add v1.2d, v18.2d, v1.2d
+; CHECK-NODOT-NEXT:    ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NODOT-NEXT:    uaddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NODOT-NEXT:    uaddl v4.2d, v4.2s, v5.2s
+; CHECK-NODOT-NEXT:    uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT:    uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT:    add v1.2d, v3.2d, v1.2d
 ; CHECK-NODOT-NEXT:    add v0.2d, v4.2d, v0.2d
 ; CHECK-NODOT-NEXT:    ret
 entry:
@@ -460,27 +442,19 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
 ;
 ; CHECK-NODOT-LABEL: sdot_8to64:
 ; CHECK-NODOT:       // %bb.0: // %entry
-; CHECK-NODOT-NEXT:    sshll v4.8h, v3.8b, #0
-; CHECK-NODOT-NEXT:    sshll v5.8h, v2.8b, #0
-; CHECK-NODOT-NEXT:    sshll2 v3.8h, v3.16b, #0
-; CHECK-NODOT-NEXT:    sshll2 v2.8h, v2.16b, #0
-; CHECK-NODOT-NEXT:    sshll v6.4s, v4.4h, #0
-; CHECK-NODOT-NEXT:    sshll v7.4s, v5.4h, #0
+; CHECK-NODOT-NEXT:    smull v4.8h, v2.8b, v3.8b
+; CHECK-NODOT-NEXT:    smull2 v2.8h, v2.16b, v3.16b
+; CHECK-NODOT-NEXT:    sshll v3.4s, v4.4h, #0
+; CHECK-NODOT-NEXT:    sshll v5.4s, v2.4h, #0
 ; CHECK-NODOT-NEXT:    sshll2 v4.4s, v4.8h, #0
-; CHECK-NODOT-NEXT:    sshll2 v5.4s, v5.8h, #0
-; CHECK-NODOT-NEXT:    sshll2 v16.4s, v3.8h, #0
-; CHECK-NODOT-NEXT:    sshll2 v17.4s, v2.8h, #0
-; CHECK-NODOT-NEXT:    sshll v3.4s, v3.4h, #0
-; CHECK-NODOT-NEXT:    sshll v2.4s, v2.4h, #0
-; CHECK-NODOT-NEXT:    smlal2 v1.2d, v7.4s, v6.4s
-; CHECK-NODOT-NEXT:    smlal v0.2d, v7.2s, v6.2s
-; CHECK-NODOT-NEXT:    smull2 v18.2d, v5.4s, v4.4s
-; CHECK-NODOT-NEXT:    smull v4.2d, v5.2s, v4.2s
-; CHECK-NODOT-NEXT:    smlal2 v1.2d, v17.4s, v16.4s
-; CHECK-NODOT-NEXT:    smlal v0.2d, v17.2s, v16.2s
-; CHECK-NODOT-NEXT:    smlal2 v18.2d, v2.4s, v3.4s
-; CHECK-NODOT-NEXT:    smlal v4.2d, v2.2s, v3.2s
-; CHECK-NODOT-NEXT:    add v1.2d, v18.2d, v1.2d
+; CHECK-NODOT-NEXT:    sshll2 v2.4s, v2.8h, #0
+;...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/130933


More information about the llvm-commits mailing list