[llvm] 51e3d2f - [AArch64][SME] Conditionally do smstart/smstop (#77113)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 01:17:27 PST 2024


Author: Matthew Devereau
Date: 2024-01-18T09:17:23Z
New Revision: 51e3d2f73d8f3f5c70f0c1b6b73f62ec9c680cb4

URL: https://github.com/llvm/llvm-project/commit/51e3d2f73d8f3f5c70f0c1b6b73f62ec9c680cb4
DIFF: https://github.com/llvm/llvm-project/commit/51e3d2f73d8f3f5c70f0c1b6b73f62ec9c680cb4.diff

LOG: [AArch64][SME] Conditionally do smstart/smstop (#77113)

This patch adds conditional enabling/disabling of streaming mode for
functions which have both the aarch64_pstate_sm_compatible and
aarch64_pstate_sm_body attributes.

This combination lets the callee determine at run time (via a call to
__arm_sme_state) whether switching streaming mode is required, instead of
relying on the caller.

Added: 
    llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a3035bf78da27f..8a6f1dc7487bae 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4854,17 +4854,9 @@ static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
   return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
 }
 
-SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
-                                           SMEAttrs Attrs, SDLoc DL,
-                                           EVT VT) const {
-  if (Attrs.hasStreamingInterfaceOrBody())
-    return DAG.getConstant(1, DL, VT);
-
-  if (Attrs.hasNonStreamingInterfaceAndBody())
-    return DAG.getConstant(0, DL, VT);
-
-  assert(Attrs.hasStreamingCompatibleInterface() && "Unexpected interface");
-
+SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
+                                                  SDValue Chain, SDLoc DL,
+                                                  EVT VT) const {
   SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
                                          getPointerTy(DAG.getDataLayout()));
   Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
@@ -6892,9 +6884,18 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
   // Insert the SMSTART if this is a locally streaming function and
   // make sure it is Glued to the last CopyFromReg value.
   if (IsLocallyStreaming) {
-    Chain =
-        changeStreamingMode(DAG, DL, /*Enable*/ true, DAG.getRoot(), Glue,
-                            DAG.getConstant(0, DL, MVT::i64), /*Entry*/ true);
+    SDValue PStateSM;
+    if (Attrs.hasStreamingCompatibleInterface()) {
+      PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
+      Register Reg = MF.getRegInfo().createVirtualRegister(
+          getRegClassFor(PStateSM.getValueType().getSimpleVT()));
+      FuncInfo->setPStateSMReg(Reg);
+      Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
+    } else {
+      PStateSM = DAG.getConstant(0, DL, MVT::i64);
+    }
+    Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue, PStateSM,
+                                /*Entry*/ true);
 
     // Ensure that the SMSTART happens after the CopyWithChain such that its
     // chain result is used.
@@ -7652,7 +7653,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   std::optional<bool> RequiresSMChange =
       CallerAttrs.requiresSMChange(CalleeAttrs);
   if (RequiresSMChange) {
-    PStateSM = getPStateSM(DAG, Chain, CallerAttrs, DL, MVT::i64);
+    if (CallerAttrs.hasStreamingInterfaceOrBody())
+      PStateSM = DAG.getConstant(1, DL, MVT::i64);
+    else if (CallerAttrs.hasNonStreamingInterface())
+      PStateSM = DAG.getConstant(0, DL, MVT::i64);
+    else
+      PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
     OptimizationRemarkEmitter ORE(&MF.getFunction());
     ORE.emit([&]() {
       auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
@@ -8205,9 +8211,17 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
   // Emit SMSTOP before returning from a locally streaming function
   SMEAttrs FuncAttrs(MF.getFunction());
   if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
-    Chain = changeStreamingMode(
-        DAG, DL, /*Enable*/ false, Chain, /*Glue*/ SDValue(),
-        DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
+    if (FuncAttrs.hasStreamingCompatibleInterface()) {
+      Register Reg = FuncInfo->getPStateSMReg();
+      assert(Reg.isValid() && "PStateSM Register is invalid");
+      SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
+      Chain =
+          changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
+                              /*Glue*/ SDValue(), PStateSM, /*Entry*/ false);
+    } else
+      Chain = changeStreamingMode(
+          DAG, DL, /*Enable*/ false, Chain,
+          /*Glue*/ SDValue(), DAG.getConstant(1, DL, MVT::i64), /*Entry*/ true);
     Glue = Chain.getValue(1);
   }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 1fd639b4f7ee8f..6047a3b7b2864a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1290,10 +1290,10 @@ class AArch64TargetLowering : public TargetLowering {
   // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
-  // Returns the runtime value for PSTATE.SM. When the function is streaming-
-  // compatible, this generates a call to __arm_sme_state.
-  SDValue getPStateSM(SelectionDAG &DAG, SDValue Chain, SMEAttrs Attrs,
-                      SDLoc DL, EVT VT) const;
+  // Returns the runtime value for PSTATE.SM by generating a call to
+  // __arm_sme_state.
+  SDValue getRuntimePStateSM(SelectionDAG &DAG, SDValue Chain, SDLoc DL,
+                             EVT VT) const;
 
   bool preferScalarizeSplat(SDNode *N) const override;
 

diff  --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index cd4a18bfbc23a8..d5941e6284111a 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -208,6 +208,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
 
   int64_t StackProbeSize = 0;
 
+  // Holds a register containing pstate.sm. This is set
+  // on function entry to record the initial pstate of a function.
+  Register PStateSMReg = MCRegister::NoRegister;
+
 public:
   AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);
 
@@ -216,6 +220,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
         const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
       const override;
 
+  Register getPStateSMReg() const { return PStateSMReg; };
+  void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
+
   bool isSVECC() const { return IsSVECC; };
   void setIsSVECC(bool s) { IsSVECC = s; };
 

diff  --git a/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll
new file mode 100644
index 00000000000000..d67573384ca959
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-body-streaming-compatible-interface.ll
@@ -0,0 +1,124 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -start-after=simplifycfg -enable-tail-merge=false -verify-machineinstrs < %s | FileCheck %s
+
+declare void @normal_callee();
+declare void @streaming_callee() "aarch64_pstate_sm_enabled";
+declare void @streaming_compatible_callee() "aarch64_pstate_sm_compatible";
+
+define float @sm_body_sm_compatible_simple() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
+; CHECK-LABEL: sm_body_sm_compatible_simple:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    str x30, [sp, #64] // 8-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x8, x0, #0x1
+; CHECK-NEXT:    tbnz w8, #0, .LBB0_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    tbnz w8, #0, .LBB0_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB0_4:
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    fmov s0, wzr
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldr x30, [sp, #64] // 8-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  ret float zeroinitializer
+}
+
+define void @sm_body_caller_sm_compatible_caller_normal_callee() "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
+; CHECK-LABEL: sm_body_caller_sm_compatible_caller_normal_callee:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbnz w19, #0, .LBB1_2
+; CHECK-NEXT:  // %bb.1:
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB1_2:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    tbnz w19, #0, .LBB1_4
+; CHECK-NEXT:  // %bb.3:
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB1_4:
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @normal_callee()
+  ret void
+}
+
+; Function Attrs: nounwind "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body"
+define void @streaming_body_and_streaming_compatible_interface_multi_basic_block(i32 noundef %x) "aarch64_pstate_sm_compatible" "aarch64_pstate_sm_body" nounwind {
+; CHECK-LABEL: streaming_body_and_streaming_compatible_interface_multi_basic_block:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    stp d15, d14, [sp, #-80]! // 16-byte Folded Spill
+; CHECK-NEXT:    stp d13, d12, [sp, #16] // 16-byte Folded Spill
+; CHECK-NEXT:    mov w8, w0
+; CHECK-NEXT:    stp d11, d10, [sp, #32] // 16-byte Folded Spill
+; CHECK-NEXT:    stp d9, d8, [sp, #48] // 16-byte Folded Spill
+; CHECK-NEXT:    stp x30, x19, [sp, #64] // 16-byte Folded Spill
+; CHECK-NEXT:    bl __arm_sme_state
+; CHECK-NEXT:    and x19, x0, #0x1
+; CHECK-NEXT:    tbnz w19, #0, .LBB2_2
+; CHECK-NEXT:  // %bb.1: // %entry
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:  .LBB2_2: // %entry
+; CHECK-NEXT:    cbz w8, .LBB2_6
+; CHECK-NEXT:  // %bb.3: // %if.else
+; CHECK-NEXT:    bl streaming_compatible_callee
+; CHECK-NEXT:    tbnz w19, #0, .LBB2_5
+; CHECK-NEXT:  // %bb.4: // %if.else
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB2_5: // %if.else
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB2_6: // %if.then
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:    bl normal_callee
+; CHECK-NEXT:    smstart sm
+; CHECK-NEXT:    tbnz w19, #0, .LBB2_8
+; CHECK-NEXT:  // %bb.7: // %if.then
+; CHECK-NEXT:    smstop sm
+; CHECK-NEXT:  .LBB2_8: // %if.then
+; CHECK-NEXT:    ldp x30, x19, [sp, #64] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d9, d8, [sp, #48] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d11, d10, [sp, #32] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d13, d12, [sp, #16] // 16-byte Folded Reload
+; CHECK-NEXT:    ldp d15, d14, [sp], #80 // 16-byte Folded Reload
+; CHECK-NEXT:    ret
+entry:
+  %cmp = icmp eq i32 %x, 0
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  tail call void @normal_callee()
+  br label %return
+
+if.else:                                          ; preds = %entry
+  tail call void @streaming_compatible_callee()
+  br label %return
+
+return:                                           ; preds = %if.else, %if.then
+  ret void
+}


        


More information about the llvm-commits mailing list