[llvm] [AArch64] Split AArch64ISD::COND_SMSTART/STOP off AArch64::SMSTART/STOP (NFC) (PR #140711)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 20 04:11:09 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

The conditional variants of SMSTART/STOP currently take the current PStateSM as a variadic value. This is not supported by the verification added in #140472 (which requires variadic values to be of type Register or RegisterMask), so this patch splits the conditional variants into new `COND_` nodes, in which these extra parameters are fixed (non-variadic) operands.

Suggested in https://github.com/llvm/llvm-project/pull/140472#discussion_r2094635066

Part of #140472.

---
Full diff: https://github.com/llvm/llvm-project/pull/140711.diff


4 Files Affected:

- (modified) llvm/docs/AArch64SME.rst (+8-6) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+15-13) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2) 
- (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+18-10) 


``````````diff
diff --git a/llvm/docs/AArch64SME.rst b/llvm/docs/AArch64SME.rst
index b5a01cb204b81..ac8ce32ddb9e6 100644
--- a/llvm/docs/AArch64SME.rst
+++ b/llvm/docs/AArch64SME.rst
@@ -213,12 +213,14 @@ Instruction Selection Nodes
 
 .. code-block:: none
 
-  AArch64ISD::SMSTART Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
-  AArch64ISD::SMSTOP  Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
-
-The ``SMSTART/SMSTOP`` nodes take ``CurrentState`` and ``ExpectedState`` operand for
-the case of a conditional SMSTART/SMSTOP. The instruction will only be executed
-if CurrentState != ExpectedState.
+  AArch64ISD::SMSTART Chain, [SM|ZA|Both][, RegMask]
+  AArch64ISD::SMSTOP  Chain, [SM|ZA|Both][, RegMask]
+  AArch64ISD::COND_SMSTART Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
+  AArch64ISD::COND_SMSTOP  Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask]
+
+The ``COND_SMSTART/COND_SMSTOP`` nodes additionally take ``CurrentState`` and
+``ExpectedState``, in this case the instruction will only be executed if
+``CurrentState != ExpectedState``.
 
 When ``CurrentState`` and ``ExpectedState`` can be evaluated at compile-time
 (i.e. they are both constants) then an unconditional ``smstart/smstop``
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 293292d47dd48..d1000dd64bdf7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2726,6 +2726,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::VG_RESTORE)
     MAKE_CASE(AArch64ISD::SMSTART)
     MAKE_CASE(AArch64ISD::SMSTOP)
+    MAKE_CASE(AArch64ISD::COND_SMSTART)
+    MAKE_CASE(AArch64ISD::COND_SMSTOP)
     MAKE_CASE(AArch64ISD::RESTORE_ZA)
     MAKE_CASE(AArch64ISD::RESTORE_ZT)
     MAKE_CASE(AArch64ISD::SAVE_ZT)
@@ -6033,14 +6035,12 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
     return DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other,
         Op->getOperand(0), // Chain
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
   case Intrinsic::aarch64_sme_za_disable:
     return DAG.getNode(
         AArch64ISD::SMSTOP, DL, MVT::Other,
         Op->getOperand(0), // Chain
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
   }
 }
 
@@ -8913,18 +8913,22 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
   SDValue MSROp =
       DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
-  SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
-  SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
+  SmallVector<SDValue> Ops = {Chain, MSROp};
+  unsigned Opcode;
   if (Condition != AArch64SME::Always) {
+    SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
+    Opcode = Enable ? AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP;
     assert(PStateSM && "PStateSM should be defined");
+    Ops.push_back(ConditionOp);
     Ops.push_back(PStateSM);
+  } else {
+    Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP;
   }
   Ops.push_back(RegMask);
 
   if (InGlue)
     Ops.push_back(InGlue);
 
-  unsigned Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP;
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
@@ -9189,9 +9193,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   if (DisableZA)
     Chain = DAG.getNode(
-        AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
+        AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
 
   // Adjust the stack pointer for the new arguments...
   // These operations are automatically eliminated by the prolog/epilog pass
@@ -9668,9 +9671,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   if (CallAttrs.requiresEnablingZAAfterCall())
     // Unconditionally resume ZA.
     Result = DAG.getNode(
-        AArch64ISD::SMSTART, DL, MVT::Other, Result,
-        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
-        DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
+        AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result,
+        DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32));
 
   if (ShouldPreserveZT0)
     Result =
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c1e6d70099fa5..59a9d7d179778 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -73,6 +73,8 @@ enum NodeType : unsigned {
 
   SMSTART,
   SMSTOP,
+  COND_SMSTART,
+  COND_SMSTOP,
   RESTORE_ZA,
   RESTORE_ZT,
   SAVE_ZT,
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 363ecee49c0f2..e7482da001074 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -10,12 +10,20 @@
 //
 //===----------------------------------------------------------------------===//
 
-def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 2,
-                             [SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 1,
+                             [SDTCisInt<0>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
-def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 2,
-                             [SDTCisInt<0>, SDTCisInt<0>]>,
+def AArch64_smstop  : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 1,
+                             [SDTCisInt<0>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
+                              SDNPOptInGlue, SDNPOutGlue]>;
+def AArch64_cond_smstart : SDNode<"AArch64ISD::COND_SMSTART", SDTypeProfile<0, 3,
+                             [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>,
+                             [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
+                              SDNPOptInGlue, SDNPOutGlue]>;
+def AArch64_cond_smstop  : SDNode<"AArch64ISD::COND_SMSTOP", SDTypeProfile<0, 3,
+                             [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>,
                              [SDNPHasChain, SDNPSideEffect, SDNPVariadic,
                               SDNPOptInGlue, SDNPOutGlue]>;
 def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3,
@@ -305,15 +313,15 @@ def MSRpstatePseudo :
   let Defs = [VG];
 }
 
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition)),
-          (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition)>;
+def : Pat<(AArch64_cond_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition), (i64 GPR64:$pstatesm)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition, GPR64:$pstatesm)>;
+def : Pat<(AArch64_cond_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition), (i64 GPR64:$pstatesm)),
+          (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition, GPR64:$pstatesm)>;
 
 // Unconditional start/stop
-def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+def : Pat<(AArch64_smstart (i32 svcr_op:$pstate)),
           (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>;
-def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)),
+def : Pat<(AArch64_smstop (i32 svcr_op:$pstate)),
           (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;
 
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/140711


More information about the llvm-commits mailing list