[llvm] [AArch64] NFC: Simplify the smstart/smstop pseudo. (PR #85067)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 13 05:12:08 PDT 2024
https://github.com/sdesmalen-arm created https://github.com/llvm/llvm-project/pull/85067
This is just a bit of cleanup to make the pseudo/code easier to understand. This is based on the observation that we only need to pass in a runtime value for 'pstate' if it is actually needed for generating a runtime check.
>From 32a0fd374f8e0dc28fa2d26d8773da8217a1c16c Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 7 Mar 2024 15:38:01 +0000
Subject: [PATCH] [AArch64] NFC: Simplify the smstart/smstop pseudo.
This is just a bit of cleanup to make the pseudo/code easier to understand.
This is based on the observation that we only need to pass in a runtime
value for 'pstate' if it is actually needed for generating a runtime check.
---
.../AArch64/AArch64ExpandPseudoInsts.cpp | 21 ++++--
.../Target/AArch64/AArch64ISelLowering.cpp | 72 ++++++++++++-------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 10 +--
.../lib/Target/AArch64/AArch64SMEInstrInfo.td | 53 +++++---------
.../Target/AArch64/Utils/AArch64BaseInfo.h | 8 +++
5 files changed, 89 insertions(+), 75 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index b2c52b443753dc..3afd48f7fb299c 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -987,7 +987,7 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
// Expand the pseudo into smstart or smstop instruction. The pseudo has the
// following operands:
//
- // MSRpstatePseudo <za|sm|both>, <0|1>, pstate.sm, expectedval, <regmask>
+ // MSRpstatePseudo <za|sm|both>, <0|1>, pstate.sm, condition, <regmask>
//
// The pseudo is expanded into a conditional smstart/smstop, with a
// check if pstate.sm (register) equals the expected value, and if not,
@@ -997,9 +997,9 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
// streaming-compatible function:
//
// OrigBB:
- // MSRpstatePseudo 3, 0, %0, 0, <regmask> <- Conditional SMSTOP
+ // MSRpstatePseudo 3, 0, %0, IfCallerIsStreaming, <regmask> <- Cond SMSTOP
// bl @normal_callee
- // MSRpstatePseudo 3, 1, %0, 0, <regmask> <- Conditional SMSTART
+ // MSRpstatePseudo 3, 1, %0, IfCallerIsStreaming, <regmask> <- Cond SMSTART
//
// ...which will be transformed into:
//
@@ -1022,11 +1022,20 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
// We test the live value of pstate.sm and toggle pstate.sm if this is not the
// expected value for the callee (0 for a normal callee and 1 for a streaming
// callee).
- auto PStateSM = MI.getOperand(2).getReg();
+ unsigned Opc;
+ switch (MI.getOperand(2).getImm()) {
+ case AArch64SME::Always:
+ llvm_unreachable("Should have matched to instruction directly");
+ case AArch64SME::IfCallerIsStreaming:
+ Opc = AArch64::TBNZW;
+ break;
+ case AArch64SME::IfCallerIsNonStreaming:
+ Opc = AArch64::TBZW;
+ break;
+ }
+ auto PStateSM = MI.getOperand(3).getReg();
auto TRI = MBB.getParent()->getSubtarget().getRegisterInfo();
unsigned SMReg32 = TRI->getSubReg(PStateSM, AArch64::sub_32);
- bool IsStreamingCallee = MI.getOperand(3).getImm();
- unsigned Opc = IsStreamingCallee ? AArch64::TBZW : AArch64::TBNZW;
MachineInstrBuilder Tbx =
BuildMI(MBB, MBBI, DL, TII->get(Opc)).addReg(SMReg32).addImm(0);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 054311d39e7b83..90c9f1fd11ff25 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5270,13 +5270,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
AArch64ISD::SMSTART, DL, MVT::Other,
Op->getOperand(0), // Chain
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
case Intrinsic::aarch64_sme_za_disable:
return DAG.getNode(
AArch64ISD::SMSTOP, DL, MVT::Other,
Op->getOperand(0), // Chain
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
}
}
@@ -7197,11 +7197,11 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
getRegClassFor(PStateSM.getValueType().getSimpleVT()));
FuncInfo->setPStateSMReg(Reg);
Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
- } else {
- PStateSM = DAG.getConstant(0, DL, MVT::i64);
- }
- Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
- /*Entry*/ true);
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+ AArch64SME::IfCallerIsNonStreaming, PStateSM);
+ } else
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+ AArch64SME::Always);
// Ensure that the SMSTART happens after the CopyWithChain such that its
// chain result is used.
@@ -7776,9 +7776,11 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
}
}
-SDValue AArch64TargetLowering::changeStreamingMode(
- SelectionDAG &DAG, SDLoc DL, bool Enable,
- SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const {
+SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
+ bool Enable, SDValue Chain,
+ SDValue InGlue,
+ unsigned Condition,
+ SDValue PStateSM) const {
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
FuncInfo->setHasStreamingModeChanges(true);
@@ -7787,10 +7789,13 @@ SDValue AArch64TargetLowering::changeStreamingMode(
SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
SDValue MSROp =
DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
-
- SDValue ExpectedSMVal =
- DAG.getTargetConstant(Entry ? Enable : !Enable, DL, MVT::i64);
- SmallVector<SDValue> Ops = {Chain, MSROp, PStateSM, ExpectedSMVal, RegMask};
+ SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
+ SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
+ if (Condition != AArch64SME::Always) {
+ assert(PStateSM && "PStateSM should be defined");
+ Ops.push_back(PStateSM);
+ }
+ Ops.push_back(RegMask);
if (InGlue)
Ops.push_back(InGlue);
@@ -7799,6 +7804,19 @@ SDValue AArch64TargetLowering::changeStreamingMode(
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}
+static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
+ const SMEAttrs &CalleeAttrs) {
+ if (!CallerAttrs.hasStreamingCompatibleInterface() ||
+ CallerAttrs.hasStreamingBody())
+ return AArch64SME::Always;
+ if (CalleeAttrs.hasNonStreamingInterface())
+ return AArch64SME::IfCallerIsStreaming;
+ if (CalleeAttrs.hasStreamingInterface())
+ return AArch64SME::IfCallerIsNonStreaming;
+
+ llvm_unreachable("Unsupported attributes");
+}
+
/// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
/// and add input and output parameter nodes.
SDValue
@@ -8018,7 +8036,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Chain = DAG.getNode(
AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
// Adjust the stack pointer for the new arguments...
// These operations are automatically eliminated by the prolog/epilog pass
@@ -8289,9 +8307,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
SDValue InGlue;
if (RequiresSMChange) {
- SDValue NewChain =
- changeStreamingMode(DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain,
- InGlue, PStateSM, true);
+ SDValue NewChain = changeStreamingMode(
+ DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
+ getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
Chain = NewChain.getValue(0);
InGlue = NewChain.getValue(1);
}
@@ -8445,8 +8463,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresSMChange) {
assert(PStateSM && "Expected a PStateSM to be set");
- Result = changeStreamingMode(DAG, DL, !CalleeAttrs.hasStreamingInterface(),
- Result, InGlue, PStateSM, false);
+ Result = changeStreamingMode(
+ DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
+ getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
}
if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
@@ -8454,7 +8473,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Result = DAG.getNode(
AArch64ISD::SMSTART, DL, MVT::Other, Result,
DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
- DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+ DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
if (ShouldPreserveZT0)
Result =
@@ -8589,13 +8608,12 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
Register Reg = FuncInfo->getPStateSMReg();
assert(Reg.isValid() && "PStateSM Register is invalid");
SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
- Chain =
- changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
- /*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+ /*Glue*/ SDValue(),
+ AArch64SME::IfCallerIsNonStreaming, PStateSM);
} else
- Chain = changeStreamingMode(
- DAG, DL, /*Enable*/ false, Chain,
- /*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+ Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+ /*Glue*/ SDValue(), AArch64SME::Always);
Glue = Chain.getValue(1);
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 68341c199e0a2a..89016cbf56e39e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -968,12 +968,12 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandCttzElements(EVT VT) const override;
/// If a change in streaming mode is required on entry to/return from a
- /// function call it emits and returns the corresponding SMSTART or SMSTOP node.
- /// \p Entry tells whether this is before/after the Call, which is necessary
- /// because PSTATE.SM is only queried once.
+ /// function call it emits and returns the corresponding SMSTART or SMSTOP
+ /// node. \p Condition should be one of the enum values from
+ /// AArch64SME::ToggleCondition.
SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
- SDValue Chain, SDValue InGlue,
- SDValue PStateSM, bool Entry) const;
+ SDValue Chain, SDValue InGlue, unsigned Condition,
+ SDValue PStateSM = SDValue()) const;
bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 2907ba74ff8108..1554f1c92b5bbb 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -10,12 +10,12 @@
//
//===----------------------------------------------------------------------===//
-def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 3,
- [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisInt<0>]>,
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
SDNPOptInGlue, SDNPOutGlue]>;
-def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 3,
- [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 2,
+ [SDTCisInt<0>, SDTCisInt<0>]>,
[SDNPHasChain, SDNPSideEffect, SDNPVariadic,
SDNPOptInGlue, SDNPOutGlue]>;
def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
@@ -158,34 +158,6 @@ def : Pat<(AArch64_restore_za
(i64 GPR64:$tpidr2_el0), (i64 GPR64sp:$tpidr2obj), (i64 texternalsym:$restore_routine)),
(RestoreZAPseudo GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, texternalsym:$restore_routine)>;
-// Scenario A:
-//
-// %pstate.before.call = 1
-// if (%pstate.before.call != 0)
-// smstop (pstate_za|pstate_sm)
-// call fn()
-// if (%pstate.before.call != 0)
-// smstart (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 1), (i64 0)), // before call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 1), (i64 0)), // after call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-
-// Scenario B:
-//
-// %pstate.before.call = 0
-// if (%pstate.before.call != 1)
-// smstart (pstate_za|pstate_sm)
-// call fn()
-// if (%pstate.before.call != 1)
-// smstop (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 0), (i64 1)), // before call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 0), (i64 1)), // after call
- (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-
// Read and write TPIDR2_EL0
def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
(MSR 0xde85, GPR64:$val)>;
@@ -230,17 +202,24 @@ defm COALESCER_BARRIER : CoalescerBarriers;
// SME instructions.
def MSRpstatePseudo :
Pseudo<(outs),
- (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
+ (ins svcr_op:$pstatefield, timm0_1:$imm, timm0_31:$condition, variable_ops), []>,
Sched<[WriteSys]> {
let hasPostISelHook = 1;
let Uses = [VG];
let Defs = [VG];
}
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
- (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
- (MSRpstatePseudo svcr_op:$pstate, 0b0, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+ (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+ (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition)>;
+
+// Unconditional start/stop
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+ (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+ (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
+
//===----------------------------------------------------------------------===//
// SME2 Instructions
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
index ed8336a2e8ad34..f821bb527aedb8 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -591,6 +591,14 @@ namespace AArch64BTIHint {
#include "AArch64GenSystemOperands.inc"
}
+namespace AArch64SME {
+enum ToggleCondition : unsigned {
+ Always,
+ IfCallerIsStreaming,
+ IfCallerIsNonStreaming
+};
+}
+
namespace AArch64SE {
enum ShiftExtSpecifiers {
Invalid = -1,
More information about the llvm-commits
mailing list