[llvm] b00c36c - [AArch64][SME] Implement ABI for calls to/from streaming functions.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 16 07:08:33 PDT 2022


Author: Sander de Smalen
Date: 2022-09-16T14:07:47Z
New Revision: b00c36c2958b250ce6031a9b9776b3011fae4f33

URL: https://github.com/llvm/llvm-project/commit/b00c36c2958b250ce6031a9b9776b3011fae4f33
DIFF: https://github.com/llvm/llvm-project/commit/b00c36c2958b250ce6031a9b9776b3011fae4f33.diff

LOG: [AArch64][SME] Implement ABI for calls to/from streaming functions.

This patch implements the ABI for calls from:

  Normal -> Streaming
  Normal -> Streaming-compatible
  Streaming -> Normal
  Streaming -> Streaming-compatible
  Streaming -> Streaming

The compiler inserts SMSTART/SMSTOP instructions before and after the call,
depending on the required transition.

More details about the SME attributes and design can be found
in D131562.

Reviewed By: aemerson

Differential Revision: https://reviews.llvm.org/D131576

Added: 
    llvm/test/CodeGen/AArch64/sme-streaming-interface.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64CallingConvention.td
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
    llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
    llvm/lib/Target/AArch64/AArch64RegisterInfo.h
    llvm/lib/Target/AArch64/AArch64RegisterInfo.td
    llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index 0000d26f3044..e53f573de66c 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -451,6 +451,10 @@ def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2
                                                  (sequence "X%u",19, 28),
                                                  LR, FP)>;
 
+// The SMSTART/SMSTOP instructions preserve only GPR registers.
+def CSR_AArch64_SMStartStop : CalleeSavedRegs<(add (sequence "X%u", 0, 28),
+                                                   LR, FP)>;
+
 def CSR_AArch64_AAPCS_SwiftTail
     : CalleeSavedRegs<(sub CSR_AArch64_AAPCS, X20, X22)>;
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f1b3b28c3370..84674703735e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2058,6 +2058,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
   switch ((AArch64ISD::NodeType)Opcode) {
   case AArch64ISD::FIRST_NUMBER:
     break;
+    MAKE_CASE(AArch64ISD::SMSTART)
+    MAKE_CASE(AArch64ISD::SMSTOP)
     MAKE_CASE(AArch64ISD::CALL)
     MAKE_CASE(AArch64ISD::ADRP)
     MAKE_CASE(AArch64ISD::ADR)
@@ -4517,6 +4519,17 @@ SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
                      Mask);
 }
 
+static Optional<SMEAttrs> getCalleeAttrsFromExternalFunction(SDValue V) {
+  if (auto *ES = dyn_cast<ExternalSymbolSDNode>(V)) {
+    StringRef S(ES->getSymbol());
+    if (S == "__arm_sme_state" || S == "__arm_tpidr2_save")
+      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Preserved);
+    if (S == "__arm_tpidr2_restore")
+      return SMEAttrs(SMEAttrs::SM_Compatible | SMEAttrs::ZA_Shared);
+  }
+  return None;
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
                                                       SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -6640,6 +6653,25 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
   return ZExtBool;
 }
 
+SDValue AArch64TargetLowering::changeStreamingMode(
+    SelectionDAG &DAG, SDLoc DL, bool Enable,
+    SDValue Chain, SDValue InFlag, SDValue PStateSM, bool Entry) const {
+  const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+  SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
+  SDValue MSROp =
+      DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
+
+  SDValue ExpectedSMVal =
+      DAG.getTargetConstant(Entry ? Enable : !Enable, DL, MVT::i64);
+  SmallVector<SDValue> Ops = {Chain, MSROp, PStateSM, ExpectedSMVal, RegMask};
+
+  if (InFlag)
+    Ops.push_back(InFlag);
+
+  unsigned Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP;
+  return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
+}
+
 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
 /// and add input and output parameter nodes.
 SDValue
@@ -6760,6 +6792,19 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
   }
 
+  // Determine whether we need any streaming mode changes.
+  SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
+  if (CLI.CB)
+    CalleeAttrs = SMEAttrs(*CLI.CB);
+  else if (Optional<SMEAttrs> Attrs =
+               getCalleeAttrsFromExternalFunction(CLI.Callee))
+    CalleeAttrs = *Attrs;
+
+  SDValue InFlag, PStateSM;
+  Optional<bool> RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
+  if (RequiresSMChange)
+    PStateSM = getPStateSM(DAG, Chain, CallerAttrs, DL, MVT::i64);
+
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
   if (!IsSibCall)
@@ -7011,9 +7056,15 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   if (!MemOpChains.empty())
     Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
 
+  if (RequiresSMChange) {
+    SDValue NewChain = changeStreamingMode(DAG, DL, *RequiresSMChange, Chain,
+                                           InFlag, PStateSM, true);
+    Chain = NewChain.getValue(0);
+    InFlag = NewChain.getValue(1);
+  }
+
   // Build a sequence of copy-to-reg nodes chained together with token chain
   // and flag operands which copy the outgoing args into the appropriate regs.
-  SDValue InFlag;
   for (auto &RegToPass : RegsToPass) {
     Chain = DAG.getCopyToReg(Chain, DL, RegToPass.first,
                              RegToPass.second, InFlag);
@@ -7143,14 +7194,36 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       DoesCalleeRestoreStack(CallConv, TailCallOpt) ? alignTo(NumBytes, 16) : 0;
 
   Chain = DAG.getCALLSEQ_END(Chain, NumBytes, CalleePopBytes, InFlag, DL);
-  if (!Ins.empty())
-    InFlag = Chain.getValue(1);
+  InFlag = Chain.getValue(1);
 
   // Handle result values, copying them out of physregs into vregs that we
   // return.
-  return LowerCallResult(Chain, InFlag, CallConv, IsVarArg, Ins, DL, DAG,
-                         InVals, IsThisReturn,
-                         IsThisReturn ? OutVals[0] : SDValue());
+  SDValue Result =
+      LowerCallResult(Chain, InFlag, CallConv, IsVarArg, Ins, DL, DAG, InVals,
+                      IsThisReturn, IsThisReturn ? OutVals[0] : SDValue());
+
+  if (!Ins.empty())
+    InFlag = Result.getValue(Result->getNumValues() - 1);
+
+  if (RequiresSMChange) {
+    assert(PStateSM && "Expected a PStateSM to be set");
+    Result = changeStreamingMode(DAG, DL, !*RequiresSMChange, Result, InFlag,
+                                 PStateSM, false);
+    for (unsigned I = 0; I < InVals.size(); ++I) {
+      // The smstart/smstop is chained as part of the call, but when the
+      // resulting chain is discarded (which happens when the call is not part
+      // of a chain, e.g. a call to @llvm.cos()), we need to ensure the
+      // smstart/smstop is chained to the result value. We can do that by doing
+      // a vreg -> vreg copy.
+      Register Reg = MF.getRegInfo().createVirtualRegister(
+          getRegClassFor(InVals[I].getValueType().getSimpleVT()));
+      SDValue X = DAG.getCopyToReg(Result, DL, Reg, InVals[I]);
+      InVals[I] = DAG.getCopyFromReg(X, DL, Reg,
+                                     InVals[I].getValueType());
+    }
+  }
+
+  return Result;
 }
 
 bool AArch64TargetLowering::CanLowerReturn(

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 7b1b0e5f3aa2..4875e786a110 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -58,6 +58,9 @@ enum NodeType : unsigned {
 
   CALL_BTI, // Function call followed by a BTI instruction.
 
+  SMSTART,
+  SMSTOP,
+
   // Produces the full sequence of instructions for getting the thread pointer
   // offset of a variable into X0, using the TLSDesc model.
   TLSDESC_CALLSEQ,
@@ -872,6 +875,14 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  /// If a change in streaming mode is required on entry to/return from a
+  /// function call it emits and returns the corresponding SMSTART or SMSTOP node.
+  /// \p Entry tells whether this is before/after the Call, which is necessary
+  /// because PSTATE.SM is only queried once.
+  SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
+                              SDValue Chain, SDValue InFlag,
+                              SDValue PStateSM, bool Entry) const;
+
 private:
   /// Keep a pointer to the AArch64Subtarget around so that we can
   /// make the right decision when generating code for different targets.

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index bd135288f42e..35a37a051951 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -1093,6 +1093,9 @@ bool AArch64InstrInfo::isSchedulingBoundary(const MachineInstr &MI,
   case AArch64::ISB:
     // DSB and ISB also are scheduling barriers.
     return true;
+  case AArch64::MSRpstatesvcrImm1:
+    // SMSTART and SMSTOP are also scheduling barriers.
+    return true;
   default:;
   }
   if (isSEHInstruction(MI))

diff  --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 91b6d183fa2e..8c75cb94cd5b 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -321,6 +321,10 @@ void AArch64RegisterInfo::UpdateCustomCallPreservedMask(MachineFunction &MF,
   *Mask = UpdatedMask;
 }
 
+const uint32_t *AArch64RegisterInfo::getSMStartStopCallPreservedMask() const {
+  return CSR_AArch64_SMStartStop_RegMask;
+}
+
 const uint32_t *AArch64RegisterInfo::getNoPreservedMask() const {
   return CSR_AArch64_NoRegs_RegMask;
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
index 3ec4a5d8cb35..666642d2f7ed 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
@@ -68,6 +68,8 @@ class AArch64RegisterInfo final : public AArch64GenRegisterInfo {
   // normal calls, so they need a different mask to represent this.
   const uint32_t *getTLSCallPreservedMask() const;
 
+  const uint32_t *getSMStartStopCallPreservedMask() const;
+
   // Funclets on ARM64 Windows don't preserve any registers.
   const uint32_t *getNoPreservedMask() const override;
 

diff  --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index eeca8a646ff2..a2576155dbe4 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -1379,7 +1379,9 @@ def SVCROperand : AsmOperandClass {
   let DiagnosticType = "Invalid" # Name;
 }
 
-def svcr_op : Operand<i32> {
+def svcr_op : Operand<i32>, TImmLeaf<i32, [{
+    return AArch64SVCR::lookupSVCRByEncoding(Imm) != nullptr;
+  }]> {
   let ParserMatchClass = SVCROperand;
   let PrintMethod = "printSVCROp";
   let DecoderMethod = "DecodeSVCROp";

diff  --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index e595d20c8d4e..d309b363e6c4 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -10,6 +10,15 @@
 //
 //===----------------------------------------------------------------------===//
 
+def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 3,
+                             [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
+                              SDNPOptInGlue, SDNPOutGlue]>;
+def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 3,
+                             [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
+                              SDNPOptInGlue, SDNPOutGlue]>;
+
 //===----------------------------------------------------------------------===//
 // Add vector elements horizontally or vertically to ZA tile.
 //===----------------------------------------------------------------------===//
@@ -117,8 +126,8 @@ defm ZERO_M : sme_zero<"zero">;
 // It's tricky to using the existing pstate operand defined in
 // AArch64SystemOperands.td since it only encodes 5 bits including op1;op2,
 // when these fields are also encoded in CRm[3:1].
-class MSRpstatesvcrImm0_1
-  : PstateWriteSimple<(ins svcr_op:$pstatefield, imm0_1:$imm), "msr",
+def MSRpstatesvcrImm1
+  : PstateWriteSimple<(ins svcr_op:$pstatefield, timm0_1:$imm), "msr",
                       "\t$pstatefield, $imm">,
     Sched<[WriteSys]> {
   bits<3> pstatefield;
@@ -129,7 +138,6 @@ class MSRpstatesvcrImm0_1
   let Inst{7-5} = 0b011; // op2
 }
 
-def MSRpstatesvcrImm1 : MSRpstatesvcrImm0_1;
 def : InstAlias<"smstart",    (MSRpstatesvcrImm1 0b011, 0b1)>;
 def : InstAlias<"smstart sm", (MSRpstatesvcrImm1 0b001, 0b1)>;
 def : InstAlias<"smstart za", (MSRpstatesvcrImm1 0b010, 0b1)>;
@@ -138,6 +146,35 @@ def : InstAlias<"smstop",     (MSRpstatesvcrImm1 0b011, 0b0)>;
 def : InstAlias<"smstop sm",  (MSRpstatesvcrImm1 0b001, 0b0)>;
 def : InstAlias<"smstop za",  (MSRpstatesvcrImm1 0b010, 0b0)>;
 
+
+// Scenario A:
+//
+//   %pstate.before.call = 1
+//   if (%pstate.before.call != 0)
+//     smstop (pstate_za|pstate_sm)
+//   call fn()
+//   if (%pstate.before.call != 0)
+//     smstart (pstate_za|pstate_sm)
+//
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 1), (i64 0)),   // before call
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 1), (i64 0)),  // after call
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
+
+// Scenario B:
+//
+//   %pstate.before.call = 0
+//   if (%pstate.before.call != 1)
+//     smstart (pstate_za|pstate_sm)
+//   call fn()
+//   if (%pstate.before.call != 1)
+//     smstop (pstate_za|pstate_sm)
+//
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 0), (i64 1)),  // before call
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 0), (i64 1)),   // after call
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
+
 // Read and write TPIDR2_EL0
 def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
           (MSR 0xde85, GPR64:$val)>;

diff  --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
new file mode 100644
index 000000000000..5725caeb706f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
@@ -0,0 +1,340 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme  -verify-machineinstrs < %s | FileCheck %s
+
+; This file tests the following combinations related to streaming-enabled functions:
+; [ ] N  ->  S    (Normal -> Streaming)
+; [ ] S  ->  N    (Streaming -> Normal)
+; [ ] S  ->  S    (Streaming -> Streaming)
+; [ ] S  ->  SC   (Streaming -> Streaming-compatible)
+;
+; The following combination is tested in sme-streaming-compatible-interface.ll
+; [ ] SC ->  S    (Streaming-compatible -> Streaming)
+
+declare void @normal_callee()
+declare void @streaming_callee() "aarch64_pstate_sm_enabled"
+declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible"
+
+; [x] N  ->  S
+; [ ] S  ->  N
+; [ ] S  ->  S
+; [ ] S  ->  SC
+define void @normal_caller_streaming_callee() nounwind {
+; CHECK-LABEL: normal_caller_streaming_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    bl streaming_callee
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @streaming_callee()
+  ret void;
+}
+
+; [ ] N  ->  S
+; [x] S  ->  N
+; [ ] S  ->  S
+; [ ] S  ->  SC
+define void @streaming_caller_normal_callee() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: streaming_caller_normal_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee()
+  ret void;
+}
+
+; [ ] N  ->  S
+; [ ] S  ->  N
+; [x] S  ->  S
+; [ ] S  ->  SC
+define void @streaming_caller_streaming_callee() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: streaming_caller_streaming_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl streaming_callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @streaming_callee()
+  ret void;
+}
+
+; [ ] N  ->  S
+; [ ] S  ->  N
+; [ ] S  ->  S
+; [x] S  ->  SC
+define void @streaming_caller_streaming_compatible_callee() nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: streaming_caller_streaming_compatible_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl streaming_compatible_callee
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @streaming_compatible_callee()
+  ret void;
+}
+
+;
+; Handle special cases here.
+;
+
+; Call to function-pointer (with attribute)
+define void @call_to_function_pointer_streaming_enabled(ptr %p) nounwind {
+; CHECK-LABEL: call_to_function_pointer_streaming_enabled:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    blr x0
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void %p() "aarch64_pstate_sm_enabled"
+  ret void
+}
+
+; Ensure NEON registers are preserved correctly.
+define <4 x i32> @smstart_clobber_simdfp(<4 x i32> %x) nounwind {
+; CHECK-LABEL: smstart_clobber_simdfp:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #96
+; CHECK-NEXT:    stp d15, d14, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    str q0, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    bl streaming_callee
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr q0, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    add sp, sp, #96
+; CHECK-NEXT:    ret
+  call void @streaming_callee()
+  ret <4 x i32> %x;
+}
+
+; Ensure SVE registers are preserved correctly.
+define <vscale x 4 x i32> @smstart_clobber_sve(<vscale x 4 x i32> %x) #0 {
+; CHECK-LABEL: smstart_clobber_sve:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-18
+; CHECK-NEXT:    str p15, [sp, #4, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p14, [sp, #5, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p13, [sp, #6, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p12, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p11, [sp, #8, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p10, [sp, #9, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p9, [sp, #10, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p8, [sp, #11, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p7, [sp, #12, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p6, [sp, #13, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p5, [sp, #14, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p4, [sp, #15, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str z23, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z22, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z21, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z20, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z19, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z18, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z17, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z16, [sp, #9, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z15, [sp, #10, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z14, [sp, #11, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z13, [sp, #12, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z12, [sp, #13, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z11, [sp, #14, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z10, [sp, #15, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z9, [sp, #16, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z8, [sp, #17, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str z0, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    bl streaming_callee
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr z0, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr p15, [sp, #4, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p14, [sp, #5, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p13, [sp, #6, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p12, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p11, [sp, #8, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p10, [sp, #9, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p9, [sp, #10, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #11, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p7, [sp, #12, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p6, [sp, #13, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p5, [sp, #14, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p4, [sp, #15, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z21, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z20, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z19, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z18, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z17, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z16, [sp, #9, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z15, [sp, #10, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z14, [sp, #11, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z13, [sp, #12, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z12, [sp, #13, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z11, [sp, #14, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #15, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z9, [sp, #16, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z8, [sp, #17, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #18
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @streaming_callee()
+  ret <vscale x 4 x i32> %x;
+}
+
+; Call streaming callee twice; there should be no spills/fills between the two
+; calls since the registers should have already been clobbered.
+define <vscale x 4 x i32> @smstart_clobber_sve_duplicate(<vscale x 4 x i32> %x) #0 {
+; CHECK-LABEL: smstart_clobber_sve_duplicate:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-18
+; CHECK-NEXT:    str p15, [sp, #4, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p14, [sp, #5, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p13, [sp, #6, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p12, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p11, [sp, #8, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p10, [sp, #9, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p9, [sp, #10, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p8, [sp, #11, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p7, [sp, #12, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p6, [sp, #13, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p5, [sp, #14, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str p4, [sp, #15, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    str z23, [sp, #2, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z22, [sp, #3, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z21, [sp, #4, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z20, [sp, #5, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z19, [sp, #6, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z18, [sp, #7, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z17, [sp, #8, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z16, [sp, #9, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z15, [sp, #10, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z14, [sp, #11, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z13, [sp, #12, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z12, [sp, #13, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z11, [sp, #14, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z10, [sp, #15, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z9, [sp, #16, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    str z8, [sp, #17, mul vl] // 16-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    str z0, [sp] // 16-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    bl streaming_callee
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    bl streaming_callee
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr z0, [sp] // 16-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr p15, [sp, #4, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p14, [sp, #5, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p13, [sp, #6, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p12, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p11, [sp, #8, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p10, [sp, #9, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p9, [sp, #10, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p8, [sp, #11, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p7, [sp, #12, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p6, [sp, #13, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p5, [sp, #14, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr p4, [sp, #15, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    ldr z23, [sp, #2, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z22, [sp, #3, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z21, [sp, #4, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z20, [sp, #5, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z19, [sp, #6, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z18, [sp, #7, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z17, [sp, #8, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z16, [sp, #9, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z15, [sp, #10, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z14, [sp, #11, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z13, [sp, #12, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z12, [sp, #13, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z11, [sp, #14, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z10, [sp, #15, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z9, [sp, #16, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr z8, [sp, #17, mul vl] // 16-byte Folded Reload
+; CHECK-NEXT:    addvl sp, sp, #18
+; CHECK-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @streaming_callee()
+  call void @streaming_callee()
+  ret <vscale x 4 x i32> %x;
+}
+
+; Ensure smstart is not removed, because call to llvm.cos is not part of a chain.
+define double @call_to_intrinsic_without_chain(double %x) nounwind "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: call_to_intrinsic_without_chain:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sub sp, sp, #96
+; CHECK-NEXT:    stp d15, d14, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #80] // 8-byte Folded Spill
+; CHECK-NEXT:    str d0, [sp, #88] // 8-byte Folded Spill
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    ldr d0, [sp, #88] // 8-byte Folded Reload
+; CHECK-NEXT:    bl cos
+; CHECK-NEXT:    str d0, [sp, #8] // 8-byte Folded Spill
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    ldp d9, d8, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr d0, [sp, #88] // 8-byte Folded Reload
+; CHECK-NEXT:    ldr d1, [sp, #8] // 8-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #80] // 8-byte Folded Reload
+; CHECK-NEXT:    fadd d0, d1, d0
+; CHECK-NEXT:    add sp, sp, #96
+; CHECK-NEXT:    ret
+entry:
+  %res = call fast double @llvm.cos.f64(double %x)
+  %res.fadd = fadd fast double %res, %x
+  ret double %res.fadd
+}
+
+declare double @llvm.cos.f64(double)
+
+attributes #0 = { nounwind "target-features"="+sve" }


        


More information about the llvm-commits mailing list