[clang] [llvm] [Clang][AArch64] Add customisable immediate range checking to NEON (PR #100278)

via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 2 08:45:14 PDT 2024


================
@@ -403,142 +369,183 @@ enum ArmSMEState : unsigned {
   ArmZT0Mask = 0b11 << 2
 };
 
+/// Checks the immediate operand \p ArgIdx of \p TheCall against the constraint
+/// selected by \p CheckTy, which is interpreted as an ImmCheckType value.
+///
+/// \p EltBitWidth and \p VecBitWidth are only read by the constraint kinds
+/// whose bounds are computed from them (extract, convert/shift and lane-index
+/// checks); the fixed-bound and rotation kinds ignore both.
+///
+/// \returns whatever the selected Sema check returns: true when that check
+/// returns true (callers accumulate this into an error flag), false otherwise.
+bool SemaARM::CheckImmediateArg(CallExpr *TheCall, unsigned CheckTy,
+                                unsigned ArgIdx, unsigned EltBitWidth,
+                                unsigned VecBitWidth) {
+  using ImmPredicateFn = bool (*)(int64_t);
+
+  // Forwards the two bounds to Sema's constant-argument range check for the
+  // operand under test. Generic so the bounds reach Sema with the same types
+  // the call sites spell them with.
+  auto CheckRange = [&](auto Low, auto High) -> bool {
+    return SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, Low, High);
+  };
+
+  // Checks that the operand is a constant accepted by IsAllowed; emits DiagID
+  // at the start of the call when it is not.
+  auto CheckSet = [&](ImmPredicateFn IsAllowed, int DiagID) -> bool {
+    Expr *ImmExpr = TheCall->getArg(ArgIdx);
+
+    // A dependent operand has no value to inspect yet, so do not diagnose.
+    if (ImmExpr->isTypeDependent() || ImmExpr->isValueDependent())
+      return false;
+
+    // The operand has to be a constant before its value can be compared.
+    llvm::APSInt Value;
+    if (SemaRef.BuiltinConstantArg(TheCall, ArgIdx, Value))
+      return true;
+
+    if (IsAllowed(Value.getSExtValue()))
+      return false;
+    return Diag(TheCall->getBeginLoc(), DiagID) << ImmExpr->getSourceRange();
+  };
+
+  switch (static_cast<ImmCheckType>(CheckTy)) {
+  // Kinds with literal bounds.
+  case ImmCheckType::ImmCheck0_0:
+    return CheckRange(0, 0);
+  case ImmCheckType::ImmCheck0_1:
+    return CheckRange(0, 1);
+  case ImmCheckType::ImmCheck0_2:
+    return CheckRange(0, 2);
+  case ImmCheckType::ImmCheck0_3:
+    return CheckRange(0, 3);
+  case ImmCheckType::ImmCheck0_7:
+    return CheckRange(0, 7);
+  case ImmCheckType::ImmCheck0_13:
+    return CheckRange(0, 13);
+  case ImmCheckType::ImmCheck0_15:
+    return CheckRange(0, 15);
+  case ImmCheckType::ImmCheck0_31:
+    return CheckRange(0, 31);
+  case ImmCheckType::ImmCheck0_63:
+    return CheckRange(0, 63);
+  case ImmCheckType::ImmCheck0_255:
+    return CheckRange(0, 255);
+  case ImmCheckType::ImmCheck1_1:
+    return CheckRange(1, 1);
+  case ImmCheckType::ImmCheck1_3:
+    return CheckRange(1, 3);
+  case ImmCheckType::ImmCheck1_7:
+    return CheckRange(1, 7);
+  case ImmCheckType::ImmCheck1_16:
+    return CheckRange(1, 16);
+  case ImmCheckType::ImmCheck1_32:
+    return CheckRange(1, 32);
+  case ImmCheckType::ImmCheck1_64:
+    return CheckRange(1, 64);
+
+  // Range check first, then the multiple-of-2 check only if the range check
+  // returned false.
+  case ImmCheckType::ImmCheck2_4_Mul2:
+    return CheckRange(2, 4) ||
+           SemaRef.BuiltinConstantArgMultiple(TheCall, ArgIdx, 2);
+
+  // Kinds whose upper bound is derived from the element width.
+  case ImmCheckType::ImmCheckExtract:
+    return CheckRange(0, (2048 / EltBitWidth) - 1);
+  case ImmCheckType::ImmCheckCvt:
+  case ImmCheckType::ImmCheckShiftRight:
+    return CheckRange(1, EltBitWidth);
+  case ImmCheckType::ImmCheckShiftRightNarrow:
+    return CheckRange(1, EltBitWidth / 2);
+  case ImmCheckType::ImmCheckShiftLeft:
+    return CheckRange(0, EltBitWidth - 1);
+
+  // Lane indices: the upper bound is derived from the vector width divided by
+  // 1x, 2x or 4x the element width.
+  case ImmCheckType::ImmCheckLaneIndex:
+    return CheckRange(0, (VecBitWidth / EltBitWidth) - 1);
+  case ImmCheckType::ImmCheckLaneIndexCompRotate:
+    return CheckRange(0, (VecBitWidth / (2 * EltBitWidth)) - 1);
+  case ImmCheckType::ImmCheckLaneIndexDot:
+    return CheckRange(0, (VecBitWidth / (4 * EltBitWidth)) - 1);
+
+  // Rotation operands must be one of a fixed set of values.
+  case ImmCheckType::ImmCheckComplexRot90_270:
+    return CheckSet([](int64_t V) { return V == 90 || V == 270; },
+                    diag::err_rotation_argument_to_cadd);
+  case ImmCheckType::ImmCheckComplexRotAll90:
+    return CheckSet(
+        [](int64_t V) { return V == 0 || V == 90 || V == 180 || V == 270; },
+        diag::err_rotation_argument_to_cmla);
+
+  // Any other value is treated as a programming error.
+  default:
+    llvm_unreachable("Invalid immediate range typeflag!");
+  }
+}
+
+/// Runs CheckImmediateArg for every entry of \p ImmChecks on \p TheCall.
+///
+/// Each tuple is read as (argument index, check type, element size in bits,
+/// vector size in bits). When \p OverloadType is non-negative, the element
+/// size stored in the tuple is replaced by the element size reported by
+/// NeonTypeFlags(OverloadType).
+///
+/// Every entry is checked even after an earlier one has failed.
+///
+/// \returns true if any individual check returned true.
+bool SemaARM::ParseNeonImmChecks(
+    CallExpr *TheCall,
+    SmallVector<std::tuple<int, int, int, int>, 2> &ImmChecks,
+    int OverloadType = -1) {
+  bool AnyFailed = false;
+
+  for (const auto &Check : ImmChecks) {
+    unsigned Arg = std::get<0>(Check);
+    unsigned Kind = std::get<1>(Check);
+    unsigned EltBits = std::get<2>(Check);
+    unsigned VecBits = std::get<3>(Check);
+
+    // An explicit overload type takes precedence over the per-entry size.
+    if (OverloadType >= 0)
+      EltBits = NeonTypeFlags(OverloadType).getEltSizeInBits();
+
+    if (CheckImmediateArg(TheCall, Kind, Arg, EltBits, VecBits))
+      AnyFailed = true;
+  }
+
+  return AnyFailed;
+}
+
 bool SemaARM::ParseSVEImmChecks(
     CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) {
-  // Perform all the immediate checks for this builtin call.
-  bool HasError = false;
-  for (auto &I : ImmChecks) {
-    int ArgNum, CheckTy, ElementSizeInBits;
-    std::tie(ArgNum, CheckTy, ElementSizeInBits) = I;
-
-    typedef bool (*OptionSetCheckFnTy)(int64_t Value);
-
-    // Function that checks whether the operand (ArgNum) is an immediate
-    // that is one of the predefined values.
-    auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm,
-                                   int ErrDiag) -> bool {
-      // We can't check the value of a dependent argument.
-      Expr *Arg = TheCall->getArg(ArgNum);
-      if (Arg->isTypeDependent() || Arg->isValueDependent())
-        return false;
-
-      // Check constant-ness first.
-      llvm::APSInt Imm;
-      if (SemaRef.BuiltinConstantArg(TheCall, ArgNum, Imm))
-        return true;
 
-      if (!CheckImm(Imm.getSExtValue()))
-        return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange();
-      return false;
-    };
+  bool HasError = false;
+  unsigned CheckTy, ArgIdx, ElementSizeInBits;
 
-    switch ((SVETypeFlags::ImmCheckType)CheckTy) {
-    case SVETypeFlags::ImmCheck0_31:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 31))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_13:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 13))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck1_16:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 16))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_7:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 7))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck1_1:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 1))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck1_3:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 3))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck1_7:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 7))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckExtract:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
-                                          (2048 / ElementSizeInBits) - 1))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckShiftRight:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1,
-                                          ElementSizeInBits))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckShiftRightNarrow:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1,
-                                          ElementSizeInBits / 2))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckShiftLeft:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
-                                          ElementSizeInBits - 1))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckLaneIndex:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
-                                          (128 / (1 * ElementSizeInBits)) - 1))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckLaneIndexCompRotate:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
-                                          (128 / (2 * ElementSizeInBits)) - 1))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckLaneIndexDot:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0,
-                                          (128 / (4 * ElementSizeInBits)) - 1))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckComplexRot90_270:
-      if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; },
-                              diag::err_rotation_argument_to_cadd))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheckComplexRotAll90:
-      if (CheckImmediateInSet(
-              [](int64_t V) {
-                return V == 0 || V == 90 || V == 180 || V == 270;
-              },
-              diag::err_rotation_argument_to_cmla))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_1:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 1))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_2:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 2))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_3:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 3))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_0:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 0))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_15:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 15))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck0_255:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 255))
-        HasError = true;
-      break;
-    case SVETypeFlags::ImmCheck2_4_Mul2:
-      if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 2, 4) ||
-          SemaRef.BuiltinConstantArgMultiple(TheCall, ArgNum, 2))
-        HasError = true;
-      break;
-    }
+  for (const auto &I : ImmChecks) {
+    std::tie(ArgIdx, CheckTy, ElementSizeInBits) = I;
+    HasError |=
----------------
SpencerAbson wrote:

Done.

https://github.com/llvm/llvm-project/pull/100278


More information about the cfe-commits mailing list