[llvm] [TableGen][NFCI] Simplify TypeSetByHwMode::intersect and make extensible (PR #81688)

Jessica Clarke via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 14 09:20:02 PST 2024


https://github.com/jrtc27 updated https://github.com/llvm/llvm-project/pull/81688

From 3393dc69564373bf9223ea080d3448fc1ded13e8 Mon Sep 17 00:00:00 2001
From: Jessica Clarke <jrtc27 at jrtc27.com>
Date: Tue, 13 Feb 2024 23:25:57 +0000
Subject: [PATCH 1/2] [TableGen][NFCI] Simplify TypeSetByHwMode::intersect and
 make extensible

The current implementation considers both iPTR+iN and everything else
all in one go, which leads to more special-casing when iPTR is present
in only one set than is described in the comment block. Moreover, this
makes it very difficult to add any new iPTR-like wildcards due to the
exponential combinatorial explosion that occurs.

Logically, iPTR+iN handling is entirely independent of everything
else, so rewrite the code to handle them separately. This removes special
cases, making the core of the implementation more succinct whilst more
clearly implementing exactly what is described in the comment block, and
allows any number of (non-overlapping) wildcards to be added to the
list, as needed by CHERI LLVM downstream (due to having a new capability
type which, much like a normal integer pointer in LLVM, varies in size
between targets and modes).
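The partitioned approach can be sketched outside LLVM with a toy model in which a type set is a bitmask. This is a minimal, hypothetical sketch rather than the patch's code: `TypeSet`, `WildPart`, `intersectPart`, the bit assignments, and the capability wildcard `capPTR` with its concrete types `c64`/`c128` are all made-up stand-ins (the capability entries only loosely model the CHERI use case). Each wildcard partition is intersected on its own, then everything outside the wildcard partitions gets a plain intersection, mirroring the structure of `IntersectP` and `WildParts` in the patch.

```cpp
#include <bitset>
#include <cassert>
#include <cstdint>

// Toy model (not LLVM's MachineValueTypeSet): bit i set means type i is present.
using TypeSet = std::uint32_t;
enum : TypeSet {
  iPTR = 1u << 0, i32 = 1u << 1, i64 = 1u << 2, i128 = 1u << 3,
  f32 = 1u << 4, f64 = 1u << 5,
  // Hypothetical capability wildcard and concrete capability types.
  capPTR = 1u << 6, c64 = 1u << 7, c128 = 1u << 8,
};

struct WildPart {
  TypeSet Wild;      // the wildcard type
  TypeSet Partition; // every type the wildcard can stand for, plus itself
};

// Must be non-overlapping, as in the patch.
constexpr WildPart WildParts[] = {
    {iPTR, iPTR | i32 | i64 | i128},
    {capPTR, capPTR | c64 | c128},
};

static int numTypes(TypeSet S) {
  return static_cast<int>(std::bitset<32>(S).count());
}

// Intersect Out with In, looking only at the types in partition P.
// Wild == 0 means the partition has no wildcard (plain intersection).
static bool intersectPart(TypeSet &Out, TypeSet In, TypeSet P, TypeSet Wild) {
  const TypeSet Before = Out;
  const TypeSet CompIn = P & ~In; // complement of In within P
  const bool OutW = Out & Wild, InW = In & Wild;
  if (!Wild || OutW == InW) {
    Out &= ~CompIn;
    return Out != Before;
  }
  if (InW) {
    // Out's leftovers within P; at most one is kept as the match for Wild.
    if (numTypes(Out & CompIn) < 2)
      return false;
    Out = (Out & ~CompIn) | Wild; // Wild replaces the leftovers
    return true;
  }
  // Out has the wildcard, In does not.
  const TypeSet InLeftovers = In & P & ~Out;
  Out &= ~CompIn; // removes at least Wild
  if (numTypes(InLeftovers) < 2) {
    Out |= InLeftovers; // a single concrete type is more specific than Wild
    return true;
  }
  Out |= Wild; // two or more candidates: keep the wildcard
  return Out != Before;
}

bool intersect(TypeSet &Out, TypeSet In) {
  bool Changed = false;
  TypeSet Covered = 0;
  for (const WildPart &W : WildParts) {
    Changed |= intersectPart(Out, In, W.Partition, W.Wild);
    Covered |= W.Partition;
  }
  // Everything outside the wildcard partitions: plain intersection.
  Changed |= intersectPart(Out, In, ~Covered, 0);
  return Changed;
}
```

In this model, `{iPTR} * {i32 i64}` leaves `Out` as `{iPTR}` and reports no change, while `{i32 c64 c128} * {iPTR capPTR}` keeps `i32` as the match for `iPTR` but collapses the two capability types to `capPTR`, because each partition is decided independently.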

In testing, this change results in identical TableGen output for all
in-tree backends (including those in LLVM_ALL_EXPERIMENTAL_TARGETS), and
this implementation is intended to be entirely equivalent to the old
one.
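Besides comparing TableGen output, one way to gain confidence in the equivalence claim is an exhaustive check over a small universe of types. The sketch below is hypothetical and uses a bitmask model rather than LLVM's MachineValueTypeSet: `oldIntersect` is a condensed transcription of the removed iPTR-only logic, `newIntersect` follows the partitioned structure, and six bits stand in for iPTR, i32, i64, i128 and two non-integer types. It only shows that the two agree, including on the "changed" return value, for this toy universe; it says nothing about the real MVT list.

```cpp
#include <bitset>
#include <cassert>
#include <cstdint>

// Toy model, not LLVM's real types: bit i set means type i is in the set.
using TypeSet = std::uint32_t;
constexpr TypeSet iPTR = 1u << 0;
constexpr TypeSet ScalarInts = 0b1110; // i32, i64, i128 (bits 1-3)
constexpr TypeSet IntPart = iPTR | ScalarInts;
constexpr int NumTypes = 6;            // bits 4-5 stand in for f32, f64

static int popcount(TypeSet S) {
  return static_cast<int>(std::bitset<32>(S).count());
}

// Old approach: everything handled in one go, with iPTR special-cased.
static bool oldIntersect(TypeSet &Out, TypeSet In) {
  const TypeSet Before = Out;
  const bool OutP = Out & iPTR, InP = In & iPTR;
  if (OutP == InP) {
    Out &= In;
    return Out != Before;
  }
  if (InP) {
    const TypeSet OutOnly = Out & ~In;
    const TypeSet Ints = OutOnly & ScalarInts;
    if (!OutOnly || (popcount(Ints) == 1 && OutOnly == Ints))
      return false;
    Out &= In;
    if (popcount(Ints) == 1)
      Out |= Ints;
    else if (popcount(Ints) > 1)
      Out |= iPTR;
    return true;
  }
  const TypeSet Ints = In & ~Out & ScalarInts; // In-only scalar integers
  Out &= In;                                   // removes at least iPTR
  if (popcount(Ints) < 2) {
    Out |= Ints;
    return true;
  }
  Out |= iPTR;
  return Out != Before;
}

// New approach: one partition at a time (Wild == 0 means no wildcard).
static bool intersectPart(TypeSet &Out, TypeSet In, TypeSet P, TypeSet Wild) {
  const TypeSet Before = Out;
  const TypeSet CompIn = P & ~In;
  const bool OutW = Out & Wild, InW = In & Wild;
  if (!Wild || OutW == InW) {
    Out &= ~CompIn;
    return Out != Before;
  }
  if (InW) {
    if (popcount(Out & CompIn) < 2)
      return false;
    Out = (Out & ~CompIn) | Wild;
    return true;
  }
  const TypeSet InLeftovers = In & P & ~Out;
  Out &= ~CompIn;
  if (popcount(InLeftovers) < 2) {
    Out |= InLeftovers;
    return true;
  }
  Out |= Wild;
  return Out != Before;
}

static bool newIntersect(TypeSet &Out, TypeSet In) {
  bool Changed = intersectPart(Out, In, IntPart, iPTR);
  Changed |= intersectPart(Out, In, ~IntPart, 0);
  return Changed;
}

// Compare both on every pair of subsets of the NumTypes toy types.
bool equivalentOnAllPairs() {
  for (TypeSet Out0 = 0; Out0 < (1u << NumTypes); ++Out0)
    for (TypeSet In = 0; In < (1u << NumTypes); ++In) {
      TypeSet A = Out0, B = Out0;
      const bool ChangedA = oldIntersect(A, In);
      const bool ChangedB = newIntersect(B, In);
      if (A != B || ChangedA != ChangedB)
        return false;
    }
  return true;
}
```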
---
 llvm/utils/TableGen/CodeGenDAGPatterns.cpp | 156 +++++++++++----------
 1 file changed, 80 insertions(+), 76 deletions(-)

diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
index a9046e09a62976..9addea8fab98d0 100644
--- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
@@ -41,7 +41,6 @@ static inline bool isIntegerOrPtr(MVT VT) {
 static inline bool isFloatingPoint(MVT VT) { return VT.isFloatingPoint(); }
 static inline bool isVector(MVT VT) { return VT.isVector(); }
 static inline bool isScalar(MVT VT) { return !VT.isVector(); }
-static inline bool isScalarInteger(MVT VT) { return VT.isScalarInteger(); }
 
 template <typename Predicate>
 static bool berase_if(MachineValueTypeSet &S, Predicate P) {
@@ -262,85 +261,90 @@ LLVM_DUMP_METHOD
 void TypeSetByHwMode::dump() const { dbgs() << *this << '\n'; }
 
 bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) {
-  bool OutP = Out.count(MVT::iPTR), InP = In.count(MVT::iPTR);
-  // Complement of In.
-  auto CompIn = [&In](MVT T) -> bool { return !In.count(T); };
-
-  if (OutP == InP)
-    return berase_if(Out, CompIn);
-
-  // Compute the intersection of scalars separately to account for only
-  // one set containing iPTR.
-  // The intersection of iPTR with a set of integer scalar types that does not
-  // include iPTR will result in the most specific scalar type:
-  // - iPTR is more specific than any set with two elements or more
-  // - iPTR is less specific than any single integer scalar type.
-  // For example
-  // { iPTR } * { i32 }     -> { i32 }
-  // { iPTR } * { i32 i64 } -> { iPTR }
-  // and
-  // { iPTR i32 } * { i32 }          -> { i32 }
-  // { iPTR i32 } * { i32 i64 }      -> { i32 i64 }
-  // { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 }
-
-  // Let In' = elements only in In, Out' = elements only in Out, and
-  // IO = elements common to both. Normally IO would be returned as the result
-  // of the intersection, but we need to account for iPTR being a "wildcard" of
-  // sorts. Since elements in IO are those that match both sets exactly, they
-  // will all belong to the output. If any of the "leftovers" (i.e. In' or
-  // Out') contain iPTR, it means that the other set doesn't have it, but it
-  // could have (1) a more specific type, or (2) a set of types that is less
-  // specific. The "leftovers" from the other set is what we want to examine
-  // more closely.
-
-  auto subtract = [](const SetType &A, const SetType &B) {
-    SetType Diff = A;
-    berase_if(Diff, [&B](MVT T) { return B.count(T); });
-    return Diff;
-  };
-
-  if (InP) {
-    SetType OutOnly = subtract(Out, In);
-    if (OutOnly.empty()) {
-      // This means that Out \subset In, so no change to Out.
-      return false;
-    }
-    unsigned NumI = llvm::count_if(OutOnly, isScalarInteger);
-    if (NumI == 1 && OutOnly.size() == 1) {
-      // There is only one element in Out', and it happens to be a scalar
-      // integer that should be kept as a match for iPTR in In.
-      return false;
+  auto IntersectP = [&](std::optional<MVT> WildVT, function_ref<bool(MVT)> P) {
+    // Complement of In within this partition.
+    auto CompIn = [&](MVT T) -> bool { return !In.count(T) && P(T); };
+
+    if (!WildVT)
+      return berase_if(Out, CompIn);
+
+    bool OutW = Out.count(*WildVT), InW = In.count(*WildVT);
+    if (OutW == InW)
+      return berase_if(Out, CompIn);
+
+    // Compute the intersection of scalars separately to account for only one
+    // set containing WildVT.
+    // The intersection of WildVT with a set of corresponding types that does
+    // not include WildVT will result in the most specific type:
+    // - WildVT is more specific than any set with two elements or more
+    // - WildVT is less specific than any single type.
+    // For example, for iPTR and scalar integer types
+    // { iPTR } * { i32 }     -> { i32 }
+    // { iPTR } * { i32 i64 } -> { iPTR }
+    // and
+    // { iPTR i32 } * { i32 }          -> { i32 }
+    // { iPTR i32 } * { i32 i64 }      -> { i32 i64 }
+    // { iPTR i32 } * { i32 i64 i128 } -> { iPTR i32 }
+
+    // Looking at just this partition, let In' = elements only in In,
+    // Out' = elements only in Out, and IO = elements common to both. Normally
+    // IO would be returned as the result of the intersection, but we need to
+    // account for WildVT being a "wildcard" of sorts. Since elements in IO are
+    // those that match both sets exactly, they will all belong to the output.
+    // If any of the "leftovers" (i.e. In' or Out') contain WildVT, it means
+    // that the other set doesn't have it, but it could have (1) a more
+    // specific type, or (2) a set of types that is less specific. The
+    // "leftovers" from the other set is what we want to examine more closely.
+
+    auto Leftovers = [&](const SetType &A, const SetType &B) {
+      SetType Diff = A;
+      berase_if(Diff, [&](MVT T) { return B.count(T) || !P(T); });
+      return Diff;
+    };
+
+    if (InW) {
+      SetType OutLeftovers = Leftovers(Out, In);
+      if (OutLeftovers.size() < 2) {
+        // WildVT not added to Out. Keep the possible single leftover.
+        return false;
+      }
+      // WildVT replaces the leftovers.
+      berase_if(Out, CompIn);
+      Out.insert(*WildVT);
+      return true;
     }
-    berase_if(Out, CompIn);
-    if (NumI == 1) {
-      // Replace the iPTR with the leftover scalar integer.
-      Out.insert(*llvm::find_if(OutOnly, isScalarInteger));
-    } else if (NumI > 1) {
-      Out.insert(MVT::iPTR);
+
+    // OutW == true
+    SetType InLeftovers = Leftovers(In, Out);
+    unsigned SizeOut = Out.size();
+    berase_if(Out, CompIn); // This will remove at least the WildVT.
+    if (InLeftovers.size() < 2) {
+      // WildVT deleted from Out. Add back the possible single leftover.
+      Out.insert(InLeftovers);
+      return true;
     }
-    return true;
-  }
 
-  // OutP == true
-  SetType InOnly = subtract(In, Out);
-  unsigned SizeOut = Out.size();
-  berase_if(Out, CompIn); // This will remove at least the iPTR.
-  unsigned NumI = llvm::count_if(InOnly, isScalarInteger);
-  if (NumI == 0) {
-    // iPTR deleted from Out.
-    return true;
-  }
-  if (NumI == 1) {
-    // Replace the iPTR with the leftover scalar integer.
-    Out.insert(*llvm::find_if(InOnly, isScalarInteger));
-    return true;
-  }
+    // Keep the WildVT in Out.
+    Out.insert(*WildVT);
+    // If WildVT was the only element initially removed from Out, then Out
+    // has not changed.
+    return SizeOut != Out.size();
+  };
 
-  // NumI > 1: Keep the iPTR in Out.
-  Out.insert(MVT::iPTR);
-  // If iPTR was the only element initially removed from Out, then Out
-  // has not changed.
-  return SizeOut != Out.size();
+  typedef std::pair<MVT, std::function<bool(MVT)>> WildPartT;
+  static const WildPartT WildParts[] = {
+      {MVT::iPTR, [](MVT T) { return T.isScalarInteger() || T == MVT::iPTR; }},
+  };
+
+  bool Changed = false;
+  for (const auto &I : WildParts)
+    Changed |= IntersectP(I.first, I.second);
+
+  Changed |= IntersectP(std::nullopt, [&](MVT T) {
+    return !any_of(WildParts, [=](const WildPartT &I) { return I.second(T); });
+  });
+
+  return Changed;
 }
 
 bool TypeSetByHwMode::validate() const {

From fa8ac858c2d6024d354b6d6dc332dcdfe29c9b4f Mon Sep 17 00:00:00 2001
From: Jessica Clarke <jrtc27 at jrtc27.com>
Date: Wed, 14 Feb 2024 17:19:24 +0000
Subject: [PATCH 2/2] Add non-overlapping note, prefer using over typedef

---
 llvm/utils/TableGen/CodeGenDAGPatterns.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
index 9addea8fab98d0..f63b7f836ff9af 100644
--- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
@@ -331,7 +331,8 @@ bool TypeSetByHwMode::intersect(SetType &Out, const SetType &In) {
     return SizeOut != Out.size();
   };
 
-  typedef std::pair<MVT, std::function<bool(MVT)>> WildPartT;
+  // Note: must be non-overlapping
+  using WildPartT = std::pair<MVT, std::function<bool(MVT)>>;
   static const WildPartT WildParts[] = {
       {MVT::iPTR, [](MVT T) { return T.isScalarInteger() || T == MVT::iPTR; }},
   };

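The "must be non-overlapping" note matters because each type should be handled by exactly one partition: the final `IntersectP(std::nullopt, ...)` call covers whatever no wildcard predicate claims, so a type claimed by two predicates would presumably be intersected twice under different wildcard rules. A minimal sketch of checking this requirement, assuming a toy `VT` enum in place of `MVT` and a hypothetical capability partition; only the `WildPartT` shape is taken from the patch.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Toy stand-in for MVT (hypothetical; not LLVM's real value types).
enum class VT { iPTR, i32, i64, f32, capPTR, c64 };
constexpr int NumVTs = 6;

// Same shape as WildPartT in the patch: a wildcard type paired with a
// predicate selecting every type in its partition (the wildcard included).
using WildPartT = std::pair<VT, std::function<bool(VT)>>;

// True if no type is claimed by more than one wildcard partition.
bool nonOverlapping(const std::vector<WildPartT> &Parts) {
  for (int I = 0; I < NumVTs; ++I) {
    const VT T = static_cast<VT>(I);
    int Claims = 0;
    for (const WildPartT &P : Parts)
      if (P.second(T))
        ++Claims;
    if (Claims > 1)
      return false;
  }
  return true;
}

// Example predicates: the iPTR partition restricted to the toy types, and a
// hypothetical capability partition.
bool isIntOrPtr(VT T) { return T == VT::iPTR || T == VT::i32 || T == VT::i64; }
bool isCapOrCapPtr(VT T) { return T == VT::capPTR || T == VT::c64; }
```

With these predicates, pairing `isIntOrPtr` with `isCapOrCapPtr` passes the check, whereas a capability predicate that also claimed `i64` would fail it.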