[llvm] ab3607c - [AArch64][SVE] Add missing load/store patterns for unpacked bfloat vectors.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 22 01:46:35 PDT 2021


Author: Sander de Smalen
Date: 2021-09-22T09:45:33+01:00
New Revision: ab3607c0ed92a7e39952ce22e72e778d2679876a

URL: https://github.com/llvm/llvm-project/commit/ab3607c0ed92a7e39952ce22e72e778d2679876a
DIFF: https://github.com/llvm/llvm-project/commit/ab3607c0ed92a7e39952ce22e72e778d2679876a.diff

LOG: [AArch64][SVE] Add missing load/store patterns for unpacked bfloat vectors.

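An "unpacked" vector such as <vscale x 4 x bfloat> keeps each 16-bit element in a wider container lane (a 32-bit lane for nxv4bf16, a 64-bit lane for nxv2bf16), so its loads and stores must use the _S/_D variants of LD1H/ST1H with a .s/.d predicate rather than the packed .h forms. The patch below adds both the predicated (masked) and unpredicated patterns for nxv2bf16 and nxv4bf16. As a minimal sketch of the kind of IR this enables to select (the function name and attribute line are illustrative, not taken from the patch):

    define <vscale x 4 x bfloat> @load_unpacked(<vscale x 4 x bfloat>* %p) #0 {
      ; Each bfloat element occupies a 32-bit container lane, so this now
      ; selects "ptrue p0.s" + "ld1h { z0.s }, p0/z, [x0]"; before the
      ; patch there was no matching pattern for the unpacked type.
      %v = load <vscale x 4 x bfloat>, <vscale x 4 x bfloat>* %p
      ret <vscale x 4 x bfloat> %v
    }
    attributes #0 = { "target-features"="+sve,+bf16" }
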
Reviewed By: c-rhodes

Differential Revision: https://reviews.llvm.org/D110063

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/test/CodeGen/AArch64/sve-ld1-addressing-mode-reg-reg.ll
    llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
    llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-reg.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 5059185c76188..7455b52d78a0e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2000,25 +2000,27 @@ let Predicates = [HasSVEorStreamingSVE] in {
   }
 
   // 2-element contiguous loads
-  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i8,   LD1B_D,  LD1B_D_IMM,  am_sve_regreg_lsl0>;
-  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i8,  LD1SB_D, LD1SB_D_IMM, am_sve_regreg_lsl0>;
-  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i16,  LD1H_D,  LD1H_D_IMM,  am_sve_regreg_lsl1>;
-  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i16, LD1SH_D, LD1SH_D_IMM, am_sve_regreg_lsl1>;
-  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i32,  LD1W_D,  LD1W_D_IMM,  am_sve_regreg_lsl2>;
-  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i32, LD1SW_D, LD1SW_D_IMM, am_sve_regreg_lsl2>;
-  defm : pred_load<nxv2i64, nxv2i1, nonext_masked_load,    LD1D,    LD1D_IMM,    am_sve_regreg_lsl3>;
-  defm : pred_load<nxv2f16, nxv2i1, nonext_masked_load,    LD1H_D,  LD1H_D_IMM,  am_sve_regreg_lsl1>;
-  defm : pred_load<nxv2f32, nxv2i1, nonext_masked_load,    LD1W_D,  LD1W_D_IMM,  am_sve_regreg_lsl2>;
-  defm : pred_load<nxv2f64, nxv2i1, nonext_masked_load,    LD1D,    LD1D_IMM,    am_sve_regreg_lsl3>;
+  defm : pred_load<nxv2i64,  nxv2i1, zext_masked_load_i8,   LD1B_D,  LD1B_D_IMM,  am_sve_regreg_lsl0>;
+  defm : pred_load<nxv2i64,  nxv2i1, asext_masked_load_i8,  LD1SB_D, LD1SB_D_IMM, am_sve_regreg_lsl0>;
+  defm : pred_load<nxv2i64,  nxv2i1, zext_masked_load_i16,  LD1H_D,  LD1H_D_IMM,  am_sve_regreg_lsl1>;
+  defm : pred_load<nxv2i64,  nxv2i1, asext_masked_load_i16, LD1SH_D, LD1SH_D_IMM, am_sve_regreg_lsl1>;
+  defm : pred_load<nxv2i64,  nxv2i1, zext_masked_load_i32,  LD1W_D,  LD1W_D_IMM,  am_sve_regreg_lsl2>;
+  defm : pred_load<nxv2i64,  nxv2i1, asext_masked_load_i32, LD1SW_D, LD1SW_D_IMM, am_sve_regreg_lsl2>;
+  defm : pred_load<nxv2i64,  nxv2i1, nonext_masked_load,    LD1D,    LD1D_IMM,    am_sve_regreg_lsl3>;
+  defm : pred_load<nxv2f16,  nxv2i1, nonext_masked_load,    LD1H_D,  LD1H_D_IMM,  am_sve_regreg_lsl1>;
+  defm : pred_load<nxv2bf16, nxv2i1, nonext_masked_load,    LD1H_D,  LD1H_D_IMM,  am_sve_regreg_lsl1>;
+  defm : pred_load<nxv2f32,  nxv2i1, nonext_masked_load,    LD1W_D,  LD1W_D_IMM,  am_sve_regreg_lsl2>;
+  defm : pred_load<nxv2f64,  nxv2i1, nonext_masked_load,    LD1D,    LD1D_IMM,    am_sve_regreg_lsl3>;
 
   // 4-element contiguous loads
-  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i8,   LD1B_S,  LD1B_S_IMM,  am_sve_regreg_lsl0>;
-  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i8,  LD1SB_S, LD1SB_S_IMM, am_sve_regreg_lsl0>;
-  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i16,  LD1H_S,  LD1H_S_IMM,  am_sve_regreg_lsl1>;
-  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i16, LD1SH_S, LD1SH_S_IMM, am_sve_regreg_lsl1>;
-  defm : pred_load<nxv4i32, nxv4i1, nonext_masked_load,    LD1W,    LD1W_IMM,    am_sve_regreg_lsl2>;
-  defm : pred_load<nxv4f16, nxv4i1, nonext_masked_load,    LD1H_S,  LD1H_S_IMM,  am_sve_regreg_lsl1>;
-  defm : pred_load<nxv4f32, nxv4i1, nonext_masked_load,    LD1W,    LD1W_IMM,    am_sve_regreg_lsl2>;
+  defm : pred_load<nxv4i32,  nxv4i1, zext_masked_load_i8,   LD1B_S,  LD1B_S_IMM,  am_sve_regreg_lsl0>;
+  defm : pred_load<nxv4i32,  nxv4i1, asext_masked_load_i8,  LD1SB_S, LD1SB_S_IMM, am_sve_regreg_lsl0>;
+  defm : pred_load<nxv4i32,  nxv4i1, zext_masked_load_i16,  LD1H_S,  LD1H_S_IMM,  am_sve_regreg_lsl1>;
+  defm : pred_load<nxv4i32,  nxv4i1, asext_masked_load_i16, LD1SH_S, LD1SH_S_IMM, am_sve_regreg_lsl1>;
+  defm : pred_load<nxv4i32,  nxv4i1, nonext_masked_load,    LD1W,    LD1W_IMM,    am_sve_regreg_lsl2>;
+  defm : pred_load<nxv4f16,  nxv4i1, nonext_masked_load,    LD1H_S,  LD1H_S_IMM,  am_sve_regreg_lsl1>;
+  defm : pred_load<nxv4bf16, nxv4i1, nonext_masked_load,    LD1H_S,  LD1H_S_IMM,  am_sve_regreg_lsl1>;
+  defm : pred_load<nxv4f32,  nxv4i1, nonext_masked_load,    LD1W,    LD1W_IMM,    am_sve_regreg_lsl2>;
 
   // 8-element contiguous loads
   defm : pred_load<nxv8i16,  nxv8i1, zext_masked_load_i8,  LD1B_H,  LD1B_H_IMM,  am_sve_regreg_lsl0>;
@@ -2045,20 +2047,22 @@ let Predicates = [HasSVEorStreamingSVE] in {
   }
 
   // 2-element contiguous stores
-  defm : pred_store<nxv2i64, nxv2i1, trunc_masked_store_i8,  ST1B_D, ST1B_D_IMM, am_sve_regreg_lsl0>;
-  defm : pred_store<nxv2i64, nxv2i1, trunc_masked_store_i16, ST1H_D, ST1H_D_IMM, am_sve_regreg_lsl1>;
-  defm : pred_store<nxv2i64, nxv2i1, trunc_masked_store_i32, ST1W_D, ST1W_D_IMM, am_sve_regreg_lsl2>;
-  defm : pred_store<nxv2i64, nxv2i1, nontrunc_masked_store,  ST1D,   ST1D_IMM,   am_sve_regreg_lsl3>;
-  defm : pred_store<nxv2f16, nxv2i1, nontrunc_masked_store,  ST1H_D, ST1H_D_IMM, am_sve_regreg_lsl1>;
-  defm : pred_store<nxv2f32, nxv2i1, nontrunc_masked_store,  ST1W_D, ST1W_D_IMM, am_sve_regreg_lsl2>;
-  defm : pred_store<nxv2f64, nxv2i1, nontrunc_masked_store,  ST1D,   ST1D_IMM,   am_sve_regreg_lsl3>;
+  defm : pred_store<nxv2i64,  nxv2i1, trunc_masked_store_i8,  ST1B_D, ST1B_D_IMM, am_sve_regreg_lsl0>;
+  defm : pred_store<nxv2i64,  nxv2i1, trunc_masked_store_i16, ST1H_D, ST1H_D_IMM, am_sve_regreg_lsl1>;
+  defm : pred_store<nxv2i64,  nxv2i1, trunc_masked_store_i32, ST1W_D, ST1W_D_IMM, am_sve_regreg_lsl2>;
+  defm : pred_store<nxv2i64,  nxv2i1, nontrunc_masked_store,  ST1D,   ST1D_IMM,   am_sve_regreg_lsl3>;
+  defm : pred_store<nxv2f16,  nxv2i1, nontrunc_masked_store,  ST1H_D, ST1H_D_IMM, am_sve_regreg_lsl1>;
+  defm : pred_store<nxv2bf16, nxv2i1, nontrunc_masked_store,  ST1H_D, ST1H_D_IMM, am_sve_regreg_lsl1>;
+  defm : pred_store<nxv2f32,  nxv2i1, nontrunc_masked_store,  ST1W_D, ST1W_D_IMM, am_sve_regreg_lsl2>;
+  defm : pred_store<nxv2f64,  nxv2i1, nontrunc_masked_store,  ST1D,   ST1D_IMM,   am_sve_regreg_lsl3>;
 
   // 4-element contiguous stores
-  defm : pred_store<nxv4i32, nxv4i1, trunc_masked_store_i8,  ST1B_S, ST1B_S_IMM, am_sve_regreg_lsl0>;
-  defm : pred_store<nxv4i32, nxv4i1, trunc_masked_store_i16, ST1H_S, ST1H_S_IMM, am_sve_regreg_lsl1>;
-  defm : pred_store<nxv4i32, nxv4i1, nontrunc_masked_store,  ST1W,   ST1W_IMM,   am_sve_regreg_lsl2>;
-  defm : pred_store<nxv4f16, nxv4i1, nontrunc_masked_store,  ST1H_S, ST1H_S_IMM, am_sve_regreg_lsl1>;
-  defm : pred_store<nxv4f32, nxv4i1, nontrunc_masked_store,  ST1W,   ST1W_IMM,   am_sve_regreg_lsl2>;
+  defm : pred_store<nxv4i32,  nxv4i1, trunc_masked_store_i8,  ST1B_S, ST1B_S_IMM, am_sve_regreg_lsl0>;
+  defm : pred_store<nxv4i32,  nxv4i1, trunc_masked_store_i16, ST1H_S, ST1H_S_IMM, am_sve_regreg_lsl1>;
+  defm : pred_store<nxv4i32,  nxv4i1, nontrunc_masked_store,  ST1W,   ST1W_IMM,   am_sve_regreg_lsl2>;
+  defm : pred_store<nxv4f16,  nxv4i1, nontrunc_masked_store,  ST1H_S, ST1H_S_IMM, am_sve_regreg_lsl1>;
+  defm : pred_store<nxv4bf16, nxv4i1, nontrunc_masked_store,  ST1H_S, ST1H_S_IMM, am_sve_regreg_lsl1>;
+  defm : pred_store<nxv4f32,  nxv4i1, nontrunc_masked_store,  ST1W,   ST1W_IMM,   am_sve_regreg_lsl2>;
 
   // 8-element contiguous stores
   defm : pred_store<nxv8i16,  nxv8i1, trunc_masked_store_i8, ST1B_H, ST1B_H_IMM, am_sve_regreg_lsl0>;
@@ -2099,23 +2103,25 @@ let Predicates = [HasSVEorStreamingSVE] in {
               (RegImmInst ZPR:$val, (PTrue 31), GPR64:$base, (i64 0))>;
   }
 
-  defm : unpred_store<         store, nxv16i8,   ST1B,   ST1B_IMM, PTRUE_B, am_sve_regreg_lsl0>;
-  defm : unpred_store< truncstorevi8, nxv8i16, ST1B_H, ST1B_H_IMM, PTRUE_H, am_sve_regreg_lsl0>;
-  defm : unpred_store< truncstorevi8, nxv4i32, ST1B_S, ST1B_S_IMM, PTRUE_S, am_sve_regreg_lsl0>;
-  defm : unpred_store< truncstorevi8, nxv2i64, ST1B_D, ST1B_D_IMM, PTRUE_D, am_sve_regreg_lsl0>;
-  defm : unpred_store<         store, nxv8i16,   ST1H,   ST1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
-  defm : unpred_store<truncstorevi16, nxv4i32, ST1H_S, ST1H_S_IMM, PTRUE_S, am_sve_regreg_lsl1>;
-  defm : unpred_store<truncstorevi16, nxv2i64, ST1H_D, ST1H_D_IMM, PTRUE_D, am_sve_regreg_lsl1>;
-  defm : unpred_store<         store, nxv4i32,   ST1W,   ST1W_IMM, PTRUE_S, am_sve_regreg_lsl2>;
-  defm : unpred_store<truncstorevi32, nxv2i64, ST1W_D, ST1W_D_IMM, PTRUE_D, am_sve_regreg_lsl2>;
-  defm : unpred_store<         store, nxv2i64,   ST1D,   ST1D_IMM, PTRUE_D, am_sve_regreg_lsl3>;
-  defm : unpred_store<         store, nxv8f16,   ST1H,   ST1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
-  defm : unpred_store<         store, nxv8bf16,  ST1H,   ST1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
-  defm : unpred_store<         store, nxv4f16, ST1H_S, ST1H_S_IMM, PTRUE_S, am_sve_regreg_lsl1>;
-  defm : unpred_store<         store, nxv2f16, ST1H_D, ST1H_D_IMM, PTRUE_D, am_sve_regreg_lsl1>;
-  defm : unpred_store<         store, nxv4f32,   ST1W,   ST1W_IMM, PTRUE_S, am_sve_regreg_lsl2>;
-  defm : unpred_store<         store, nxv2f32, ST1W_D, ST1W_D_IMM, PTRUE_D, am_sve_regreg_lsl2>;
-  defm : unpred_store<         store, nxv2f64,   ST1D,   ST1D_IMM, PTRUE_D, am_sve_regreg_lsl3>;
+  defm : unpred_store<         store, nxv16i8,    ST1B,   ST1B_IMM, PTRUE_B, am_sve_regreg_lsl0>;
+  defm : unpred_store< truncstorevi8, nxv8i16,  ST1B_H, ST1B_H_IMM, PTRUE_H, am_sve_regreg_lsl0>;
+  defm : unpred_store< truncstorevi8, nxv4i32,  ST1B_S, ST1B_S_IMM, PTRUE_S, am_sve_regreg_lsl0>;
+  defm : unpred_store< truncstorevi8, nxv2i64,  ST1B_D, ST1B_D_IMM, PTRUE_D, am_sve_regreg_lsl0>;
+  defm : unpred_store<         store, nxv8i16,    ST1H,   ST1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
+  defm : unpred_store<truncstorevi16, nxv4i32,  ST1H_S, ST1H_S_IMM, PTRUE_S, am_sve_regreg_lsl1>;
+  defm : unpred_store<truncstorevi16, nxv2i64,  ST1H_D, ST1H_D_IMM, PTRUE_D, am_sve_regreg_lsl1>;
+  defm : unpred_store<         store, nxv4i32,    ST1W,   ST1W_IMM, PTRUE_S, am_sve_regreg_lsl2>;
+  defm : unpred_store<truncstorevi32, nxv2i64,  ST1W_D, ST1W_D_IMM, PTRUE_D, am_sve_regreg_lsl2>;
+  defm : unpred_store<         store, nxv2i64,    ST1D,   ST1D_IMM, PTRUE_D, am_sve_regreg_lsl3>;
+  defm : unpred_store<         store, nxv8f16,    ST1H,   ST1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
+  defm : unpred_store<         store, nxv8bf16,   ST1H,   ST1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
+  defm : unpred_store<         store, nxv4f16,  ST1H_S, ST1H_S_IMM, PTRUE_S, am_sve_regreg_lsl1>;
+  defm : unpred_store<         store, nxv4bf16, ST1H_S, ST1H_S_IMM, PTRUE_S, am_sve_regreg_lsl1>;
+  defm : unpred_store<         store, nxv2f16,  ST1H_D, ST1H_D_IMM, PTRUE_D, am_sve_regreg_lsl1>;
+  defm : unpred_store<         store, nxv2bf16, ST1H_D, ST1H_D_IMM, PTRUE_D, am_sve_regreg_lsl1>;
+  defm : unpred_store<         store, nxv4f32,    ST1W,   ST1W_IMM, PTRUE_S, am_sve_regreg_lsl2>;
+  defm : unpred_store<         store, nxv2f32,  ST1W_D, ST1W_D_IMM, PTRUE_D, am_sve_regreg_lsl2>;
+  defm : unpred_store<         store, nxv2f64,    ST1D,   ST1D_IMM, PTRUE_D, am_sve_regreg_lsl3>;
 
   multiclass unpred_load<PatFrag Load, ValueType Ty, Instruction RegRegInst,
                          Instruction RegImmInst, Instruction PTrue,
@@ -2162,7 +2168,9 @@ let Predicates = [HasSVEorStreamingSVE] in {
   defm : unpred_load<        load, nxv8f16,    LD1H,    LD1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
   defm : unpred_load<        load, nxv8bf16,   LD1H,    LD1H_IMM, PTRUE_H, am_sve_regreg_lsl1>;
   defm : unpred_load<        load, nxv4f16,  LD1H_S,  LD1H_S_IMM, PTRUE_S, am_sve_regreg_lsl1>;
+  defm : unpred_load<        load, nxv4bf16, LD1H_S,  LD1H_S_IMM, PTRUE_S, am_sve_regreg_lsl1>;
   defm : unpred_load<        load, nxv2f16,  LD1H_D,  LD1H_D_IMM, PTRUE_D, am_sve_regreg_lsl1>;
+  defm : unpred_load<        load, nxv2bf16, LD1H_D,  LD1H_D_IMM, PTRUE_D, am_sve_regreg_lsl1>;
   defm : unpred_load<        load, nxv4f32,    LD1W,    LD1W_IMM, PTRUE_S, am_sve_regreg_lsl2>;
   defm : unpred_load<        load, nxv2f32,  LD1W_D,  LD1W_D_IMM, PTRUE_D, am_sve_regreg_lsl2>;
   defm : unpred_load<        load, nxv2f64,    LD1D,    LD1D_IMM, PTRUE_D, am_sve_regreg_lsl3>;

diff --git a/llvm/test/CodeGen/AArch64/sve-ld1-addressing-mode-reg-reg.ll b/llvm/test/CodeGen/AArch64/sve-ld1-addressing-mode-reg-reg.ll
index 4079126000868..fe1fb107ecf70 100644
--- a/llvm/test/CodeGen/AArch64/sve-ld1-addressing-mode-reg-reg.ll
+++ b/llvm/test/CodeGen/AArch64/sve-ld1-addressing-mode-reg-reg.ll
@@ -231,6 +231,18 @@ define <vscale x 4 x half> @ld1_nxv4f16(half* %addr, i64 %off) {
   ret <vscale x 4 x half> %val
 }
 
+define <vscale x 4 x bfloat> @ld1_nxv4bf16(bfloat* %addr, i64 %off) {
+; CHECK-LABEL: ld1_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x0, x1, lsl #1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds bfloat, bfloat* %addr, i64 %off
+  %ptrcast = bitcast bfloat* %ptr to <vscale x 4 x bfloat>*
+  %val = load volatile <vscale x 4 x bfloat>, <vscale x 4 x bfloat>* %ptrcast
+  ret <vscale x 4 x bfloat> %val
+}
+
 define <vscale x 2 x half> @ld1_nxv2f16(half* %addr, i64 %off) {
 ; CHECK-LABEL: ld1_nxv2f16:
 ; CHECK:       // %bb.0:
@@ -243,6 +255,18 @@ define <vscale x 2 x half> @ld1_nxv2f16(half* %addr, i64 %off) {
   ret <vscale x 2 x half> %val
 }
 
+define <vscale x 2 x bfloat> @ld1_nxv2bf16(bfloat* %addr, i64 %off) {
+; CHECK-LABEL: ld1_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ld1h { z0.d }, p0/z, [x0, x1, lsl #1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds bfloat, bfloat* %addr, i64 %off
+  %ptrcast = bitcast bfloat* %ptr to <vscale x 2 x bfloat>*
+  %val = load volatile <vscale x 2 x bfloat>, <vscale x 2 x bfloat>* %ptrcast
+  ret <vscale x 2 x bfloat> %val
+}
+
 ; LD1W
 
 define <vscale x 4 x i32> @ld1_nxv4i32(i32* %addr, i64 %off) {

diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
index 47a8e1a5db3ec..085c153967782 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -60,6 +60,14 @@ define <vscale x 2 x half> @masked_load_nxv2f16(<vscale x 2 x half> *%a, <vscale
   ret <vscale x 2 x half> %load
 }
 
+define <vscale x 2 x bfloat> @masked_load_nxv2bf16(<vscale x 2 x bfloat> *%a, <vscale x 2 x i1> %mask) nounwind #0 {
+; CHECK-LABEL: masked_load_nxv2bf16:
+; CHECK-NEXT: ld1h { z0.d }, p0/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(<vscale x 2 x bfloat> *%a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x bfloat> undef)
+  ret <vscale x 2 x bfloat> %load
+}
+
 define <vscale x 4 x float> @masked_load_nxv4f32(<vscale x 4 x float> *%a, <vscale x 4 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv4f32:
 ; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0]
@@ -76,6 +84,14 @@ define <vscale x 4 x half> @masked_load_nxv4f16(<vscale x 4 x half> *%a, <vscale
   ret <vscale x 4 x half> %load
 }
 
+define <vscale x 4 x bfloat> @masked_load_nxv4bf16(<vscale x 4 x bfloat> *%a, <vscale x 4 x i1> %mask) nounwind #0 {
+; CHECK-LABEL: masked_load_nxv4bf16:
+; CHECK-NEXT: ld1h { z0.s }, p0/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(<vscale x 4 x bfloat> *%a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x bfloat> undef)
+  ret <vscale x 4 x bfloat> %load
+}
+
 define <vscale x 8 x half> @masked_load_nxv8f16(<vscale x 8 x half> *%a, <vscale x 8 x i1> %mask) nounwind {
 ; CHECK-LABEL: masked_load_nxv8f16:
 ; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
@@ -185,6 +201,22 @@ define void @masked_store_nxv8f16(<vscale x 8 x half> *%a, <vscale x 8 x half> %
   ret void
 }
 
+define void @masked_store_nxv2bf16(<vscale x 2 x bfloat> *%a, <vscale x 2 x bfloat> %val, <vscale x 2 x i1> %mask) nounwind #0 {
+; CHECK-LABEL: masked_store_nxv2bf16:
+; CHECK-NEXT: st1h { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.masked.store.nxv2bf16(<vscale x 2 x bfloat> %val, <vscale x 2 x bfloat> *%a, i32 2, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @masked_store_nxv4bf16(<vscale x 4 x bfloat> *%a, <vscale x 4 x bfloat> %val, <vscale x 4 x i1> %mask) nounwind #0 {
+; CHECK-LABEL: masked_store_nxv4bf16:
+; CHECK-NEXT: st1h { z0.s }, p0, [x0]
+; CHECK-NEXT: ret
+  call void @llvm.masked.store.nxv4bf16(<vscale x 4 x bfloat> %val, <vscale x 4 x bfloat> *%a, i32 2, <vscale x 4 x i1> %mask)
+  ret void
+}
+
 define void @masked_store_nxv8bf16(<vscale x 8 x bfloat> *%a, <vscale x 8 x bfloat> %val, <vscale x 8 x i1> %mask) nounwind #0 {
 ; CHECK-LABEL: masked_store_nxv8bf16:
 ; CHECK-NEXT: st1h { z0.h }, p0, [x0]
@@ -292,6 +324,8 @@ declare <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half>*, i32,
 declare <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float>*, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
 declare <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half>*, i32, <vscale x 4 x i1>, <vscale x 4 x half>)
 declare <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half>*, i32, <vscale x 8 x i1>, <vscale x 8 x half>)
+declare <vscale x 2 x bfloat> @llvm.masked.load.nxv2bf16(<vscale x 2 x bfloat>*, i32, <vscale x 2 x i1>, <vscale x 2 x bfloat>)
+declare <vscale x 4 x bfloat> @llvm.masked.load.nxv4bf16(<vscale x 4 x bfloat>*, i32, <vscale x 4 x i1>, <vscale x 4 x bfloat>)
 declare <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16(<vscale x 8 x bfloat>*, i32, <vscale x 8 x i1>, <vscale x 8 x bfloat>)
 
 declare void @llvm.masked.store.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>*, i32, <vscale x 2 x i1>)
@@ -305,6 +339,8 @@ declare void @llvm.masked.store.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>
 declare void @llvm.masked.store.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>*, i32, <vscale x 4 x i1>)
 declare void @llvm.masked.store.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>*, i32, <vscale x 4 x i1>)
 declare void @llvm.masked.store.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>*, i32, <vscale x 8 x i1>)
+declare void @llvm.masked.store.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat>*, i32, <vscale x 2 x i1>)
+declare void @llvm.masked.store.nxv4bf16(<vscale x 4 x bfloat>, <vscale x 4 x bfloat>*, i32, <vscale x 4 x i1>)
 declare void @llvm.masked.store.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8 x bfloat>*, i32, <vscale x 8 x i1>)
 
 declare <vscale x 2 x i8*> @llvm.masked.load.nxv2p0i8.p0nxv2p0i8(<vscale x 2 x i8*>*, i32 immarg, <vscale x 2 x i1>, <vscale x 2 x i8*>)

diff --git a/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-reg.ll b/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-reg.ll
index 5dbc3366bd113..846c479c50806 100644
--- a/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-reg.ll
+++ b/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-reg.ll
@@ -166,6 +166,18 @@ define void @st1_nxv4f16(half* %addr, i64 %off, <vscale x 4 x half> %val) {
   ret void
 }
 
+define void @st1_nxv4bf16(bfloat* %addr, i64 %off, <vscale x 4 x bfloat> %val) {
+; CHECK-LABEL: st1_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    st1h { z0.s }, p0, [x0, x1, lsl #1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds bfloat, bfloat* %addr, i64 %off
+  %ptrcast = bitcast bfloat* %ptr to <vscale x 4 x bfloat>*
+  store <vscale x 4 x bfloat> %val, <vscale x 4 x bfloat>* %ptrcast
+  ret void
+}
+
 define void @st1_nxv2f16(half* %addr, i64 %off, <vscale x 2 x half> %val) {
 ; CHECK-LABEL: st1_nxv2f16:
 ; CHECK:       // %bb.0:
@@ -178,6 +190,18 @@ define void @st1_nxv2f16(half* %addr, i64 %off, <vscale x 2 x half> %val) {
   ret void
 }
 
+define void @st1_nxv2bf16(bfloat* %addr, i64 %off, <vscale x 2 x bfloat> %val) {
+; CHECK-LABEL: st1_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    st1h { z0.d }, p0, [x0, x1, lsl #1]
+; CHECK-NEXT:    ret
+  %ptr = getelementptr inbounds bfloat, bfloat* %addr, i64 %off
+  %ptrcast = bitcast bfloat* %ptr to <vscale x 2 x bfloat>*
+  store <vscale x 2 x bfloat> %val, <vscale x 2 x bfloat>* %ptrcast
+  ret void
+}
+
 ; ST1W
 
 define void @st1_nxv4i32(i32* %addr, i64 %off, <vscale x 4 x i32> %val) {

