[llvm] [AsmParser][NFCI] Restructure DiagnosticPredicate (PR #126653)

Sam Elliott via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 10 18:13:40 PST 2025


https://github.com/lenary created https://github.com/llvm/llvm-project/pull/126653

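This restructures the DiagnosticPredicate type so that it more closely matches the way ParseStatus is defined. The DiagnosticPredicateTy enum is moved inside the class (as the nested PredicateTy enum), and its enumerators are re-exposed at class scope, via `using enum` when compiling as C++20 or `static constexpr` members otherwise, so that references to them stay short (e.g. `DiagnosticPredicate::NoMatch`).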

The main user of this code is AArch64, which I have also updated to use the new structure of the code.

>From 45ce215f07c702e26fb3e8a3c419c2b5dbaed8d6 Mon Sep 17 00:00:00 2001
From: Sam Elliott <quic_aelliott at quicinc.com>
Date: Mon, 10 Feb 2025 18:03:26 -0800
Subject: [PATCH] [AsmParser][NFCI] Restructure DiagnosticPredicate

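This restructures the DiagnosticPredicate type so that it more closely
matches the way ParseStatus is defined. The DiagnosticPredicateTy enum is
moved inside the class (as the nested PredicateTy enum), and its
enumerators are re-exposed at class scope, via `using enum` when
compiling as C++20 or `static constexpr` members otherwise, so that
references to them stay short (e.g. `DiagnosticPredicate::NoMatch`).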

The main user of this code is AArch64, which I have also updated to use
the new structure of the code.
---
 .../llvm/MC/MCParser/MCTargetAsmParser.h      |  43 +++---
 .../AArch64/AsmParser/AArch64AsmParser.cpp    | 126 +++++++++---------
 2 files changed, 90 insertions(+), 79 deletions(-)

diff --git a/llvm/include/llvm/MC/MCParser/MCTargetAsmParser.h b/llvm/include/llvm/MC/MCParser/MCTargetAsmParser.h
index 4c88448e6a1285a..c856d9f68b4822c 100644
--- a/llvm/include/llvm/MC/MCParser/MCTargetAsmParser.h
+++ b/llvm/include/llvm/MC/MCParser/MCTargetAsmParser.h
@@ -167,12 +167,6 @@ class ParseStatus {
   }
 };
 
-enum class DiagnosticPredicateTy {
-  Match,
-  NearMatch,
-  NoMatch,
-};
-
 // When an operand is parsed, the assembler will try to iterate through a set of
 // possible operand classes that the operand might match and call the
 // corresponding PredicateMethod to determine that.
@@ -199,20 +193,35 @@ enum class DiagnosticPredicateTy {
 // below which collects *all* possible diagnostics. This alternative
 // is optional and fully backward compatible with existing
 // PredicateMethods that return a 'bool' (match or no match).
-struct DiagnosticPredicate {
-  DiagnosticPredicateTy Type;
+class DiagnosticPredicate {
+  enum class PredicateTy {
+    Match,     // Matches
+    NearMatch, // Close Match: use Specific Diagnostic
+    NoMatch,   // No Match: use `InvalidOperand`
+  } Predicate;
+
+public:
+#if __cplusplus >= 202002L
+  using enum PredicateTy;
+#else
+  static constexpr PredicateTy Match = PredicateTy::Match;
+  static constexpr PredicateTy NearMatch = PredicateTy::NearMatch;
+  static constexpr PredicateTy NoMatch = PredicateTy::NoMatch;
+#endif
+
+  constexpr DiagnosticPredicate(PredicateTy T) : Predicate(T) {}
 
-  explicit DiagnosticPredicate(bool Match)
-      : Type(Match ? DiagnosticPredicateTy::Match
-                   : DiagnosticPredicateTy::NearMatch) {}
-  DiagnosticPredicate(DiagnosticPredicateTy T) : Type(T) {}
-  DiagnosticPredicate(const DiagnosticPredicate &) = default;
   DiagnosticPredicate& operator=(const DiagnosticPredicate &) = default;
+  DiagnosticPredicate(const DiagnosticPredicate &) = default;
+
+  explicit constexpr DiagnosticPredicate(bool Matches)
+      : Predicate(Matches ? Match : NearMatch) {}
+
+  operator bool() const { return Predicate == Match; }
 
-  operator bool() const { return Type == DiagnosticPredicateTy::Match; }
-  bool isMatch() const { return Type == DiagnosticPredicateTy::Match; }
-  bool isNearMatch() const { return Type == DiagnosticPredicateTy::NearMatch; }
-  bool isNoMatch() const { return Type == DiagnosticPredicateTy::NoMatch; }
+  constexpr bool isMatch() const { return Predicate == Match; }
+  constexpr bool isNearMatch() const { return Predicate == NearMatch; }
+  constexpr bool isNoMatch() const { return Predicate == NoMatch; }
 };
 
 // When matching of an assembly instruction fails, there may be multiple
diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
index 335b46b76688fa4..38963810e28d637 100644
--- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
+++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
@@ -824,7 +824,7 @@ class AArch64Operand : public MCParsedAsmOperand {
   DiagnosticPredicate isUImmScaled() const {
     if (IsRange && isImmRange() &&
         (getLastImmVal() != getFirstImmVal() + Offset))
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     return isImmScaled<Bits, Scale, IsRange>(false);
   }
@@ -833,7 +833,7 @@ class AArch64Operand : public MCParsedAsmOperand {
   DiagnosticPredicate isImmScaled(bool Signed) const {
     if ((!isImm() && !isImmRange()) || (isImm() && IsRange) ||
         (isImmRange() && !IsRange))
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     int64_t Val;
     if (isImmRange())
@@ -841,7 +841,7 @@ class AArch64Operand : public MCParsedAsmOperand {
     else {
       const MCConstantExpr *MCE = dyn_cast<MCConstantExpr>(getImm());
       if (!MCE)
-        return DiagnosticPredicateTy::NoMatch;
+        return DiagnosticPredicate::NoMatch;
       Val = MCE->getValue();
     }
 
@@ -856,33 +856,33 @@ class AArch64Operand : public MCParsedAsmOperand {
     }
 
     if (Val >= MinVal && Val <= MaxVal && (Val % Scale) == 0)
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   DiagnosticPredicate isSVEPattern() const {
     if (!isImm())
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     auto *MCE = dyn_cast<MCConstantExpr>(getImm());
     if (!MCE)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     int64_t Val = MCE->getValue();
     if (Val >= 0 && Val < 32)
-      return DiagnosticPredicateTy::Match;
-    return DiagnosticPredicateTy::NearMatch;
+      return DiagnosticPredicate::Match;
+    return DiagnosticPredicate::NearMatch;
   }
 
   DiagnosticPredicate isSVEVecLenSpecifier() const {
     if (!isImm())
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     auto *MCE = dyn_cast<MCConstantExpr>(getImm());
     if (!MCE)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     int64_t Val = MCE->getValue();
     if (Val >= 0 && Val <= 1)
-      return DiagnosticPredicateTy::Match;
-    return DiagnosticPredicateTy::NearMatch;
+      return DiagnosticPredicate::Match;
+    return DiagnosticPredicate::NearMatch;
   }
 
   bool isSymbolicUImm12Offset(const MCExpr *Expr) const {
@@ -1057,7 +1057,7 @@ class AArch64Operand : public MCParsedAsmOperand {
   template <typename T>
   DiagnosticPredicate isSVECpyImm() const {
     if (!isShiftedImm() && (!isImm() || !isa<MCConstantExpr>(getImm())))
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     bool IsByte = std::is_same<int8_t, std::make_signed_t<T>>::value ||
                   std::is_same<int8_t, T>::value;
@@ -1065,9 +1065,9 @@ class AArch64Operand : public MCParsedAsmOperand {
       if (!(IsByte && ShiftedImm->second) &&
           AArch64_AM::isSVECpyImm<T>(uint64_t(ShiftedImm->first)
                                      << ShiftedImm->second))
-        return DiagnosticPredicateTy::Match;
+        return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   // Unsigned value in the range 0 to 255. For element widths of
@@ -1075,7 +1075,7 @@ class AArch64Operand : public MCParsedAsmOperand {
   // range 0 to 65280.
   template <typename T> DiagnosticPredicate isSVEAddSubImm() const {
     if (!isShiftedImm() && (!isImm() || !isa<MCConstantExpr>(getImm())))
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     bool IsByte = std::is_same<int8_t, std::make_signed_t<T>>::value ||
                   std::is_same<int8_t, T>::value;
@@ -1083,15 +1083,15 @@ class AArch64Operand : public MCParsedAsmOperand {
       if (!(IsByte && ShiftedImm->second) &&
           AArch64_AM::isSVEAddSubImm<T>(ShiftedImm->first
                                         << ShiftedImm->second))
-        return DiagnosticPredicateTy::Match;
+        return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   template <typename T> DiagnosticPredicate isSVEPreferredLogicalImm() const {
     if (isLogicalImm<T>() && !isSVECpyImm<T>())
-      return DiagnosticPredicateTy::Match;
-    return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::Match;
+    return DiagnosticPredicate::NoMatch;
   }
 
   bool isCondCode() const { return Kind == k_CondCode; }
@@ -1319,48 +1319,48 @@ class AArch64Operand : public MCParsedAsmOperand {
   template <int ElementWidth, unsigned Class>
   DiagnosticPredicate isSVEPredicateVectorRegOfWidth() const {
     if (Kind != k_Register || Reg.Kind != RegKind::SVEPredicateVector)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     if (isSVEVectorReg<Class>() && (Reg.ElementWidth == ElementWidth))
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   template <int ElementWidth, unsigned Class>
   DiagnosticPredicate isSVEPredicateOrPredicateAsCounterRegOfWidth() const {
     if (Kind != k_Register || (Reg.Kind != RegKind::SVEPredicateAsCounter &&
                                Reg.Kind != RegKind::SVEPredicateVector))
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     if ((isSVEPredicateAsCounterReg<Class>() ||
          isSVEPredicateVectorRegOfWidth<ElementWidth, Class>()) &&
         Reg.ElementWidth == ElementWidth)
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   template <int ElementWidth, unsigned Class>
   DiagnosticPredicate isSVEPredicateAsCounterRegOfWidth() const {
     if (Kind != k_Register || Reg.Kind != RegKind::SVEPredicateAsCounter)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     if (isSVEPredicateAsCounterReg<Class>() && (Reg.ElementWidth == ElementWidth))
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   template <int ElementWidth, unsigned Class>
   DiagnosticPredicate isSVEDataVectorRegOfWidth() const {
     if (Kind != k_Register || Reg.Kind != RegKind::SVEDataVector)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     if (isSVEVectorReg<Class>() && Reg.ElementWidth == ElementWidth)
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   template <int ElementWidth, unsigned Class,
@@ -1369,7 +1369,7 @@ class AArch64Operand : public MCParsedAsmOperand {
   DiagnosticPredicate isSVEDataVectorRegWithShiftExtend() const {
     auto VectorMatch = isSVEDataVectorRegOfWidth<ElementWidth, Class>();
     if (!VectorMatch.isMatch())
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     // Give a more specific diagnostic when the user has explicitly typed in
     // a shift-amount that does not match what is expected, but for which
@@ -1378,12 +1378,12 @@ class AArch64Operand : public MCParsedAsmOperand {
     if (!MatchShift && (ShiftExtendTy == AArch64_AM::UXTW ||
                         ShiftExtendTy == AArch64_AM::SXTW) &&
         !ShiftWidthAlwaysSame && hasShiftExtendAmount() && ShiftWidth == 8)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     if (MatchShift && ShiftExtendTy == getShiftExtendType())
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   bool isGPR32as64() const {
@@ -1420,15 +1420,17 @@ class AArch64Operand : public MCParsedAsmOperand {
 
   template<int64_t Angle, int64_t Remainder>
   DiagnosticPredicate isComplexRotation() const {
-    if (!isImm()) return DiagnosticPredicateTy::NoMatch;
+    if (!isImm())
+      return DiagnosticPredicate::NoMatch;
 
     const MCConstantExpr *CE = dyn_cast<MCConstantExpr>(getImm());
-    if (!CE) return DiagnosticPredicateTy::NoMatch;
+    if (!CE)
+      return DiagnosticPredicate::NoMatch;
     uint64_t Value = CE->getValue();
 
     if (Value % Angle == Remainder && Value <= 270)
-      return DiagnosticPredicateTy::Match;
-    return DiagnosticPredicateTy::NearMatch;
+      return DiagnosticPredicate::Match;
+    return DiagnosticPredicate::NearMatch;
   }
 
   template <unsigned RegClassID> bool isGPR64() const {
@@ -1439,12 +1441,12 @@ class AArch64Operand : public MCParsedAsmOperand {
   template <unsigned RegClassID, int ExtWidth>
   DiagnosticPredicate isGPR64WithShiftExtend() const {
     if (Kind != k_Register || Reg.Kind != RegKind::Scalar)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     if (isGPR64<RegClassID>() && getShiftExtendType() == AArch64_AM::LSL &&
         getShiftExtendAmount() == Log2_32(ExtWidth / 8))
-      return DiagnosticPredicateTy::Match;
-    return DiagnosticPredicateTy::NearMatch;
+      return DiagnosticPredicate::Match;
+    return DiagnosticPredicate::NearMatch;
   }
 
   /// Is this a vector list with the type implicit (presumably attached to the
@@ -1479,10 +1481,10 @@ class AArch64Operand : public MCParsedAsmOperand {
     bool Res =
         isTypedVectorList<VectorKind, NumRegs, NumElements, ElementWidth>();
     if (!Res)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     if (!AArch64MCRegisterClasses[RegClass].contains(VectorList.RegNum))
-      return DiagnosticPredicateTy::NearMatch;
-    return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::NearMatch;
+    return DiagnosticPredicate::Match;
   }
 
   template <RegKind VectorKind, unsigned NumRegs, unsigned Stride,
@@ -1491,21 +1493,21 @@ class AArch64Operand : public MCParsedAsmOperand {
     bool Res = isTypedVectorList<VectorKind, NumRegs, /*NumElements*/ 0,
                                  ElementWidth, Stride>();
     if (!Res)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     if ((VectorList.RegNum < (AArch64::Z0 + Stride)) ||
         ((VectorList.RegNum >= AArch64::Z16) &&
          (VectorList.RegNum < (AArch64::Z16 + Stride))))
-      return DiagnosticPredicateTy::Match;
-    return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::Match;
+    return DiagnosticPredicate::NoMatch;
   }
 
   template <int Min, int Max>
   DiagnosticPredicate isVectorIndex() const {
     if (Kind != k_VectorIndex)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     if (VectorIndex.Val >= Min && VectorIndex.Val <= Max)
-      return DiagnosticPredicateTy::Match;
-    return DiagnosticPredicateTy::NearMatch;
+      return DiagnosticPredicate::Match;
+    return DiagnosticPredicate::NearMatch;
   }
 
   bool isToken() const override { return Kind == k_Token; }
@@ -1531,7 +1533,7 @@ class AArch64Operand : public MCParsedAsmOperand {
 
   template <unsigned ImmEnum> DiagnosticPredicate isExactFPImm() const {
     if (Kind != k_FPImm)
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
 
     if (getFPImmIsExact()) {
       // Lookup the immediate from table of supported immediates.
@@ -1546,19 +1548,19 @@ class AArch64Operand : public MCParsedAsmOperand {
         llvm_unreachable("FP immediate is not exact");
 
       if (getFPImm().bitwiseIsEqual(RealVal))
-        return DiagnosticPredicateTy::Match;
+        return DiagnosticPredicate::Match;
     }
 
-    return DiagnosticPredicateTy::NearMatch;
+    return DiagnosticPredicate::NearMatch;
   }
 
   template <unsigned ImmA, unsigned ImmB>
   DiagnosticPredicate isExactFPImm() const {
-    DiagnosticPredicate Res = DiagnosticPredicateTy::NoMatch;
+    DiagnosticPredicate Res = DiagnosticPredicate::NoMatch;
     if ((Res = isExactFPImm<ImmA>()))
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
     if ((Res = isExactFPImm<ImmB>()))
-      return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::Match;
     return Res;
   }
 
@@ -1741,12 +1743,12 @@ class AArch64Operand : public MCParsedAsmOperand {
   template <MatrixKind Kind, unsigned EltSize, unsigned RegClass>
   DiagnosticPredicate isMatrixRegOperand() const {
     if (!isMatrix())
-      return DiagnosticPredicateTy::NoMatch;
+      return DiagnosticPredicate::NoMatch;
     if (getMatrixKind() != Kind ||
         !AArch64MCRegisterClasses[RegClass].contains(getMatrixReg()) ||
         EltSize != getMatrixElementWidth())
-      return DiagnosticPredicateTy::NearMatch;
-    return DiagnosticPredicateTy::Match;
+      return DiagnosticPredicate::NearMatch;
+    return DiagnosticPredicate::Match;
   }
 
   bool isPAuthPCRelLabel16Operand() const {



More information about the llvm-commits mailing list