[llvm] bd61664 - [AArch64][SME] Add ldr/str (fill/spill) intrinsics

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 14 05:58:28 PDT 2022


Author: David Sherwood
Date: 2022-06-14T13:58:22+01:00
New Revision: bd616641675591ecd136b44df8af2ea61298c30f

URL: https://github.com/llvm/llvm-project/commit/bd616641675591ecd136b44df8af2ea61298c30f
DIFF: https://github.com/llvm/llvm-project/commit/bd616641675591ecd136b44df8af2ea61298c30f.diff

LOG: [AArch64][SME] Add ldr/str (fill/spill) intrinsics

This patch adds implementations for the fill/spill SME ACLE intrinsics:

    @llvm.aarch64.sme.ldr
    @llvm.aarch64.sme.str

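Each intrinsic takes a 32-bit vector-select index and a base pointer,
and fills (loads) or spills (stores) a single ZA array vector at that
address. A minimal IR sketch, mirroring the tests added below:

    call void @llvm.aarch64.sme.ldr(i32 0, i8* %ptr) ; fill a ZA vector from [%ptr]
    call void @llvm.aarch64.sme.str(i32 0, i8* %ptr) ; spill a ZA vector to [%ptr]
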
Differential Revision: https://reviews.llvm.org/D127317

Added: 
    

Modified: 
    llvm/include/llvm/IR/IntrinsicsAArch64.td
    llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/SMEInstrFormats.td
    llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll
    llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 6aa976e4e8df..b608c55bcc13 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2625,4 +2625,10 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sme_st1w_vert  : SME_Load_Store_S_Intrinsic;
   def int_aarch64_sme_st1d_vert  : SME_Load_Store_D_Intrinsic;
   def int_aarch64_sme_st1q_vert  : SME_Load_Store_Q_Intrinsic;
+
+  // Spill + fill
+  def int_aarch64_sme_ldr : DefaultAttrsIntrinsic<
+    [], [llvm_i32_ty, llvm_ptr_ty]>;
+  def int_aarch64_sme_str : DefaultAttrsIntrinsic<
+    [], [llvm_i32_ty, llvm_ptr_ty]>;
 }

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index ad97a2b265f5..c7e982771e3f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -5114,6 +5114,10 @@ static EVT getMemVTFromNode(LLVMContext &Ctx, SDNode *Root) {
 
   const unsigned IntNo =
       cast<ConstantSDNode>(Root->getOperand(1))->getZExtValue();
+  if (IntNo == Intrinsic::aarch64_sme_ldr ||
+      IntNo == Intrinsic::aarch64_sme_str)
+    return MVT::nxv16i8;
+
   if (IntNo != Intrinsic::aarch64_sve_prf)
     return EVT();
 

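Returning MVT::nxv16i8 for these intrinsics tells the existing
SelectAddrModeIndexedSVE logic (reached via the am_sme_indexed_b4
pattern added below) to treat the access as one full SVE vector of
bytes, so a VL-scaled offset of 0-15 vectors can be folded into the
4-bit immediate. A sketch of the effect, taken from the ldr tests
below:

    %vscale = call i64 @llvm.vscale.i64()
    %mulvl = mul i64 %vscale, 240          ; 15 x (vscale x 16) bytes
    %base = getelementptr i8, i8* %ptr, i64 %mulvl
    call void @llvm.aarch64.sme.ldr(i32 0, i8* %base)
    ; => ldr za[w12, 15], [x0, #15, mul vl]
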
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7e2fbd33de32..5d61e621ae91 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2341,6 +2341,22 @@ AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
   return BB;
 }
 
+MachineBasicBlock *
+AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const {
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  MachineInstrBuilder MIB =
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::LDR_ZA));
+
+  MIB.addReg(AArch64::ZA, RegState::Define);
+  MIB.add(MI.getOperand(0)); // Vector select register
+  MIB.add(MI.getOperand(1)); // Vector select offset
+  MIB.add(MI.getOperand(2)); // Base
+  MIB.add(MI.getOperand(1)); // Offset, same as vector select offset
+
+  MI.eraseFromParent(); // The pseudo is gone now.
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
   switch (MI.getOpcode()) {
@@ -2391,6 +2407,8 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitTileLoad(AArch64::LD1_MXIPXX_V_D, AArch64::ZAD0, MI, BB);
   case AArch64::LD1_MXIPXX_V_PSEUDO_Q:
     return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
+  case AArch64::LDR_ZA_PSEUDO:
+    return EmitFill(MI, BB);
   }
 }
 

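In the underlying LDR (fill) instruction, a single 4-bit immediate
serves as both the vector-select offset and the VL-scaled memory
offset, which is why EmitFill() adds MI.getOperand(1) twice when
expanding LDR_ZA_PSEUDO into LDR_ZA. A small sketch with a runtime
vector-select index (not one of the tests in this patch; the mov into
w12 is an assumption about which register in MatrixIndexGPR32Op12_15
the allocator picks):

    declare void @llvm.aarch64.sme.ldr(i32, i8*)

    define void @fill(i32 %v, i8* %ptr) {
      call void @llvm.aarch64.sme.ldr(i32 %v, i8* %ptr)
      ret void
    }
    ; expected, roughly:
    ;   mov w12, w0
    ;   ldr za[w12, 0], [x0]
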
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 2b8fbbab5517..e3f1121c93a6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -559,6 +559,7 @@ class AArch64TargetLowering : public TargetLowering {
   MachineBasicBlock *EmitTileLoad(unsigned Opc, unsigned BaseReg,
                                   MachineInstr &MI,
                                   MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
 
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,

diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 9a1dc17d4486..9d5670f9b93c 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -22,6 +22,8 @@ def tileslice32  : ComplexPattern<i32 , 2, "SelectSMETileSlice<2>", []>;
 def tileslice64  : ComplexPattern<i32 , 2, "SelectSMETileSlice<1>", []>;
 def tileslice128 : ComplexPattern<i32 , 2, "SelectSMETileSlice<0>", []>; // nop
 
+def am_sme_indexed_b4 : ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;
+
 //===----------------------------------------------------------------------===//
 // SME Outer Products
 //===----------------------------------------------------------------------===//
@@ -509,7 +511,7 @@ multiclass sme_mem_st_ss<string mnemonic> {
 // SME Save and Restore Array
 //===----------------------------------------------------------------------===//
 
-class sme_spill_fill_inst<bit isStore, dag outs, dag ins, string opcodestr>
+class sme_spill_fill_base<bit isStore, dag outs, dag ins, string opcodestr>
     : I<outs, ins, opcodestr, "\t$ZAt[$Rv, $imm4], [$Rn, $offset, mul vl]", "",
         []>,
       Sched<[]> {
@@ -524,33 +526,61 @@ class sme_spill_fill_inst<bit isStore, dag outs, dag ins, string opcodestr>
   let Inst{9-5}   = Rn;
   let Inst{4}     = 0b0;
   let Inst{3-0}   = imm4;
-
-  let mayLoad = !not(isStore);
-  let mayStore = isStore;
 }
 
-multiclass sme_spill_fill<bit isStore, dag outs, dag ins, string opcodestr> {
-  def NAME : sme_spill_fill_inst<isStore, outs, ins, opcodestr>;
-
+let mayStore = 1 in
+class sme_spill_inst<string opcodestr>
+    : sme_spill_fill_base<0b1, (outs),
+                          (ins MatrixOp:$ZAt, MatrixIndexGPR32Op12_15:$Rv,
+                               sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
+                               imm0_15:$offset),
+                          opcodestr>;
+let mayLoad = 1 in
+class sme_fill_inst<string opcodestr>
+    : sme_spill_fill_base<0b0, (outs MatrixOp:$ZAt),
+                          (ins MatrixIndexGPR32Op12_15:$Rv,
+                               sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
+                               imm0_15:$offset),
+                          opcodestr>;
+multiclass sme_spill<string opcodestr> {
+  def NAME : sme_spill_inst<opcodestr>;
   def : InstAlias<opcodestr # "\t$ZAt[$Rv, $imm4], [$Rn]",
                   (!cast<Instruction>(NAME) MatrixOp:$ZAt,
                    MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
-}
-
-multiclass sme_spill<string opcodestr> {
-  defm NAME : sme_spill_fill<0b1, (outs),
-                             (ins MatrixOp:$ZAt, MatrixIndexGPR32Op12_15:$Rv,
-                                  sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
-                                  imm0_15:$offset),
-                             opcodestr>;
+  // base
+  def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base),
+            (!cast<Instruction>(NAME) ZA, $idx, 0, $base, 0)>;
+  // scalar + immediate (mul vl)
+  let AddedComplexity = 2 in {
+    def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx,
+                                   (am_sme_indexed_b4 GPR64sp:$base, imm0_15:$imm4)),
+              (!cast<Instruction>(NAME) ZA, $idx, 0, $base, $imm4)>;
+  }
 }
 
 multiclass sme_fill<string opcodestr> {
-  defm NAME : sme_spill_fill<0b0, (outs MatrixOp:$ZAt),
-                             (ins MatrixIndexGPR32Op12_15:$Rv,
-                                  sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
-                                  imm0_15:$offset),
-                             opcodestr>;
+  def NAME : sme_fill_inst<opcodestr>;
+  def : InstAlias<opcodestr # "\t$ZAt[$Rv, $imm4], [$Rn]",
+                  (!cast<Instruction>(NAME) MatrixOp:$ZAt,
+                   MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
+  def NAME # _PSEUDO
+      : Pseudo<(outs),
+               (ins MatrixIndexGPR32Op12_15:$idx, imm0_15:$imm4,
+                    GPR64sp:$base), []>,
+        Sched<[]> {
+    // Translated to actual instruction in AArch64ISelLowering.cpp
+    let usesCustomInserter = 1;
+    let mayLoad = 1;
+  }
+  // base
+  def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base),
+            (!cast<Instruction>(NAME # _PSEUDO) $idx, 0, $base)>;
+  // scalar + immediate (mul vl)
+  let AddedComplexity = 2 in {
+    def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx,
+                                   (am_sme_indexed_b4 GPR64sp:$base, imm0_15:$imm4)),
+              (!cast<Instruction>(NAME # _PSEUDO) $idx, $imm4, $base)>;
+  }
 }
 
 //===----------------------------------------------------------------------===//

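Note the asymmetry between the two multiclasses: the spill (str)
patterns select the real instruction directly, since STR only reads ZA
and the pattern can pass it as a fixed input operand, whereas the fill
(ldr) defines ZA as an output, which a pattern for a resultless
intrinsic cannot populate, so selection goes through LDR_ZA_PSEUDO and
the custom inserter above. The direct str selection for the base-only
case, from the stores test below:

    call void @llvm.aarch64.sme.str(i32 0, i8* %ptr)
    ; => mov w12, wzr
    ;    str za[w12, 0], [x0]
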
diff --git a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll
index 3418d7e8a819..9bfe6280e652 100644
--- a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-loads.ll
@@ -246,6 +246,55 @@ define void @ld1q_with_addr_offset(<vscale x 16 x i1> %pg, i128* %ptr, i64 %inde
   ret void;
 }
 
+define void @ldr(i8* %ptr) {
+; CHECK-LABEL: ldr:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    ldr za[w12, 0], [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %ptr)
+  ret void;
+}
+
+define void @ldr_with_off_15(i8* %ptr) {
+; CHECK-LABEL: ldr_with_off_15:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    add x8, x0, #15
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %base = getelementptr i8, i8* %ptr, i64 15
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %base)
+  ret void;
+}
+
+define void @ldr_with_off_15mulvl(i8* %ptr) {
+; CHECK-LABEL: ldr_with_off_15mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    ldr za[w12, 15], [x0, #15, mul vl]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 240
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %base)
+  ret void;
+}
+
+define void @ldr_with_off_16mulvl(i8* %ptr) {
+; CHECK-LABEL: ldr_with_off_16mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    addvl x8, x0, #16
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 256
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.ldr(i32 0, i8* %base)
+  ret void;
+}
+
 declare void @llvm.aarch64.sme.ld1b.horiz(<vscale x 16 x i1>, i8*, i64, i32)
 declare void @llvm.aarch64.sme.ld1h.horiz(<vscale x 16 x i1>, i16*, i64, i32)
 declare void @llvm.aarch64.sme.ld1w.horiz(<vscale x 16 x i1>, i32*, i64, i32)
@@ -256,3 +305,6 @@ declare void @llvm.aarch64.sme.ld1h.vert(<vscale x 16 x i1>, i16*, i64, i32)
 declare void @llvm.aarch64.sme.ld1w.vert(<vscale x 16 x i1>, i32*, i64, i32)
 declare void @llvm.aarch64.sme.ld1d.vert(<vscale x 16 x i1>, i64*, i64, i32)
 declare void @llvm.aarch64.sme.ld1q.vert(<vscale x 16 x i1>, i128*, i64, i32)
+
+declare void @llvm.aarch64.sme.ldr(i32, i8*)
+declare i64 @llvm.vscale.i64()

diff --git a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll
index 5b8acd21520c..fa2b7cae5162 100644
--- a/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/SME/sme-intrinsics-stores.ll
@@ -246,6 +246,55 @@ define void @st1q_with_addr_offset(<vscale x 16 x i1> %pg, i128* %ptr, i64 %inde
   ret void;
 }
 
+define void @str(i8* %ptr) {
+; CHECK-LABEL: str:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    str za[w12, 0], [x0]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.str(i32 0, i8* %ptr)
+  ret void;
+}
+
+define void @str_with_off_15(i8* %ptr) {
+; CHECK-LABEL: str_with_off_15:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    add x8, x0, #15
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %base = getelementptr i8, i8* %ptr, i64 15
+  call void @llvm.aarch64.sme.str(i32 0, i8* %base)
+  ret void;
+}
+
+define void @str_with_off_15mulvl(i8* %ptr) {
+; CHECK-LABEL: str_with_off_15mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    str za[w12, 0], [x0, #15, mul vl]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 240
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.str(i32 0, i8* %base)
+  ret void;
+}
+
+define void @str_with_off_16mulvl(i8* %ptr) {
+; CHECK-LABEL: str_with_off_16mulvl:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, wzr
+; CHECK-NEXT:    addvl x8, x0, #16
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %vscale = call i64 @llvm.vscale.i64()
+  %mulvl = mul i64 %vscale, 256
+  %base = getelementptr i8, i8* %ptr, i64 %mulvl
+  call void @llvm.aarch64.sme.str(i32 0, i8* %base)
+  ret void;
+}
+
 declare void @llvm.aarch64.sme.st1b.horiz(<vscale x 16 x i1>, i8*, i64, i32)
 declare void @llvm.aarch64.sme.st1h.horiz(<vscale x 16 x i1>, i16*, i64, i32)
 declare void @llvm.aarch64.sme.st1w.horiz(<vscale x 16 x i1>, i32*, i64, i32)
@@ -256,3 +305,6 @@ declare void @llvm.aarch64.sme.st1h.vert(<vscale x 16 x i1>, i16*, i64, i32)
 declare void @llvm.aarch64.sme.st1w.vert(<vscale x 16 x i1>, i32*, i64, i32)
 declare void @llvm.aarch64.sme.st1d.vert(<vscale x 16 x i1>, i64*, i64, i32)
 declare void @llvm.aarch64.sme.st1q.vert(<vscale x 16 x i1>, i128*, i64, i32)
+
+declare void @llvm.aarch64.sme.str(i32, i8*)
+declare i64 @llvm.vscale.i64()
