[clang] Emit @llvm.assume when we know the streaming mode of the function (PR #121917)
Nicholas Guy via cfe-commits
cfe-commits at lists.llvm.org
Tue Jan 7 21:56:54 PST 2025
https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/121917
From 27e9773135d1171c931aaa6b3f8c5f954b658969 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 7 Jan 2025 11:09:18 +0000
Subject: [PATCH 1/3] [clang] Emit @llvm.assume when we know the streaming mode
of the function
---
clang/lib/CodeGen/CGBuiltin.cpp | 6 ++++
.../sme-intrinsics/acle_sme_state_funs.c | 35 ++++++++++++-------
2 files changed, 28 insertions(+), 13 deletions(-)
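In a function whose streaming mode is fixed by its type (declared __arm_streaming, or neither streaming nor streaming-compatible), __arm_in_streaming_mode() folds to a compile-time constant. With this patch, expanding the builtin additionally emits a call to @llvm.aarch64.sme.in.streaming.mode feeding @llvm.assume (negated for non-streaming functions), so the known PSTATE.SM value stays visible in the IR and can fold the same query in streaming-compatible code later inlined into the function. A minimal sketch of the scenario (helper() and caller() are illustrative, not from the patch; assumes an SME-enabled AArch64 target, e.g. -march=armv9-a+sme, and arm_sme.h):

  #include <arm_sme.h>
  #include <stdbool.h>

  // Streaming-compatible: its mode is unknown in isolation, so the query
  // lowers to a real call to @llvm.aarch64.sme.in.streaming.mode().
  static inline bool helper(void) __arm_streaming_compatible {
    return __arm_in_streaming_mode();
  }

  bool caller(void) __arm_streaming {
    // Folds to 'true'; with this patch the expansion also emits:
    //   %sm = call i1 @llvm.aarch64.sme.in.streaming.mode()
    //   call void @llvm.assume(i1 %sm)
    bool here = __arm_in_streaming_mode();
    // After helper() is inlined, the assume lets its query fold to 'true' too.
    return helper() && here;
  }
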
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index dcea32969fb990..3765285c58f6ca 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -11335,6 +11335,12 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
unsigned SMEAttrs = FPT->getAArch64SMEAttributes();
if (!(SMEAttrs & FunctionType::SME_PStateSMCompatibleMask)) {
bool IsStreaming = SMEAttrs & FunctionType::SME_PStateSMEnabledMask;
+ // Emit the llvm.assume intrinsic so that called functions can use the
+ // streaming mode information discerned here
+ Value* call = Builder.CreateCall(CGM.getIntrinsic(Builtin->LLVMIntrinsic));
+ if (!IsStreaming)
+ call = Builder.CreateNot(call);
+ Builder.CreateIntrinsic(Intrinsic::assume, {}, {call});
return ConstantInt::getBool(Builder.getContext(), IsStreaming);
}
}
diff --git a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c
index 72f2d17fc6dc11..1e630e196fcb66 100644
--- a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c
+++ b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c
@@ -22,23 +22,32 @@ bool test_in_streaming_mode_streaming_compatible(void) __arm_streaming_compatibl
// CHECK-LABEL: @test_in_streaming_mode_streaming(
// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
+// CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP0]])
// CHECK-NEXT: ret i1 true
//
// CPP-CHECK-LABEL: @_Z32test_in_streaming_mode_streamingv(
// CPP-CHECK-NEXT: entry:
+// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
+// CPP-CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP0]])
// CPP-CHECK-NEXT: ret i1 true
//
bool test_in_streaming_mode_streaming(void) __arm_streaming {
-//
return __arm_in_streaming_mode();
}
// CHECK-LABEL: @test_in_streaming_mode_non_streaming(
// CHECK-NEXT: entry:
+// CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
+// CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[TMP0]], true
+// CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP1]])
// CHECK-NEXT: ret i1 false
//
// CPP-CHECK-LABEL: @_Z36test_in_streaming_mode_non_streamingv(
// CPP-CHECK-NEXT: entry:
+// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
+// CPP-CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[TMP0]], true
+// CPP-CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP1]])
// CPP-CHECK-NEXT: ret i1 false
//
bool test_in_streaming_mode_non_streaming(void) {
@@ -47,12 +56,12 @@ bool test_in_streaming_mode_non_streaming(void) {
// CHECK-LABEL: @test_za_disable(
// CHECK-NEXT: entry:
-// CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR7:[0-9]+]]
+// CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR8:[0-9]+]]
// CHECK-NEXT: ret void
//
// CPP-CHECK-LABEL: @_Z15test_za_disablev(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR7:[0-9]+]]
+// CPP-CHECK-NEXT: tail call void @__arm_za_disable() #[[ATTR8:[0-9]+]]
// CPP-CHECK-NEXT: ret void
//
void test_za_disable(void) __arm_streaming_compatible {
@@ -61,14 +70,14 @@ void test_za_disable(void) __arm_streaming_compatible {
// CHECK-LABEL: @test_has_sme(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR7]]
+// CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR8]]
// CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0
// CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp slt i64 [[TMP1]], 0
// CHECK-NEXT: ret i1 [[TOBOOL_I]]
//
// CPP-CHECK-LABEL: @_Z12test_has_smev(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR7]]
+// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call aarch64_sme_preservemost_from_x2 { i64, i64 } @__arm_sme_state() #[[ATTR8]]
// CPP-CHECK-NEXT: [[TMP1:%.*]] = extractvalue { i64, i64 } [[TMP0]], 0
// CPP-CHECK-NEXT: [[TOBOOL_I:%.*]] = icmp slt i64 [[TMP1]], 0
// CPP-CHECK-NEXT: ret i1 [[TOBOOL_I]]
@@ -91,12 +100,12 @@ void test_svundef_za(void) __arm_streaming_compatible __arm_out("za") {
// CHECK-LABEL: @test_sc_memcpy(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z14test_sc_memcpyPvPKvm(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memcpy(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memcpy(void *dest, const void *src, size_t n) __arm_streaming_compatible {
@@ -105,12 +114,12 @@ void *test_sc_memcpy(void *dest, const void *src, size_t n) __arm_streaming_comp
// CHECK-LABEL: @test_sc_memmove(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z15test_sc_memmovePvPKvm(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memmove(ptr noundef [[DEST:%.*]], ptr noundef [[SRC:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memmove(void *dest, const void *src, size_t n) __arm_streaming_compatible {
@@ -119,12 +128,12 @@ void *test_sc_memmove(void *dest, const void *src, size_t n) __arm_streaming_com
// CHECK-LABEL: @test_sc_memset(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z14test_sc_memsetPvim(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memset(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memset(void *s, int c, size_t n) __arm_streaming_compatible {
@@ -133,12 +142,12 @@ void *test_sc_memset(void *s, int c, size_t n) __arm_streaming_compatible {
// CHECK-LABEL: @test_sc_memchr(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CHECK-NEXT: ret ptr [[CALL]]
//
// CPP-CHECK-LABEL: @_Z14test_sc_memchrPvim(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR7]]
+// CPP-CHECK-NEXT: [[CALL:%.*]] = tail call ptr @__arm_sc_memchr(ptr noundef [[S:%.*]], i32 noundef [[C:%.*]], i64 noundef [[N:%.*]]) #[[ATTR8]]
// CPP-CHECK-NEXT: ret ptr [[CALL]]
//
void *test_sc_memchr(void *s, int c, size_t n) __arm_streaming_compatible {
From 7cc675068c00103c1ae5089c3709d0aa3f35563a Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 7 Jan 2025 11:38:53 +0000
Subject: [PATCH 2/3] Format
---
clang/lib/CodeGen/CGBuiltin.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 3765285c58f6ca..454cc6d16352f6 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -11337,7 +11337,8 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
bool IsStreaming = SMEAttrs & FunctionType::SME_PStateSMEnabledMask;
// Emit the llvm.assume intrinsic so that called functions can use the
// streaming mode information discerned here
- Value* call = Builder.CreateCall(CGM.getIntrinsic(Builtin->LLVMIntrinsic));
+ Value *call =
+ Builder.CreateCall(CGM.getIntrinsic(Builtin->LLVMIntrinsic));
if (!IsStreaming)
call = Builder.CreateNot(call);
Builder.CreateIntrinsic(Intrinsic::assume, {}, {call});
From ac671ca41235910f37d84ea0050e968482984ece Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 8 Jan 2025 05:43:19 +0000
Subject: [PATCH 3/3] Add hook to allow targets to emit additional IR at call
sites
---
clang/lib/CodeGen/CGBuiltin.cpp | 7 -----
clang/lib/CodeGen/CGCall.cpp | 3 ++
clang/lib/CodeGen/TargetInfo.h | 4 +++
clang/lib/CodeGen/Targets/AArch64.cpp | 29 +++++++++++++++++++
.../sme-intrinsics/acle_sme_state_funs.c | 10 -------
5 files changed, 36 insertions(+), 17 deletions(-)
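Patch 3 reverts the builtin-expansion change and instead adds a generic hook, TargetCodeGenInfo::emitFunctionCallProlog, invoked from CodeGenFunction::EmitCall before argument setup. The AArch64 implementation emits the streaming-mode assumption in the caller ahead of any call to a streaming-compatible function, provided SME is available and the caller's own streaming mode is known from its prototype. This makes the assumption available even when the caller never queries the mode itself. A sketch of the effect at a call site (names are illustrative, not from the patch):

  #include <arm_sme.h>
  #include <stdbool.h>

  bool sc_query(void) __arm_streaming_compatible;  // e.g. returns __arm_in_streaming_mode()

  bool caller(void) __arm_streaming {
    // Just before this call, the AArch64 hook emits in the caller:
    //   %sm = call i1 @llvm.aarch64.sme.in.streaming.mode()
    //   call void @llvm.assume(i1 %sm)
    // so if sc_query() is later inlined, its streaming-mode check folds to true.
    return sc_query();
  }
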
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 454cc6d16352f6..dcea32969fb990 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -11335,13 +11335,6 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
unsigned SMEAttrs = FPT->getAArch64SMEAttributes();
if (!(SMEAttrs & FunctionType::SME_PStateSMCompatibleMask)) {
bool IsStreaming = SMEAttrs & FunctionType::SME_PStateSMEnabledMask;
- // Emit the llvm.assume intrinsic so that called functions can use the
- // streaming mode information discerned here
- Value *call =
- Builder.CreateCall(CGM.getIntrinsic(Builtin->LLVMIntrinsic));
- if (!IsStreaming)
- call = Builder.CreateNot(call);
- Builder.CreateIntrinsic(Intrinsic::assume, {}, {call});
return ConstantInt::getBool(Builder.getContext(), IsStreaming);
}
}
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 89e2eace9120bf..44b4e3ece8ee24 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -5101,6 +5101,9 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
const FunctionDecl *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl);
CGM.getTargetCodeGenInfo().checkFunctionCallABI(CGM, Loc, CallerDecl,
CalleeDecl, CallArgs, RetTy);
+ // 0. Allow the target to emit an additional prolog for the function call
+ CGM.getTargetCodeGenInfo().emitFunctionCallProlog(Builder, CallerDecl,
+ CalleeDecl);
// 1. Set up the arguments.
diff --git a/clang/lib/CodeGen/TargetInfo.h b/clang/lib/CodeGen/TargetInfo.h
index ab3142bdea684e..5ba6fc4acc9f08 100644
--- a/clang/lib/CodeGen/TargetInfo.h
+++ b/clang/lib/CodeGen/TargetInfo.h
@@ -443,6 +443,10 @@ class TargetCodeGenInfo {
return nullptr;
}
+ virtual void emitFunctionCallProlog(CGBuilderTy &Builder,
+ const FunctionDecl *Caller,
+ const FunctionDecl *Callee) const {}
+
// Set the Branch Protection Attributes of the Function accordingly to the
// BPI. Remove attributes that contradict with current BPI.
static void
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 7db67ecba07c8f..4cfafcc6747fb7 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -10,6 +10,7 @@
#include "TargetInfo.h"
#include "clang/AST/Decl.h"
#include "clang/Basic/DiagnosticFrontend.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/TargetParser/AArch64TargetParser.h"
using namespace clang;
@@ -181,6 +182,9 @@ class AArch64TargetCodeGenInfo : public TargetCodeGenInfo {
bool wouldInliningViolateFunctionCallABI(
const FunctionDecl *Caller, const FunctionDecl *Callee) const override;
+ void emitFunctionCallProlog(CGBuilderTy &Builder, const FunctionDecl *Caller,
+ const FunctionDecl *Callee) const override;
+
private:
// Diagnose calls between functions with incompatible Streaming SVE
// attributes.
@@ -1275,6 +1279,31 @@ bool AArch64TargetCodeGenInfo::wouldInliningViolateFunctionCallABI(
GetArmSMEInlinability(Caller, Callee) != ArmSMEInlinability::Ok;
}
+void AArch64TargetCodeGenInfo::emitFunctionCallProlog(
+ CGBuilderTy &Builder, const FunctionDecl *Caller,
+ const FunctionDecl *Callee) const {
+ const AArch64ABIInfo &ABIInfo = getABIInfo<AArch64ABIInfo>();
+ const TargetInfo &TI = ABIInfo.getContext().getTargetInfo();
+
+ if (!TI.hasFeature("sme"))
+ return;
+
+ if (!Callee || !isStreamingCompatible(Callee))
+ return;
+
+ if (const auto *FPT = Caller->getType()->getAs<FunctionProtoType>()) {
+ unsigned SMEAttrs = FPT->getAArch64SMEAttributes();
+ if (!(SMEAttrs & FunctionType::SME_PStateSMCompatibleMask)) {
+ bool IsStreaming = SMEAttrs & FunctionType::SME_PStateSMEnabledMask;
+ llvm::Value *Call = Builder.CreateIntrinsic(
+ llvm::Intrinsic::aarch64_sme_in_streaming_mode, {}, {});
+ if (!IsStreaming)
+ Call = Builder.CreateNot(Call);
+ Builder.CreateAssumption(Call);
+ }
+ }
+}
+
void AArch64ABIInfo::appendAttributeMangling(TargetClonesAttr *Attr,
unsigned Index,
raw_ostream &Out) const {
diff --git a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c
index 1e630e196fcb66..80af80682d1946 100644
--- a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c
+++ b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_state_funs.c
@@ -22,14 +22,10 @@ bool test_in_streaming_mode_streaming_compatible(void) __arm_streaming_compatibl
// CHECK-LABEL: @test_in_streaming_mode_streaming(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
-// CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP0]])
// CHECK-NEXT: ret i1 true
//
// CPP-CHECK-LABEL: @_Z32test_in_streaming_mode_streamingv(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
-// CPP-CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP0]])
// CPP-CHECK-NEXT: ret i1 true
//
bool test_in_streaming_mode_streaming(void) __arm_streaming {
@@ -38,16 +34,10 @@ bool test_in_streaming_mode_streaming(void) __arm_streaming {
// CHECK-LABEL: @test_in_streaming_mode_non_streaming(
// CHECK-NEXT: entry:
-// CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
-// CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[TMP0]], true
-// CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP1]])
// CHECK-NEXT: ret i1 false
//
// CPP-CHECK-LABEL: @_Z36test_in_streaming_mode_non_streamingv(
// CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call i1 @llvm.aarch64.sme.in.streaming.mode()
-// CPP-CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[TMP0]], true
-// CPP-CHECK-NEXT: tail call void @llvm.assume(i1 [[TMP1]])
// CPP-CHECK-NEXT: ret i1 false
//
bool test_in_streaming_mode_non_streaming(void) {