[llvm] 85cf958 - [AArch64] Improve codegen for some fixed-width partial reductions (#126529)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 25 01:56:09 PST 2025


Author: David Sherwood
Date: 2025-02-25T09:56:06Z
New Revision: 85cf95876c4b21ee6ecd0253a2c9de0e90c4a521

URL: https://github.com/llvm/llvm-project/commit/85cf95876c4b21ee6ecd0253a2c9de0e90c4a521
DIFF: https://github.com/llvm/llvm-project/commit/85cf95876c4b21ee6ecd0253a2c9de0e90c4a521.diff

LOG: [AArch64] Improve codegen for some fixed-width partial reductions (#126529)

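This patch teaches optimizeExtendOrTruncateConversion to bail out (i.e. leave
the extend alone instead of rewriting it, e.g. into a tbl-based sequence) if
the user of a zero-extend is a partial reduction intrinsic that we know will
get lowered efficiently to a udot instruction. It also turns the intrinsic-ID
check in shouldExpandPartialReductionIntrinsic into an assertion, so callers
must only pass experimental_vector_partial_reduce_add.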

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d519bfc06af05..b00aa11f8499d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2055,8 +2055,9 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const IntrinsicInst *I) const {
-  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
-    return true;
+  assert(I->getIntrinsicID() ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Unexpected intrinsic!");
   if (EnablePartialReduceNodes)
     return true;
 
@@ -16890,9 +16891,16 @@ bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(
     // mul(zext(i8), sext) can be transformed into smull(zext, sext) which
     // performs one extend implicitly. If DstWidth is at most 4 * SrcWidth, at
     // most one extra extend step is needed and using tbl is not profitable.
+    // Similarly, bail out if partial_reduce(acc, zext(i8)) can be lowered to a
+    // udot instruction.
     if (SrcWidth * 4 <= DstWidth && I->hasOneUser()) {
       auto *SingleUser = cast<Instruction>(*I->user_begin());
-      if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))))
+      if (match(SingleUser, m_c_Mul(m_Specific(I), m_SExt(m_Value()))) ||
+          (match(SingleUser,
+                 m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+                     m_Value(), m_Specific(I))) &&
+           !shouldExpandPartialReductionIntrinsic(
+               cast<IntrinsicInst>(SingleUser))))
         return false;
     }
 

diff  --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 40daf8ffb63ea..249675470e38c 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -2,7 +2,7 @@
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
 ; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
 ; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
 
 define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
 ; CHECK-DOT-LABEL: udot:
@@ -27,6 +27,66 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @udot_in_loop(ptr %p1, ptr %p2){
+; CHECK-DOT-LABEL: udot_in_loop:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT:    mov x8, xzr
+; CHECK-DOT-NEXT:  .LBB1_1: // %vector.body
+; CHECK-DOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT:    ldr q2, [x0, x8]
+; CHECK-DOT-NEXT:    ldr q3, [x1, x8]
+; CHECK-DOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-NEXT:    add x8, x8, #16
+; CHECK-DOT-NEXT:    udot v1.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT:    cmp x8, #16
+; CHECK-DOT-NEXT:    b.ne .LBB1_1
+; CHECK-DOT-NEXT:  // %bb.2: // %end
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_in_loop:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NODOT-NEXT:    mov x8, xzr
+; CHECK-NODOT-NEXT:  .LBB1_1: // %vector.body
+; CHECK-NODOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NODOT-NEXT:    ldr q0, [x0, x8]
+; CHECK-NODOT-NEXT:    ldr q2, [x1, x8]
+; CHECK-NODOT-NEXT:    add x8, x8, #16
+; CHECK-NODOT-NEXT:    cmp x8, #16
+; CHECK-NODOT-NEXT:    umull v3.8h, v0.8b, v2.8b
+; CHECK-NODOT-NEXT:    umull2 v2.8h, v0.16b, v2.16b
+; CHECK-NODOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-NODOT-NEXT:    ushll v1.4s, v2.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v4.4s, v0.4s, v3.4h
+; CHECK-NODOT-NEXT:    uaddw2 v1.4s, v1.4s, v3.8h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v4.4s, v2.8h
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    b.ne .LBB1_1
+; CHECK-NODOT-NEXT:  // %bb.2: // %end
+; CHECK-NODOT-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = zext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = zext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
 ; CHECK-DOT-LABEL: udot_narrow:
 ; CHECK-DOT:       // %bb.0:
@@ -129,6 +189,68 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @usdot_in_loop(ptr %p1, ptr %p2){
+; CHECK-NOI8MM-LABEL: usdot_in_loop:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NOI8MM-NEXT:    mov x8, xzr
+; CHECK-NOI8MM-NEXT:  .LBB6_1: // %vector.body
+; CHECK-NOI8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NOI8MM-NEXT:    ldr q0, [x0, x8]
+; CHECK-NOI8MM-NEXT:    ldr q2, [x1, x8]
+; CHECK-NOI8MM-NEXT:    add x8, x8, #16
+; CHECK-NOI8MM-NEXT:    cmp x8, #16
+; CHECK-NOI8MM-NEXT:    sshll v3.8h, v0.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v4.8h, v0.16b, #0
+; CHECK-NOI8MM-NEXT:    ushll v5.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-NOI8MM-NEXT:    smlal v1.4s, v3.4h, v5.4h
+; CHECK-NOI8MM-NEXT:    smull v6.4s, v4.4h, v2.4h
+; CHECK-NOI8MM-NEXT:    smlal2 v1.4s, v4.8h, v2.8h
+; CHECK-NOI8MM-NEXT:    smlal2 v6.4s, v3.8h, v5.8h
+; CHECK-NOI8MM-NEXT:    add v1.4s, v6.4s, v1.4s
+; CHECK-NOI8MM-NEXT:    b.ne .LBB6_1
+; CHECK-NOI8MM-NEXT:  // %bb.2: // %end
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-I8MM-LABEL: usdot_in_loop:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-I8MM-NEXT:    mov x8, xzr
+; CHECK-I8MM-NEXT:  .LBB6_1: // %vector.body
+; CHECK-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-I8MM-NEXT:    ldr q2, [x0, x8]
+; CHECK-I8MM-NEXT:    ldr q3, [x1, x8]
+; CHECK-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-I8MM-NEXT:    add x8, x8, #16
+; CHECK-I8MM-NEXT:    usdot v1.4s, v3.16b, v2.16b
+; CHECK-I8MM-NEXT:    cmp x8, #16
+; CHECK-I8MM-NEXT:    b.ne .LBB6_1
+; CHECK-I8MM-NEXT:  // %bb.2: // %end
+; CHECK-I8MM-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = sext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = zext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-LABEL: usdot_narrow:
 ; CHECK-NOI8MM:       // %bb.0:
@@ -176,13 +298,75 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
 ; CHECK-I8MM:       // %bb.0:
 ; CHECK-I8MM-NEXT:    usdot v0.4s, v2.16b, v1.16b
 ; CHECK-I8MM-NEXT:    ret
-  %u.wide = sext <16 x i8> %u to <16 x i32>
-  %s.wide = zext <16 x i8> %s to <16 x i32>
-  %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+  %s.wide = sext <16 x i8> %u to <16 x i32>
+  %u.wide = zext <16 x i8> %s to <16 x i32>
+  %mult = mul nuw nsw <16 x i32> %u.wide, %s.wide
   %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @sudot_in_loop(ptr %p1, ptr %p2){
+; CHECK-NOI8MM-LABEL: sudot_in_loop:
+; CHECK-NOI8MM:       // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NOI8MM-NEXT:    mov x8, xzr
+; CHECK-NOI8MM-NEXT:  .LBB9_1: // %vector.body
+; CHECK-NOI8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NOI8MM-NEXT:    ldr q0, [x0, x8]
+; CHECK-NOI8MM-NEXT:    ldr q2, [x1, x8]
+; CHECK-NOI8MM-NEXT:    add x8, x8, #16
+; CHECK-NOI8MM-NEXT:    cmp x8, #16
+; CHECK-NOI8MM-NEXT:    ushll v3.8h, v0.8b, #0
+; CHECK-NOI8MM-NEXT:    ushll2 v4.8h, v0.16b, #0
+; CHECK-NOI8MM-NEXT:    sshll v5.8h, v2.8b, #0
+; CHECK-NOI8MM-NEXT:    sshll2 v2.8h, v2.16b, #0
+; CHECK-NOI8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-NOI8MM-NEXT:    smlal v1.4s, v3.4h, v5.4h
+; CHECK-NOI8MM-NEXT:    smull v6.4s, v4.4h, v2.4h
+; CHECK-NOI8MM-NEXT:    smlal2 v1.4s, v4.8h, v2.8h
+; CHECK-NOI8MM-NEXT:    smlal2 v6.4s, v3.8h, v5.8h
+; CHECK-NOI8MM-NEXT:    add v1.4s, v6.4s, v1.4s
+; CHECK-NOI8MM-NEXT:    b.ne .LBB9_1
+; CHECK-NOI8MM-NEXT:  // %bb.2: // %end
+; CHECK-NOI8MM-NEXT:    ret
+;
+; CHECK-I8MM-LABEL: sudot_in_loop:
+; CHECK-I8MM:       // %bb.0: // %entry
+; CHECK-I8MM-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-I8MM-NEXT:    mov x8, xzr
+; CHECK-I8MM-NEXT:  .LBB9_1: // %vector.body
+; CHECK-I8MM-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-I8MM-NEXT:    ldr q2, [x0, x8]
+; CHECK-I8MM-NEXT:    ldr q3, [x1, x8]
+; CHECK-I8MM-NEXT:    mov v0.16b, v1.16b
+; CHECK-I8MM-NEXT:    add x8, x8, #16
+; CHECK-I8MM-NEXT:    usdot v1.4s, v2.16b, v3.16b
+; CHECK-I8MM-NEXT:    cmp x8, #16
+; CHECK-I8MM-NEXT:    b.ne .LBB9_1
+; CHECK-I8MM-NEXT:  // %bb.2: // %end
+; CHECK-I8MM-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep1 = getelementptr i8, ptr %p1, i64 %index
+  %load1 = load <16 x i8>, ptr %gep1, align 16
+  %load1.wide = zext <16 x i8> %load1 to <16 x i32>
+  %gep2 = getelementptr i8, ptr %p2, i64 %index
+  %load2 = load <16 x i8>, ptr %gep2, align 16
+  %load2.wide = sext <16 x i8> %load2 to <16 x i32>
+  %mul = mul nuw nsw <16 x i32> %load1.wide, %load2.wide
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mul)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
 ; CHECK-NOI8MM-LABEL: sudot_narrow:
 ; CHECK-NOI8MM:       // %bb.0:
@@ -390,6 +574,62 @@ define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
   ret <4 x i32> %partial.reduce
 }
 
+define <4 x i32> @udot_no_bin_op_in_loop(ptr %p){
+; CHECK-DOT-LABEL: udot_no_bin_op_in_loop:
+; CHECK-DOT:       // %bb.0: // %entry
+; CHECK-DOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-DOT-NEXT:    movi v2.16b, #1
+; CHECK-DOT-NEXT:    mov x8, xzr
+; CHECK-DOT-NEXT:  .LBB16_1: // %vector.body
+; CHECK-DOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-DOT-NEXT:    ldr q3, [x0, x8]
+; CHECK-DOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-DOT-NEXT:    add x8, x8, #16
+; CHECK-DOT-NEXT:    cmp x8, #16
+; CHECK-DOT-NEXT:    udot v1.4s, v3.16b, v2.16b
+; CHECK-DOT-NEXT:    b.ne .LBB16_1
+; CHECK-DOT-NEXT:  // %bb.2: // %end
+; CHECK-DOT-NEXT:    ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_in_loop:
+; CHECK-NODOT:       // %bb.0: // %entry
+; CHECK-NODOT-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NODOT-NEXT:    mov x8, xzr
+; CHECK-NODOT-NEXT:  .LBB16_1: // %vector.body
+; CHECK-NODOT-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NODOT-NEXT:    ldr q0, [x0, x8]
+; CHECK-NODOT-NEXT:    add x8, x8, #16
+; CHECK-NODOT-NEXT:    cmp x8, #16
+; CHECK-NODOT-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-NODOT-NEXT:    ushll2 v3.8h, v0.16b, #0
+; CHECK-NODOT-NEXT:    mov v0.16b, v1.16b
+; CHECK-NODOT-NEXT:    ushll v1.4s, v3.4h, #0
+; CHECK-NODOT-NEXT:    uaddw v4.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT:    uaddw2 v1.4s, v1.4s, v2.8h
+; CHECK-NODOT-NEXT:    uaddw2 v2.4s, v4.4s, v3.8h
+; CHECK-NODOT-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NODOT-NEXT:    b.ne .LBB16_1
+; CHECK-NODOT-NEXT:  // %bb.2: // %end
+; CHECK-NODOT-NEXT:    ret
+
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %acc = phi <4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce, %vector.body ]
+  %gep = getelementptr i8, ptr %p, i64 %index
+  %load = load <16 x i8>, ptr %gep, align 16
+  %load.wide = zext <16 x i8> %load to <16 x i32>
+  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %load.wide)
+  %index.next = add nuw i64 %index, 16
+  %cmp = icmp eq i64 %index.next, 16
+  br i1 %cmp, label %end, label %vector.body
+
+end:
+  ret <4 x i32> %acc
+}
+
 define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
 ; CHECK-DOT-LABEL: sdot_no_bin_op:
 ; CHECK-DOT:       // %bb.0:


        


More information about the llvm-commits mailing list