[llvm] 45d2877 - [AArch64][SME] Fix lowering of llvm.aarch64.get.pstatesm()

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 15 08:16:00 PDT 2022


Author: Sander de Smalen
Date: 2022-09-15T15:14:13Z
New Revision: 45d28779c5dc6c8afa6feb24d68606f01b9800f4

URL: https://github.com/llvm/llvm-project/commit/45d28779c5dc6c8afa6feb24d68606f01b9800f4
DIFF: https://github.com/llvm/llvm-project/commit/45d28779c5dc6c8afa6feb24d68606f01b9800f4.diff

LOG: [AArch64][SME] Fix lowering of llvm.aarch64.get.pstatesm()

A thread may not have access to SME or TPIDR2_EL0, so in order to
safely query PSTATE.SM in a streaming-compatible function, the
code should call `__arm_sme_state()`, as described in the ABI:

  https://github.com/ARM-software/abi-aa/pull/123/commits/c2bb09c4d4ee60a5787baf1ccc7e92e67e4240b7

This means that the value of PSTATE.SM is:
* 0 if the function is non-streaming.
* 1 if the function has `arm_streaming` or `arm_locally_streaming`.
* otherwise (i.e. for a streaming-compatible function), evaluated at runtime by a call to `__arm_sme_state()`.

This patch also adds two calling conventions for calls to SME ABI support routines.

At some point we can remove the need for the llvm.aarch64.get.pstatesm() intrinsic
and use function calls (with the corresponding cc) directly instead.

Reviewed By: aemerson

Differential Revision: https://reviews.llvm.org/D131571

Added: 
    llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll

Modified: 
    llvm/include/llvm/AsmParser/LLToken.h
    llvm/include/llvm/IR/CallingConv.h
    llvm/lib/AsmParser/LLLexer.cpp
    llvm/lib/AsmParser/LLParser.cpp
    llvm/lib/IR/AsmWriter.cpp
    llvm/lib/Target/AArch64/AArch64CallingConvention.td
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
    llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 04235f0fdc4ef..a070cd89e3873 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -141,6 +141,8 @@ enum Kind {
   kw_arm_aapcs_vfpcc,
   kw_aarch64_vector_pcs,
   kw_aarch64_sve_vector_pcs,
+  kw_aarch64_sme_preservemost_from_x0,
+  kw_aarch64_sme_preservemost_from_x2,
   kw_msp430_intrcc,
   kw_avr_intrcc,
   kw_avr_signalcc,

diff  --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h
index 030a7e1503da4..9fefeef05cb21 100644
--- a/llvm/include/llvm/IR/CallingConv.h
+++ b/llvm/include/llvm/IR/CallingConv.h
@@ -235,6 +235,12 @@ namespace CallingConv {
     /// Used for M68k interrupt routines.
     M68k_INTR = 101,
 
+    /// Preserve X0-X13, X19-X29, SP, Z0-Z31, P0-P15.
+    AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 = 102,
+
+    /// Preserve X2-X15, X19-X29, SP, Z0-Z31, P0-P15.
+    AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 = 103,
+
     /// The highest possible ID. Must be some 2^k - 1.
     MaxID = 1023
   };

diff  --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index c9a982693fa75..c020fe779827e 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -597,6 +597,8 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(arm_aapcs_vfpcc);
   KEYWORD(aarch64_vector_pcs);
   KEYWORD(aarch64_sve_vector_pcs);
+  KEYWORD(aarch64_sme_preservemost_from_x0);
+  KEYWORD(aarch64_sme_preservemost_from_x2);
   KEYWORD(msp430_intrcc);
   KEYWORD(avr_intrcc);
   KEYWORD(avr_signalcc);

diff  --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 10a775f89cbcc..7475868f1e019 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -1875,6 +1875,8 @@ void LLParser::parseOptionalDLLStorageClass(unsigned &Res) {
 ///   ::= 'arm_aapcs_vfpcc'
 ///   ::= 'aarch64_vector_pcs'
 ///   ::= 'aarch64_sve_vector_pcs'
+///   ::= 'aarch64_sme_preservemost_from_x0'
+///   ::= 'aarch64_sme_preservemost_from_x2'
 ///   ::= 'msp430_intrcc'
 ///   ::= 'avr_intrcc'
 ///   ::= 'avr_signalcc'
@@ -1925,6 +1927,12 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) {
   case lltok::kw_aarch64_sve_vector_pcs:
     CC = CallingConv::AArch64_SVE_VectorCall;
     break;
+  case lltok::kw_aarch64_sme_preservemost_from_x0:
+    CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0;
+    break;
+  case lltok::kw_aarch64_sme_preservemost_from_x2:
+    CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2;
+    break;
   case lltok::kw_msp430_intrcc:  CC = CallingConv::MSP430_INTR; break;
   case lltok::kw_avr_intrcc:     CC = CallingConv::AVR_INTR; break;
   case lltok::kw_avr_signalcc:   CC = CallingConv::AVR_SIGNAL; break;

diff  --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 0ee559a7d7cf2..d9443f43ae14f 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -312,6 +312,12 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) {
   case CallingConv::AArch64_SVE_VectorCall:
     Out << "aarch64_sve_vector_pcs";
     break;
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+    Out << "aarch64_sme_preservemost_from_x0";
+    break;
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
+    Out << "aarch64_sme_preservemost_from_x2";
+    break;
   case CallingConv::MSP430_INTR:   Out << "msp430_intrcc"; break;
   case CallingConv::AVR_INTR:      Out << "avr_intrcc "; break;
   case CallingConv::AVR_SIGNAL:    Out << "avr_signalcc "; break;

diff  --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index 6cf7bf6d1cfcf..0000d26f30443 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -435,6 +435,22 @@ def CSR_AArch64_SVE_AAPCS : CalleeSavedRegs<(add (sequence "Z%u", 8, 23),
                                                  X19, X20, X21, X22, X23, X24,
                                                  X25, X26, X27, X28, LR, FP)>;
 
+// Registers preserved by SME ABI routines such as __arm_tpidr2_save/restore.
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0
+                          : CalleeSavedRegs<(add (sequence "Z%u", 0, 31),
+                                                 (sequence "P%u", 0, 15),
+                                                 (sequence "X%u", 0, 13),
+                                                 (sequence "X%u",19, 28),
+                                                 LR, FP)>;
+
+// Registers preserved by the SME ABI support routine __arm_sme_state.
+def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2
+                          : CalleeSavedRegs<(add (sequence "Z%u", 0, 31),
+                                                 (sequence "P%u", 0, 15),
+                                                 (sequence "X%u", 2, 15),
+                                                 (sequence "X%u",19, 28),
+                                                 LR, FP)>;
+
 def CSR_AArch64_AAPCS_SwiftTail
     : CalleeSavedRegs<(sub CSR_AArch64_AAPCS, X20, X22)>;
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 198d332015252..944c8bae782c9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4490,6 +4490,32 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
   return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
 }
 
+SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
+                                           SMEAttrs Attrs, SDLoc DL,
+                                           EVT VT) const {
+  if (Attrs.hasStreamingInterfaceOrBody()) // Streaming or locally-streaming.
+    return DAG.getConstant(1, DL, VT);
+  // Non-streaming interface and body: PSTATE.SM is known to be 0.
+  if (Attrs.hasNonStreamingInterfaceAndBody())
+    return DAG.getConstant(0, DL, VT);
+  // Otherwise the value is only known at runtime (streaming-compatible).
+  assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
+  // Call __arm_sme_state(). NOTE(review): its out-chain is not returned.
+  SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
+                                         getPointerTy(DAG.getDataLayout()));
+  Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
+  Type *RetTy = StructType::get(Int64Ty, Int64Ty); // Two i64 results.
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  ArgListTy Args; // __arm_sme_state takes no arguments.
+  CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+      CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
+      RetTy, Callee, std::move(Args));
+  std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
+  SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
+  return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
+                     Mask); // NOTE(review): always i64, regardless of VT.
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
                                                       SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -4521,13 +4547,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
     return DAG.getMergeValues({MS.getValue(0), MS.getValue(2)}, DL);
   }
   case Intrinsic::aarch64_sme_get_pstatesm: {
-    SDValue Chain = Op.getOperand(0);
-    SDValue MRS = DAG.getNode(
-        AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
-        Chain, DAG.getConstant(AArch64SysReg::SVCR, DL, MVT::i64));
-    SDValue Mask = DAG.getConstant(/* PSTATE.SM */ 1, DL, MVT::i64);
-    SDValue And = DAG.getNode(ISD::AND, DL, MVT::i64, MRS, Mask);
-    return DAG.getMergeValues({And, Chain}, DL);
+    SDValue Chain = Op->getOperand(0);
+    SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
+    SDValue PStateSM = getPStateSM(DAG, Chain, Attrs, DL, Op.getValueType());
+    return DAG.getMergeValues({PStateSM, Chain}, DL);
   }
   }
 }
@@ -5834,6 +5857,8 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
      return CC_AArch64_Win64_CFGuard_Check;
    case CallingConv::AArch64_VectorCall:
    case CallingConv::AArch64_SVE_VectorCall:
+   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
      return CC_AArch64_AAPCS;
   }
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 82e05790dd08f..a5552caa9bb09 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -15,6 +15,7 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64ISELLOWERING_H
 
 #include "AArch64.h"
+#include "Utils/AArch64SMEAttributes.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/SelectionDAG.h"
@@ -1158,6 +1159,11 @@ class AArch64TargetLowering : public TargetLowering {
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
+  // Returns the value of PSTATE.SM: a constant when Attrs fixes the streaming
+  // mode, otherwise (streaming-compatible) it is read via __arm_sme_state().
+  SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
+                      SDLoc DL, EVT VT) const;
+
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };

diff  --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index f92fccac21dbb..91b6d183fa2e2 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -91,6 +91,18 @@ AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
     return CSR_AArch64_AAVPCS_SaveList;
   if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
     return CSR_AArch64_SVE_AAPCS_SaveList;
+  if (MF->getFunction().getCallingConv() ==
+          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+        "only supported to improve calls to SME ACLE save/restore/disable-za "
+        "functions, and is not intended to be used beyond that scope.");
+  if (MF->getFunction().getCallingConv() ==
+          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+        "only supported to improve calls to SME ACLE __arm_sme_state "
+        "and is not intended to be used beyond that scope.");
   if (MF->getSubtarget<AArch64Subtarget>().getTargetLowering()
           ->supportSwiftError() &&
       MF->getFunction().getAttributes().hasAttrSomewhere(
@@ -123,6 +135,18 @@ AArch64RegisterInfo::getDarwinCalleeSavedRegs(const MachineFunction *MF) const {
   if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SVE_VectorCall)
     report_fatal_error(
         "Calling convention SVE_VectorCall is unsupported on Darwin.");
+  if (MF->getFunction().getCallingConv() ==
+          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+        "only supported to improve calls to SME ACLE save/restore/disable-za "
+        "functions, and is not intended to be used beyond that scope.");
+  if (MF->getFunction().getCallingConv() ==
+          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+        "only supported to improve calls to SME ACLE __arm_sme_state "
+        "and is not intended to be used beyond that scope.");
   if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS)
     return MF->getInfo<AArch64FunctionInfo>()->isSplitCSR()
                ? CSR_Darwin_AArch64_CXX_TLS_PE_SaveList
@@ -193,6 +217,14 @@ AArch64RegisterInfo::getDarwinCallPreservedMask(const MachineFunction &MF,
   if (CC == CallingConv::AArch64_SVE_VectorCall)
     report_fatal_error(
         "Calling convention SVE_VectorCall is unsupported on Darwin.");
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is "
+        "unsupported on Darwin.");
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    report_fatal_error(
+        "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is "
+        "unsupported on Darwin.");
   if (CC == CallingConv::CFGuard_Check)
     report_fatal_error(
         "Calling convention CFGuard_Check is unsupported on Darwin.");
@@ -230,6 +262,10 @@ AArch64RegisterInfo::getCallPreservedMask(const MachineFunction &MF,
   if (CC == CallingConv::AArch64_SVE_VectorCall)
     return SCS ? CSR_AArch64_SVE_AAPCS_SCS_RegMask
                : CSR_AArch64_SVE_AAPCS_RegMask;
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0)
+    return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask;
+  if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2)
+    return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2_RegMask;
   if (CC == CallingConv::CFGuard_Check)
     return CSR_Win_AArch64_CFGuard_Check_RegMask;
   if (MF.getSubtarget<AArch64Subtarget>().getTargetLowering()
@@ -539,6 +575,8 @@ bool AArch64RegisterInfo::isArgumentRegister(const MachineFunction &MF,
     return HasReg(CC_AArch64_Win64_CFGuard_Check_ArgRegs, Reg);
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return HasReg(CC_AArch64_AAPCS_ArgRegs, Reg);
   }
 }

diff  --git a/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll b/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
index 2586a8c40793f..a20abc8247217 100644
--- a/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
+++ b/llvm/test/CodeGen/AArch64/sme-get-pstatesm.ll
@@ -1,14 +1,46 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs -stop-after=finalize-isel < %s | FileCheck %s --check-prefix=CHECK-CSRMASK
 
-define i64 @is_streaming() {
-; CHECK-LABEL: is_streaming:
+define i64 @get_pstatesm_normal() nounwind { ; Non-streaming: folds to 0.
+; CHECK-LABEL: get_pstatesm_normal:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mrs x8, SVCR
-; CHECK-NEXT:    and x0, x8, #0x1
+; CHECK-NEXT:    mov x0, xzr
 ; CHECK-NEXT:    ret
   %pstate = call i64 @llvm.aarch64.sme.get.pstatesm()
   ret i64 %pstate
 }
 
+define i64 @get_pstatesm_streaming() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: get_pstatesm_streaming:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    ret
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm() ; Folds to 1.
+  ret i64 %pstate
+}
+
+define i64 @get_pstatesm_locally_streaming() nounwind "aarch64_pstate_sm_body" {
+; CHECK-LABEL: get_pstatesm_locally_streaming:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    ret
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm() ; Folds to 1 (body).
+  ret i64 %pstate
+}
+
+define i64 @get_pstatesm_streaming_compatible() nounwind "aarch64_pstate_sm_compatible" {
+; CHECK-LABEL: get_pstatesm_streaming_compatible:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x0, x0, #0x1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-CSRMASK-LABEL: name: get_pstatesm_streaming_compatible
+; CHECK-CSRMASK: BL &__arm_sme_state, csr_aarch64_sme_abi_support_routines_preservemost_from_x2
+  %pstate = call i64 @llvm.aarch64.sme.get.pstatesm() ; Runtime libcall.
+  ret i64 %pstate
+}
+
 declare i64 @llvm.aarch64.sme.get.pstatesm()

diff  --git a/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll b/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll
new file mode 100644
index 0000000000000..d88deec40ce72
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll
@@ -0,0 +1,37 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs -stop-after=finalize-isel < %s | FileCheck %s --check-prefix=CHECK-CSRMASK
+
+; Test that the SME support-routine calling conventions are accepted by the
+; parser and that calls using them get the matching register mask.
+
+define void @test_sme_calling_convention_x0() nounwind {
+; CHECK-LABEL: test_sme_calling_convention_x0:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_tpidr2_save
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-CSRMASK-LABEL: name: test_sme_calling_convention_x0
+; CHECK-CSRMASK: BL @__arm_tpidr2_save, csr_aarch64_sme_abi_support_routines_preservemost_from_x0
+  call aarch64_sme_preservemost_from_x0 void @__arm_tpidr2_save() ; x0 mask
+  ret void
+}
+
+define i64 @test_sme_calling_convention_x2() nounwind {
+; CHECK-LABEL: test_sme_calling_convention_x2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+;
+; CHECK-CSRMASK-LABEL: name: test_sme_calling_convention_x2
+; CHECK-CSRMASK: BL @__arm_sme_state, csr_aarch64_sme_abi_support_routines_preservemost_from_x2
+  %pstate = call aarch64_sme_preservemost_from_x2 {i64, i64} @__arm_sme_state()
+  %pstate.sm = extractvalue {i64, i64} %pstate, 0 ; First of the two results.
+  ret i64 %pstate.sm
+}
+
+declare void @__arm_tpidr2_save()
+declare {i64, i64} @__arm_sme_state()


        


More information about the llvm-commits mailing list