[llvm] 1a80828 - [AArch64] Extend vecreduce -> udot handling to mla reductions

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 10 14:26:50 PST 2021


Author: David Green
Date: 2021-03-10T22:25:12Z
New Revision: 1a808286eff01fde07794b2c94138a96e7099561

URL: https://github.com/llvm/llvm-project/commit/1a808286eff01fde07794b2c94138a96e7099561
DIFF: https://github.com/llvm/llvm-project/commit/1a808286eff01fde07794b2c94138a96e7099561.diff

LOG: [AArch64] Extend vecreduce -> udot handling to mla reductions

We previously had lowering for:
  vecreduce.add(zext(X)) to vecreduce.add(UDOT(zero, X, one))
This extends that to also handle:
  vecreduce.add(mul(zext(X), zext(Y))) to vecreduce.add(UDOT(zero, X, Y))
It does so by extending the existing code to optionally handle a mul whose
two operands are matching extends of the same source type.
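
As a rough illustration (not taken from the patch's test file; the function
name here is made up, and the reduce intrinsic name may differ between
releases), IR of the following shape should now be lowered to a single udot
plus a horizontal add when the target has +dotprod, e.g. via
llc -mtriple=aarch64 -mattr=+dotprod:

  ; Hypothetical example of the newly handled pattern: an i8
  ; multiply-accumulate reduction. Both mul operands are zero-extends of
  ; v16i8 values with the same source type, so the combine can rewrite the
  ; whole reduction as vecreduce.add(UDOT(zero, %a, %b)).
  define i32 @udot_mla_example(<16 x i8> %a, <16 x i8> %b) {
    %az = zext <16 x i8> %a to <16 x i32>
    %bz = zext <16 x i8> %b to <16 x i32>
    %m = mul <16 x i32> %az, %bz
    %r = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %m)
    ret i32 %r
  }
  declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)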

Differential Revision: https://reviews.llvm.org/D97280

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/neon-dotreduce.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c827baf4fb3b..b6ef5eb2b629 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -11747,31 +11747,46 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
 
 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
 //   vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
+//   vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
 static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                                           const AArch64Subtarget *ST) {
   SDValue Op0 = N->getOperand(0);
-  if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32)
-    return SDValue();
-
-  if (Op0.getValueType().getVectorElementType() != MVT::i32)
+  if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32 ||
+      Op0.getValueType().getVectorElementType() != MVT::i32)
     return SDValue();
 
   unsigned ExtOpcode = Op0.getOpcode();
+  SDValue A = Op0;
+  SDValue B;
+  if (ExtOpcode == ISD::MUL) {
+    A = Op0.getOperand(0);
+    B = Op0.getOperand(1);
+    if (A.getOpcode() != B.getOpcode() ||
+        A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
+      return SDValue();
+    ExtOpcode = A.getOpcode();
+  }
   if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
     return SDValue();
 
-  EVT Op0VT = Op0.getOperand(0).getValueType();
+  EVT Op0VT = A.getOperand(0).getValueType();
   if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8)
     return SDValue();
 
   SDLoc DL(Op0);
-  SDValue Ones = DAG.getConstant(1, DL, Op0VT);
+  // For non-mla reductions B can be set to 1. For MLA we take the operand of
+  // the extend B.
+  if (!B)
+    B = DAG.getConstant(1, DL, Op0VT);
+  else
+    B = B.getOperand(0);
+
   SDValue Zeros =
       DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32);
   auto DotOpcode =
       (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
   SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
-                            Ones, Op0.getOperand(0));
+                            A.getOperand(0), B);
   return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
 }
 

diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index 91a526bc0488..1827bb0990be 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -9,11 +9,10 @@ define i32 @test_udot_v8i8(i8* nocapture readonly %a, i8* nocapture readonly %b)
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr d0, [x0]
 ; CHECK-NEXT:    ldr d1, [x1]
-; CHECK-NEXT:    dup v2.2s, wzr
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
 ; CHECK-NEXT:    udot v2.2s, v1.8b, v0.8b
 ; CHECK-NEXT:    addp v0.2s, v2.2s, v2.2s
-; CHECK-NEXT:    fmov x0, d0
-; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
 entry:
   %0 = bitcast i8* %a to <8 x i8>*
@@ -33,7 +32,7 @@ define i32 @test_udot_v8i8_nomla(i8* nocapture readonly %a1) {
 ; CHECK-NEXT:    ldr d0, [x0]
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    movi v2.8b, #1
-; CHECK-NEXT:    udot v1.2s, v2.8b, v0.8b
+; CHECK-NEXT:    udot v1.2s, v0.8b, v2.8b
 ; CHECK-NEXT:    addp v0.2s, v1.2s, v1.2s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -50,11 +49,10 @@ define i32 @test_sdot_v8i8(i8* nocapture readonly %a, i8* nocapture readonly %b)
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr d0, [x0]
 ; CHECK-NEXT:    ldr d1, [x1]
-; CHECK-NEXT:    dup v2.2s, wzr
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
 ; CHECK-NEXT:    sdot v2.2s, v1.8b, v0.8b
 ; CHECK-NEXT:    addp v0.2s, v2.2s, v2.2s
-; CHECK-NEXT:    fmov x0, d0
-; CHECK-NEXT:    // kill: def $w0 killed $w0 killed $x0
+; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
 entry:
   %0 = bitcast i8* %a to <8 x i8>*
@@ -74,7 +72,7 @@ define i32 @test_sdot_v8i8_nomla(i8* nocapture readonly %a1) {
 ; CHECK-NEXT:    ldr d0, [x0]
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    movi v2.8b, #1
-; CHECK-NEXT:    sdot v1.2s, v2.8b, v0.8b
+; CHECK-NEXT:    sdot v1.2s, v0.8b, v2.8b
 ; CHECK-NEXT:    addp v0.2s, v1.2s, v1.2s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -92,7 +90,7 @@ define i32 @test_udot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
-; CHECK-NEXT:    dup v2.4s, wzr
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
 ; CHECK-NEXT:    udot v2.4s, v1.16b, v0.16b
 ; CHECK-NEXT:    addv s0, v2.4s
 ; CHECK-NEXT:    fmov w8, s0
@@ -117,7 +115,7 @@ define i32 @test_udot_v16i8_nomla(i8* nocapture readonly %a1) {
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    movi v1.16b, #1
 ; CHECK-NEXT:    movi v2.2d, #0000000000000000
-; CHECK-NEXT:    udot v2.4s, v1.16b, v0.16b
+; CHECK-NEXT:    udot v2.4s, v0.16b, v1.16b
 ; CHECK-NEXT:    addv s0, v2.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -134,7 +132,7 @@ define i32 @test_sdot_v16i8(i8* nocapture readonly %a, i8* nocapture readonly %b
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
-; CHECK-NEXT:    dup v2.4s, wzr
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
 ; CHECK-NEXT:    sdot v2.4s, v1.16b, v0.16b
 ; CHECK-NEXT:    addv s0, v2.4s
 ; CHECK-NEXT:    fmov w8, s0
@@ -159,7 +157,7 @@ define i32 @test_sdot_v16i8_nomla(i8* nocapture readonly %a1) {
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    movi v1.16b, #1
 ; CHECK-NEXT:    movi v2.2d, #0000000000000000
-; CHECK-NEXT:    sdot v2.4s, v1.16b, v0.16b
+; CHECK-NEXT:    sdot v2.4s, v0.16b, v1.16b
 ; CHECK-NEXT:    addv s0, v2.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -175,20 +173,10 @@ entry:
 define i32 @test_udot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
 ; CHECK-LABEL: test_udot_v8i8_double:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
-; CHECK-NEXT:    ushll v3.8h, v3.8b, #0
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    umull v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    ext v1.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT:    umull v2.4s, v2.4h, v3.4h
-; CHECK-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEXT:    umlal v0.4s, v4.4h, v5.4h
-; CHECK-NEXT:    umlal v2.4s, v1.4h, v3.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v2.4s
-; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEXT:    udot v4.2s, v2.8b, v3.8b
+; CHECK-NEXT:    udot v4.2s, v0.8b, v1.8b
+; CHECK-NEXT:    addp v0.2s, v4.2s, v4.2s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
 entry:
@@ -209,8 +197,8 @@ define i32 @test_udot_v8i8_double_nomla(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    movi v3.8b, #1
-; CHECK-NEXT:    udot v1.2s, v3.8b, v2.8b
-; CHECK-NEXT:    udot v1.2s, v3.8b, v0.8b
+; CHECK-NEXT:    udot v1.2s, v2.8b, v3.8b
+; CHECK-NEXT:    udot v1.2s, v0.8b, v3.8b
 ; CHECK-NEXT:    addp v0.2s, v1.2s, v1.2s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -226,30 +214,10 @@ entry:
 define i32 @test_udot_v16i8_double(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
 ; CHECK-LABEL: test_udot_v16i8_double:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll2 v4.8h, v0.16b, #0
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll2 v5.8h, v1.16b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    ext v6.16b, v4.16b, v4.16b, #8
-; CHECK-NEXT:    ext v7.16b, v5.16b, v5.16b, #8
-; CHECK-NEXT:    umull2 v16.4s, v0.8h, v1.8h
-; CHECK-NEXT:    umlal v16.4s, v6.4h, v7.4h
-; CHECK-NEXT:    ushll2 v6.8h, v2.16b, #0
-; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
-; CHECK-NEXT:    ushll2 v7.8h, v3.16b, #0
-; CHECK-NEXT:    ushll v3.8h, v3.8b, #0
-; CHECK-NEXT:    umull v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    ext v1.16b, v6.16b, v6.16b, #8
-; CHECK-NEXT:    umlal v0.4s, v4.4h, v5.4h
-; CHECK-NEXT:    ext v4.16b, v7.16b, v7.16b, #8
-; CHECK-NEXT:    umull v5.4s, v2.4h, v3.4h
-; CHECK-NEXT:    umull2 v2.4s, v2.8h, v3.8h
-; CHECK-NEXT:    umlal v2.4s, v1.4h, v4.4h
-; CHECK-NEXT:    umlal v5.4s, v6.4h, v7.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v16.4s
-; CHECK-NEXT:    add v1.4s, v5.4s, v2.4s
-; CHECK-NEXT:    add v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEXT:    udot v4.4s, v2.16b, v3.16b
+; CHECK-NEXT:    udot v4.4s, v0.16b, v1.16b
+; CHECK-NEXT:    addv s0, v4.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
 entry:
@@ -270,8 +238,8 @@ define i32 @test_udot_v16i8_double_nomla(<16 x i8> %a, <16 x i8> %b, <16 x i8> %
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v1.16b, #1
 ; CHECK-NEXT:    movi v3.2d, #0000000000000000
-; CHECK-NEXT:    udot v3.4s, v1.16b, v2.16b
-; CHECK-NEXT:    udot v3.4s, v1.16b, v0.16b
+; CHECK-NEXT:    udot v3.4s, v2.16b, v1.16b
+; CHECK-NEXT:    udot v3.4s, v0.16b, v1.16b
 ; CHECK-NEXT:    addv s0, v3.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -287,20 +255,10 @@ entry:
 define i32 @test_sdot_v8i8_double(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
 ; CHECK-LABEL: test_sdot_v8i8_double:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    sshll v2.8h, v2.8b, #0
-; CHECK-NEXT:    sshll v3.8h, v3.8b, #0
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    smull v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    ext v1.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT:    smull v2.4s, v2.4h, v3.4h
-; CHECK-NEXT:    ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEXT:    smlal v0.4s, v4.4h, v5.4h
-; CHECK-NEXT:    smlal v2.4s, v1.4h, v3.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v2.4s
-; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEXT:    sdot v4.2s, v2.8b, v3.8b
+; CHECK-NEXT:    sdot v4.2s, v0.8b, v1.8b
+; CHECK-NEXT:    addp v0.2s, v4.2s, v4.2s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
 entry:
@@ -321,8 +279,8 @@ define i32 @test_sdot_v8i8_double_nomla(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    movi v3.8b, #1
-; CHECK-NEXT:    sdot v1.2s, v3.8b, v2.8b
-; CHECK-NEXT:    sdot v1.2s, v3.8b, v0.8b
+; CHECK-NEXT:    sdot v1.2s, v2.8b, v3.8b
+; CHECK-NEXT:    sdot v1.2s, v0.8b, v3.8b
 ; CHECK-NEXT:    addp v0.2s, v1.2s, v1.2s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -338,30 +296,10 @@ entry:
 define i32 @test_sdot_v16i8_double(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
 ; CHECK-LABEL: test_sdot_v16i8_double:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll2 v4.8h, v0.16b, #0
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll2 v5.8h, v1.16b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    ext v6.16b, v4.16b, v4.16b, #8
-; CHECK-NEXT:    ext v7.16b, v5.16b, v5.16b, #8
-; CHECK-NEXT:    smull2 v16.4s, v0.8h, v1.8h
-; CHECK-NEXT:    smlal v16.4s, v6.4h, v7.4h
-; CHECK-NEXT:    sshll2 v6.8h, v2.16b, #0
-; CHECK-NEXT:    sshll v2.8h, v2.8b, #0
-; CHECK-NEXT:    sshll2 v7.8h, v3.16b, #0
-; CHECK-NEXT:    sshll v3.8h, v3.8b, #0
-; CHECK-NEXT:    smull v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    ext v1.16b, v6.16b, v6.16b, #8
-; CHECK-NEXT:    smlal v0.4s, v4.4h, v5.4h
-; CHECK-NEXT:    ext v4.16b, v7.16b, v7.16b, #8
-; CHECK-NEXT:    smull v5.4s, v2.4h, v3.4h
-; CHECK-NEXT:    smull2 v2.4s, v2.8h, v3.8h
-; CHECK-NEXT:    smlal v2.4s, v1.4h, v4.4h
-; CHECK-NEXT:    smlal v5.4s, v6.4h, v7.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v16.4s
-; CHECK-NEXT:    add v1.4s, v5.4s, v2.4s
-; CHECK-NEXT:    add v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEXT:    sdot v4.4s, v2.16b, v3.16b
+; CHECK-NEXT:    sdot v4.4s, v0.16b, v1.16b
+; CHECK-NEXT:    addv s0, v4.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
 entry:
@@ -382,8 +320,8 @@ define i32 @test_sdot_v16i8_double_nomla(<16 x i8> %a, <16 x i8> %b, <16 x i8> %
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v1.16b, #1
 ; CHECK-NEXT:    movi v3.2d, #0000000000000000
-; CHECK-NEXT:    sdot v3.4s, v1.16b, v2.16b
-; CHECK-NEXT:    sdot v3.4s, v1.16b, v0.16b
+; CHECK-NEXT:    sdot v3.4s, v2.16b, v1.16b
+; CHECK-NEXT:    sdot v3.4s, v0.16b, v1.16b
 ; CHECK-NEXT:    addv s0, v3.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
