[clang] [llvm] [AArch64][SME] Disable inlining of callees with new ZT0 state (PR #121338)

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 31 01:39:05 PST 2024


https://github.com/kmclaughlin-arm updated https://github.com/llvm/llvm-project/pull/121338

>From 8b8f191d9c6980f7342c0bea2681ffd8d1dbe90b Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Mon, 30 Dec 2024 13:24:34 +0000
Subject: [PATCH 1/2] [AArch64][SME] Disable inlining of callees with new ZT0
 state

Inlining must be disabled for new-ZT0 callees, because such a callee is
required to save ZT0 and toggle PSTATE.ZA on entry.
---
 .../AArch64/AArch64TargetTransformInfo.cpp    |  2 +-
 .../Inline/AArch64/sme-pstateza-attrs.ll      | 30 +++++++++++++++++--
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 0566a875900127..82a5f8c61bd849 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -256,7 +256,7 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
     CalleeAttrs.set(SMEAttrs::SM_Enabled, true);
   }
 
-  if (CalleeAttrs.isNewZA())
+  if (CalleeAttrs.isNewZA() || CalleeAttrs.isNewZT0())
     return false;
 
   if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
diff --git a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
index 7ffbd64c700aa2..4cd1491611be0f 100644
--- a/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
+++ b/llvm/test/Transforms/Inline/AArch64/sme-pstateza-attrs.ll
@@ -391,9 +391,33 @@ define void @nonzt0_callee() {
   ret void
 }
 
+define void @new_zt0_callee() "aarch64_new_zt0" {
+; CHECK-LABEL: define void @new_zt0_callee
+; CHECK-SAME: () #[[ATTR4:[0-9]+]] {
+; CHECK-NEXT:    call void asm sideeffect "
+; CHECK-NEXT:    call void @inlined_body()
+; CHECK-NEXT:    ret void
+;
+  call void asm sideeffect "; inlineasm", ""()
+  call void @inlined_body()
+  ret void
+}
+
+define void @nonzt0_caller_new_zt0_callee_dont_inline() {
+; CHECK-LABEL: define void @nonzt0_caller_new_zt0_callee_dont_inline
+; CHECK-SAME: () #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @new_zt0_callee()
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @new_zt0_callee()
+  ret void
+}
+
 define void @shared_zt0_caller_nonzt0_callee_dont_inline() "aarch64_inout_zt0" {
 ; CHECK-LABEL: define void @shared_zt0_caller_nonzt0_callee_dont_inline
-; CHECK-SAME: () #[[ATTR4:[0-9]+]] {
+; CHECK-SAME: () #[[ATTR5:[0-9]+]] {
 ; CHECK-NEXT:    call void @nonzt0_callee()
 ; CHECK-NEXT:    ret void
 ;
@@ -403,7 +427,7 @@ define void @shared_zt0_caller_nonzt0_callee_dont_inline() "aarch64_inout_zt0" {
 
 define void @shared_zt0_callee() "aarch64_inout_zt0" {
 ; CHECK-LABEL: define void @shared_zt0_callee
-; CHECK-SAME: () #[[ATTR4]] {
+; CHECK-SAME: () #[[ATTR5]] {
 ; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void
@@ -415,7 +439,7 @@ define void @shared_zt0_callee() "aarch64_inout_zt0" {
 
 define void @shared_zt0_caller_shared_zt0_callee_inline() "aarch64_inout_zt0" {
 ; CHECK-LABEL: define void @shared_zt0_caller_shared_zt0_callee_inline
-; CHECK-SAME: () #[[ATTR4]] {
+; CHECK-SAME: () #[[ATTR5]] {
 ; CHECK-NEXT:    call void asm sideeffect "
 ; CHECK-NEXT:    call void @inlined_body()
 ; CHECK-NEXT:    ret void

>From 6755e6e451ae50d949055940aa5183c6cc7f55dd Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Mon, 30 Dec 2024 17:02:12 +0000
Subject: [PATCH 2/2] Add a similar restriction for new ZT0 callees to
 GetArmSMEInlinability in Clang

---
 .../clang/Basic/DiagnosticFrontendKinds.td        |  2 ++
 clang/lib/CodeGen/Targets/AArch64.cpp             | 15 ++++++++++++---
 .../AArch64/sme-inline-callees-streaming-attrs.c  | 13 +++++++++++--
 .../CodeGen/AArch64/sme-inline-streaming-attrs.c  | 12 ++++++++----
 4 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticFrontendKinds.td b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
index 1ed379c76c8ea2..f3593f5313340b 100644
--- a/clang/include/clang/Basic/DiagnosticFrontendKinds.td
+++ b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
@@ -291,6 +291,8 @@ def warn_function_always_inline_attribute_mismatch : Warning<
   "inlining may change runtime behaviour">, InGroup<AArch64SMEAttributes>;
 def err_function_always_inline_new_za : Error<
   "always_inline function %0 has new za state">;
+def err_function_always_inline_new_zt0
+    : Error<"always_inline function %0 has new zt0 state">;
 
 def warn_avx_calling_convention
     : Warning<"AVX vector %select{return|argument}0 of type %1 without '%2' "
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index ad7f405cc72550..f7f3f8b0dc4236 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -1169,8 +1169,9 @@ void AArch64TargetCodeGenInfo::checkFunctionABI(
 enum class ArmSMEInlinability : uint8_t {
   Ok = 0,
   ErrorCalleeRequiresNewZA = 1 << 0,
-  WarnIncompatibleStreamingModes = 1 << 1,
-  ErrorIncompatibleStreamingModes = 1 << 2,
+  ErrorCalleeRequiresNewZT0 = 1 << 1,
+  WarnIncompatibleStreamingModes = 1 << 2,
+  ErrorIncompatibleStreamingModes = 1 << 3,
 
   IncompatibleStreamingModes =
       WarnIncompatibleStreamingModes | ErrorIncompatibleStreamingModes,
@@ -1198,9 +1199,12 @@ static ArmSMEInlinability GetArmSMEInlinability(const FunctionDecl *Caller,
     else
       Inlinability |= ArmSMEInlinability::WarnIncompatibleStreamingModes;
   }
-  if (auto *NewAttr = Callee->getAttr<ArmNewAttr>())
+  if (auto *NewAttr = Callee->getAttr<ArmNewAttr>()) {
     if (NewAttr->isNewZA())
       Inlinability |= ArmSMEInlinability::ErrorCalleeRequiresNewZA;
+    if (NewAttr->isNewZT0())
+      Inlinability |= ArmSMEInlinability::ErrorCalleeRequiresNewZT0;
+  }
 
   return Inlinability;
 }
@@ -1227,6 +1231,11 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABIStreaming(
       ArmSMEInlinability::ErrorCalleeRequiresNewZA)
     CGM.getDiags().Report(CallLoc, diag::err_function_always_inline_new_za)
         << Callee->getDeclName();
+
+  if ((Inlinability & ArmSMEInlinability::ErrorCalleeRequiresNewZT0) ==
+      ArmSMEInlinability::ErrorCalleeRequiresNewZT0)
+    CGM.getDiags().Report(CallLoc, diag::err_function_always_inline_new_zt0)
+        << Callee->getDeclName();
 }
 
 // If the target does not have floating-point registers, but we are using a
diff --git a/clang/test/CodeGen/AArch64/sme-inline-callees-streaming-attrs.c b/clang/test/CodeGen/AArch64/sme-inline-callees-streaming-attrs.c
index ce6f203631fc5c..2071e66e0d652c 100644
--- a/clang/test/CodeGen/AArch64/sme-inline-callees-streaming-attrs.c
+++ b/clang/test/CodeGen/AArch64/sme-inline-callees-streaming-attrs.c
@@ -1,5 +1,5 @@
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -emit-llvm -target-feature +sme %s -DUSE_FLATTEN -o - | FileCheck %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -emit-llvm -target-feature +sme %s -DUSE_ALWAYS_INLINE_STMT -o - | FileCheck %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -emit-llvm -target-feature +sme -target-feature +sme2 %s -DUSE_FLATTEN -o - | FileCheck %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -emit-llvm -target-feature +sme -target-feature +sme2 %s -DUSE_ALWAYS_INLINE_STMT -o - | FileCheck %s
 
 // REQUIRES: aarch64-registered-target
 
@@ -20,6 +20,7 @@ void fn_streaming_compatible(void) __arm_streaming_compatible { was_inlined(); }
 void fn_streaming(void) __arm_streaming { was_inlined(); }
 __arm_locally_streaming void fn_locally_streaming(void) { was_inlined(); }
 __arm_new("za") void fn_streaming_new_za(void) __arm_streaming { was_inlined(); }
+__arm_new("zt0") void fn_streaming_new_zt0(void) __arm_streaming { was_inlined(); }
 
 FN_ATTR
 void caller(void) {
@@ -28,6 +29,7 @@ void caller(void) {
     STMT_ATTR fn_streaming();
     STMT_ATTR fn_locally_streaming();
     STMT_ATTR fn_streaming_new_za();
+    STMT_ATTR fn_streaming_new_zt0();
 }
 // CHECK-LABEL: void @caller()
 //  CHECK-NEXT: entry:
@@ -36,6 +38,7 @@ void caller(void) {
 //  CHECK-NEXT:   call void @fn_streaming
 //  CHECK-NEXT:   call void @fn_locally_streaming
 //  CHECK-NEXT:   call void @fn_streaming_new_za
+//  CHECK-NEXT:   call void @fn_streaming_new_zt0
 
 FN_ATTR void caller_streaming_compatible(void) __arm_streaming_compatible {
     STMT_ATTR fn();
@@ -43,6 +46,7 @@ FN_ATTR void caller_streaming_compatible(void) __arm_streaming_compatible {
     STMT_ATTR fn_streaming();
     STMT_ATTR fn_locally_streaming();
     STMT_ATTR fn_streaming_new_za();
+    STMT_ATTR fn_streaming_new_zt0();
 }
 // CHECK-LABEL: void @caller_streaming_compatible()
 //  CHECK-NEXT: entry:
@@ -51,6 +55,7 @@ FN_ATTR void caller_streaming_compatible(void) __arm_streaming_compatible {
 //  CHECK-NEXT:   call void @fn_streaming
 //  CHECK-NEXT:   call void @fn_locally_streaming
 //  CHECK-NEXT:   call void @fn_streaming_new_za
+//  CHECK-NEXT:   call void @fn_streaming_new_zt0
 
 FN_ATTR void caller_streaming(void) __arm_streaming {
     STMT_ATTR fn();
@@ -58,6 +63,7 @@ FN_ATTR void caller_streaming(void) __arm_streaming {
     STMT_ATTR fn_streaming();
     STMT_ATTR fn_locally_streaming();
     STMT_ATTR fn_streaming_new_za();
+    STMT_ATTR fn_streaming_new_zt0();
 }
 // CHECK-LABEL: void @caller_streaming()
 //  CHECK-NEXT: entry:
@@ -66,6 +72,7 @@ FN_ATTR void caller_streaming(void) __arm_streaming {
 //  CHECK-NEXT:   call void @was_inlined
 //  CHECK-NEXT:   call void @was_inlined
 //  CHECK-NEXT:   call void @fn_streaming_new_za
+//  CHECK-NEXT:   call void @fn_streaming_new_zt0
 
 FN_ATTR __arm_locally_streaming
 void caller_locally_streaming(void) {
@@ -74,6 +81,7 @@ void caller_locally_streaming(void) {
     STMT_ATTR fn_streaming();
     STMT_ATTR fn_locally_streaming();
     STMT_ATTR fn_streaming_new_za();
+    STMT_ATTR fn_streaming_new_zt0();
 }
 // CHECK-LABEL: void @caller_locally_streaming()
 //  CHECK-NEXT: entry:
@@ -82,3 +90,4 @@ void caller_locally_streaming(void) {
 //  CHECK-NEXT:   call void @was_inlined
 //  CHECK-NEXT:   call void @was_inlined
 //  CHECK-NEXT:   call void @fn_streaming_new_za
+//  CHECK-NEXT:   call void @fn_streaming_new_zt0
diff --git a/clang/test/CodeGen/AArch64/sme-inline-streaming-attrs.c b/clang/test/CodeGen/AArch64/sme-inline-streaming-attrs.c
index 9c3d08a25945a3..68102c9ded40c4 100644
--- a/clang/test/CodeGen/AArch64/sme-inline-streaming-attrs.c
+++ b/clang/test/CodeGen/AArch64/sme-inline-streaming-attrs.c
@@ -1,7 +1,7 @@
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_NONE %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_COMPATIBLE %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_STREAMING %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_LOCALLY %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -target-feature +sme2 -verify -DTEST_NONE %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -target-feature +sme2 -verify -DTEST_COMPATIBLE %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -target-feature +sme2 -verify -DTEST_STREAMING %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -target-feature +sme2 -verify -DTEST_LOCALLY %s
 
 // REQUIRES: aarch64-registered-target
 
@@ -10,6 +10,8 @@ __ai void inlined_fn(void) {}
 __ai void inlined_fn_streaming_compatible(void) __arm_streaming_compatible {}
 __ai void inlined_fn_streaming(void) __arm_streaming {}
 __ai __arm_locally_streaming void inlined_fn_local(void) {}
+__ai __arm_new("za") void inlined_fn_za(void) {}
+__ai __arm_new("zt0") void inlined_fn_zt0(void) {}
 
 #ifdef TEST_NONE
 void caller(void) {
@@ -17,6 +19,8 @@ void caller(void) {
     inlined_fn_streaming_compatible();
     inlined_fn_streaming(); // expected-error {{always_inline function 'inlined_fn_streaming' and its caller 'caller' have mismatching streaming attributes}}
     inlined_fn_local(); // expected-error {{always_inline function 'inlined_fn_local' and its caller 'caller' have mismatching streaming attributes}}
+    inlined_fn_za(); // expected-error {{always_inline function 'inlined_fn_za' has new za state}}
+    inlined_fn_zt0(); // expected-error {{always_inline function 'inlined_fn_zt0' has new zt0 state}}
 }
 #endif
 



More information about the llvm-commits mailing list