[clang] [llvm] [Clang][AArch64] Add customisable immediate range checking to NEON (PR #100278)
Momchil Velikov via cfe-commits
cfe-commits at lists.llvm.org
Mon Sep 2 03:36:52 PDT 2024
================
@@ -403,142 +369,183 @@ enum ArmSMEState : unsigned {
ArmZT0Mask = 0b11 << 2
};
+/// Check one immediate operand of an AArch64 NEON/SVE builtin call.
+///
+/// \param TheCall     The builtin call being checked.
+/// \param CheckTy     An ImmCheckType value selecting the accepted range/set.
+/// \param ArgIdx      Index of the immediate operand within \p TheCall.
+/// \param EltBitWidth Element width in bits; only read by the width-dependent
+///                    checks (extract, shifts, conversions, lane indices).
+/// \param VecBitWidth Vector width in bits; only read by the lane-index checks.
+/// \returns true if the operand was rejected (the called checker reports the
+///          diagnostic), false if it is acceptable. The set-membership checks
+///          accept type/value-dependent operands without inspecting them.
+bool SemaARM::CheckImmediateArg(CallExpr *TheCall, unsigned CheckTy,
+                                unsigned ArgIdx, unsigned EltBitWidth,
+                                unsigned VecBitWidth) {
+  using OptionSetCheckFnTy = bool (*)(int64_t Value);
+
+  // Checks that the operand (ArgIdx) is a constant accepted by CheckImm,
+  // emitting ErrDiag when it is a constant outside the predefined set.
+  auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm,
+                                 int ErrDiag) -> bool {
+    // We can't check the value of a dependent argument.
+    Expr *Arg = TheCall->getArg(ArgIdx);
+    if (Arg->isTypeDependent() || Arg->isValueDependent())
+      return false;
+
+    // Check constant-ness first.
+    llvm::APSInt Imm;
+    if (SemaRef.BuiltinConstantArg(TheCall, ArgIdx, Imm))
+      return true;
+
+    if (!CheckImm(Imm.getSExtValue()))
+      return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange();
+    return false;
+  };
+
+  // Each checker already reports its own diagnostic, so its result (true ==
+  // rejected) is the result of this function.
+  switch (static_cast<ImmCheckType>(CheckTy)) {
+  case ImmCheckType::ImmCheck0_31:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 31);
+  case ImmCheckType::ImmCheck0_13:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 13);
+  case ImmCheckType::ImmCheck0_63:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 63);
+  case ImmCheckType::ImmCheck1_16:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 16);
+  case ImmCheckType::ImmCheck0_7:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 7);
+  case ImmCheckType::ImmCheck1_1:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 1);
+  case ImmCheckType::ImmCheck1_3:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 3);
+  case ImmCheckType::ImmCheck1_7:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 7);
+  case ImmCheckType::ImmCheckExtract:
+    // NOTE(review): 2048 appears to be the maximum vector length in bits, so
+    // the bound is the largest element index in such a vector — confirm.
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0,
+                                           (2048 / EltBitWidth) - 1);
+  case ImmCheckType::ImmCheckCvt:
+  case ImmCheckType::ImmCheckShiftRight:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, EltBitWidth);
+  case ImmCheckType::ImmCheckShiftRightNarrow:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1,
+                                           EltBitWidth / 2);
+  case ImmCheckType::ImmCheckShiftLeft:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0,
+                                           EltBitWidth - 1);
+  case ImmCheckType::ImmCheckLaneIndex:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0,
+                                           (VecBitWidth / EltBitWidth) - 1);
+  case ImmCheckType::ImmCheckLaneIndexCompRotate:
+    // Lanes are counted in pairs of elements.
+    return SemaRef.BuiltinConstantArgRange(
+        TheCall, ArgIdx, 0, (VecBitWidth / (2 * EltBitWidth)) - 1);
+  case ImmCheckType::ImmCheckLaneIndexDot:
+    // Lanes are counted in groups of four elements.
+    return SemaRef.BuiltinConstantArgRange(
+        TheCall, ArgIdx, 0, (VecBitWidth / (4 * EltBitWidth)) - 1);
+  case ImmCheckType::ImmCheckComplexRot90_270:
+    return CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; },
+                               diag::err_rotation_argument_to_cadd);
+  case ImmCheckType::ImmCheckComplexRotAll90:
+    return CheckImmediateInSet(
+        [](int64_t V) { return V == 0 || V == 90 || V == 180 || V == 270; },
+        diag::err_rotation_argument_to_cmla);
+  case ImmCheckType::ImmCheck0_1:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 1);
+  case ImmCheckType::ImmCheck0_2:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 2);
+  case ImmCheckType::ImmCheck0_3:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 3);
+  case ImmCheckType::ImmCheck0_0:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 0);
+  case ImmCheckType::ImmCheck0_15:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 15);
+  case ImmCheckType::ImmCheck0_255:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 255);
+  case ImmCheckType::ImmCheck1_32:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 32);
+  case ImmCheckType::ImmCheck1_64:
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 64);
+  case ImmCheckType::ImmCheck2_4_Mul2:
+    // Short-circuits: the multiple-of-2 check only runs if the range passes.
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 2, 4) ||
+           SemaRef.BuiltinConstantArgMultiple(TheCall, ArgIdx, 2);
+  default:
+    llvm_unreachable("Invalid immediate range typeflag!");
+  }
+}
+
+/// Run every immediate check in \p ImmChecks against \p TheCall.
+///
+/// Each entry is (ArgIdx, CheckTy, ElementSizeInBits, VecSizeInBits). When
+/// \p OverloadType is non-negative, the element width passed to each check is
+/// taken from NeonTypeFlags(OverloadType) instead of the entry's own width.
+/// All checks are run, even after a failure, so that every bad operand is
+/// diagnosed.
+/// \returns true if any check rejected its operand.
+///
+/// NOTE(review): a default argument on this out-of-line definition must not
+/// also appear on the in-class declaration (that is a redefinition error);
+/// confirm the default is specified in exactly one place.
+bool SemaARM::ParseNeonImmChecks(
+    CallExpr *TheCall,
+    SmallVector<std::tuple<int, int, int, int>, 2> &ImmChecks,
+    int OverloadType = -1) {
+  bool HasError = false;
+
+  for (const auto &Check : ImmChecks) {
+    // Copy (not reference) so the element width can be overridden below.
+    auto [ArgIdx, CheckTy, ElementSizeInBits, VecSizeInBits] = Check;
+
+    if (OverloadType >= 0)
+      ElementSizeInBits = NeonTypeFlags(OverloadType).getEltSizeInBits();
+
+    HasError |= CheckImmediateArg(TheCall, CheckTy, ArgIdx, ElementSizeInBits,
+                                  VecSizeInBits);
+  }
+
+  return HasError;
+}
+
bool SemaARM::ParseSVEImmChecks(
CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
- // Perform all the immediate checks for this builtin call.
- bool HasError = false;
- for (auto &I : ImmChecks) {
- int ArgNum, CheckTy, ElementSizeInBits;
- std::tie(ArgNum, CheckTy, ElementSizeInBits) = I;
-
- typedef bool (*OptionSetCheckFnTy)(int64_t Value);
-
- // Function that checks whether the operand (ArgNum) is an immediate
- // that is one of the predefined values.
- auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm,
- int ErrDiag) -> bool {
- // We can't check the value of a dependent argument.
- Expr *Arg = TheCall->getArg(ArgNum);
- if (Arg->isTypeDependent() || Arg->isValueDependent())
- return false;
-
- // Check constant-ness first.
- llvm::APSInt Imm;
- if (SemaRef.BuiltinConstantArg(TheCall, ArgNum, Imm))
- return true;
- if (!CheckImm(Imm.getSExtValue()))
- return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange();
- return false;
- };
+ bool HasError = false;
+ unsigned CheckTy, ArgIdx, ElementSizeInBits;
- switch ((SVETypeFlags::ImmCheckType)CheckTy) {
- case SVETypeFlags::ImmCheck0_31:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 31))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_13:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 13))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck1_16:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 16))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_7:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 7))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck1_1:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 1))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck1_3:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 3))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck1_7:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 7))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckExtract:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
- (2048 / ElementSizeInBits) - 1))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckShiftRight:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1,
- ElementSizeInBits))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckShiftRightNarrow:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1,
- ElementSizeInBits / 2))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckShiftLeft:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
- ElementSizeInBits - 1))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckLaneIndex:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
- (128 / (1 * ElementSizeInBits)) - 1))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckLaneIndexCompRotate:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
- (128 / (2 * ElementSizeInBits)) - 1))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckLaneIndexDot:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
- (128 / (4 * ElementSizeInBits)) - 1))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckComplexRot90_270:
- if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; },
- diag::err_rotation_argument_to_cadd))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheckComplexRotAll90:
- if (CheckImmediateInSet(
- [](int64_t V) {
- return V == 0 || V == 90 || V == 180 || V == 270;
- },
- diag::err_rotation_argument_to_cmla))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_1:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 1))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_2:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 2))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_3:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 3))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_0:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 0))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_15:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 15))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck0_255:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 255))
- HasError = true;
- break;
- case SVETypeFlags::ImmCheck2_4_Mul2:
- if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 2, 4) ||
- SemaRef.BuiltinConstantArgMultiple(TheCall, ArgNum, 2))
- HasError = true;
- break;
- }
+ for (const auto &I : ImmChecks) {
+ std::tie(ArgIdx, CheckTy, ElementSizeInBits) = I;
+ HasError |=
----------------
momchil-velikov wrote:
Likewise, this can use a structured binding declaration.
https://github.com/llvm/llvm-project/pull/100278
More information about the cfe-commits
mailing list