[clang] [AArch64][SME] Warn when using a streaming builtin from a non-streaming function (PR #74064)
via cfe-commits
cfe-commits at lists.llvm.org
Fri Dec 1 03:43:41 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Sam Tebbs (SamTebbs33)
<details>
<summary>Changes</summary>
This PR adds a warning that is emitted when a streaming builtin is called from a non-streaming function. It also adds a warning that is emitted when a builtin that uses ZA state is called from a function without ZA state.
Uses work by Kerry McLaughlin.
---
Full diff: https://github.com/llvm/llvm-project/pull/74064.diff
8 Files Affected:
- (modified) clang/include/clang/Basic/CMakeLists.txt (+6)
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
- (modified) clang/include/clang/Sema/Sema.h (+3)
- (modified) clang/lib/Sema/SemaChecking.cpp (+191)
- (modified) clang/test/Sema/aarch64-incompat-sm-builtin-calls.c (+21)
- (modified) clang/utils/TableGen/SveEmitter.cpp (+68)
- (modified) clang/utils/TableGen/TableGen.cpp (+9)
- (modified) clang/utils/TableGen/TableGenBackends.h (+1)
``````````diff
diff --git a/clang/include/clang/Basic/CMakeLists.txt b/clang/include/clang/Basic/CMakeLists.txt
index 085e316fcc671df..bdd72d1d63c431b 100644
--- a/clang/include/clang/Basic/CMakeLists.txt
+++ b/clang/include/clang/Basic/CMakeLists.txt
@@ -97,6 +97,12 @@ clang_tablegen(arm_sme_builtin_cg.inc -gen-arm-sme-builtin-codegen
clang_tablegen(arm_sme_sema_rangechecks.inc -gen-arm-sme-sema-rangechecks
SOURCE arm_sme.td
TARGET ClangARMSmeSemaRangeChecks)
+clang_tablegen(arm_sme_streaming_attrs.inc -gen-arm-sme-streaming-attrs
+ SOURCE arm_sme.td
+ TARGET ClangARMSmeStreamingAttrs)
+clang_tablegen(arm_sme_builtins_za_state.inc -gen-arm-sme-builtin-za-state
+ SOURCE arm_sme.td
+ TARGET ClangARMSmeBuiltinsZAState)
clang_tablegen(arm_cde_builtins.inc -gen-arm-cde-builtin-def
SOURCE arm_cde.td
TARGET ClangARMCdeBuiltinsDef)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 6dfb2d7195203a3..c7036fc881a13e1 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3151,6 +3151,9 @@ def err_attribute_arm_feature_sve_bits_unsupported : Error<
def warn_attribute_arm_sm_incompat_builtin : Warning<
"builtin call has undefined behaviour when called from a %0 function">,
InGroup<DiagGroup<"undefined-arm-streaming">>;
+def warn_attribute_arm_za_builtin_no_za_state : Warning<
+ "builtin call is not valid when calling from a function without active ZA state">,
+ InGroup<DiagGroup<"undefined-arm-za">>;
def err_sve_vector_in_non_sve_target : Error<
"SVE vector type %0 cannot be used in a target without sve">;
def err_attribute_riscv_rvv_bits_unsupported : Error<
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 6de1a098e067a38..c13c6942c219700 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13845,7 +13845,10 @@ class Sema final {
bool CheckNeonBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
bool CheckMVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+ bool ParseSVEImmChecks(CallExpr *TheCall,
+ SmallVector<std::tuple<int, int, int>, 3> &ImmChecks);
bool CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+ bool CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckCDEBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
CallExpr *TheCall);
bool CheckARMCoprocessorImmediate(const TargetInfo &TI, const Expr *CoprocArg,
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 77c8334f3ca25d3..f27eb8ad95cc703 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2995,6 +2995,134 @@ static QualType getNeonEltType(NeonTypeFlags Flags, ASTContext &Context,
enum ArmStreamingType { ArmNonStreaming, ArmStreaming, ArmStreamingCompatible };
+bool Sema::ParseSVEImmChecks(
+ CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
+ // Perform all the immediate checks for this builtin call.
+ bool HasError = false;
+ for (auto &I : ImmChecks) {
+ int ArgNum, CheckTy, ElementSizeInBits;
+ std::tie(ArgNum, CheckTy, ElementSizeInBits) = I;
+
+ typedef bool (*OptionSetCheckFnTy)(int64_t Value);
+
+ // Function that checks whether the operand (ArgNum) is an immediate
+ // that is one of the predefined values.
+ auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm,
+ int ErrDiag) -> bool {
+ // We can't check the value of a dependent argument.
+ Expr *Arg = TheCall->getArg(ArgNum);
+ if (Arg->isTypeDependent() || Arg->isValueDependent())
+ return false;
+
+ // Check constant-ness first.
+ llvm::APSInt Imm;
+ if (SemaBuiltinConstantArg(TheCall, ArgNum, Imm))
+ return true;
+
+ if (!CheckImm(Imm.getSExtValue()))
+ return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange();
+ return false;
+ };
+
+ switch ((SVETypeFlags::ImmCheckType)CheckTy) {
+ case SVETypeFlags::ImmCheck0_31:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 31))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_13:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 13))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck1_16:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, 16))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_7:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 7))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckExtract:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+ (2048 / ElementSizeInBits) - 1))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckShiftRight:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1, ElementSizeInBits))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckShiftRightNarrow:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 1,
+ ElementSizeInBits / 2))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckShiftLeft:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+ ElementSizeInBits - 1))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckLaneIndex:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+ (128 / (1 * ElementSizeInBits)) - 1))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckLaneIndexCompRotate:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+ (128 / (2 * ElementSizeInBits)) - 1))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckLaneIndexDot:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0,
+ (128 / (4 * ElementSizeInBits)) - 1))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckComplexRot90_270:
+ if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; },
+ diag::err_rotation_argument_to_cadd))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheckComplexRotAll90:
+ if (CheckImmediateInSet(
+ [](int64_t V) {
+ return V == 0 || V == 90 || V == 180 || V == 270;
+ },
+ diag::err_rotation_argument_to_cmla))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_1:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 1))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_2:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 2))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_3:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 3))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_0:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 0))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_15:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 15))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck0_255:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 0, 255))
+ HasError = true;
+ break;
+ case SVETypeFlags::ImmCheck2_4_Mul2:
+ if (SemaBuiltinConstantArgRange(TheCall, ArgNum, 2, 4) ||
+ SemaBuiltinConstantArgMultiple(TheCall, ArgNum, 2))
+ HasError = true;
+ break;
+ }
+ }
+
+ return HasError;
+}
+
static ArmStreamingType getArmStreamingFnType(const FunctionDecl *FD) {
if (FD->hasAttr<ArmLocallyStreamingAttr>())
return ArmStreaming;
@@ -3023,6 +3151,66 @@ static void checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
<< TheCall->getSourceRange() << "streaming compatible";
return;
}
+
+ if (FnType == ArmNonStreaming && BuiltinType == ArmStreaming) {
+ S.Diag(TheCall->getBeginLoc(), diag::warn_attribute_arm_sm_incompat_builtin)
+ << TheCall->getSourceRange() << "non-streaming";
+ }
+}
+
+static bool hasSMEZAState(const FunctionDecl *FD) {
+ if (FD->hasAttr<ArmNewZAAttr>())
+ return true;
+ if (const auto *T = FD->getType()->getAs<FunctionProtoType>())
+ if (T->getAArch64SMEAttributes() & FunctionType::SME_PStateZASharedMask)
+ return true;
+ return false;
+}
+
+static bool hasSMEZAState(unsigned BuiltinID) {
+ switch (BuiltinID) {
+ default:
+ return false;
+#define GET_SME_BUILTIN_HAS_ZA_STATE
+#include "clang/Basic/arm_sme_builtins_za_state.inc"
+#undef GET_SME_BUILTIN_HAS_ZA_STATE
+ }
+}
+
+bool Sema::CheckSMEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+ if (const FunctionDecl *FD = getCurFunctionDecl()) {
+ bool debug = FD->getDeclName().getAsString() == "incompat_sve_sm";
+ std::optional<ArmStreamingType> BuiltinType;
+
+ switch (BuiltinID) {
+ default:
+ break;
+#define GET_SME_STREAMING_ATTRS
+#include "clang/Basic/arm_sme_streaming_attrs.inc"
+#undef GET_SME_STREAMING_ATTRS
+ }
+
+ if (BuiltinType)
+ checkArmStreamingBuiltin(*this, TheCall, FD, *BuiltinType);
+
+ if (hasSMEZAState(BuiltinID) && !hasSMEZAState(FD))
+ Diag(TheCall->getBeginLoc(),
+ diag::warn_attribute_arm_za_builtin_no_za_state)
+ << TheCall->getSourceRange();
+ }
+
+ // Range check SME intrinsics that take immediate values.
+ SmallVector<std::tuple<int, int, int>, 3> ImmChecks;
+
+ switch (BuiltinID) {
+ default:
+ return false;
+#define GET_SME_IMMEDIATE_CHECK
+#include "clang/Basic/arm_sme_sema_rangechecks.inc"
+#undef GET_SME_IMMEDIATE_CHECK
+ }
+
+ return ParseSVEImmChecks(TheCall, ImmChecks);
}
bool Sema::CheckSVEBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -3559,6 +3747,9 @@ bool Sema::CheckAArch64BuiltinFunctionCall(const TargetInfo &TI,
if (CheckSVEBuiltinFunctionCall(BuiltinID, TheCall))
return true;
+ if (CheckSMEBuiltinFunctionCall(BuiltinID, TheCall))
+ return true;
+
// For intrinsics which take an immediate value as part of the instruction,
// range check them here.
unsigned i = 0, l = 0, u = 0;
diff --git a/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c b/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
index e77e09c4435188d..8f33075c7a9b51f 100644
--- a/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
+++ b/clang/test/Sema/aarch64-incompat-sm-builtin-calls.c
@@ -5,6 +5,7 @@
// REQUIRES: aarch64-registered-target
#include "arm_neon.h"
+#include "arm_sme_draft_spec_subject_to_change.h"
int16x8_t incompat_neon_sm(int16x8_t splat) __arm_streaming {
// expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming function}}
@@ -20,3 +21,23 @@ int16x8_t incompat_neon_smc(int16x8_t splat) __arm_streaming_compatible {
// expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming compatible function}}
return (int16x8_t)__builtin_neon_vqaddq_v((int8x16_t)splat, (int8x16_t)splat, 33);
}
+
+void incompat_sme_norm(svbool_t pg, void const *ptr) __arm_shared_za {
+ // expected-warning@+1 {{builtin call has undefined behaviour when called from a non-streaming function}}
+ return __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
+}
+
+void incompat_sme_smc(svbool_t pg, void const *ptr) __arm_streaming_compatible __arm_shared_za {
+ // expected-warning@+1 {{builtin call has undefined behaviour when called from a streaming compatible function}}
+ return __builtin_sme_svld1_hor_za128(0, 0, pg, ptr);
+}
+
+void incompat_sme_sm(svbool_t pn, svbool_t pm, svfloat32_t zn, svfloat32_t zm) __arm_shared_za {
+ // expected-warning@+1 {{builtin call has undefined behaviour when called from a non-streaming function}}
+ svmops_za32_f32_m(0, pn, pm, zn, zm);
+}
+
+svbool_t streaming_caller_ptrue(void) __arm_streaming {
+ // expected-no-warning
+ return svand_z(svptrue_b16(), svptrue_pat_b16(SV_ALL), svptrue_pat_b16(SV_VL4));
+}
diff --git a/clang/utils/TableGen/SveEmitter.cpp b/clang/utils/TableGen/SveEmitter.cpp
index b380bd9dfe6643a..cb02ef09e130eb3 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -378,6 +378,9 @@ class SVEEmitter {
/// Emit all the information needed to map builtin -> LLVM IR intrinsic.
void createSMECodeGenMap(raw_ostream &o);
+ /// Create a table for a builtin's requirement for PSTATE.SM.
+ void createStreamingAttrs(raw_ostream &o, ACLEKind Kind);
+
/// Emit all the range checks for the immediates.
void createSMERangeChecks(raw_ostream &o);
@@ -1369,6 +1372,12 @@ void SVEEmitter::createHeader(raw_ostream &OS) {
OS << "#define __aio static __inline__ __attribute__((__always_inline__, "
"__nodebug__, __overloadable__))\n\n";
+ OS << "#ifdef __ARM_FEATURE_SME\n";
+ OS << "#define __asc __attribute__((arm_streaming_compatible))\n";
+ OS << "#else\n";
+ OS << "#define __asc\n";
+ OS << "#endif\n\n";
+
// Add reinterpret functions.
for (auto [N, Suffix] :
std::initializer_list<std::pair<unsigned, const char *>>{
@@ -1688,6 +1697,61 @@ void SVEEmitter::createSMERangeChecks(raw_ostream &OS) {
OS << "#endif\n\n";
}
+void SVEEmitter::createStreamingAttrs(raw_ostream &OS, ACLEKind Kind) {
+ std::vector<Record *> RV = Records.getAllDerivedDefinitions("Inst");
+ SmallVector<std::unique_ptr<Intrinsic>, 128> Defs;
+ for (auto *R : RV)
+ createIntrinsic(R, Defs);
+
+ // The mappings must be sorted based on BuiltinID.
+ llvm::sort(Defs, [](const std::unique_ptr<Intrinsic> &A,
+ const std::unique_ptr<Intrinsic> &B) {
+ return A->getMangledName() < B->getMangledName();
+ });
+
+ switch (Kind) {
+ case ACLEKind::SME:
+ OS << "#ifdef GET_SME_STREAMING_ATTRS\n";
+ break;
+ case ACLEKind::SVE:
+ OS << "#ifdef GET_SVE_STREAMING_ATTRS\n";
+ break;
+ }
+
+ // Ensure these are only emitted once.
+ std::set<std::string> Emitted;
+
+ uint64_t IsStreamingFlag = getEnumValueForFlag("IsStreaming");
+ uint64_t IsStreamingCompatibleFlag =
+ getEnumValueForFlag("IsStreamingCompatible");
+ for (auto &Def : Defs) {
+ if (Emitted.find(Def->getMangledName()) != Emitted.end())
+ continue;
+
+ switch (Kind) {
+ case ACLEKind::SME:
+ OS << "case SME::BI__builtin_sme_";
+ break;
+ case ACLEKind::SVE:
+ OS << "case SVE::BI__builtin_sve_";
+ break;
+ }
+ OS << Def->getMangledName() << ":\n";
+
+ if (Def->isFlagSet(IsStreamingFlag))
+ OS << " BuiltinType = ArmStreaming;\n";
+ else if (Def->isFlagSet(IsStreamingCompatibleFlag))
+ OS << " BuiltinType = ArmStreamingCompatible;\n";
+ else
+ OS << " BuiltinType = ArmNonStreaming;\n";
+ OS << " break;\n";
+
+ Emitted.insert(Def->getMangledName());
+ }
+
+ OS << "#endif\n\n";
+}
+
namespace clang {
void EmitSveHeader(RecordKeeper &Records, raw_ostream &OS) {
SVEEmitter(Records).createHeader(OS);
@@ -1724,4 +1788,8 @@ void EmitSmeBuiltinCG(RecordKeeper &Records, raw_ostream &OS) {
void EmitSmeRangeChecks(RecordKeeper &Records, raw_ostream &OS) {
SVEEmitter(Records).createSMERangeChecks(OS);
}
+
+void EmitSmeStreamingAttrs(RecordKeeper &Records, raw_ostream &OS) {
+ SVEEmitter(Records).createStreamingAttrs(OS, ACLEKind::SME);
+}
} // End namespace clang
diff --git a/clang/utils/TableGen/TableGen.cpp b/clang/utils/TableGen/TableGen.cpp
index 7efb6c731d3e5ee..9ba2fb07f1380da 100644
--- a/clang/utils/TableGen/TableGen.cpp
+++ b/clang/utils/TableGen/TableGen.cpp
@@ -89,6 +89,8 @@ enum ActionType {
GenArmSmeBuiltins,
GenArmSmeBuiltinCG,
GenArmSmeRangeChecks,
+ GenArmSmeStreamingAttrs,
+ GenArmSmeBuiltinZAState,
GenArmCdeHeader,
GenArmCdeBuiltinDef,
GenArmCdeBuiltinSema,
@@ -251,6 +253,10 @@ cl::opt<ActionType> Action(
"Generate arm_sme_builtin_cg_map.inc for clang"),
clEnumValN(GenArmSmeRangeChecks, "gen-arm-sme-sema-rangechecks",
"Generate arm_sme_sema_rangechecks.inc for clang"),
+ clEnumValN(GenArmSmeStreamingAttrs, "gen-arm-sme-streaming-attrs",
+ "Generate arm_sme_streaming_attrs.inc for clang"),
+ clEnumValN(GenArmSmeBuiltinZAState, "gen-arm-sme-builtin-za-state",
+ "Generate arm_sme_builtins_za_state.inc for clang"),
clEnumValN(GenArmMveHeader, "gen-arm-mve-header",
"Generate arm_mve.h for clang"),
clEnumValN(GenArmMveBuiltinDef, "gen-arm-mve-builtin-def",
@@ -500,6 +506,9 @@ bool ClangTableGenMain(raw_ostream &OS, RecordKeeper &Records) {
case GenArmSmeRangeChecks:
EmitSmeRangeChecks(Records, OS);
break;
+ case GenArmSmeStreamingAttrs:
+ EmitSmeStreamingAttrs(Records, OS);
+ break;
case GenArmCdeHeader:
EmitCdeHeader(Records, OS);
break;
diff --git a/clang/utils/TableGen/TableGenBackends.h b/clang/utils/TableGen/TableGenBackends.h
index d8f447069376bca..2f1c96bfa59640c 100644
--- a/clang/utils/TableGen/TableGenBackends.h
+++ b/clang/utils/TableGen/TableGenBackends.h
@@ -109,6 +109,7 @@ void EmitSmeHeader(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
void EmitSmeBuiltins(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
void EmitSmeBuiltinCG(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
void EmitSmeRangeChecks(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
+void EmitSmeStreamingAttrs(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
void EmitMveHeader(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
void EmitMveBuiltinDef(llvm::RecordKeeper &Records, llvm::raw_ostream &OS);
``````````
</details>
https://github.com/llvm/llvm-project/pull/74064
More information about the cfe-commits
mailing list