[llvm] [TableGen] Add a !listflatten operator to TableGen (PR #109346)
    Rahul Joshi via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Mon Sep 23 13:47:10 PDT 2024
    
    
  
https://github.com/jurahul updated https://github.com/llvm/llvm-project/pull/109346
>From 56f251159b66022483221f633d15bab5c00a7c31 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Thu, 19 Sep 2024 15:28:40 -0700
Subject: [PATCH] [TableGen] Add a !listflatten operator to TableGen
Add a !listflatten operator that will transform an input list of
type `list<list<X>>` to `list<X>` by concatenating elements of the
constituent lists of the input argument.
---
 llvm/docs/TableGen/ProgRef.rst          | 18 +++++++++-----
 llvm/include/llvm/TableGen/Record.h     |  3 ++-
 llvm/lib/TableGen/Record.cpp            | 29 ++++++++++++++++++++++
 llvm/lib/TableGen/TGLexer.cpp           |  1 +
 llvm/lib/TableGen/TGLexer.h             |  1 +
 llvm/lib/TableGen/TGParser.cpp          | 32 +++++++++++++++++++++----
 llvm/test/TableGen/listflatten-error.td |  6 +++++
 llvm/test/TableGen/listflatten.td       | 32 +++++++++++++++++++++++++
 8 files changed, 110 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/TableGen/listflatten-error.td
 create mode 100644 llvm/test/TableGen/listflatten.td
diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index dcea3b721dae27..5cf48d6ed29786 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -223,12 +223,12 @@ TableGen provides "bang operators" that have a wide variety of uses:
                : !div         !empty       !eq          !exists      !filter
                : !find        !foldl       !foreach     !ge          !getdagarg
                : !getdagname  !getdagop    !gt          !head        !if
-               : !interleave  !isa         !le          !listconcat  !listremove
-               : !listsplat   !logtwo      !lt          !mul         !ne
-               : !not         !or          !range       !repr        !setdagarg
-               : !setdagname  !setdagop    !shl         !size        !sra
-               : !srl         !strconcat   !sub         !subst       !substr
-               : !tail        !tolower     !toupper     !xor
+               : !interleave  !isa         !le          !listconcat  !listflatten
+               : !listremove  !listsplat   !logtwo      !lt          !mul
+               : !ne          !not         !or          !range       !repr
+               : !setdagarg   !setdagname  !setdagop    !shl         !size
+               : !sra         !srl         !strconcat   !sub         !subst
+               : !substr      !tail        !tolower     !toupper     !xor
 
 The ``!cond`` operator has a slightly different
 syntax compared to other bang operators, so it is defined separately:
@@ -1832,6 +1832,12 @@ and non-0 as true.
     This operator concatenates the list arguments *list1*, *list2*, etc., and
     produces the resulting list. The lists must have the same element type.
 
+``!listflatten(``\ *list*\ ``)``
+    This operator flattens a list of lists *list* and produces a list with all
+    elements of the constituent lists concatenated. If *list* is of type
+    ``list<list<X>>`` the resulting list is of type ``list<X>``. If *list*'s
+    element type is not a list, the result is *list* itself.
+
 ``!listremove(``\ *list1*\ ``,`` *list2*\ ``)``
     This operator returns a copy of *list1* removing all elements that also occur in
     *list2*. The lists must have the same element type.
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 5348c1177f63ed..4cd73c3f675527 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -847,7 +847,8 @@ class UnOpInit : public OpInit, public FoldingSetNode {
     EMPTY,
     GETDAGOP,
     LOG2,
-    REPR
+    REPR,
+    LISTFLATTEN,
   };
 
 private:
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index a72a0799a08025..0f99b4a13d2f95 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -986,6 +986,32 @@ Init *UnOpInit::Fold(Record *CurRec, bool IsFinal) const {
       }
     }
     break;
+
+  case LISTFLATTEN:
+    if (ListInit *LHSList = dyn_cast<ListInit>(LHS)) {
+      ListRecTy *InnerListTy = dyn_cast<ListRecTy>(LHSList->getElementType());
+      // For a list of non-lists, !listflatten() is a NOP.
+      if (!InnerListTy)
+        return LHS;
+
+      auto Flatten = [](ListInit *List) -> std::optional<std::vector<Init *>> {
+        std::vector<Init *> Flattened;
+        // Concatenate elements of all the inner lists.
+        for (Init *InnerInit : List->getValues()) {
+          ListInit *InnerList = dyn_cast<ListInit>(InnerInit);
+          if (!InnerList)
+            return std::nullopt;
+          for (Init *InnerElem : InnerList->getValues())
+            Flattened.push_back(InnerElem);
+        };
+        return Flattened;
+      };
+
+      auto Flattened = Flatten(LHSList);
+      if (Flattened)
+        return ListInit::get(*Flattened, InnerListTy->getElementType());
+    }
+    break;
   }
   return const_cast<UnOpInit *>(this);
 }
@@ -1010,6 +1036,9 @@ std::string UnOpInit::getAsString() const {
   case EMPTY: Result = "!empty"; break;
   case GETDAGOP: Result = "!getdagop"; break;
   case LOG2 : Result = "!logtwo"; break;
+  case LISTFLATTEN:
+    Result = "!listflatten";
+    break;
   case REPR:
     Result = "!repr";
     break;
diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp
index 62a884e01a5306..8fe7f69ecf8e59 100644
--- a/llvm/lib/TableGen/TGLexer.cpp
+++ b/llvm/lib/TableGen/TGLexer.cpp
@@ -628,6 +628,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
           .Case("foreach", tgtok::XForEach)
           .Case("filter", tgtok::XFilter)
           .Case("listconcat", tgtok::XListConcat)
+          .Case("listflatten", tgtok::XListFlatten)
           .Case("listsplat", tgtok::XListSplat)
           .Case("listremove", tgtok::XListRemove)
           .Case("range", tgtok::XRange)
diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h
index 9adc03ccc72b85..4fa4d84d0535d3 100644
--- a/llvm/lib/TableGen/TGLexer.h
+++ b/llvm/lib/TableGen/TGLexer.h
@@ -122,6 +122,7 @@ enum TokKind {
   XSRL,
   XSHL,
   XListConcat,
+  XListFlatten,
   XListSplat,
   XStrConcat,
   XInterleave,
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index 1a60c2a567a297..54c9a902ec27a1 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -1190,6 +1190,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
   case tgtok::XNOT:
   case tgtok::XToLower:
   case tgtok::XToUpper:
+  case tgtok::XListFlatten:
   case tgtok::XLOG2:
   case tgtok::XHead:
   case tgtok::XTail:
@@ -1235,6 +1236,11 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       Code = UnOpInit::NOT;
       Type = IntRecTy::get(Records);
       break;
+    case tgtok::XListFlatten:
+      Lex.Lex(); // eat the operation.
+      Code = UnOpInit::LISTFLATTEN;
+      Type = IntRecTy::get(Records); // Bogus type used here.
+      break;
     case tgtok::XLOG2:
       Lex.Lex();  // eat the operation
       Code = UnOpInit::LOG2;
@@ -1309,7 +1315,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       }
     }
 
-    if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL) {
+    if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL ||
+        Code == UnOpInit::LISTFLATTEN) {
       ListInit *LHSl = dyn_cast<ListInit>(LHS);
       TypedInit *LHSt = dyn_cast<TypedInit>(LHS);
       if (!LHSl && !LHSt) {
@@ -1328,6 +1335,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
         TokError("empty list argument in unary operator");
         return nullptr;
       }
+      bool UseElementType =
+          Code == UnOpInit::HEAD || Code == UnOpInit::LISTFLATTEN;
       if (LHSl) {
         Init *Item = LHSl->getElement(0);
         TypedInit *Itemt = dyn_cast<TypedInit>(Item);
@@ -1335,12 +1344,25 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
           TokError("untyped list element in unary operator");
           return nullptr;
         }
-        Type = (Code == UnOpInit::HEAD) ? Itemt->getType()
-                                        : ListRecTy::get(Itemt->getType());
+        Type = UseElementType ? Itemt->getType()
+                              : ListRecTy::get(Itemt->getType());
       } else {
         assert(LHSt && "expected list type argument in unary operator");
         ListRecTy *LType = dyn_cast<ListRecTy>(LHSt->getType());
-        Type = (Code == UnOpInit::HEAD) ? LType->getElementType() : LType;
+        Type = UseElementType ? LType->getElementType() : LType;
+      }
+
+      // For !listflatten, we expect a list of lists, but also support a list of
+      // non-lists, in which case !listflatten is a NOP.
+      if (Code == UnOpInit::LISTFLATTEN) {
+        ListRecTy *InnerListTy = dyn_cast<ListRecTy>(Type);
+        if (InnerListTy) {
+          // listflatten will convert list<list<X>> to list<X>.
+          Type = ListRecTy::get(InnerListTy->getElementType());
+        } else {
+          // If it's a list of non-lists, !listflatten will be a NOP.
+          Type = ListRecTy::get(Type);
+        }
       }
     }
 
@@ -1378,7 +1400,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
 
   case tgtok::XExists: {
     // Value ::= !exists '<' Type '>' '(' Value ')'
-    Lex.Lex(); // eat the operation
+    Lex.Lex(); // eat the operation.
 
     RecTy *Type = ParseOperatorType();
     if (!Type)
diff --git a/llvm/test/TableGen/listflatten-error.td b/llvm/test/TableGen/listflatten-error.td
new file mode 100644
index 00000000000000..56062420982a11
--- /dev/null
+++ b/llvm/test/TableGen/listflatten-error.td
@@ -0,0 +1,6 @@
+// RUN: not llvm-tblgen %s 2>&1 | FileCheck %s  -DFILE=%s
+
+// CHECK: [[FILE]]:[[@LINE+2]]:33: error: expected list type argument in unary operator
+class Flatten<int A> {
+    list<int> F = !listflatten(A);
+}
diff --git a/llvm/test/TableGen/listflatten.td b/llvm/test/TableGen/listflatten.td
new file mode 100644
index 00000000000000..bc9b1c71ea88d7
--- /dev/null
+++ b/llvm/test/TableGen/listflatten.td
@@ -0,0 +1,32 @@
+// RUN: llvm-tblgen %s | FileCheck %s
+
+class Flatten<list<int> A, list<int> B> {
+    list<int> Flat1 = !listflatten([A, B, [6], [7, 8]]);
+
+    list<list<int>> X = [A, B];
+    list<int> Flat2 = !listflatten(!listconcat(X, [[7]]));
+
+    // Generate a nested list of integers.
+    list<int> Y0 = [1, 2, 3, 4];
+    list<list<int>> Y1 = !foreach(elem, Y0, [elem]);
+    list<list<list<int>>> Y2 = !foreach(elem, Y1, [elem]);
+    list<list<list<list<int>>>> Y3 = !foreach(elem, Y2, [elem]);
+
+    // Flatten it completely.
+    list<int> Flat3=!listflatten(!listflatten(!listflatten(Y3)));
+
+    // Flatten it partially.
+    list<list<list<int>>> Flat4 = !listflatten(Y3);
+    list<list<int>> Flat5 = !listflatten(!listflatten(Y3));
+
+    // Test NOP flattening.
+    list<string> Flat6 = !listflatten(["a", "b"]);
+}
+
+// CHECK: list<int> Flat1 = [1, 2, 3, 4, 5, 6, 7, 8];
+// CHECK: list<int> Flat2 = [1, 2, 3, 4, 5, 7];
+// CHECK: list<int> Flat3 = [1, 2, 3, 4];
+// CHECK{LITERAL}: list<list<list<int>>> Flat4 = [[[1]], [[2]], [[3]], [[4]]];
+// CHECK: list<string> Flat6 = ["a", "b"];
+def F : Flatten<[1,2], [3,4,5]>;
+
    
    
More information about the llvm-commits
mailing list