[llvm] e639e7e - [AArch64] NFC: Simplify the smstart/smstop pseudo. (#85067)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 15 01:51:02 PDT 2024


Author: Sander de Smalen
Date: 2024-03-15T08:50:58Z
New Revision: e639e7e986e0c1dcb5af3de65548d8518eb685a6

URL: https://github.com/llvm/llvm-project/commit/e639e7e986e0c1dcb5af3de65548d8518eb685a6
DIFF: https://github.com/llvm/llvm-project/commit/e639e7e986e0c1dcb5af3de65548d8518eb685a6.diff

LOG: [AArch64] NFC: Simplify the smstart/smstop pseudo. (#85067)

This is just a bit of cleanup to make the pseudo/code easier to
understand. This is based on the observation that we only need to pass
in a runtime value for 'pstate' if it is actually needed for generating a
runtime check.

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
    llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index b2c52b443753dc..03f0778bae59d5 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -987,7 +987,7 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
   // Expand the pseudo into smstart or smstop instruction. The pseudo has the
   // following operands:
   //
-  //   MSRpstatePseudo <za|sm|both>, <0|1>, pstate.sm, expectedval, <regmask>
+  //   MSRpstatePseudo <za|sm|both>, <0|1>, condition[, pstate.sm], <regmask>
   //
   // The pseudo is expanded into a conditional smstart/smstop, with a
   // check if pstate.sm (register) equals the expected value, and if not,
@@ -997,9 +997,9 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
   // streaming-compatible function:
   //
   // OrigBB:
-  //   MSRpstatePseudo 3, 0, %0, 0, <regmask>             <- Conditional SMSTOP
+  //   MSRpstatePseudo 3, 0, IfCallerIsStreaming, %0, <regmask>  <- Cond SMSTOP
   //   bl @normal_callee
-  //   MSRpstatePseudo 3, 1, %0, 0, <regmask>             <- Conditional SMSTART
+  //   MSRpstatePseudo 3, 1, IfCallerIsStreaming, %0, <regmask>  <- Cond SMSTART
   //
   // ...which will be transformed into:
   //
@@ -1022,11 +1022,20 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
   // We test the live value of pstate.sm and toggle pstate.sm if this is not the
   // expected value for the callee (0 for a normal callee and 1 for a streaming
   // callee).
-  auto PStateSM = MI.getOperand(2).getReg();
+  unsigned Opc;
+  switch (MI.getOperand(2).getImm()) {
+  case AArch64SME::Always:
+    llvm_unreachable("Should have matched to instruction directly");
+  case AArch64SME::IfCallerIsStreaming:
+    Opc = AArch64::TBNZW;
+    break;
+  case AArch64SME::IfCallerIsNonStreaming:
+    Opc = AArch64::TBZW;
+    break;
+  }
+  auto PStateSM = MI.getOperand(3).getReg();
   auto TRI = MBB.getParent()->getSubtarget().getRegisterInfo();
   unsigned SMReg32 = TRI->getSubReg(PStateSM, AArch64::sub_32);
-  bool IsStreamingCallee = MI.getOperand(3).getImm();
-  unsigned Opc = IsStreamingCallee ? AArch64::TBZW : AArch64::TBNZW;
   MachineInstrBuilder Tbx =
       BuildMI(MBB, MBBI, DL, TII->get(Opc)).addReg(SMReg32).addImm(0);
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9665ae5ceb903f..7720e0a46c8e93 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5270,13 +5270,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
         AArch64ISD::SMSTART, DL, MVT::Other,
         Op->getOperand(0), // Chain
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
   case Intrinsic::aarch64_sme_za_disable:
     return DAG.getNode(
         AArch64ISD::SMSTOP, DL, MVT::Other,
         Op->getOperand(0), // Chain
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
   }
 }
 
@@ -7197,11 +7197,11 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
           getRegClassFor(PStateSM.getValueType().getSimpleVT()));
       FuncInfo->setPStateSMReg(Reg);
       Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
-    } else {
-      PStateSM = DAG.getConstant(0, DL, MVT::i64);
-    }
-    Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
-                                /*Entry*/ true);
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+                                  AArch64SME::IfCallerIsNonStreaming, PStateSM);
+    } else
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
+                                  AArch64SME::Always);
 
     // Ensure that the SMSTART happens after the CopyWithChain such that its
     // chain result is used.
@@ -7786,9 +7786,11 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
   }
 }
 
-SDValue AArch64TargetLowering::changeStreamingMode(
-    SelectionDAG &DAG, SDLoc DL, bool Enable,
-    SDValue Chain, SDValue InGlue, SDValue PStateSM, bool Entry) const {
+SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
+                                                   bool Enable, SDValue Chain,
+                                                   SDValue InGlue,
+                                                   unsigned Condition,
+                                                   SDValue PStateSM) const {
   MachineFunction &MF = DAG.getMachineFunction();
   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   FuncInfo->setHasStreamingModeChanges(true);
@@ -7797,10 +7799,13 @@ SDValue AArch64TargetLowering::changeStreamingMode(
   SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
   SDValue MSROp =
       DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
-
-  SDValue ExpectedSMVal =
-      DAG.getTargetConstant(Entry ? Enable : !Enable, DL, MVT::i64);
-  SmallVector<SDValue> Ops = {Chain, MSROp, PStateSM, ExpectedSMVal, RegMask};
+  SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
+  SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
+  if (Condition != AArch64SME::Always) {
+    assert(PStateSM && "PStateSM should be defined");
+    Ops.push_back(PStateSM);
+  }
+  Ops.push_back(RegMask);
 
   if (InGlue)
     Ops.push_back(InGlue);
@@ -7809,6 +7814,19 @@ SDValue AArch64TargetLowering::changeStreamingMode(
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+static unsigned getSMCondition(const SMEAttrs &CallerAttrs, // SM toggle cond.
+                               const SMEAttrs &CalleeAttrs) {
+  if (!CallerAttrs.hasStreamingCompatibleInterface() ||
+      CallerAttrs.hasStreamingBody())
+    return AArch64SME::Always; // No runtime pstate.sm test needed.
+  if (CalleeAttrs.hasNonStreamingInterface())
+    return AArch64SME::IfCallerIsStreaming; // Runtime pstate.sm test (TBNZW).
+  if (CalleeAttrs.hasStreamingInterface())
+    return AArch64SME::IfCallerIsNonStreaming; // Runtime test (TBZW).
+
+  llvm_unreachable("Unsupported attributes");
+}
+
 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
 /// and add input and output parameter nodes.
 SDValue
@@ -8028,7 +8046,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     Chain = DAG.getNode(
         AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
 
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
@@ -8299,9 +8317,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   SDValue InGlue;
   if (RequiresSMChange) {
-    SDValue NewChain =
-        changeStreamingMode(DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain,
-                            InGlue, PStateSM, true);
+    SDValue NewChain = changeStreamingMode(
+        DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
+        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
     Chain = NewChain.getValue(0);
     InGlue = NewChain.getValue(1);
   }
@@ -8455,8 +8473,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   if (RequiresSMChange) {
     assert(PStateSM && "Expected a PStateSM to be set");
-    Result = changeStreamingMode(DAG, DL, !CalleeAttrs.hasStreamingInterface(),
-                                 Result, InGlue, PStateSM, false);
+    Result = changeStreamingMode(
+        DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
+        getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
   }
 
   if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
@@ -8464,7 +8483,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
     Result = DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other, Result,
         DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(0, DL, MVT::i64), DAG.getConstant(1, DL, MVT::i64));
+        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
 
   if (ShouldPreserveZT0)
     Result =
@@ -8599,13 +8618,12 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
       Register Reg = FuncInfo->getPStateSMReg();
       assert(Reg.isValid() && "PStateSM Register is invalid");
       SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
-      Chain =
-          changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
-                              /*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+                                  /*Glue*/ SDValue(),
+                                  AArch64SME::IfCallerIsNonStreaming, PStateSM);
     } else
-      Chain = changeStreamingMode(
-          DAG, DL, /*Enable*/ false, Chain,
-          /*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+      Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+                                  /*Glue*/ SDValue(), AArch64SME::Always);
     Glue = Chain.getValue(1);
   }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 68341c199e0a2a..89016cbf56e39e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -968,12 +968,12 @@ class AArch64TargetLowering : public TargetLowering {
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
-  /// function call it emits and returns the corresponding SMSTART or SMSTOP node.
-  /// \p Entry tells whether this is before/after the Call, which is necessary
-  /// because PSTATE.SM is only queried once.
+  /// function call it emits and returns the corresponding SMSTART or SMSTOP
+  /// node. \p Condition is an AArch64SME::ToggleCondition value; \p PStateSM
+  /// is required (asserted) unless \p Condition is AArch64SME::Always.
   SDValue changeStreamingMode(SelectionDAG &DAG, SDLoc DL, bool Enable,
-                              SDValue Chain, SDValue InGlue,
-                              SDValue PStateSM, bool Entry) const;
+                              SDValue Chain, SDValue InGlue, unsigned Condition,
+                              SDValue PStateSM = SDValue()) const;
 
   bool isVScaleKnownToBeAPowerOfTwo() const override { return true; }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 2907ba74ff8108..1554f1c92b5bbb 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -10,12 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 3,
-                             [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisInt<0>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
-def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 3,
-                             [SDTCisInt<0>, SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 2,
+                             [SDTCisInt<0>, SDTCisInt<0>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
 def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
@@ -158,34 +158,6 @@ def : Pat<(AArch64_restore_za
             (i64 GPR64:$tpidr2_el0), (i64 GPR64sp:$tpidr2obj), (i64 texternalsym:$restore_routine)),
           (RestoreZAPseudo GPR64:$tpidr2_el0, GPR64sp:$tpidr2obj, texternalsym:$restore_routine)>;
 
-// Scenario A:
-//
-//   %pstate.before.call = 1
-//   if (%pstate.before.call != 0)
-//     smstop (pstate_za|pstate_sm)
-//   call fn()
-//   if (%pstate.before.call != 0)
-//     smstart (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 1), (i64 0)),   // before call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 1), (i64 0)),  // after call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-
-// Scenario B:
-//
-//   %pstate.before.call = 0
-//   if (%pstate.before.call != 1)
-//     smstart (pstate_za|pstate_sm)
-//   call fn()
-//   if (%pstate.before.call != 1)
-//     smstop (pstate_za|pstate_sm)
-//
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 0), (i64 1)),  // before call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 0), (i64 1)),   // after call
-          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
-
 // Read and write TPIDR2_EL0
 def : Pat<(int_aarch64_sme_set_tpidr2 i64:$val),
           (MSR 0xde85, GPR64:$val)>;
@@ -230,17 +202,24 @@ defm COALESCER_BARRIER : CoalescerBarriers;
 // SME instructions.
 def MSRpstatePseudo :
   Pseudo<(outs),
-           (ins svcr_op:$pstatefield, timm0_1:$imm, GPR64:$rtpstate, timm0_1:$expected_pstate, variable_ops), []>,
+           (ins svcr_op:$pstatefield, timm0_1:$imm, timm0_31:$condition, variable_ops), []>,
     Sched<[WriteSys]> {
   let hasPostISelHook = 1;
   let Uses = [VG];
   let Defs = [VG];
 }
 
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b1, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 GPR64:$rtpstate), (i64 timm0_1:$expected_pstate)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b0, GPR64:$rtpstate, timm0_1:$expected_pstate)>;
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition)>;
+
+// Unconditional start/stop (0 == AArch64SME::Always): select MSR directly.
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+          (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
+
 
 //===----------------------------------------------------------------------===//
 // SME2 Instructions

diff  --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
index ed8336a2e8ad34..f821bb527aedb8 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -591,6 +591,14 @@ namespace AArch64BTIHint {
   #include "AArch64GenSystemOperands.inc"
 }
 
+namespace AArch64SME {
+enum ToggleCondition : unsigned { // Condition operand of SMSTART/SMSTOP.
+  Always,                // Unconditional; TableGen patterns hard-code it as 0.
+  IfCallerIsStreaming,   // Toggle only if pstate.sm bit 0 is set (TBNZW).
+  IfCallerIsNonStreaming // Toggle only if pstate.sm bit 0 is clear (TBZW).
+};
+} // namespace AArch64SME
+
 namespace AArch64SE {
     enum ShiftExtSpecifiers {
         Invalid = -1,


        


More information about the llvm-commits mailing list