[llvm] 641428f - [TableGen] Enhance the six comparison bang operators.

Paul C. Anagnostopoulos via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 13 06:57:58 PST 2020


Author: Paul C. Anagnostopoulos
Date: 2020-11-13T09:57:27-05:00
New Revision: 641428f9288b9ae2b574219ebf773d3bfbf6e8a0

URL: https://github.com/llvm/llvm-project/commit/641428f9288b9ae2b574219ebf773d3bfbf6e8a0
DIFF: https://github.com/llvm/llvm-project/commit/641428f9288b9ae2b574219ebf773d3bfbf6e8a0.diff

LOG: [TableGen] Enhance the six comparison bang operators.

Update the Programmer's Reference.

Differential Revision: https://reviews.llvm.org/D91036

Added: 
    

Modified: 
    llvm/docs/TableGen/ProgRef.rst
    llvm/lib/TableGen/Record.cpp
    llvm/lib/TableGen/TGParser.cpp
    llvm/test/TableGen/compare.td

Removed: 
    


################################################################################
diff  --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index 57abe142babc..c805afc8ccc1 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -1560,8 +1560,8 @@ and non-0 as true.
 
 ``!eq(`` *a*\ `,` *b*\ ``)``
     This operator produces 1 if *a* is equal to *b*; 0 otherwise.
-    The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values.
-    Use ``!cast<string>`` to compare other types of objects.
+    The arguments must be ``bit``, ``bits``, ``int``, ``string``, or 
+    record values. Use ``!cast<string>`` to compare other types of objects.
 
 ``!filter(``\ *var*\ ``,`` *list*\ ``,`` *predicate*\ ``)``
 
@@ -1603,7 +1603,7 @@ and non-0 as true.
 
 ``!ge(``\ *a*\ `,` *b*\ ``)``
     This operator produces 1 if *a* is greater than or equal to *b*; 0 otherwise.
-    The arguments must be ``bit``, ``bits``, or ``int`` values.
+    The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values.
 
 ``!getdagop(``\ *dag*\ ``)`` --or-- ``!getdagop<``\ *type*\ ``>(``\ *dag*\ ``)``
     This operator produces the operator of the given *dag* node.
@@ -1629,7 +1629,7 @@ and non-0 as true.
 
 ``!gt(``\ *a*\ `,` *b*\ ``)``
     This operator produces 1 if *a* is greater than *b*; 0 otherwise.
-    The arguments must be ``bit``, ``bits``, or ``int`` values.
+    The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values.
 
 ``!head(``\ *a*\ ``)``
     This operator produces the zeroth element of the list *a*.
@@ -1652,7 +1652,7 @@ and non-0 as true.
 
 ``!le(``\ *a*\ ``,`` *b*\ ``)``
     This operator produces 1 if *a* is less than or equal to *b*; 0 otherwise.
-    The arguments must be ``bit``, ``bits``, or ``int`` values.
+    The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values.
 
 ``!listconcat(``\ *list1*\ ``,`` *list2*\ ``, ...)``
     This operator concatenates the list arguments *list1*, *list2*, etc., and
@@ -1665,15 +1665,15 @@ and non-0 as true.
 
 ``!lt(``\ *a*\ `,` *b*\ ``)``
     This operator produces 1 if *a* is less than *b*; 0 otherwise.
-    The arguments must be ``bit``, ``bits``, or ``int`` values.
+    The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values.
 
 ``!mul(``\ *a*\ ``,`` *b*\ ``, ...)``
     This operator multiplies *a*, *b*, etc., and produces the product.
 
 ``!ne(``\ *a*\ `,` *b*\ ``)``
     This operator produces 1 if *a* is not equal to *b*; 0 otherwise.
-    The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values.
-    Use ``!cast<string>`` to compare other types of objects.
+    The arguments must be ``bit``, ``bits``, ``int``, ``string``,
+    or record values. Use ``!cast<string>`` to compare other types of objects.
 
 ``!not(``\ *a*\ ``)``
     This operator performs a logical NOT on *a*, which must be

diff  --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index e17a29cba009..81d700263bb1 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -1010,36 +1010,53 @@ Init *BinOpInit::Fold(Record *CurRec) const {
   case LT:
   case GE:
   case GT: {
-    // try to fold eq comparison for 'bit' and 'int', otherwise fallback
-    // to string objects.
-    IntInit *L =
+    // First see if we have two bit, bits, or int.
+    IntInit *LHSi =
         dyn_cast_or_null<IntInit>(LHS->convertInitializerTo(IntRecTy::get()));
-    IntInit *R =
+    IntInit *RHSi =
         dyn_cast_or_null<IntInit>(RHS->convertInitializerTo(IntRecTy::get()));
 
-    if (L && R) {
+    if (LHSi && RHSi) {
       bool Result;
       switch (getOpcode()) {
-      case EQ: Result = L->getValue() == R->getValue(); break;
-      case NE: Result = L->getValue() != R->getValue(); break;
-      case LE: Result = L->getValue() <= R->getValue(); break;
-      case LT: Result = L->getValue() < R->getValue(); break;
-      case GE: Result = L->getValue() >= R->getValue(); break;
-      case GT: Result = L->getValue() > R->getValue(); break;
+      case EQ: Result = LHSi->getValue() == RHSi->getValue(); break;
+      case NE: Result = LHSi->getValue() != RHSi->getValue(); break;
+      case LE: Result = LHSi->getValue() <= RHSi->getValue(); break;
+      case LT: Result = LHSi->getValue() <  RHSi->getValue(); break;
+      case GE: Result = LHSi->getValue() >= RHSi->getValue(); break;
+      case GT: Result = LHSi->getValue() >  RHSi->getValue(); break;
       default: llvm_unreachable("unhandled comparison");
       }
       return BitInit::get(Result);
     }
 
-    if (getOpcode() == EQ || getOpcode() == NE) {
-      StringInit *LHSs = dyn_cast<StringInit>(LHS);
-      StringInit *RHSs = dyn_cast<StringInit>(RHS);
+    // Next try strings.
+    StringInit *LHSs = dyn_cast<StringInit>(LHS);
+    StringInit *RHSs = dyn_cast<StringInit>(RHS);
 
-      // Make sure we've resolved
-      if (LHSs && RHSs) {
-        bool Equal = LHSs->getValue() == RHSs->getValue();
-        return BitInit::get(getOpcode() == EQ ? Equal : !Equal);
+    if (LHSs && RHSs) {
+      bool Result;
+      switch (getOpcode()) {
+      case EQ: Result = LHSs->getValue() == RHSs->getValue(); break;
+      case NE: Result = LHSs->getValue() != RHSs->getValue(); break;
+      case LE: Result = LHSs->getValue() <= RHSs->getValue(); break;
+      case LT: Result = LHSs->getValue() <  RHSs->getValue(); break;
+      case GE: Result = LHSs->getValue() >= RHSs->getValue(); break;
+      case GT: Result = LHSs->getValue() >  RHSs->getValue(); break;
+      default: llvm_unreachable("unhandled comparison");
       }
+      return BitInit::get(Result);
+////      bool Equal = LHSs->getValue() == RHSs->getValue();
+////      return BitInit::get(getOpcode() == EQ ? Equal : !Equal);
+    }
+
+    // Finally, !eq and !ne can be used with records.
+    if (getOpcode() == EQ || getOpcode() == NE) {
+      DefInit *LHSd = dyn_cast<DefInit>(LHS);
+      DefInit *RHSd = dyn_cast<DefInit>(RHS);
+      if (LHSd && RHSd)
+        return BitInit::get((getOpcode() == EQ) ? LHSd == RHSd
+                                                : LHSd != RHSd);
     }
 
     break;

diff  --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index ccfde95247a3..90b5afa0e82e 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -1148,15 +1148,12 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       break;
     case tgtok::XEq:
     case tgtok::XNe:
-      Type = BitRecTy::get();
-      // ArgType for Eq / Ne is not known at this point
-      break;
     case tgtok::XLe:
     case tgtok::XLt:
     case tgtok::XGe:
     case tgtok::XGt:
       Type = BitRecTy::get();
-      ArgType = IntRecTy::get();
+      // ArgType for the comparison operators is not yet known.
       break;
     case tgtok::XListConcat:
       // We don't know the list type until we parse the first argument
@@ -1244,10 +1241,24 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
           break;
         case BinOpInit::EQ:
         case BinOpInit::NE:
+          if (!ArgType->typeIsConvertibleTo(IntRecTy::get()) &&
+              !ArgType->typeIsConvertibleTo(StringRecTy::get()) &&
+              !ArgType->typeIsConvertibleTo(RecordRecTy::get({}))) {
+            Error(InitLoc, Twine("expected bit, bits, int, string, or record; "
+                                 "got value of type '") + ArgType->getAsString() + 
+                                 "'");
+            return nullptr;
+          }
+          break;
+        case BinOpInit::LE:
+        case BinOpInit::LT:
+        case BinOpInit::GE:
+        case BinOpInit::GT:
           if (!ArgType->typeIsConvertibleTo(IntRecTy::get()) &&
               !ArgType->typeIsConvertibleTo(StringRecTy::get())) {
-            Error(InitLoc, Twine("expected int, bits, or string; got value of "
-                                 "type '") + ArgType->getAsString() + "'");
+            Error(InitLoc, Twine("expected bit, bits, int, or string; "
+                                 "got value of type '") + ArgType->getAsString() + 
+                                 "'");
             return nullptr;
           }
           break;

diff  --git a/llvm/test/TableGen/compare.td b/llvm/test/TableGen/compare.td
index e54d853b9a60..1ecd95f4d35c 100644
--- a/llvm/test/TableGen/compare.td
+++ b/llvm/test/TableGen/compare.td
@@ -1,54 +1,117 @@
 // RUN: llvm-tblgen %s | FileCheck %s
-// XFAIL: vg_leak
-
-// CHECK: --- Defs ---
-
-// CHECK: def A0 {
-// CHECK:   bit eq = 1;
-// CHECK:   bit ne = 0;
-// CHECK:   bit le = 1;
-// CHECK:   bit lt = 0;
-// CHECK:   bit ge = 1;
-// CHECK:   bit gt = 0;
-// CHECK: }
-
-// CHECK: def A1 {
-// CHECK:   bit eq = 0;
-// CHECK:   bit ne = 1;
-// CHECK:   bit le = 1;
-// CHECK:   bit lt = 1;
-// CHECK:   bit ge = 0;
-// CHECK:   bit gt = 0;
-// CHECK: }
-
-// CHECK: def A2 {
-// CHECK:   bit eq = 0;
-// CHECK:   bit ne = 1;
-// CHECK:   bit le = 0;
-// CHECK:   bit lt = 0;
-// CHECK:   bit ge = 1;
-// CHECK:   bit gt = 1;
-// CHECK: }
-
-// CHECK: def A3 {
-// CHECK:   bit eq = 0;
-// CHECK:   bit ne = 1;
-// CHECK:   bit le = 0;
-// CHECK:   bit lt = 0;
-// CHECK:   bit ge = 1;
-// CHECK:   bit gt = 1;
-// CHECK: }
-
-class A<int x, int y> {
-  bit eq = !eq(x, y);
-  bit ne = !ne(x, y);
-  bit le = !le(x, y);
-  bit lt = !lt(x, y);
-  bit ge = !ge(x, y);
-  bit gt = !gt(x, y);
-}
-
-def A0 : A<-3, -3>;
-def A1 : A<-1, 4>;
-def A2 : A<3, -2>;
-def A3 : A<4, 2>;
+// RUN: not llvm-tblgen -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
+// RUN: not llvm-tblgen -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
+
+// This file tests the comparison bang operators.
+
+class BitCompare<bit a, bit b> {
+  list<bit> compare = [!eq(a, b), !ne(a, b),
+                       !lt(a, b), !le(a, b),
+                       !gt(a, b), !ge(a, b)];
+}
+
+class BitsCompare<bits<3> a, bits<3> b> {
+  list<bit> compare = [!eq(a, b), !ne(a, b),
+                       !lt(a, b), !le(a, b),
+                       !gt(a, b), !ge(a, b)];
+}
+
+class IntCompare<int a, int b> {
+  list<bit> compare = [!eq(a, b), !ne(a, b),
+                       !lt(a, b), !le(a, b),
+                       !gt(a, b), !ge(a, b)];
+}
+
+class StringCompare<string a, string b> {
+  list<bit> compare = [!eq(a, b), !ne(a, b),
+                       !lt(a, b), !le(a, b),
+                       !gt(a, b), !ge(a, b)];
+}
+
+multiclass MC {
+  def _MC;
+}
+
+// CHECK: def Bit00
+// CHECK:   compare = [1, 0, 0, 1, 0, 1];
+// CHECK: def Bit01
+// CHECK:   compare = [0, 1, 1, 1, 0, 0];
+// CHECK: def Bit10
+// CHECK:   compare = [0, 1, 0, 0, 1, 1];
+// CHECK: def Bit11
+// CHECK:   compare = [1, 0, 0, 1, 0, 1];
+
+def Bit00 : BitCompare<0, 0>;
+def Bit01 : BitCompare<0, 1>;
+def Bit10 : BitCompare<1, 0>;
+def Bit11 : BitCompare<1, 1>;
+
+// CHECK: def Bits1
+// CHECK:   compare = [0, 1, 1, 1, 0, 0];
+// CHECK: def Bits2
+// CHECK:   compare = [1, 0, 0, 1, 0, 1];
+// CHECK: def Bits3
+// CHECK:   compare = [0, 1, 0, 0, 1, 1];
+
+def Bits1 : BitsCompare<{0, 1, 0}, {1, 0, 1}>;
+def Bits2 : BitsCompare<{0, 1, 1}, {0, 1, 1}>;
+def Bits3 : BitsCompare<{1, 1, 1}, {0, 1, 1}>;
+
+// CHECK: def Int1
+// CHECK:   compare = [0, 1, 1, 1, 0, 0];
+// CHECK: def Int2
+// CHECK:   compare = [1, 0, 0, 1, 0, 1];
+// CHECK: def Int3
+// CHECK:   compare = [0, 1, 0, 0, 1, 1];
+
+def Int1 : IntCompare<-7, 13>;
+def Int2 : IntCompare<42, 42>;
+def Int3 : IntCompare<108, 42>;
+
+// CHECK: def Record1
+// CHECK:   compare1 = [1, 0];
+// CHECK:   compare2 = [0, 1];
+// CHECK:   compare3 = [1, 1];
+
+defm foo : MC;
+defm bar : MC;
+
+def Record1 {
+  list<bit> compare1 = [!eq(Bit00, Bit00), !eq(Bit00, Bit01)];
+  list<bit> compare2 = [!ne(Bit00, Bit00), !ne(Bit00, Int1)];
+  list<bit> compare3 = [!eq(bar_MC, bar_MC), !ne(bar_MC, foo_MC)];
+}
+
+// CHECK: def String1
+// CHECK:   compare = [0, 1, 1, 1, 0, 0];
+// CHECK: def String2
+// CHECK:   compare = [1, 0, 0, 1, 0, 1];
+// CHECK: def String3
+// CHECK:   compare = [0, 1, 0, 0, 1, 1];
+// CHECK: def String4
+// CHECK:   compare = [0, 1, 0, 0, 1, 1];
+def String1 : StringCompare<"bar", "foo">;
+def String2 : StringCompare<"foo", "foo">;
+def String3 : StringCompare<"foo", "bar">;
+def String4 : StringCompare<"foo", "Foo">;
+
+#ifdef ERROR1
+
+// ERROR1: expected bit, bits, int, string, or record; got value
+
+def Zerror1 {
+  bit compare1 = !eq([0, 1, 2], [0, 1, 2]);
+}
+
+#endif
+
+#ifdef ERROR2
+
+// ERROR2: expected bit, bits, int, or string; got value
+
+def Zerror2 {
+  bit compare1 = !lt(Bit00, Bit00);
+}
+
+#endif
+


        


More information about the llvm-commits mailing list