[clang] [llvm] [Clang][SME] Detect always_inline used with mismatched streaming attributes (PR #77936)

Sam Tebbs via cfe-commits cfe-commits at lists.llvm.org
Wed Jan 24 07:11:53 PST 2024


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/77936

>From 7314429a203900a8f555e1b0471fdd4cfd4d8d03 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 10 Jan 2024 14:57:04 +0000
Subject: [PATCH 1/8] [Clang][SME] Detect always_inline used with mismatched
 streaming attributes

This patch adds an error that is emitted when a streaming function is
marked as always_inline and is called from a non-streaming function.
---
 .../clang/Basic/DiagnosticFrontendKinds.td    |  2 ++
 clang/include/clang/Sema/Sema.h               |  9 +++++++
 clang/lib/CodeGen/CMakeLists.txt              |  1 +
 clang/lib/CodeGen/Targets/AArch64.cpp         | 20 ++++++++++++++
 clang/lib/Sema/SemaChecking.cpp               | 27 +++++++------------
 ...-sme-func-attrs-inline-locally-streaming.c | 12 +++++++++
 .../aarch64-sme-func-attrs-inline-streaming.c | 12 +++++++++
 7 files changed, 66 insertions(+), 17 deletions(-)
 create mode 100644 clang/test/CodeGen/aarch64-sme-func-attrs-inline-locally-streaming.c
 create mode 100644 clang/test/CodeGen/aarch64-sme-func-attrs-inline-streaming.c

diff --git a/clang/include/clang/Basic/DiagnosticFrontendKinds.td b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
index 85ecfdf9de62d44..2d0f971858840db 100644
--- a/clang/include/clang/Basic/DiagnosticFrontendKinds.td
+++ b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
@@ -279,6 +279,8 @@ def err_builtin_needs_feature : Error<"%0 needs target feature %1">;
 def err_function_needs_feature : Error<
   "always_inline function %1 requires target feature '%2', but would "
   "be inlined into function %0 that is compiled without support for '%2'">;
+def err_function_alwaysinline_attribute_mismatch : Error<
+  "always_inline function %1 and its caller %0 have mismatched %2 attributes">;
 
 def warn_avx_calling_convention
     : Warning<"AVX vector %select{return|argument}0 of type %1 without '%2' "
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 6ce422d66ae5b0e..dd75b5aad3d9c86 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13832,8 +13832,17 @@ class Sema final {
     FormatArgumentPassingKind ArgPassingKind;
   };
 
+enum ArmStreamingType {
+  ArmNonStreaming,
+  ArmStreaming,
+  ArmStreamingCompatible,
+  ArmStreamingOrSVE2p1
+};
+
+
   static bool getFormatStringInfo(const FormatAttr *Format, bool IsCXXMember,
                                   bool IsVariadic, FormatStringInfo *FSI);
+  static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD);
 
 private:
   void CheckArrayAccess(const Expr *BaseExpr, const Expr *IndexExpr,
diff --git a/clang/lib/CodeGen/CMakeLists.txt b/clang/lib/CodeGen/CMakeLists.txt
index 52216d93a302bbb..03a6f2f1d7a9d26 100644
--- a/clang/lib/CodeGen/CMakeLists.txt
+++ b/clang/lib/CodeGen/CMakeLists.txt
@@ -151,4 +151,5 @@ add_clang_library(clangCodeGen
   clangFrontend
   clangLex
   clangSerialization
+  clangSema
   )
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index ee7f95084d2e0b6..4018f91422e358f 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -8,6 +8,8 @@
 
 #include "ABIInfoImpl.h"
 #include "TargetInfo.h"
+#include "clang/Basic/DiagnosticFrontend.h"
+#include "clang/Sema/Sema.h"
 
 using namespace clang;
 using namespace clang::CodeGen;
@@ -155,6 +157,11 @@ class AArch64TargetCodeGenInfo : public TargetCodeGenInfo {
     }
     return TargetCodeGenInfo::isScalarizableAsmOperand(CGF, Ty);
   }
+
+  void checkFunctionCallABI(CodeGenModule &CGM, SourceLocation CallLoc,
+                            const FunctionDecl *Caller,
+                            const FunctionDecl *Callee,
+                            const CallArgList &Args) const override;
 };
 
 class WindowsAArch64TargetCodeGenInfo : public AArch64TargetCodeGenInfo {
@@ -814,6 +821,19 @@ Address AArch64ABIInfo::EmitMSVAArg(CodeGenFunction &CGF, Address VAListAddr,
                           /*allowHigherAlign*/ false);
 }
 
+void AArch64TargetCodeGenInfo::checkFunctionCallABI(
+    CodeGenModule &CGM, SourceLocation CallLoc, const FunctionDecl *Caller,
+    const FunctionDecl *Callee, const CallArgList &Args) const {
+    if (!Callee->hasAttr<AlwaysInlineAttr>())
+      return;
+
+    auto CalleeIsStreaming = Sema::getArmStreamingFnType(Callee) == Sema::ArmStreaming;
+    auto CallerIsStreaming = Sema::getArmStreamingFnType(Caller) == Sema::ArmStreaming;
+
+    if (CalleeIsStreaming && !CallerIsStreaming)
+        CGM.getDiags().Report(CallLoc, diag::err_function_alwaysinline_attribute_mismatch) << Caller->getDeclName() << Callee->getDeclName() << "streaming";
+}
+
 std::unique_ptr<TargetCodeGenInfo>
 CodeGen::createAArch64TargetCodeGenInfo(CodeGenModule &CGM,
                                         AArch64ABIKind Kind) {
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index ace3e386988f005..a92db7d67e1cbd0 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2998,13 +2998,6 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
   llvm_unreachable("Invalid NeonTypeFlag!");
 }
 
-enum ArmStreamingType {
-  ArmNonStreaming,
-  ArmStreaming,
-  ArmStreamingCompatible,
-  ArmStreamingOrSVE2p1
-};
-
 bool Sema::ParseSVEImmChecks(
     CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
   // Perform all the immediate checks for this builtin call.
@@ -3145,7 +3138,7 @@ bool Sema::ParseSVEImmChecks(
   return HasError;
 }
 
-static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
+Sema::ArmStreamingType Sema::getArmStreamingFnType(const FunctionDecl *FD) {
   if (FD->hasAttr<ArmLocallyStreamingAttr>())
     return ArmStreaming;
   if (const auto *T = FD->getType()->getAs<FunctionProtoType>()) {
@@ -3159,31 +3152,31 @@ static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
 
 static void checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
                                      const FunctionDecl *FD,
-                                     ArmStreamingType BuiltinType) {
-  ArmStreamingType FnType = getArmStreamingFnType(FD);
-  if (BuiltinType == ArmStreamingOrSVE2p1) {
+                                     Sema::ArmStreamingType BuiltinType) {
+  Sema::ArmStreamingType FnType = Sema::getArmStreamingFnType(FD);
+  if (BuiltinType == Sema::ArmStreamingOrSVE2p1) {
     // Check intrinsics that are available in [sve2p1 or sme/sme2].
     llvm::StringMap<bool> CallerFeatureMap;
     S.Context.getFunctionFeatureMap(CallerFeatureMap, FD);
     if (Builtin::evaluateRequiredTargetFeatures("sve2p1", CallerFeatureMap))
-      BuiltinType = ArmStreamingCompatible;
+      BuiltinType = Sema::ArmStreamingCompatible;
     else
-      BuiltinType = ArmStreaming;
+      BuiltinType = Sema::ArmStreaming;
   }
 
-  if (FnType == ArmStreaming && BuiltinType == ArmNonStreaming) {
+  if (FnType == Sema::ArmStreaming && BuiltinType == Sema::ArmNonStreaming) {
     S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
         << TheCall->getSourceRange() << "streaming";
   }
 
-  if (FnType == ArmStreamingCompatible &&
-      BuiltinType != ArmStreamingCompatible) {
+  if (FnType == Sema::ArmStreamingCompatible &&
+      BuiltinType != Sema::ArmStreamingCompatible) {
     S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
         << TheCall->getSourceRange() << "streaming compatible";
     return;
   }
 
-  if (FnType == ArmNonStreaming && BuiltinType == ArmStreaming) {
+  if (FnType == Sema::ArmNonStreaming && BuiltinType == Sema::ArmStreaming) {
     S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
         << TheCall->getSourceRange() << "non-streaming";
   }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline-locally-streaming.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline-locally-streaming.c
new file mode 100644
index 000000000000000..4aa9fbf4a8fa18c
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline-locally-streaming.c
@@ -0,0 +1,12 @@
+// RUN: %clang --target=aarch64-none-linux-gnu -march=armv9-a+sme -O3 -S -Xclang -verify %s
+
+// Conflicting attributes when using always_inline
+__attribute__((always_inline)) __arm_locally_streaming
+int inlined_fn_local(void) {
+    return 42;
+}
+// expected-error@+1 {{always_inline function 'inlined_fn_local' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
+int inlined_fn_caller(void) { return inlined_fn_local(); }
+__arm_locally_streaming
+int inlined_fn_caller_local(void) { return inlined_fn_local(); }
+int inlined_fn_caller_streaming(void) __arm_streaming { return inlined_fn_local(); }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline-streaming.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline-streaming.c
new file mode 100644
index 000000000000000..7268a49bb2491d0
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline-streaming.c
@@ -0,0 +1,12 @@
+// RUN: %clang --target=aarch64-none-linux-gnu -march=armv9-a+sme -O3 -S -Xclang -verify %s
+
+// Conflicting attributes when using always_inline
+__attribute__((always_inline))
+int inlined_fn_streaming(void) __arm_streaming {
+    return 42;
+}
+// expected-error@+1 {{always_inline function 'inlined_fn_streaming' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
+int inlined_fn_caller(void) { return inlined_fn_streaming(); }
+__arm_locally_streaming
+int inlined_fn_caller_local(void) { return inlined_fn_streaming(); }
+int inlined_fn_caller_streaming(void) __arm_streaming { return inlined_fn_streaming(); }

>From 04866d0c0be106a3d0296938f5e1b1d5f9e2591e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 12 Jan 2024 15:13:28 +0000
Subject: [PATCH 2/8] fixup: formatting

---
 clang/include/clang/Sema/Sema.h       | 13 ++++++-------
 clang/lib/CodeGen/Targets/AArch64.cpp | 16 ++++++++++------
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index dd75b5aad3d9c86..1ee065e7de9b15b 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13832,13 +13832,12 @@ class Sema final {
     FormatArgumentPassingKind ArgPassingKind;
   };
 
-enum ArmStreamingType {
-  ArmNonStreaming,
-  ArmStreaming,
-  ArmStreamingCompatible,
-  ArmStreamingOrSVE2p1
-};
-
+  enum ArmStreamingType {
+    ArmNonStreaming,
+    ArmStreaming,
+    ArmStreamingCompatible,
+    ArmStreamingOrSVE2p1
+  };
 
   static bool getFormatStringInfo(const FormatAttr *Format, bool IsCXXMember,
                                   bool IsVariadic, FormatStringInfo *FSI);
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 4018f91422e358f..fe7b384c721d695 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -824,14 +824,18 @@ Address AArch64ABIInfo::EmitMSVAArg(CodeGenFunction &CGF, Address VAListAddr,
 void AArch64TargetCodeGenInfo::checkFunctionCallABI(
     CodeGenModule &CGM, SourceLocation CallLoc, const FunctionDecl *Caller,
     const FunctionDecl *Callee, const CallArgList &Args) const {
-    if (!Callee->hasAttr<AlwaysInlineAttr>())
-      return;
+  if (!Callee->hasAttr<AlwaysInlineAttr>())
+    return;
 
-    auto CalleeIsStreaming = Sema::getArmStreamingFnType(Callee) == Sema::ArmStreaming;
-    auto CallerIsStreaming = Sema::getArmStreamingFnType(Caller) == Sema::ArmStreaming;
+  auto CalleeIsStreaming =
+      Sema::getArmStreamingFnType(Callee) == Sema::ArmStreaming;
+  auto CallerIsStreaming =
+      Sema::getArmStreamingFnType(Caller) == Sema::ArmStreaming;
 
-    if (CalleeIsStreaming && !CallerIsStreaming)
-        CGM.getDiags().Report(CallLoc, diag::err_function_alwaysinline_attribute_mismatch) << Caller->getDeclName() << Callee->getDeclName() << "streaming";
+  if (CalleeIsStreaming && !CallerIsStreaming)
+    CGM.getDiags().Report(CallLoc,
+                          diag::err_function_alwaysinline_attribute_mismatch)
+        << Caller->getDeclName() << Callee->getDeclName() << "streaming";
 }
 
 std::unique_ptr<TargetCodeGenInfo>

>From 156c96bc9660472c17c666ff465613d7276e0db9 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 15 Jan 2024 11:32:44 +0000
Subject: [PATCH 3/8] fixup: allow streaming compatible callee and check if
 modes are the same

---
 clang/lib/CodeGen/Targets/AArch64.cpp           | 11 ++++++-----
 .../sme-inline-streaming-compatible-caller.c    |  9 +++++++++
 .../sme-inline-streaming-compatible.c           | 17 +++++++++++++++++
 .../sme-inline-streaming-locally.c}             |  0
 .../sme-inline-streaming.c}                     |  0
 5 files changed, 32 insertions(+), 5 deletions(-)
 create mode 100644 clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c
 create mode 100644 clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c
 rename clang/test/CodeGen/{aarch64-sme-func-attrs-inline-locally-streaming.c => aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c} (100%)
 rename clang/test/CodeGen/{aarch64-sme-func-attrs-inline-streaming.c => aarch64-sme-func-attrs-inline/sme-inline-streaming.c} (100%)

diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index fe7b384c721d695..72f70c931fd9e67 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -827,12 +827,13 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABI(
   if (!Callee->hasAttr<AlwaysInlineAttr>())
     return;
 
-  auto CalleeIsStreaming =
-      Sema::getArmStreamingFnType(Callee) == Sema::ArmStreaming;
-  auto CallerIsStreaming =
-      Sema::getArmStreamingFnType(Caller) == Sema::ArmStreaming;
+  auto CalleeStreamingMode = Sema::getArmStreamingFnType(Callee);
+  auto CallerStreamingMode = Sema::getArmStreamingFnType(Caller);
 
-  if (CalleeIsStreaming && !CallerIsStreaming)
+  // The caller can inline the callee if their streaming modes match or the
+  // callee is streaming compatible
+  if (CalleeStreamingMode != CallerStreamingMode &&
+      CalleeStreamingMode != Sema::ArmStreamingCompatible)
     CGM.getDiags().Report(CallLoc,
                           diag::err_function_alwaysinline_attribute_mismatch)
         << Caller->getDeclName() << Callee->getDeclName() << "streaming";
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c
new file mode 100644
index 000000000000000..5c1779291a7e5b9
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c
@@ -0,0 +1,9 @@
+// RUN: %clang --target=aarch64-none-linux-gnu -march=armv9-a+sme -O3 -S -Xclang -verify %s
+
+// Conflicting attributes when using always_inline
+__attribute__((always_inline))
+int inlined_fn_streaming(void) __arm_streaming {
+    return 42;
+}
+// expected-error@+1 {{always_inline function 'inlined_fn_streaming' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
+int inlined_fn_caller(void) __arm_streaming_compatible { return inlined_fn_streaming(); }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c
new file mode 100644
index 000000000000000..a996c429fdda0da
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c
@@ -0,0 +1,17 @@
+// RUN: %clang --target=aarch64-none-linux-gnu -march=armv9-a+sme -O3 -S -Xclang -verify %s
+
+// Conflicting attributes when using always_inline
+__attribute__((always_inline))
+int inlined_fn_streaming_compatible(void) __arm_streaming_compatible {
+    return 42;
+}
+__attribute__((always_inline))
+int inlined_fn(void) {
+    return 42;
+}
+int inlined_fn_caller(void) { return inlined_fn_streaming_compatible(); }
+__arm_locally_streaming
+int inlined_fn_caller_local(void) { return inlined_fn_streaming_compatible(); }
+int inlined_fn_caller_streaming(void) __arm_streaming { return inlined_fn_streaming_compatible(); }
+// expected-error@+1 {{always_inline function 'inlined_fn' and its caller 'inlined_fn_caller_compatible' have mismatched streaming attributes}}
+int inlined_fn_caller_compatible(void) __arm_streaming_compatible { return inlined_fn(); }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline-locally-streaming.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
similarity index 100%
rename from clang/test/CodeGen/aarch64-sme-func-attrs-inline-locally-streaming.c
rename to clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline-streaming.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c
similarity index 100%
rename from clang/test/CodeGen/aarch64-sme-func-attrs-inline-streaming.c
rename to clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c

>From 5428dd7ed0745fd58cec768cc12742c40b86ce60 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 15 Jan 2024 16:14:00 +0000
Subject: [PATCH 4/8] fixup: rename error

---
 clang/include/clang/Basic/DiagnosticFrontendKinds.td | 2 +-
 clang/lib/CodeGen/Targets/AArch64.cpp                | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticFrontendKinds.td b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
index 2d0f971858840db..27149553c79a2a7 100644
--- a/clang/include/clang/Basic/DiagnosticFrontendKinds.td
+++ b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
@@ -279,7 +279,7 @@ def err_builtin_needs_feature : Error<"%0 needs target feature %1">;
 def err_function_needs_feature : Error<
   "always_inline function %1 requires target feature '%2', but would "
   "be inlined into function %0 that is compiled without support for '%2'">;
-def err_function_alwaysinline_attribute_mismatch : Error<
+def err_function_always_inline_attribute_mismatch : Error<
   "always_inline function %1 and its caller %0 have mismatched %2 attributes">;
 
 def warn_avx_calling_convention
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 72f70c931fd9e67..8d9f840c018fa53 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -835,7 +835,7 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABI(
   if (CalleeStreamingMode != CallerStreamingMode &&
       CalleeStreamingMode != Sema::ArmStreamingCompatible)
     CGM.getDiags().Report(CallLoc,
-                          diag::err_function_alwaysinline_attribute_mismatch)
+                          diag::err_function_always_inline_attribute_mismatch)
         << Caller->getDeclName() << Callee->getDeclName() << "streaming";
 }
 

>From 3b5451c8743a9f5a00d460c6a8eb998b6acaedc2 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 15 Jan 2024 16:14:14 +0000
Subject: [PATCH 5/8] fixup: return void in tests

---
 .../sme-inline-streaming-compatible-caller.c     |  6 ++----
 .../sme-inline-streaming-compatible.c            | 16 ++++++----------
 .../sme-inline-streaming-locally.c               | 10 ++++------
 .../sme-inline-streaming.c                       | 10 ++++------
 4 files changed, 16 insertions(+), 26 deletions(-)

diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c
index 5c1779291a7e5b9..add3f464bbaaa36 100644
--- a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible-caller.c
@@ -2,8 +2,6 @@
 
 // Conflicting attributes when using always_inline
 __attribute__((always_inline))
-int inlined_fn_streaming(void) __arm_streaming {
-    return 42;
-}
+void inlined_fn_streaming(void) __arm_streaming {}
 // expected-error@+1 {{always_inline function 'inlined_fn_streaming' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
-int inlined_fn_caller(void) __arm_streaming_compatible { return inlined_fn_streaming(); }
+void inlined_fn_caller(void) __arm_streaming_compatible { inlined_fn_streaming(); }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c
index a996c429fdda0da..a07c42c141d5b21 100644
--- a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-compatible.c
@@ -2,16 +2,12 @@
 
 // Conflicting attributes when using always_inline
 __attribute__((always_inline))
-int inlined_fn_streaming_compatible(void) __arm_streaming_compatible {
-    return 42;
-}
+void inlined_fn_streaming_compatible(void) __arm_streaming_compatible {}
 __attribute__((always_inline))
-int inlined_fn(void) {
-    return 42;
-}
-int inlined_fn_caller(void) { return inlined_fn_streaming_compatible(); }
+void inlined_fn(void) {}
+void inlined_fn_caller(void) { inlined_fn_streaming_compatible(); }
 __arm_locally_streaming
-int inlined_fn_caller_local(void) { return inlined_fn_streaming_compatible(); }
-int inlined_fn_caller_streaming(void) __arm_streaming { return inlined_fn_streaming_compatible(); }
+void inlined_fn_caller_local(void) { inlined_fn_streaming_compatible(); }
+void inlined_fn_caller_streaming(void) __arm_streaming { inlined_fn_streaming_compatible(); }
 // expected-error@+1 {{always_inline function 'inlined_fn' and its caller 'inlined_fn_caller_compatible' have mismatched streaming attributes}}
-int inlined_fn_caller_compatible(void) __arm_streaming_compatible { return inlined_fn(); }
+void inlined_fn_caller_compatible(void) __arm_streaming_compatible { inlined_fn(); }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
index 4aa9fbf4a8fa18c..6d645334ecb9d1c 100644
--- a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
@@ -2,11 +2,9 @@
 
 // Conflicting attributes when using always_inline
 __attribute__((always_inline)) __arm_locally_streaming
-int inlined_fn_local(void) {
-    return 42;
-}
+void inlined_fn_local(void) {}
 // expected-error@+1 {{always_inline function 'inlined_fn_local' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
-int inlined_fn_caller(void) { return inlined_fn_local(); }
+void inlined_fn_caller(void) { inlined_fn_local(); }
 __arm_locally_streaming
-int inlined_fn_caller_local(void) { return inlined_fn_local(); }
-int inlined_fn_caller_streaming(void) __arm_streaming { return inlined_fn_local(); }
+void inlined_fn_caller_local(void) { inlined_fn_local(); }
+void inlined_fn_caller_streaming(void) __arm_streaming { inlined_fn_local(); }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c
index 7268a49bb2491d0..8afce84837838c7 100644
--- a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c
@@ -2,11 +2,9 @@
 
 // Conflicting attributes when using always_inline
 __attribute__((always_inline))
-int inlined_fn_streaming(void) __arm_streaming {
-    return 42;
-}
+void inlined_fn_streaming(void) __arm_streaming {}
 // expected-error@+1 {{always_inline function 'inlined_fn_streaming' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
-int inlined_fn_caller(void) { return inlined_fn_streaming(); }
+void inlined_fn_caller(void) { inlined_fn_streaming(); }
 __arm_locally_streaming
-int inlined_fn_caller_local(void) { return inlined_fn_streaming(); }
-int inlined_fn_caller_streaming(void) __arm_streaming { return inlined_fn_streaming(); }
+void inlined_fn_caller_local(void) { inlined_fn_streaming(); }
+void inlined_fn_caller_streaming(void) __arm_streaming { inlined_fn_streaming(); }

>From c49ff5e797a123a146c7df722556eb4d12c85047 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 17 Jan 2024 17:10:41 +0000
Subject: [PATCH 6/8] fixup: use SMEAttrs class

---
 clang/include/clang/Sema/Sema.h               |  8 ------
 clang/lib/CodeGen/CMakeLists.txt              |  2 +-
 clang/lib/CodeGen/Targets/AArch64.cpp         | 25 ++++++++++++-----
 clang/lib/Sema/SemaChecking.cpp               | 27 ++++++++++++-------
 .../llvm/Support}/AArch64SMEAttributes.h      |  0
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  2 +-
 llvm/lib/Target/AArch64/SMEABIPass.cpp        |  2 +-
 .../AArch64/Utils/AArch64SMEAttributes.cpp    |  2 +-
 .../Target/AArch64/SMEAttributesTest.cpp      |  2 +-
 9 files changed, 40 insertions(+), 30 deletions(-)
 rename llvm/{lib/Target/AArch64/Utils => include/llvm/Support}/AArch64SMEAttributes.h (100%)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 1ee065e7de9b15b..6ce422d66ae5b0e 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13832,16 +13832,8 @@ class Sema final {
     FormatArgumentPassingKind ArgPassingKind;
   };
 
-  enum ArmStreamingType {
-    ArmNonStreaming,
-    ArmStreaming,
-    ArmStreamingCompatible,
-    ArmStreamingOrSVE2p1
-  };
-
   static bool getFormatStringInfo(const FormatAttr *Format, bool IsCXXMember,
                                   bool IsVariadic, FormatStringInfo *FSI);
-  static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD);
 
 private:
   void CheckArrayAccess(const Expr *BaseExpr, const Expr *IndexExpr,
diff --git a/clang/lib/CodeGen/CMakeLists.txt b/clang/lib/CodeGen/CMakeLists.txt
index 03a6f2f1d7a9d26..919b826f5c57168 100644
--- a/clang/lib/CodeGen/CMakeLists.txt
+++ b/clang/lib/CodeGen/CMakeLists.txt
@@ -31,6 +31,7 @@ set(LLVM_LINK_COMPONENTS
   Target
   TargetParser
   TransformUtils
+  AArch64Utils
   )
 
 # Workaround for MSVC ARM64 performance regression:
@@ -151,5 +152,4 @@ add_clang_library(clangCodeGen
   clangFrontend
   clangLex
   clangSerialization
-  clangSema
   )
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 8d9f840c018fa53..8b15b798c5f896d 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -9,7 +9,7 @@
 #include "ABIInfoImpl.h"
 #include "TargetInfo.h"
 #include "clang/Basic/DiagnosticFrontend.h"
-#include "clang/Sema/Sema.h"
+#include "llvm/Support/AArch64SMEAttributes.h"
 
 using namespace clang;
 using namespace clang::CodeGen;
@@ -827,13 +827,24 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABI(
   if (!Callee->hasAttr<AlwaysInlineAttr>())
     return;
 
-  auto CalleeStreamingMode = Sema::getArmStreamingFnType(Callee);
-  auto CallerStreamingMode = Sema::getArmStreamingFnType(Caller);
+  auto GetSMEAttrs = [](const FunctionDecl *F) {
+    llvm::SMEAttrs FAttrs;
+    if (F->hasAttr<ArmLocallyStreamingAttr>())
+      FAttrs.set(llvm::SMEAttrs::Mask::SM_Enabled);
+    if (const auto *T = F->getType()->getAs<FunctionProtoType>()) {
+      if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateSMEnabledMask)
+        FAttrs.set(llvm::SMEAttrs::Mask::SM_Enabled);
+      if (T->getAArch64SMEAttributes() &
+          FunctionType::SME_PStateSMCompatibleMask)
+        FAttrs.set(llvm::SMEAttrs::Mask::SM_Compatible);
+    }
+    return FAttrs;
+  };
+
+  auto CalleeAttrs = GetSMEAttrs(Callee);
+  auto CallerAttrs = GetSMEAttrs(Caller);
 
-  // The caller can inline the callee if their streaming modes match or the
-  // callee is streaming compatible
-  if (CalleeStreamingMode != CallerStreamingMode &&
-      CalleeStreamingMode != Sema::ArmStreamingCompatible)
+  if (CallerAttrs.requiresSMChange(CalleeAttrs, true))
     CGM.getDiags().Report(CallLoc,
                           diag::err_function_always_inline_attribute_mismatch)
         << Caller->getDeclName() << Callee->getDeclName() << "streaming";
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index a92db7d67e1cbd0..ace3e386988f005 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2998,6 +2998,13 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
   llvm_unreachable("Invalid NeonTypeFlag!");
 }
 
+enum ArmStreamingType {
+  ArmNonStreaming,
+  ArmStreaming,
+  ArmStreamingCompatible,
+  ArmStreamingOrSVE2p1
+};
+
 bool Sema::ParseSVEImmChecks(
     CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
   // Perform all the immediate checks for this builtin call.
@@ -3138,7 +3145,7 @@ bool Sema::ParseSVEImmChecks(
   return HasError;
 }
 
-Sema::ArmStreamingType Sema::getArmStreamingFnType(const FunctionDecl *FD) {
+static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
   if (FD->hasAttr<ArmLocallyStreamingAttr>())
     return ArmStreaming;
   if (const auto *T = FD->getType()->getAs<FunctionProtoType>()) {
@@ -3152,31 +3159,31 @@ Sema::ArmStreamingType Sema::getArmStreamingFnType(const FunctionDecl *FD) {
 
 static void checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
                                      const FunctionDecl *FD,
-                                     Sema::ArmStreamingType BuiltinType) {
-  Sema::ArmStreamingType FnType = Sema::getArmStreamingFnType(FD);
-  if (BuiltinType == Sema::ArmStreamingOrSVE2p1) {
+                                     ArmStreamingType BuiltinType) {
+  ArmStreamingType FnType = getArmStreamingFnType(FD);
+  if (BuiltinType == ArmStreamingOrSVE2p1) {
     // Check intrinsics that are available in [sve2p1 or sme/sme2].
     llvm::StringMap<bool> CallerFeatureMap;
     S.Context.getFunctionFeatureMap(CallerFeatureMap, FD);
     if (Builtin::evaluateRequiredTargetFeatures("sve2p1", CallerFeatureMap))
-      BuiltinType = Sema::ArmStreamingCompatible;
+      BuiltinType = ArmStreamingCompatible;
     else
-      BuiltinType = Sema::ArmStreaming;
+      BuiltinType = ArmStreaming;
   }
 
-  if (FnType == Sema::ArmStreaming && BuiltinType == Sema::ArmNonStreaming) {
+  if (FnType == ArmStreaming && BuiltinType == ArmNonStreaming) {
     S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
         << TheCall->getSourceRange() << "streaming";
   }
 
-  if (FnType == Sema::ArmStreamingCompatible &&
-      BuiltinType != Sema::ArmStreamingCompatible) {
+  if (FnType == ArmStreamingCompatible &&
+      BuiltinType != ArmStreamingCompatible) {
     S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
         << TheCall->getSourceRange() << "streaming compatible";
     return;
   }
 
-  if (FnType == Sema::ArmNonStreaming && BuiltinType == Sema::ArmStreaming) {
+  if (FnType == ArmNonStreaming && BuiltinType == ArmStreaming) {
     S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
         << TheCall->getSourceRange() << "non-streaming";
   }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/include/llvm/Support/AArch64SMEAttributes.h
similarity index 100%
rename from llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
rename to llvm/include/llvm/Support/AArch64SMEAttributes.h
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 6047a3b7b2864aa..9d9d67490aa7644 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -15,13 +15,13 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64ISELLOWERING_H
 
 #include "AArch64.h"
-#include "Utils/AArch64SMEAttributes.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/IR/CallingConv.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/Support/AArch64SMEAttributes.h"
 
 namespace llvm {
 
diff --git a/llvm/lib/Target/AArch64/SMEABIPass.cpp b/llvm/lib/Target/AArch64/SMEABIPass.cpp
index 3315171798d9f1b..b9a68179966c043 100644
--- a/llvm/lib/Target/AArch64/SMEABIPass.cpp
+++ b/llvm/lib/Target/AArch64/SMEABIPass.cpp
@@ -14,7 +14,6 @@
 
 #include "AArch64.h"
 #include "Utils/AArch64BaseInfo.h"
-#include "Utils/AArch64SMEAttributes.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilder.h"
@@ -23,6 +22,7 @@
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Support/AArch64SMEAttributes.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index ccdec78d7808663..4af54c1d611d563 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "AArch64SMEAttributes.h"
+#include "llvm/Support/AArch64SMEAttributes.h"
 #include "llvm/IR/InstrTypes.h"
 #include <cassert>
 
diff --git a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
index 2f7201464ba2f23..b99cb4c0d775a85 100644
--- a/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
+++ b/llvm/unittests/Target/AArch64/SMEAttributesTest.cpp
@@ -1,7 +1,7 @@
-#include "Utils/AArch64SMEAttributes.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/AArch64SMEAttributes.h"
 #include "llvm/Support/SourceMgr.h"
 
 #include "gtest/gtest.h"

>From b38b7d59246de3a8a26e10238aed5cb757b89e1d Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 18 Jan 2024 11:07:00 +0000
Subject: [PATCH 7/8] fixup: rebase and check for new za state

---
 clang/include/clang/Basic/DiagnosticFrontendKinds.td       | 2 ++
 clang/lib/CodeGen/Targets/AArch64.cpp                      | 7 +++++++
 .../aarch64-sme-func-attrs-inline/sme-inline-new-za.c      | 6 ++++++
 3 files changed, 15 insertions(+)
 create mode 100644 clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-new-za.c

diff --git a/clang/include/clang/Basic/DiagnosticFrontendKinds.td b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
index 27149553c79a2a7..cbb5db3e0b1c826 100644
--- a/clang/include/clang/Basic/DiagnosticFrontendKinds.td
+++ b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
@@ -281,6 +281,8 @@ def err_function_needs_feature : Error<
   "be inlined into function %0 that is compiled without support for '%2'">;
 def err_function_always_inline_attribute_mismatch : Error<
   "always_inline function %1 and its caller %0 have mismatched %2 attributes">;
+def err_function_always_inline_new_za : Error<
+  "always_inline function %0 has new za state">;
 
 def warn_avx_calling_convention
     : Warning<"AVX vector %select{return|argument}0 of type %1 without '%2' "
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 8b15b798c5f896d..9f4323bef38896c 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -831,6 +831,10 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABI(
     llvm::SMEAttrs FAttrs;
     if (F->hasAttr<ArmLocallyStreamingAttr>())
       FAttrs.set(llvm::SMEAttrs::Mask::SM_Enabled);
+    if (auto *NewAttr = F->getAttr<ArmNewAttr>()) {
+      if (NewAttr->isNewZA())
+        FAttrs.set(llvm::SMEAttrs::Mask::ZA_New);
+    }
     if (const auto *T = F->getType()->getAs<FunctionProtoType>()) {
       if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateSMEnabledMask)
         FAttrs.set(llvm::SMEAttrs::Mask::SM_Enabled);
@@ -848,6 +852,9 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABI(
     CGM.getDiags().Report(CallLoc,
                           diag::err_function_always_inline_attribute_mismatch)
         << Caller->getDeclName() << Callee->getDeclName() << "streaming";
+  if (CalleeAttrs.hasNewZABody())
+    CGM.getDiags().Report(CallLoc, diag::err_function_always_inline_new_za)
+        << Callee->getDeclName();
 }
 
 std::unique_ptr<TargetCodeGenInfo>
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-new-za.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-new-za.c
new file mode 100644
index 000000000000000..97af6ec2be7a35c
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-new-za.c
@@ -0,0 +1,6 @@
+// RUN: %clang --target=aarch64-none-linux-gnu -march=armv9-a+sme -O3 -S -Xclang -verify %s
+
+__attribute__((always_inline)) __arm_new("za")
+void inline_new_za(void)  { }
+// expected-error@+1 {{always_inline function 'inline_new_za' has new za state}}
+void inline_caller() { inline_new_za(); }

>From c624882613fc3245ca3c8d73ce68b17a9cb43068 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 23 Jan 2024 14:43:32 +0000
Subject: [PATCH 8/8] fixup: re-implement SMEAttributes and check for lazy save

---
 .../clang/Basic/DiagnosticFrontendKinds.td    |  2 +
 clang/lib/CodeGen/Targets/AArch64.cpp         | 81 +++++++++++++++----
 .../sme-inline-lazy-save.c                    | 14 ++++
 .../sme-inline-streaming-locally.c            |  4 +-
 .../sme-inline-streaming.c                    |  4 +-
 5 files changed, 86 insertions(+), 19 deletions(-)
 create mode 100644 clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-lazy-save.c

diff --git a/clang/include/clang/Basic/DiagnosticFrontendKinds.td b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
index cbb5db3e0b1c826..d31577b5cf5e74d 100644
--- a/clang/include/clang/Basic/DiagnosticFrontendKinds.td
+++ b/clang/include/clang/Basic/DiagnosticFrontendKinds.td
@@ -283,6 +283,8 @@ def err_function_always_inline_attribute_mismatch : Error<
   "always_inline function %1 and its caller %0 have mismatched %2 attributes">;
 def err_function_always_inline_new_za : Error<
   "always_inline function %0 has new za state">;
+def err_function_always_inline_lazy_save : Error<
+  "inlining always_inline function %0 into %1 would require a lazy za save">;
 
 def warn_avx_calling_convention
     : Warning<"AVX vector %select{return|argument}0 of type %1 without '%2' "
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 9f4323bef38896c..64d47a2f296d335 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -821,32 +821,80 @@ Address AArch64ABIInfo::EmitMSVAArg(CodeGenFunction &CGF, Address VAListAddr,
                           /*allowHigherAlign*/ false);
 }
 
-void AArch64TargetCodeGenInfo::checkFunctionCallABI(
-    CodeGenModule &CGM, SourceLocation CallLoc, const FunctionDecl *Caller,
-    const FunctionDecl *Callee, const CallArgList &Args) const {
-  if (!Callee->hasAttr<AlwaysInlineAttr>())
-    return;
+class SMEAttributes {
+public:
+  bool IsStreaming = false;
+  bool IsStreamingBody = false;
+  bool IsStreamingCompatible = false;
+  bool HasNewZA = false;
 
-  auto GetSMEAttrs = [](const FunctionDecl *F) {
-    llvm::SMEAttrs FAttrs;
+  SMEAttributes(const FunctionDecl *F) {
     if (F->hasAttr<ArmLocallyStreamingAttr>())
-      FAttrs.set(llvm::SMEAttrs::Mask::SM_Enabled);
+      IsStreamingBody = true;
     if (auto *NewAttr = F->getAttr<ArmNewAttr>()) {
       if (NewAttr->isNewZA())
-        FAttrs.set(llvm::SMEAttrs::Mask::ZA_New);
+        HasNewZA = true;
     }
     if (const auto *T = F->getType()->getAs<FunctionProtoType>()) {
       if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateSMEnabledMask)
-        FAttrs.set(llvm::SMEAttrs::Mask::SM_Enabled);
+        IsStreaming = true;
       if (T->getAArch64SMEAttributes() &
           FunctionType::SME_PStateSMCompatibleMask)
-        FAttrs.set(llvm::SMEAttrs::Mask::SM_Compatible);
+        IsStreamingCompatible = true;
+    }
+  }
+
+  bool hasStreamingBody() const { return IsStreamingBody; }
+  bool hasStreamingInterface() const { return IsStreaming; }
+  bool hasStreamingCompatibleInterface() const { return IsStreamingCompatible; }
+  bool hasStreamingInterfaceOrBody() const {
+    return hasStreamingBody() || hasStreamingInterface();
+  }
+  bool hasNonStreamingInterface() const {
+    return !hasStreamingInterface() && !hasStreamingCompatibleInterface();
+  }
+  bool hasNonStreamingInterfaceAndBody() const {
+    return hasNonStreamingInterface() && !hasStreamingBody();
+  }
+
+  bool requiresSMChange(const SMEAttributes Callee,
+                        bool BodyOverridesInterface = false) {
+    // If the transition is not through a call (e.g. when considering inlining)
+    // and Callee has a streaming body, then we can ignore the interface of
+    // Callee.
+    if (BodyOverridesInterface && Callee.hasStreamingBody()) {
+      return !hasStreamingInterfaceOrBody();
     }
-    return FAttrs;
-  };
 
-  auto CalleeAttrs = GetSMEAttrs(Callee);
-  auto CallerAttrs = GetSMEAttrs(Caller);
+    if (Callee.hasStreamingCompatibleInterface())
+      return false;
+
+    if (hasStreamingCompatibleInterface())
+      return true;
+
+    // Both non-streaming
+    if (hasNonStreamingInterfaceAndBody() && Callee.hasNonStreamingInterface())
+      return false;
+
+    // Both streaming
+    if (hasStreamingInterfaceOrBody() && Callee.hasStreamingInterface())
+      return false;
+
+    return Callee.hasStreamingInterface();
+  }
+
+  bool hasNewZABody() { return HasNewZA; }
+  bool requiresLazySave() const { return HasNewZA; }
+};
+
+void AArch64TargetCodeGenInfo::checkFunctionCallABI(
+    CodeGenModule &CGM, SourceLocation CallLoc, const FunctionDecl *Caller,
+    const FunctionDecl *Callee, const CallArgList &Args) const {
+  if (!Callee->hasAttr<AlwaysInlineAttr>())
+    return;
+
+  SMEAttributes CalleeAttrs(Callee);
+  SMEAttributes CallerAttrs(Caller);
 
   if (CallerAttrs.requiresSMChange(CalleeAttrs, true))
     CGM.getDiags().Report(CallLoc,
@@ -855,6 +903,9 @@ void AArch64TargetCodeGenInfo::checkFunctionCallABI(
   if (CalleeAttrs.hasNewZABody())
     CGM.getDiags().Report(CallLoc, diag::err_function_always_inline_new_za)
         << Callee->getDeclName();
+  if (CallerAttrs.requiresLazySave())
+    CGM.getDiags().Report(CallLoc, diag::err_function_always_inline_lazy_save)
+        << Callee->getDeclName() << Caller->getDeclName();
 }
 
 std::unique_ptr<TargetCodeGenInfo>
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-lazy-save.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-lazy-save.c
new file mode 100644
index 000000000000000..4cd0b32a08dab07
--- /dev/null
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-lazy-save.c
@@ -0,0 +1,14 @@
+// RUN: %clang --target=aarch64-none-linux-gnu -march=armv9-a+sme -O3 -S -Xclang -verify %s
+
+__attribute__((always_inline))
+void inlined(void) {}
+
+void inline_caller(void) {
+    inlined();
+}
+
+__arm_new("za")
+// expected-error@+2 {{inlining always_inline function 'inlined' into 'inline_caller_new_za' would require a lazy za save}}
+void inline_caller_new_za(void) {
+    inlined();
+}
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
index 6d645334ecb9d1c..09bbbaf43ea5a27 100644
--- a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming-locally.c
@@ -3,8 +3,8 @@
 // Conflicting attributes when using always_inline
 __attribute__((always_inline)) __arm_locally_streaming
 void inlined_fn_local(void) {}
-// expected-error@+1 {{always_inline function 'inlined_fn_local' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
-void inlined_fn_caller(void) { inlined_fn_local(); }
 __arm_locally_streaming
 void inlined_fn_caller_local(void) { inlined_fn_local(); }
 void inlined_fn_caller_streaming(void) __arm_streaming { inlined_fn_local(); }
+// expected-error@+1 {{always_inline function 'inlined_fn_local' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
+void inlined_fn_caller(void) { inlined_fn_local(); }
diff --git a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c
index 8afce84837838c7..9f0e944c8f84010 100644
--- a/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c
+++ b/clang/test/CodeGen/aarch64-sme-func-attrs-inline/sme-inline-streaming.c
@@ -3,8 +3,8 @@
 // Conflicting attributes when using always_inline
 __attribute__((always_inline))
 void inlined_fn_streaming(void) __arm_streaming {}
-// expected-error@+1 {{always_inline function 'inlined_fn_streaming' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
-void inlined_fn_caller(void) { inlined_fn_streaming(); }
 __arm_locally_streaming
 void inlined_fn_caller_local(void) { inlined_fn_streaming(); }
 void inlined_fn_caller_streaming(void) __arm_streaming { inlined_fn_streaming(); }
+// expected-error@+1 {{always_inline function 'inlined_fn_streaming' and its caller 'inlined_fn_caller' have mismatched streaming attributes}}
+void inlined_fn_caller(void) { inlined_fn_streaming(); }



More information about the cfe-commits mailing list