[llvm] [AArch64] Generalize bfdotq_lane patterns to work for f32/i32 duplanes (PR #171146)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 8 07:06:59 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This also removes an overly specific vector-scalar BFDOT pattern from AArch64InstrInfo.td that is now redundant with the generalized patterns.
Fixes #170883
---
Full diff: https://github.com/llvm/llvm-project/pull/171146.diff
3 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64InstrFormats.td (+47-30)
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (-17)
- (modified) llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll (+31)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 4d2e740779961..821dfbd8e9191 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -10,6 +10,20 @@
// Describe AArch64 instructions format here
//
+// Helper class to convert vector element types to integers.
+class ChangeElementTypeToInteger<ValueType InVT> {
+ ValueType VT = !cond(
+ !eq(InVT, v2f32): v2i32,
+ !eq(InVT, v4f32): v4i32,
+ // TODO: Other types.
+ true : untyped);
+}
+
+class VTPair<ValueType A, ValueType B> {
+ ValueType VT0 = A;
+ ValueType VT1 = B;
+}
+
// Format specifies the encoding used by the instruction. This is part of the
// ad-hoc solution used to emit machine instruction encodings by our machine
// code emitter.
@@ -8952,36 +8966,6 @@ multiclass SIMDThreeSameVectorBFDot<bit U, string asm> {
v4f32, v8bf16>;
}
-class BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
- string dst_kind, string lhs_kind,
- string rhs_kind,
- RegisterOperand RegType,
- ValueType AccumType,
- ValueType InputType>
- : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111,
- RegType, RegType, V128, VectorIndexS,
- asm, "", dst_kind, lhs_kind, rhs_kind,
- [(set (AccumType RegType:$dst),
- (AccumType (int_aarch64_neon_bfdot
- (AccumType RegType:$Rd),
- (InputType RegType:$Rn),
- (InputType (bitconvert (AccumType
- (AArch64duplane32 (v4f32 V128:$Rm),
- VectorIndexS:$idx)))))))]> {
-
- bits<2> idx;
- let Inst{21} = idx{0}; // L
- let Inst{11} = idx{1}; // H
-}
-
-multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
-
- def v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
- ".2h", V64, v2f32, v4bf16>;
- def v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
- ".2h", V128, v4f32, v8bf16>;
-}
-
let mayRaiseFPException = 1, Uses = [FPCR] in
class SIMDBF16MLAL<bit Q, string asm, SDPatternOperator OpNode>
: BaseSIMDThreeSameVectorTied<Q, 0b1, 0b110, 0b11111, V128, asm, ".4s",
@@ -9054,6 +9038,39 @@ class BF16ToSinglePrecision<string asm>
}
} // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0
+multiclass BaseSIMDThreeSameVectorBF16DotI<bit Q, bit U, string asm,
+ string dst_kind, string lhs_kind,
+ string rhs_kind,
+ RegisterOperand RegType,
+ ValueType AccumType,
+ ValueType InputType> {
+ let mayLoad = 0, mayStore = 0, hasSideEffects = 0 in {
+ def NAME : BaseSIMDIndexedTied<Q, U, 0b0, 0b01, 0b1111, RegType, RegType, V128, VectorIndexS,
+ asm, "", dst_kind, lhs_kind, rhs_kind, []>
+ {
+ bits<2> idx;
+ let Inst{21} = idx{0}; // L
+ let Inst{11} = idx{1}; // H
+ }
+ }
+
+ foreach DupTypes = [VTPair<AccumType, v4f32>,
+ VTPair<ChangeElementTypeToInteger<AccumType>.VT, v4i32>] in {
+ def : Pat<(AccumType (int_aarch64_neon_bfdot
+ (AccumType RegType:$Rd), (InputType RegType:$Rn),
+ (InputType (bitconvert
+ (DupTypes.VT0 (AArch64duplane32 (DupTypes.VT1 V128:$Rm), VectorIndexS:$Idx)))))),
+ (!cast<Instruction>(NAME) $Rd, $Rn, $Rm, VectorIndexS:$Idx)>;
+ }
+}
+
+multiclass SIMDThreeSameVectorBF16DotI<bit U, string asm> {
+ defm v4bf16 : BaseSIMDThreeSameVectorBF16DotI<0, U, asm, ".2s", ".4h",
+ ".2h", V64, v2f32, v4bf16>;
+ defm v8bf16 : BaseSIMDThreeSameVectorBF16DotI<1, U, asm, ".4s", ".8h",
+ ".2h", V128, v4f32, v8bf16>;
+}
+
//----------------------------------------------------------------------------
class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
string asm, string dst_kind,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 64017d7cafca3..d2e34219d40aa 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1729,23 +1729,6 @@ def BFCVTN2 : SIMD_BFCVTN2;
def : Pat<(concat_vectors (v4bf16 V64:$Rd), (any_fpround (v4f32 V128:$Rn))),
(BFCVTN2 (v8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub)), V128:$Rn)>;
-
-// Vector-scalar BFDOT:
-// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
-// register (the instruction uses a single 32-bit lane from it), so the pattern
-// is a bit tricky.
-def : Pat<(v2f32 (int_aarch64_neon_bfdot
- (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
- (v4bf16 (bitconvert
- (v2i32 (AArch64duplane32
- (v4i32 (bitconvert
- (v8bf16 (insert_subvector undef,
- (v4bf16 V64:$Rm),
- (i64 0))))),
- VectorIndexS:$idx)))))),
- (BF16DOTlanev4bf16 (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
- (SUBREG_TO_REG (i32 0), V64:$Rm, dsub),
- VectorIndexS:$idx)>;
}
let Predicates = [HasNEONandIsStreamingSafe, HasBF16] in {
diff --git a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
index 52b542790e82d..ca3cd6bbae549 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-bf16-dotprod-intrinsics.ll
@@ -151,6 +151,37 @@ entry:
ret <4 x float> %vbfmlaltq_v3.i
}
+define <4 x float> @test_vbfdotq_laneq_f32_v4i32_shufflevector(<8 x bfloat> %a, <8 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_laneq_f32_v4i32_shufflevector:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: bfdot v2.4s, v0.8h, v1.2h[0]
+; CHECK-NEXT: mov v0.16b, v2.16b
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <8 x bfloat> %b to <4 x i32>
+ %1 = shufflevector <4 x i32> %0, <4 x i32> poison, <4 x i32> zeroinitializer
+ %2 = bitcast <4 x i32> %1 to <8 x bfloat>
+ %vbfdotq = call <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float> zeroinitializer, <8 x bfloat> %a, <8 x bfloat> %2)
+ ret <4 x float> %vbfdotq
+}
+
+define <2 x float> @test_vbfdotq_laneq_f32_v2i32_shufflevector(<4 x bfloat> %a, <4 x bfloat> %b) {
+; CHECK-LABEL: test_vbfdotq_laneq_f32_v2i32_shufflevector:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: movi d2, #0000000000000000
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: bfdot v2.2s, v0.4h, v1.2h[0]
+; CHECK-NEXT: fmov d0, d2
+; CHECK-NEXT: ret
+entry:
+ %0 = bitcast <4 x bfloat> %b to <2 x i32>
+ %1 = shufflevector <2 x i32> %0, <2 x i32> poison, <2 x i32> zeroinitializer
+ %2 = bitcast <2 x i32> %1 to <4 x bfloat>
+ %vbfdotq = call <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float> zeroinitializer, <4 x bfloat> %a, <4 x bfloat> %2)
+ ret <2 x float> %vbfdotq
+}
+
declare <2 x float> @llvm.aarch64.neon.bfdot.v2f32.v4bf16(<2 x float>, <4 x bfloat>, <4 x bfloat>)
declare <4 x float> @llvm.aarch64.neon.bfdot.v4f32.v8bf16(<4 x float>, <8 x bfloat>, <8 x bfloat>)
declare <4 x float> @llvm.aarch64.neon.bfmmla(<4 x float>, <8 x bfloat>, <8 x bfloat>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/171146
More information about the llvm-commits mailing list