[llvm] [AArch64] Generalize bfdotq_lane patterns to work for f32/i32 duplanes (PR #171146)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 8 07:06:59 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This also removes an overly specific vector-scalar BFDOT pattern from AArch64InstrInfo.td, which this change makes redundant.

Fixes #170883

---
Full diff: https://github.com/llvm/llvm-project/pull/171146.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64InstrFormats.td (+47-30) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (-17) 
- (modified) llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll (+31) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 4d2e740779961..821dfbd8e9191 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -10,6 +10,20 @@
 //  Describe AArch64 instructions format here
 //
 
+// Helper class to convert vector element types to integers.
+class ChangeElementTypeToInteger<ValueType InVT> {
+  ValueType VT = !cond(
+    !eq(InVT, v2f32): v2i32,
+    !eq(InVT, v4f32): v4i32,
+    // TODO: Other types.
+    true : untyped);
+}
+
+class VTPair<ValueType A, ValueType B> {
+  ValueType VT0 = A;
+  ValueType VT1 = B;
+}
+
 // Format specifies the encoding used by the instruction.  This is part of the
 // ad-hoc solution used to emit machine instruction encodings by our machine
 // code emitter.
@@ -8952,36 +8966,6 @@ multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
                                            v4f32, v8bf16>;
 }
 
-class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
-                                      string dst_kind, string lhs_kind,
-                                      string rhs_kind,
-                                      RegisterOperand RegType,
-                                      ValueType AccumType,
-                                      ValueType InputType>
-  : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
-                        RegType, RegType, V128, VectorIndexS,
-                        asm, "", dst_kind, lhs_kind, rhs_kind,
-        [(set (AccumType RegType:$dst),
-              (AccumType (int_aarch64_neon_bfdot
-                                 (AccumType RegType:$Rd),
-                                 (InputType RegType:$Rn),
-                                 (InputType (bitconvert (AccumType
-                                    (AArch64duplane32 (v4f32 V128:$Rm),
-                                        VectorIndexS:$idx)))))))]> {
-
-  bits<2> idx;
-  let Inst{21}    = idx{0};  // L
-  let Inst{11}    = idx{1};  // H
-}
-
-multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
-
-  def v4bf16  : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
-                                               ".2h", V64, v2f32, v4bf16>;
-  def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
-                                              ".2h", V128, v4f32, v8bf16>;
-}
-
 let mayRaiseFPException = 1, Uses = [FPCR] in
 class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
   : BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
@@ -9054,6 +9038,39 @@ class BF16ToSinglePrecision<string asm>
 }
 } // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0
 
+multiclass BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
+                                           string dst_kind, string lhs_kind,
+                                           string rhs_kind,
+                                           RegisterOperand RegType,
+                                           ValueType AccumType,
+                                           ValueType InputType> {
+  let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in {
+    def NAME : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111, RegType, RegType, V128, VectorIndexS,
+                                   asm, "", dst_kind, lhs_kind, rhs_kind, []>
+    {
+      bits<2> idx;
+      let Inst{21} = idx{0};  // L
+      let Inst{11} = idx{1};  // H
+    }
+  }
+
+  foreach DupTypes = [VTPair<AccumType, v4f32>,
+                      VTPair<ChangeElementTypeToInteger<AccumType>.VT, v4i32>] in {
+    def : Pat<(AccumType (int_aarch64_neon_bfdot
+                (AccumType RegType:$Rd), (InputType RegType:$Rn),
+                (InputType (bitconvert
+                  (DupTypes.VT0 (AArch64duplane32 (DupTypes.VT1 V128:$Rm), VectorIndexS:$Idx)))))),
+              (!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
+  }
+}
+
+multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
+  defm v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
+                                               ".2h", V64, v2f32, v4bf16>;
+  defm v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
+                                              ".2h", V128, v4f32, v8bf16>;
+}
+
 //----------------------------------------------------------------------------
 class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
                                     string asm, string dst_kind,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 64017d7cafca3..d2e34219d40aa 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1729,23 +1729,6 @@ def BFCVTN2      : SIMD_BFCVTN2;
 
 def : Pat<(concat_vectors (v4bf16 V64:$Rd), (any_fpround (v4f32 V128:$Rn))),
           (BFCVTN2 (v8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub)), V128:$Rn)>;
-
-// Vector-scalar BFDOT:
-// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
-// register (the instruction uses a single 32-bit lane from it), so the pattern
-// is a bit tricky.
-def : Pat<(v2f32 (int_aarch64_neon_bfdot
-                    (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
-                    (v4bf16 (bitconvert
-                      (v2i32 (AArch64duplane32
-                        (v4i32 (bitconvert
-                          (v8bf16 (insert_subvector undef,
-                            (v4bf16 V64:$Rm),
-                            (i64 0))))),
-                        VectorIndexS:$idx)))))),
-          (BF16DOTlanev4bf16 (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
-                             (SUBREG_TO_REG (i32 0), V64:$Rm, dsub),
-                             VectorIndexS:$idx)>;
 }
 
 let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in {
diff --git a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
index 52b542790e82d..ca3cd6bbae549 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -151,6 +151,37 @@ entry:
   ret <4 x float> %vbfmlaltq_v3.i
 }
 
+define <4 x float> @test_vbfdotq_laneq_f32_v4i32_shufflevector(<8 x bfloat> %a, <8 x bfloat> %b)  {
+; CHECK-LABEL: test_vbfdotq_laneq_f32_v4i32_shufflevector:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    bfdot v2.4s, v0.8h, v1.2h[0]
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <8 x bfloat> %b to <4 x i32>
+  %1 = shufflevector <4 x i32> %0, <4 x i32> poison, <4 x i32> zeroinitializer
+  %2 = bitcast <4 x i32> %1 to <8 x bfloat>
+  %vbfdotq = call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float> zeroinitializer, <8 x bfloat> %a, <8 x bfloat> %2)
+  ret <4 x float> %vbfdotq
+}
+
+define <2 x float> @test_vbfdotq_laneq_f32_v2i32_shufflevector(<4 x bfloat> %a, <4 x bfloat> %b)  {
+; CHECK-LABEL: test_vbfdotq_laneq_f32_v2i32_shufflevector:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi d2, #0000000000000000
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    bfdot v2.2s, v0.4h, v1.2h[0]
+; CHECK-NEXT:    fmov d0, d2
+; CHECK-NEXT:    ret
+entry:
+  %0 = bitcast <4 x bfloat> %b to <2 x i32>
+  %1 = shufflevector <2 x i32> %0, <2 x i32> poison, <2 x i32> zeroinitializer
+  %2 = bitcast <2 x i32> %1 to <4 x bfloat>
+  %vbfdotq = call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float> zeroinitializer, <4 x bfloat> %a, <4 x bfloat> %2)
+  ret <2 x float> %vbfdotq
+}
+
 declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float>, <4 x bfloat>, <4 x bfloat>)
 declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float>, <8 x bfloat>, <8 x bfloat>)
 declare <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float>, <8 x bfloat>, <8 x bfloat>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/171146


More information about the llvm-commits mailing list