[llvm] [TableGen] Add a !listflatten operator to TableGen (PR #109346)
Rahul Joshi via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 19 15:51:14 PDT 2024
https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/109346
Add a !listflatten operator that transforms an input list of type `list<list<X>>` into a list of type `list<X>` by concatenating the constituent lists of the input argument in order.
>From 130b8bad1ff0711b7a353b644db5cfd5da1094d4 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Thu, 19 Sep 2024 15:28:40 -0700
Subject: [PATCH] [TableGen] Add a !listflatten operator to TableGen
Add a !listflatten operator that transforms an input list of type
`list<list<X>>` into a list of type `list<X>` by concatenating the
constituent lists of the input argument in order.
---
llvm/docs/TableGen/ProgRef.rst | 18 +++++++++-----
llvm/include/llvm/TableGen/Record.h | 3 ++-
llvm/lib/TableGen/Record.cpp | 25 ++++++++++++++++++++
llvm/lib/TableGen/TGLexer.cpp | 1 +
llvm/lib/TableGen/TGLexer.h | 1 +
llvm/lib/TableGen/TGParser.cpp | 31 +++++++++++++++++++++----
llvm/test/TableGen/listflatten-error.td | 6 +++++
llvm/test/TableGen/listflatten.td | 29 +++++++++++++++++++++++
8 files changed, 102 insertions(+), 12 deletions(-)
create mode 100644 llvm/test/TableGen/listflatten-error.td
create mode 100644 llvm/test/TableGen/listflatten.td
diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index dcea3b721dae27..8a823d1e1cb362 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -223,12 +223,12 @@ TableGen provides "bang operators" that have a wide variety of uses:
: !div !empty !eq !exists !filter
: !find !foldl !foreach !ge !getdagarg
: !getdagname !getdagop !gt !head !if
- : !interleave !isa !le !listconcat !listremove
- : !listsplat !logtwo !lt !mul !ne
- : !not !or !range !repr !setdagarg
- : !setdagname !setdagop !shl !size !sra
- : !srl !strconcat !sub !subst !substr
- : !tail !tolower !toupper !xor
+ : !interleave !isa !le !listconcat !listflatten
+ : !listremove !listsplat !logtwo !lt !mul
+ : !ne !not !or !range !repr
+ : !setdagarg !setdagname !setdagop !shl !size
+ : !sra !srl !strconcat !sub !subst
+ : !substr !tail !tolower !toupper !xor
The ``!cond`` operator has a slightly different
syntax compared to other bang operators, so it is defined separately:
@@ -1832,6 +1832,12 @@ and non-0 as true.
This operator concatenates the list arguments *list1*, *list2*, etc., and
produces the resulting list. The lists must have the same element type.
+``!listflatten(``\ *list*\ ``)``
+ This operator flattens the list of lists *list*, producing a single list
+ that contains the elements of each constituent list, concatenated in order.
+ The constituent lists must have the same element type: if *list* is of type
+ ``list<list<X>>``, the resulting list is of type ``list<X>``.
+
``!listremove(``\ *list1*\ ``,`` *list2*\ ``)``
This operator returns a copy of *list1* removing all elements that also occur in
*list2*. The lists must have the same element type.
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 5348c1177f63ed..4cd73c3f675527 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -847,7 +847,8 @@ class UnOpInit : public OpInit, public FoldingSetNode {
EMPTY,
GETDAGOP,
LOG2,
- REPR
+ REPR,
+ LISTFLATTEN,
};
private:
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index ff2da3badb3628..331629ec4677fe 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -987,6 +987,28 @@ Init *UnOpInit::Fold(Record *CurRec, bool IsFinal) const {
}
}
break;
+
+ case LISTFLATTEN:
+ ListInit *LHSList = dyn_cast<ListInit>(LHS);
+ if (!LHSList)
+ break;
+ // Iterate over inner lists.
+ ListRecTy *InnerListTy = cast<ListRecTy>(LHSList->getElementType());
+ if (!InnerListTy)
+ break;
+ std::vector<Init *> Flattened;
+ bool Failed = false;
+ for (Init *InnerInit : LHSList->getValues()) {
+ ListInit *InnerList = dyn_cast<ListInit>(InnerInit);
+ if (!InnerList) {
+ Failed = true;
+ break;
+ }
+ for (Init *InnerElem : InnerList->getValues())
+ Flattened.push_back(InnerElem);
+ }
+ if (!Failed)
+ return ListInit::get(Flattened, InnerListTy->getElementType());
}
return const_cast<UnOpInit *>(this);
}
@@ -1011,6 +1033,9 @@ std::string UnOpInit::getAsString() const {
case EMPTY: Result = "!empty"; break;
case GETDAGOP: Result = "!getdagop"; break;
case LOG2 : Result = "!logtwo"; break;
+ case LISTFLATTEN:
+ Result = "!listflatten";
+ break;
case REPR:
Result = "!repr";
break;
diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp
index 62a884e01a5306..8fe7f69ecf8e59 100644
--- a/llvm/lib/TableGen/TGLexer.cpp
+++ b/llvm/lib/TableGen/TGLexer.cpp
@@ -628,6 +628,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
.Case("foreach", tgtok::XForEach)
.Case("filter", tgtok::XFilter)
.Case("listconcat", tgtok::XListConcat)
+ .Case("listflatten", tgtok::XListFlatten)
.Case("listsplat", tgtok::XListSplat)
.Case("listremove", tgtok::XListRemove)
.Case("range", tgtok::XRange)
diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h
index 9adc03ccc72b85..4fa4d84d0535d3 100644
--- a/llvm/lib/TableGen/TGLexer.h
+++ b/llvm/lib/TableGen/TGLexer.h
@@ -122,6 +122,7 @@ enum TokKind {
XSRL,
XSHL,
XListConcat,
+ XListFlatten,
XListSplat,
XStrConcat,
XInterleave,
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index 1a60c2a567a297..4aff39c53ff9fc 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -1190,6 +1190,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XNOT:
case tgtok::XToLower:
case tgtok::XToUpper:
+ case tgtok::XListFlatten:
case tgtok::XLOG2:
case tgtok::XHead:
case tgtok::XTail:
@@ -1235,6 +1236,11 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
Code = UnOpInit::NOT;
Type = IntRecTy::get(Records);
break;
+ case tgtok::XListFlatten:
+ Lex.Lex(); // eat the operation.
+ Code = UnOpInit::LISTFLATTEN;
+ Type = IntRecTy::get(Records); // Bogus type used here.
+ break;
case tgtok::XLOG2:
Lex.Lex(); // eat the operation
Code = UnOpInit::LOG2;
@@ -1309,7 +1315,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
}
}
- if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL) {
+ if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL ||
+ Code == UnOpInit::LISTFLATTEN) {
ListInit *LHSl = dyn_cast<ListInit>(LHS);
TypedInit *LHSt = dyn_cast<TypedInit>(LHS);
if (!LHSl && !LHSt) {
@@ -1328,6 +1335,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
TokError("empty list argument in unary operator");
return nullptr;
}
+ bool UseElementType =
+ Code == UnOpInit::HEAD || Code == UnOpInit::LISTFLATTEN;
if (LHSl) {
Init *Item = LHSl->getElement(0);
TypedInit *Itemt = dyn_cast<TypedInit>(Item);
@@ -1335,12 +1344,24 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
TokError("untyped list element in unary operator");
return nullptr;
}
- Type = (Code == UnOpInit::HEAD) ? Itemt->getType()
- : ListRecTy::get(Itemt->getType());
+ Type = UseElementType ? Itemt->getType()
+ : ListRecTy::get(Itemt->getType());
} else {
assert(LHSt && "expected list type argument in unary operator");
ListRecTy *LType = dyn_cast<ListRecTy>(LHSt->getType());
- Type = (Code == UnOpInit::HEAD) ? LType->getElementType() : LType;
+ Type = UseElementType ? LType->getElementType() : LType;
+ }
+
+ // for listflatten, we expect a list of lists.
+ if (Code == UnOpInit::LISTFLATTEN) {
+ ListRecTy *InnerListTy = dyn_cast<ListRecTy>(Type);
+ if (!InnerListTy) {
+ TokError(
+ "expected list of list type argument in !listflatten operator");
+ return nullptr;
+ }
+ // listflatten will convert list<list<X>> to list<X>.
+ Type = ListRecTy::get(InnerListTy->getElementType());
}
}
@@ -1378,7 +1399,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XExists: {
// Value ::= !exists '<' Type '>' '(' Value ')'
- Lex.Lex(); // eat the operation
+ Lex.Lex(); // eat the operation.
RecTy *Type = ParseOperatorType();
if (!Type)
diff --git a/llvm/test/TableGen/listflatten-error.td b/llvm/test/TableGen/listflatten-error.td
new file mode 100644
index 00000000000000..560b3506154889
--- /dev/null
+++ b/llvm/test/TableGen/listflatten-error.td
@@ -0,0 +1,6 @@
+// RUN: not llvm-tblgen %s 2>&1 | FileCheck %s -DFILE=%s
+// A list<int> argument is not a list of lists, so !listflatten must be rejected.
+// CHECK: [[FILE]]:[[@LINE+2]]:33: error: expected list of list type argument in !listflatten operator
+class Flatten<list<int> A> {
+ list<int> F = !listflatten(A);
+}
diff --git a/llvm/test/TableGen/listflatten.td b/llvm/test/TableGen/listflatten.td
new file mode 100644
index 00000000000000..20119b24cce0a3
--- /dev/null
+++ b/llvm/test/TableGen/listflatten.td
@@ -0,0 +1,29 @@
+// RUN: llvm-tblgen %s | FileCheck %s

+// Exercise !listflatten on literal, computed, and deeply nested lists.
+class Flatten<list<int> A, list<int> B> {
+ list<int> Flat1 = !listflatten([A, B, [6], [7, 8]]);
+
+ list<list<int>> X = [A, B];
+ list<int> Flat2 = !listflatten(!listconcat(X, [[7]]));
+
+ // Generate a nested list of integers.
+ list<int> Y0 = [1, 2, 3, 4];
+ list<list<int>> Y1 = !foreach(elem, Y0, [elem]);
+ list<list<list<int>>> Y2 = !foreach(elem, Y1, [elem]);
+ list<list<list<list<int>>>> Y3 = !foreach(elem, Y2, [elem]);
+
+ // Flatten it completely.
+ list<int> Flat3 = !listflatten(!listflatten(!listflatten(Y3)));
+
+ // Flatten it partially.
+ list<list<list<int>>> Flat4 = !listflatten(Y3);
+ list<list<int>> Flat5 = !listflatten(!listflatten(Y3));
+}

+// CHECK: list<int> Flat1 = [1, 2, 3, 4, 5, 6, 7, 8];
+// CHECK: list<int> Flat2 = [1, 2, 3, 4, 5, 7];
+// CHECK: list<int> Flat3 = [1, 2, 3, 4];
+// CHECK{LITERAL}: list<list<list<int>>> Flat4 = [[[1]], [[2]], [[3]], [[4]]];
+// CHECK{LITERAL}: list<list<int>> Flat5 = [[1], [2], [3], [4]];
+def F : Flatten<[1, 2], [3, 4, 5]>;
More information about the llvm-commits
mailing list