[llvm] [TableGen] Add a !listflatten operator to TableGen (PR #109346)

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 19 22:39:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-tablegen

Author: Rahul Joshi (jurahul)

<details>
<summary>Changes</summary>

Add a `!listflatten` operator that transforms an input list of type `list<list<X>>` into a list of type `list<X>` by concatenating the elements of the input's constituent lists.

---
Full diff: https://github.com/llvm/llvm-project/pull/109346.diff


8 Files Affected:

- (modified) llvm/docs/TableGen/ProgRef.rst (+11-6) 
- (modified) llvm/include/llvm/TableGen/Record.h (+2-1) 
- (modified) llvm/lib/TableGen/Record.cpp (+25) 
- (modified) llvm/lib/TableGen/TGLexer.cpp (+1) 
- (modified) llvm/lib/TableGen/TGLexer.h (+1) 
- (modified) llvm/lib/TableGen/TGParser.cpp (+26-5) 
- (added) llvm/test/TableGen/listflatten-error.td (+6) 
- (added) llvm/test/TableGen/listflatten.td (+29) 


``````````diff
diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index dcea3b721dae27..69cfaeb5f8442e 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -223,12 +223,12 @@ TableGen provides "bang operators" that have a wide variety of uses:
                : !div         !empty       !eq          !exists      !filter
                : !find        !foldl       !foreach     !ge          !getdagarg
                : !getdagname  !getdagop    !gt          !head        !if
-               : !interleave  !isa         !le          !listconcat  !listremove
-               : !listsplat   !logtwo      !lt          !mul         !ne
-               : !not         !or          !range       !repr        !setdagarg
-               : !setdagname  !setdagop    !shl         !size        !sra
-               : !srl         !strconcat   !sub         !subst       !substr
-               : !tail        !tolower     !toupper     !xor
+               : !interleave  !isa         !le          !listconcat  !listflatten
+               : !listremove  !listsplat   !logtwo      !lt          !mul
+               : !ne          !not         !or          !range       !repr
+               : !setdagarg   !setdagname  !setdagop    !shl         !size
+               : !sra         !srl         !strconcat   !sub         !subst
+               : !substr      !tail        !tolower     !toupper     !xor
 
 The ``!cond`` operator has a slightly different
 syntax compared to other bang operators, so it is defined separately:
@@ -1832,6 +1832,11 @@ and non-0 as true.
     This operator concatenates the list arguments *list1*, *list2*, etc., and
     produces the resulting list. The lists must have the same element type.
 
+``!listflatten(``\ *list*\ ``)``
+    This operator flattens a list of lists *list* by one level, producing a
+    list of all elements of the constituent lists, concatenated in order. If
+    *list* is of type ``list<list<X>>``, the result is of type ``list<X>``.
+
 ``!listremove(``\ *list1*\ ``,`` *list2*\ ``)``
     This operator returns a copy of *list1* removing all elements that also occur in
     *list2*. The lists must have the same element type.
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 5348c1177f63ed..4cd73c3f675527 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -847,7 +847,8 @@ class UnOpInit : public OpInit, public FoldingSetNode {
     EMPTY,
     GETDAGOP,
     LOG2,
-    REPR
+    REPR,
+    LISTFLATTEN,
   };
 
 private:
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index ff2da3badb3628..1f403e19339a2a 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -987,6 +987,28 @@ Init *UnOpInit::Fold(Record *CurRec, bool IsFinal) const {
       }
     }
     break;
+
+  case LISTFLATTEN: {
+    // Fold only when the operand is a list literal whose element type is
+    // itself a list type; otherwise leave the operator unfolded.
+    ListInit *LHSList = dyn_cast<ListInit>(LHS);
+    if (!LHSList)
+      break;
+    ListRecTy *InnerListTy = dyn_cast<ListRecTy>(LHSList->getElementType());
+    if (!InnerListTy)
+      break;
+    std::vector<Init *> Flattened;
+    // Concatenate elements of all the inner lists; leave the operator
+    // unfolded if any element is not (yet) a list literal.
+    for (Init *InnerInit : LHSList->getValues()) {
+      ListInit *InnerList = dyn_cast<ListInit>(InnerInit);
+      if (!InnerList)
+        return const_cast<UnOpInit *>(this);
+      for (Init *InnerElem : InnerList->getValues())
+        Flattened.push_back(InnerElem);
+    }
+    return ListInit::get(Flattened, InnerListTy->getElementType());
+  }
   }
   return const_cast<UnOpInit *>(this);
 }
@@ -1011,6 +1033,9 @@ std::string UnOpInit::getAsString() const {
   case EMPTY: Result = "!empty"; break;
   case GETDAGOP: Result = "!getdagop"; break;
   case LOG2 : Result = "!logtwo"; break;
+  case LISTFLATTEN:
+    Result = "!listflatten";
+    break;
   case REPR:
     Result = "!repr";
     break;
diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp
index 62a884e01a5306..8fe7f69ecf8e59 100644
--- a/llvm/lib/TableGen/TGLexer.cpp
+++ b/llvm/lib/TableGen/TGLexer.cpp
@@ -628,6 +628,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
           .Case("foreach", tgtok::XForEach)
           .Case("filter", tgtok::XFilter)
           .Case("listconcat", tgtok::XListConcat)
+          .Case("listflatten", tgtok::XListFlatten)
           .Case("listsplat", tgtok::XListSplat)
           .Case("listremove", tgtok::XListRemove)
           .Case("range", tgtok::XRange)
diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h
index 9adc03ccc72b85..4fa4d84d0535d3 100644
--- a/llvm/lib/TableGen/TGLexer.h
+++ b/llvm/lib/TableGen/TGLexer.h
@@ -122,6 +122,7 @@ enum TokKind {
   XSRL,
   XSHL,
   XListConcat,
+  XListFlatten,
   XListSplat,
   XStrConcat,
   XInterleave,
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index 1a60c2a567a297..20de3cc4dad9e9 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -1190,6 +1190,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
   case tgtok::XNOT:
   case tgtok::XToLower:
   case tgtok::XToUpper:
+  case tgtok::XListFlatten:
   case tgtok::XLOG2:
   case tgtok::XHead:
   case tgtok::XTail:
@@ -1235,6 +1236,11 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       Code = UnOpInit::NOT;
       Type = IntRecTy::get(Records);
       break;
+    case tgtok::XListFlatten:
+      Lex.Lex(); // eat the operation.
+      Code = UnOpInit::LISTFLATTEN;
+      Type = IntRecTy::get(Records); // Bogus type used here.
+      break;
     case tgtok::XLOG2:
       Lex.Lex();  // eat the operation
       Code = UnOpInit::LOG2;
@@ -1309,7 +1315,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       }
     }
 
-    if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL) {
+    if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL ||
+        Code == UnOpInit::LISTFLATTEN) {
       ListInit *LHSl = dyn_cast<ListInit>(LHS);
       TypedInit *LHSt = dyn_cast<TypedInit>(LHS);
       if (!LHSl && !LHSt) {
@@ -1328,6 +1335,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
         TokError("empty list argument in unary operator");
         return nullptr;
       }
+      bool UseElementType =
+          Code == UnOpInit::HEAD || Code == UnOpInit::LISTFLATTEN;
       if (LHSl) {
         Init *Item = LHSl->getElement(0);
         TypedInit *Itemt = dyn_cast<TypedInit>(Item);
@@ -1335,12 +1344,24 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
           TokError("untyped list element in unary operator");
           return nullptr;
         }
-        Type = (Code == UnOpInit::HEAD) ? Itemt->getType()
-                                        : ListRecTy::get(Itemt->getType());
+        Type = UseElementType ? Itemt->getType()
+                              : ListRecTy::get(Itemt->getType());
       } else {
         assert(LHSt && "expected list type argument in unary operator");
         ListRecTy *LType = dyn_cast<ListRecTy>(LHSt->getType());
-        Type = (Code == UnOpInit::HEAD) ? LType->getElementType() : LType;
+        Type = UseElementType ? LType->getElementType() : LType;
+      }
+
+      // for listflatten, we expect a list of lists.
+      if (Code == UnOpInit::LISTFLATTEN) {
+        ListRecTy *InnerListTy = dyn_cast<ListRecTy>(Type);
+        if (!InnerListTy) {
+          TokError("expected argument of type list of list in !listflatten "
+                   "operator");
+          return nullptr;
+        }
+        // listflatten will convert list<list<X>> to list<X>.
+        Type = ListRecTy::get(InnerListTy->getElementType());
       }
     }
 
@@ -1378,7 +1399,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
 
   case tgtok::XExists: {
     // Value ::= !exists '<' Type '>' '(' Value ')'
-    Lex.Lex(); // eat the operation
+    Lex.Lex(); // eat the operation.
 
     RecTy *Type = ParseOperatorType();
     if (!Type)
diff --git a/llvm/test/TableGen/listflatten-error.td b/llvm/test/TableGen/listflatten-error.td
new file mode 100644
index 00000000000000..e18528a08e6bf6
--- /dev/null
+++ b/llvm/test/TableGen/listflatten-error.td
@@ -0,0 +1,6 @@
+// RUN: not llvm-tblgen %s 2>&1 | FileCheck %s  -DFILE=%s
+
+// CHECK: [[FILE]]:[[@LINE+2]]:33: error: expected argument of type list of list in !listflatten operator
+class Flatten<list<int> A> { // A is list<int>, not a list of lists.
+    list<int> F = !listflatten(A);
+}
diff --git a/llvm/test/TableGen/listflatten.td b/llvm/test/TableGen/listflatten.td
new file mode 100644
index 00000000000000..20119b24cce0a3
--- /dev/null
+++ b/llvm/test/TableGen/listflatten.td
@@ -0,0 +1,29 @@
+// RUN: llvm-tblgen %s | FileCheck %s
+
+class Flatten<list<int> A, list<int> B> {
+    list<int> Flat1 = !listflatten([A, B, [6], [7, 8]]);
+
+    list<list<int>> X = [A, B];
+    list<int> Flat2 = !listflatten(!listconcat(X, [[7]]));
+
+    // Generate a nested list of integers.
+    list<int> Y0 = [1, 2, 3, 4];
+    list<list<int>> Y1 = !foreach(elem, Y0, [elem]);
+    list<list<list<int>>> Y2 = !foreach(elem, Y1, [elem]);
+    list<list<list<list<int>>>> Y3 = !foreach(elem, Y2, [elem]);
+
+    // Flatten it completely.
+    list<int> Flat3 = !listflatten(!listflatten(!listflatten(Y3)));
+
+    // Flatten it partially.
+    list<list<list<int>>> Flat4 = !listflatten(Y3);
+    list<list<int>> Flat5 = !listflatten(!listflatten(Y3));
+}
+
+// CHECK: list<int> Flat1 = [1, 2, 3, 4, 5, 6, 7, 8];
+// CHECK: list<int> Flat2 = [1, 2, 3, 4, 5, 7];
+// CHECK: list<int> Flat3 = [1, 2, 3, 4];
+// CHECK{LITERAL}: list<list<list<int>>> Flat4 = [[[1]], [[2]], [[3]], [[4]]];
+// CHECK{LITERAL}: list<list<int>> Flat5 = [[1], [2], [3], [4]];
+def F : Flatten<[1,2], [3,4,5]>;
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/109346


More information about the llvm-commits mailing list