[llvm] [AArch64][SME] Conditionally do smstart/smstop (PR #77113)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 5 08:02:16 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Matthew Devereau (MDevereau)
<details>
<summary>Changes</summary>
This patch makes the `smstart` on entry and the `smstop` on return conditional for locally streaming functions that have both the `aarch64_pstate_sm_compatible` and `aarch64_pstate_sm_body` attributes.
With this combination of attributes, the callee itself queries the current streaming mode on entry (via `__arm_sme_state`) and only switches streaming mode when required, instead of relying on the caller to do so.
---
Full diff: https://github.com/llvm/llvm-project/pull/77113.diff
3 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+26-10)
- (modified) llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h (+5)
- (added) llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll (+177)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 50658a855cfb37..72ae574efb6191 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4855,14 +4855,14 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
SMEAttrs Attrs, SDLoc DL,
EVT VT) const {
- if (Attrs.hasStreamingInterfaceOrBody())
+ if (Attrs.hasStreamingInterfaceOrBody() &&
+ !Attrs.hasStreamingCompatibleInterface())
return DAG.getConstant(1, DL, VT);
- if (Attrs.hasNonStreamingInterfaceAndBody())
+ if (Attrs.hasNonStreamingInterfaceAndBody() &&
+ !Attrs.hasStreamingCompatibleInterface())
return DAG.getConstant(0, DL, VT);
- assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
-
SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
getPointerTy(DAG.getDataLayout()));
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
@@ -6888,9 +6888,18 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
// Insert the SMSTART if this is a locally streaming function and
// make sure it is Glued to the last CopyFromReg value.
if (IsLocallyStreaming) {
- Chain =
- changeStreamingMode(DAG, DL, /*Enable*/ true, DAG.getRoot(), Glue,
- DAG.getConstant(0, DL, MVT::i64), /*Entry*/ true);
+ SDValue PStateSM;
+ if (Attrs.hasStreamingCompatibleInterface()) {
+ PStateSM = getPStateSM(DAG, Chain, Attrs, DL, MVT::i64);
+ Register Reg = MF.getRegInfo().createVirtualRegister(
+ getRegClassFor(PStateSM.getValueType().getSimpleVT()));
+ FuncInfo->setPStateSMReg(Reg);
+ Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
+ } else {
+ PStateSM = DAG.getConstant(0, DL, MVT::i64);
+ }
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
+ /*Entry*/ true);
// Ensure that the SMSTART happens after the CopyWithChain such that its
// chain result is used.
@@ -8201,9 +8210,16 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// Emit SMSTOP before returning from a locally streaming function
SMEAttrs FuncAttrs(MF.getFunction());
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
- Chain = changeStreamingMode(
- DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(),
- DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+ SDValue PStateSM;
+ if (FuncAttrs.hasStreamingCompatibleInterface()) {
+ Register Reg = FuncInfo->getPStateSMReg();
+ assert(Reg.isValid() && "PStateSM Register is invalid");
+ PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
+ } else {
+ PStateSM = DAG.getConstant(1, DL, MVT::i64);
+ }
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+ /*Glue*/ SDValue(), PStateSM, /*Entry*/ true);
Glue = Chain.getValue(1);
}
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index cd4a18bfbc23a8..096fde364a2dcc 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -208,6 +208,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t StackProbeSize = 0;
+ Register PStateSMReg = MCRegister::NoRegister;
+
public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
@@ -216,6 +218,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
const override;
+ Register getPStateSMReg() const { return PStateSMReg; };
+ void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
+
bool isSVECC() const { return IsSVECC; };
void setIsSVECC(bool s) { IsSVECC = s; };
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll
new file mode 100644
index 00000000000000..16375041e4298f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible.ll
@@ -0,0 +1,177 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+declare void @normal_callee();
+declare void @streaming_callee() "aarch64_pstate_sm_enabled";
+declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible";
+
+define float @sm_body_sm_compatible_simple() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_sm_compatible_simple:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 80
+; CHECK-NEXT: .cfi_offset w30, -16
+; CHECK-NEXT: .cfi_offset b8, -24
+; CHECK-NEXT: .cfi_offset b9, -32
+; CHECK-NEXT: .cfi_offset b10, -40
+; CHECK-NEXT: .cfi_offset b11, -48
+; CHECK-NEXT: .cfi_offset b12, -56
+; CHECK-NEXT: .cfi_offset b13, -64
+; CHECK-NEXT: .cfi_offset b14, -72
+; CHECK-NEXT: .cfi_offset b15, -80
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x8, x0, #0x1
+; CHECK-NEXT: tbnz w8, #0, .LBB0_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB0_2:
+; CHECK-NEXT: tbz w8, #0, .LBB0_4
+; CHECK-NEXT: // %bb.3:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB0_4:
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: fmov s0, wzr
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ ret float zeroinitializer
+}
+
+define void @sm_body_caller_sm_compatible_caller_normal_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_normal_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp d15, d14, [sp, #-96]! // 16-byte Folded Spill
+; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT: stp x20, x19, [sp, #80] // 16-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 96
+; CHECK-NEXT: .cfi_offset w19, -8
+; CHECK-NEXT: .cfi_offset w20, -16
+; CHECK-NEXT: .cfi_offset w30, -32
+; CHECK-NEXT: .cfi_offset b8, -40
+; CHECK-NEXT: .cfi_offset b9, -48
+; CHECK-NEXT: .cfi_offset b10, -56
+; CHECK-NEXT: .cfi_offset b11, -64
+; CHECK-NEXT: .cfi_offset b12, -72
+; CHECK-NEXT: .cfi_offset b13, -80
+; CHECK-NEXT: .cfi_offset b14, -88
+; CHECK-NEXT: .cfi_offset b15, -96
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x19, x0, #0x1
+; CHECK-NEXT: tbnz w19, #0, .LBB1_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB1_2:
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x20, x0, #0x1
+; CHECK-NEXT: tbz w20, #0, .LBB1_4
+; CHECK-NEXT: // %bb.3:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB1_4:
+; CHECK-NEXT: bl normal_callee
+; CHECK-NEXT: tbz w20, #0, .LBB1_6
+; CHECK-NEXT: // %bb.5:
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB1_6:
+; CHECK-NEXT: tbz w19, #0, .LBB1_8
+; CHECK-NEXT: // %bb.7:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB1_8:
+; CHECK-NEXT: ldp x20, x19, [sp, #80] // 16-byte Folded Reload
+; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #96 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @normal_callee()
+ ret void
+}
+
+define void @sm_body_caller_sm_compatible_caller_streaming_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_streaming_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 80
+; CHECK-NEXT: .cfi_offset w19, -8
+; CHECK-NEXT: .cfi_offset w30, -16
+; CHECK-NEXT: .cfi_offset b8, -24
+; CHECK-NEXT: .cfi_offset b9, -32
+; CHECK-NEXT: .cfi_offset b10, -40
+; CHECK-NEXT: .cfi_offset b11, -48
+; CHECK-NEXT: .cfi_offset b12, -56
+; CHECK-NEXT: .cfi_offset b13, -64
+; CHECK-NEXT: .cfi_offset b14, -72
+; CHECK-NEXT: .cfi_offset b15, -80
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x19, x0, #0x1
+; CHECK-NEXT: tbnz w19, #0, .LBB2_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB2_2:
+; CHECK-NEXT: bl streaming_callee
+; CHECK-NEXT: tbz w19, #0, .LBB2_4
+; CHECK-NEXT: // %bb.3:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB2_4:
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @streaming_callee()
+ ret void
+}
+
+define void @sm_body_caller_sm_compatible_caller_streaming_compatible_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_streaming_compatible_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 80
+; CHECK-NEXT: .cfi_offset w19, -8
+; CHECK-NEXT: .cfi_offset w30, -16
+; CHECK-NEXT: .cfi_offset b8, -24
+; CHECK-NEXT: .cfi_offset b9, -32
+; CHECK-NEXT: .cfi_offset b10, -40
+; CHECK-NEXT: .cfi_offset b11, -48
+; CHECK-NEXT: .cfi_offset b12, -56
+; CHECK-NEXT: .cfi_offset b13, -64
+; CHECK-NEXT: .cfi_offset b14, -72
+; CHECK-NEXT: .cfi_offset b15, -80
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x19, x0, #0x1
+; CHECK-NEXT: tbnz w19, #0, .LBB3_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB3_2:
+; CHECK-NEXT: bl streaming_compatible_callee
+; CHECK-NEXT: tbz w19, #0, .LBB3_4
+; CHECK-NEXT: // %bb.3:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB3_4:
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @streaming_compatible_callee()
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/77113
More information about the llvm-commits mailing list