[llvm] 51e3d2f - [AArch64][SME] Conditionally do smstart/smstop (#77113)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 18 01:17:27 PST 2024
Author: Matthew Devereau
Date: 2024-01-18T09:17:23Z
New Revision: 51e3d2f73d8f3f5c70f0c1b6b73f62ec9c680cb4
URL: https://github.com/llvm/llvm-project/commit/51e3d2f73d8f3f5c70f0c1b6b73f62ec9c680cb4
DIFF: https://github.com/llvm/llvm-project/commit/51e3d2f73d8f3f5c70f0c1b6b73f62ec9c680cb4.diff
LOG: [AArch64][SME] Conditionally do smstart/smstop (#77113)
This patch adds conditional enabling/disabling of streaming mode for
functions which have both the aarch64_pstate_sm_compatible and
aarch64_pstate_sm_body attributes.
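With this combination the callee itself queries PSTATE.SM at runtime (by
calling __arm_sme_state) on entry, only enables streaming mode if it is not
already active, and restores the original mode before returning, instead of
relying on the caller to switch modes.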
Added:
llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a3035bf78da27f..8a6f1dc7487bae 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4854,17 +4854,9 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
}
-SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
- SMEAttrs Attrs, SDLoc DL,
- EVT VT) const {
- if (Attrs.hasStreamingInterfaceOrBody())
- return DAG.getConstant(1, DL, VT);
-
- if (Attrs.hasNonStreamingInterfaceAndBody())
- return DAG.getConstant(0, DL, VT);
-
- assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
-
+SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
+ SDValue Chain, SDLoc DL,
+ EVT VT) const {
SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
getPointerTy(DAG.getDataLayout()));
Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
@@ -6892,9 +6884,18 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
// Insert the SMSTART if this is a locally streaming function and
// make sure it is Glued to the last CopyFromReg value.
if (IsLocallyStreaming) {
- Chain =
- changeStreamingMode(DAG, DL, /*Enable*/ true, DAG.getRoot(), Glue,
- DAG.getConstant(0, DL, MVT::i64), /*Entry*/ true);
+ SDValue PStateSM;
+ if (Attrs.hasStreamingCompatibleInterface()) {
+ PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
+ Register Reg = MF.getRegInfo().createVirtualRegister(
+ getRegClassFor(PStateSM.getValueType().getSimpleVT()));
+ FuncInfo->setPStateSMReg(Reg);
+ Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
+ } else {
+ PStateSM = DAG.getConstant(0, DL, MVT::i64);
+ }
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
+ /*Entry*/ true);
// Ensure that the SMSTART happens after the CopyWithChain such that its
// chain result is used.
@@ -7652,7 +7653,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
std::optional<bool> RequiresSMChange =
CallerAttrs.requiresSMChange(CalleeAttrs);
if (RequiresSMChange) {
- PStateSM = getPStateSM(DAG, Chain, CallerAttrs, DL, MVT::i64);
+ if (CallerAttrs.hasStreamingInterfaceOrBody())
+ PStateSM = DAG.getConstant(1, DL, MVT::i64);
+ else if (CallerAttrs.hasNonStreamingInterface())
+ PStateSM = DAG.getConstant(0, DL, MVT::i64);
+ else
+ PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
OptimizationRemarkEmitter ORE(&MF.getFunction());
ORE.emit([&]() {
auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
@@ -8205,9 +8211,17 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
// Emit SMSTOP before returning from a locally streaming function
SMEAttrs FuncAttrs(MF.getFunction());
if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
- Chain = changeStreamingMode(
- DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(),
- DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+ if (FuncAttrs.hasStreamingCompatibleInterface()) {
+ Register Reg = FuncInfo->getPStateSMReg();
+ assert(Reg.isValid() && "PStateSM Register is invalid");
+ SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
+ Chain =
+ changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+ /*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
+ } else
+ Chain = changeStreamingMode(
+ DAG, DL, /*Enable*/ false, Chain,
+ /*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
Glue = Chain.getValue(1);
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1fd639b4f7ee8f..6047a3b7b2864a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1290,10 +1290,10 @@ class AArch64TargetLowering : public TargetLowering {
// This function does not handle predicate bitcasts.
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
- // Returns the runtime value for PSTATE.SM. When the function is streaming-
- // compatible, this generates a call to __arm_sme_state.
- SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
- SDLoc DL, EVT VT) const;
+ // Returns the runtime value for PSTATE.SM by generating a call to
+ // __arm_sme_state.
+ SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
+ EVT VT) const;
bool preferScalarizeSplat(SDNode *N) const override;
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index cd4a18bfbc23a8..d5941e6284111a 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -208,6 +208,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
int64_t StackProbeSize = 0;
+ // Virtual register holding the value of PSTATE.SM on function entry. It is
+ // only set for functions with a streaming-compatible interface and a streaming body.
+ Register PStateSMReg = MCRegister::NoRegister;
+
public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
@@ -216,6 +220,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
const override;
+ Register getPStateSMReg() const { return PStateSMReg; };
+ void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
+
bool isSVECC() const { return IsSVECC; };
void setIsSVECC(bool s) { IsSVECC = s; };
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll
new file mode 100644
index 00000000000000..d67573384ca959
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll
@@ -0,0 +1,124 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+declare void @normal_callee();
+declare void @streaming_callee() "aarch64_pstate_sm_enabled";
+declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible";
+
+define float @sm_body_sm_compatible_simple() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
+; CHECK-LABEL: sm_body_sm_compatible_simple:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT: str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x8, x0, #0x1
+; CHECK-NEXT: tbnz w8, #0, .LBB0_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB0_2:
+; CHECK-NEXT: tbnz w8, #0, .LBB0_4
+; CHECK-NEXT: // %bb.3:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB0_4:
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: fmov s0, wzr
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ ret float zeroinitializer
+}
+
+define void @sm_body_caller_sm_compatible_caller_normal_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_normal_callee:
+; CHECK: // %bb.0:
+; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x19, x0, #0x1
+; CHECK-NEXT: tbnz w19, #0, .LBB1_2
+; CHECK-NEXT: // %bb.1:
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB1_2:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: bl normal_callee
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: tbnz w19, #0, .LBB1_4
+; CHECK-NEXT: // %bb.3:
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB1_4:
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+ call void @normal_callee()
+ ret void
+}
+
+; Calls a non-streaming callee on one path and a streaming-compatible callee on the other.
+define void @streaming_body_and_streaming_compatible_interface_multi_basic_block(i32 noundef %x) "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
+; CHECK-LABEL: streaming_body_and_streaming_compatible_interface_multi_basic_block:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT: stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT: mov w8, w0
+; CHECK-NEXT: stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT: stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT: bl __arm_sme_state
+; CHECK-NEXT: and x19, x0, #0x1
+; CHECK-NEXT: tbnz w19, #0, .LBB2_2
+; CHECK-NEXT: // %bb.1: // %entry
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: .LBB2_2: // %entry
+; CHECK-NEXT: cbz w8, .LBB2_6
+; CHECK-NEXT: // %bb.3: // %if.else
+; CHECK-NEXT: bl streaming_compatible_callee
+; CHECK-NEXT: tbnz w19, #0, .LBB2_5
+; CHECK-NEXT: // %bb.4: // %if.else
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB2_5: // %if.else
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+; CHECK-NEXT: .LBB2_6: // %if.then
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: bl normal_callee
+; CHECK-NEXT: smstart sm
+; CHECK-NEXT: tbnz w19, #0, .LBB2_8
+; CHECK-NEXT: // %bb.7: // %if.then
+; CHECK-NEXT: smstop sm
+; CHECK-NEXT: .LBB2_8: // %if.then
+; CHECK-NEXT: ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT: ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT: ret
+entry:
+ %cmp = icmp eq i32 %x, 0
+ br i1 %cmp, label %if.then, label %if.else
+
+if.then: ; preds = %entry
+ tail call void @normal_callee()
+ br label %return
+
+if.else: ; preds = %entry
+ tail call void @streaming_compatible_callee()
+ br label %return
+
+return: ; preds = %if.else, %if.then
+ ret void
+}
More information about the llvm-commits
mailing list