[llvm] [TableGen] Enhance !range bang operator (PR #66489)

Wang Pengcheng via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 15 03:47:07 PDT 2023


https://github.com/wangpc-pp created https://github.com/llvm/llvm-project/pull/66489

We add a third argument `step` to the `!range` bang operator so that it
has the same semantics as `range` in Python.

`step` can be negative. `step` is 1 by default and `step` can't be
0. If `start` < `end` and `step` is negative, or `start` > `end`
and `step` is positive, the result is an empty list.


From e4bb68ef924ef93fe67a28ed0b7bf5ca814394ce Mon Sep 17 00:00:00 2001
From: wangpc <wangpengcheng.pp at bytedance.com>
Date: Fri, 15 Sep 2023 18:34:55 +0800
Subject: [PATCH] [TableGen] Enhance !range bang operator

We add a third argument `step` to the `!range` bang operator so that it
has the same semantics as `range` in Python.

`step` can be negative. `step` is 1 by default and `step` can't be
0. If `start` < `end` and `step` is negative, or `start` > `end`
and `step` is positive, the result is an empty list.
---
 llvm/docs/TableGen/ProgRef.rst      |  12 ++-
 llvm/include/llvm/TableGen/Record.h |   2 +-
 llvm/lib/TableGen/Record.cpp        |  31 +++++-
 llvm/lib/TableGen/TGParser.cpp      | 158 ++++++++++++++++++----------
 llvm/test/TableGen/range-op-fail.td |   4 +-
 llvm/test/TableGen/range-op.td      |  16 +++
 6 files changed, 155 insertions(+), 68 deletions(-)

diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index aa29d830ec90772..299143a8982be88 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -1831,11 +1831,13 @@ and non-0 as true.
     result. A logical OR can be performed if all the arguments are either
     0 or 1.
 
-``!range([``\ *a*\ ``,``] *b*\ ``)``
-    This operator produces half-open range sequence ``[a : b)`` as ``list<int>``.
-    *a* is ``0`` by default. ``!range(4)`` is equivalent to ``!range(0, 4)``.
-    The result is `[0, 1, 2, 3]`.
-    If *a* ``>=`` *b*, then the result is `[]<list<int>>`.
+``!range([``\ *start*\ ``,]`` *end*\ ``[,``\ *step*\ ``])``
+    This operator produces a half-open range sequence ``[start : end : step)``
+    as ``list<int>``. *start* is ``0`` and *step* is ``1`` by default. *step* can
+    be negative and cannot be 0. If *start* ``<`` *end* and *step* is negative,
+    or *start* ``>`` *end* and *step* is positive, the result is an empty list
+    ``[]<list<int>>``. For example, ``!range(4)`` is equivalent to
+    ``!range(0, 4, 1)`` and the result is ``[0, 1, 2, 3]``.
 
 ``!range(``\ *list*\ ``)``
     Equivalent to ``!range(0, !size(list))``.
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 06e4abb27e59983..7d743d24f4801ba 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -911,7 +911,6 @@ class BinOpInit : public OpInit, public FoldingSetNode {
     LISTREMOVE,
     LISTELEM,
     LISTSLICE,
-    RANGE,
     RANGEC,
     STRCONCAT,
     INTERLEAVE,
@@ -988,6 +987,7 @@ class TernOpInit : public OpInit, public FoldingSetNode {
     FILTER,
     IF,
     DAG,
+    RANGE,
     SUBSTR,
     FIND,
     SETDAGARG,
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index 20db470855a102b..ac8c12ace98bb60 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -1287,7 +1287,6 @@ Init *BinOpInit::Fold(Record *CurRec) const {
     }
     return ListInit::get(Args, TheList->getElementType());
   }
-  case RANGE:
   case RANGEC: {
     auto *LHSi = dyn_cast<IntInit>(LHS);
     auto *RHSi = dyn_cast<IntInit>(RHS);
@@ -1488,7 +1487,6 @@ std::string BinOpInit::getAsString() const {
   case LISTCONCAT: Result = "!listconcat"; break;
   case LISTSPLAT: Result = "!listsplat"; break;
   case LISTREMOVE: Result = "!listremove"; break;
-  case RANGE: Result = "!range"; break;
   case STRCONCAT: Result = "!strconcat"; break;
   case INTERLEAVE: Result = "!interleave"; break;
   case SETDAGOP: Result = "!setdagop"; break;
@@ -1704,6 +1702,34 @@ Init *TernOpInit::Fold(Record *CurRec) const {
     break;
   }
 
+  case RANGE: {
+    auto *LHSi = dyn_cast<IntInit>(LHS);
+    auto *MHSi = dyn_cast<IntInit>(MHS);
+    auto *RHSi = dyn_cast<IntInit>(RHS);
+    if (!LHSi || !MHSi || !RHSi)
+      break;
+
+    auto Start = LHSi->getValue();
+    auto End = MHSi->getValue();
+    auto Step = RHSi->getValue();
+    if (Step == 0)
+      PrintError(CurRec->getLoc(), "Step of !range can't be 0");
+
+    SmallVector<Init *, 8> Args;
+    if (Start < End && Step > 0) {
+      Args.reserve((End - Start) / Step);
+      for (auto I = Start; I < End; I += Step)
+        Args.push_back(IntInit::get(getRecordKeeper(), I));
+    } else if (Start > End && Step < 0) {
+      Args.reserve((Start - End) / -Step);
+      for (auto I = Start; I > End; I += Step)
+        Args.push_back(IntInit::get(getRecordKeeper(), I));
+    } else {
+      // Empty set
+    }
+    return ListInit::get(Args, LHSi->getType());
+  }
+
   case SUBSTR: {
     StringInit *LHSs = dyn_cast<StringInit>(LHS);
     IntInit *MHSi = dyn_cast<IntInit>(MHS);
@@ -1823,6 +1849,7 @@ std::string TernOpInit::getAsString() const {
   case FILTER: Result = "!filter"; UnquotedLHS = true; break;
   case FOREACH: Result = "!foreach"; UnquotedLHS = true; break;
   case IF: Result = "!if"; break;
+  case RANGE: Result = "!range"; break;
   case SUBST: Result = "!subst"; break;
   case SUBSTR: Result = "!substr"; break;
   case FIND: Result = "!find"; break;
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index b4bc9272fa324af..ae2f07cd94530da 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -1408,7 +1408,6 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
   case tgtok::XListConcat:
   case tgtok::XListSplat:
   case tgtok::XListRemove:
-  case tgtok::XRange:
   case tgtok::XStrConcat:
   case tgtok::XInterleave:
   case tgtok::XGetDagArg:
@@ -1441,7 +1440,6 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
     case tgtok::XListConcat: Code = BinOpInit::LISTCONCAT; break;
     case tgtok::XListSplat:  Code = BinOpInit::LISTSPLAT; break;
     case tgtok::XListRemove: Code = BinOpInit::LISTREMOVE; break;
-    case tgtok::XRange:      Code = BinOpInit::RANGE; break;
     case tgtok::XStrConcat:  Code = BinOpInit::STRCONCAT; break;
     case tgtok::XInterleave: Code = BinOpInit::INTERLEAVE; break;
     case tgtok::XSetDagOp:   Code = BinOpInit::SETDAGOP; break;
@@ -1508,10 +1506,6 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       // We don't know the list type until we parse the first argument.
       ArgType = ItemType;
       break;
-    case tgtok::XRange:
-      Type = IntRecTy::get(Records)->getListTy();
-      // ArgType may be either Int or List.
-      break;
     case tgtok::XStrConcat:
       Type = StringRecTy::get(Records);
       ArgType = StringRecTy::get(Records);
@@ -1596,27 +1590,6 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
             return nullptr;
           }
           break;
-        case BinOpInit::RANGE:
-          if (InitList.size() == 1) {
-            if (isa<ListRecTy>(ArgType)) {
-              ArgType = nullptr; // Detect error if 2nd arg were present.
-            } else if (isa<IntRecTy>(ArgType)) {
-              // Assume 2nd arg should be IntRecTy
-            } else {
-              Error(InitLoc,
-                    Twine("expected list or int, got value of type '") +
-                        ArgType->getAsString() + "'");
-              return nullptr;
-            }
-          } else {
-            // Don't come here unless 1st arg is ListRecTy.
-            assert(isa<ListRecTy>(cast<TypedInit>(InitList[0])->getType()));
-            Error(InitLoc,
-                  Twine("expected one list, got extra value of type '") +
-                      ArgType->getAsString() + "'");
-            return nullptr;
-          }
-          break;
         case BinOpInit::EQ:
         case BinOpInit::NE:
           if (!ArgType->typeIsConvertibleTo(IntRecTy::get(Records)) &&
@@ -1726,37 +1699,6 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
     if (Code == BinOpInit::LISTREMOVE)
       Type = ArgType;
 
-    if (Code == BinOpInit::RANGE) {
-      Init *LHS, *RHS;
-      auto ArgCount = InitList.size();
-      assert(ArgCount >= 1);
-      auto *Arg0 = cast<TypedInit>(InitList[0]);
-      auto *Arg0Ty = Arg0->getType();
-      if (ArgCount == 1) {
-        if (isa<ListRecTy>(Arg0Ty)) {
-          // (0, !size(arg))
-          LHS = IntInit::get(Records, 0);
-          RHS = UnOpInit::get(UnOpInit::SIZE, Arg0, IntRecTy::get(Records))
-                    ->Fold(CurRec);
-        } else {
-          assert(isa<IntRecTy>(Arg0Ty));
-          // (0, arg)
-          LHS = IntInit::get(Records, 0);
-          RHS = Arg0;
-        }
-      } else if (ArgCount == 2) {
-        assert(isa<IntRecTy>(Arg0Ty));
-        auto *Arg1 = cast<TypedInit>(InitList[1]);
-        assert(isa<IntRecTy>(Arg1->getType()));
-        LHS = Arg0;
-        RHS = Arg1;
-      } else {
-        Error(OpLoc, "expected at most two values of integer");
-        return nullptr;
-      }
-      return BinOpInit::get(Code, LHS, RHS, Type)->Fold(CurRec);
-    }
-
     // We allow multiple operands to associative operators like !strconcat as
     // shorthand for nesting them.
     if (Code == BinOpInit::STRCONCAT || Code == BinOpInit::LISTCONCAT ||
@@ -1783,6 +1725,105 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
     return ParseOperationForEachFilter(CurRec, ItemType);
   }
 
+  case tgtok::XRange: {
+    SMLoc OpLoc = Lex.getLoc();
+    Lex.Lex(); // eat the operation
+
+    if (!consume(tgtok::l_paren)) {
+      TokError("expected '(' after !range operator");
+      return nullptr;
+    }
+
+    SmallVector<Init *, 2> Args;
+    bool FirstArgIsList = false;
+    for (;;) {
+      if (Args.size() >= 3) {
+        TokError("expected at most three values of integer");
+        return nullptr;
+      }
+
+      SMLoc InitLoc = Lex.getLoc();
+      Args.push_back(ParseValue(CurRec));
+      if (!Args.back())
+        return nullptr;
+
+      TypedInit *ArgBack = dyn_cast<TypedInit>(Args.back());
+      if (!ArgBack) {
+        Error(OpLoc, Twine("expected value to be a typed value, got '" +
+                           Args.back()->getAsString() + "'"));
+        return nullptr;
+      }
+
+      RecTy *ArgBackType = ArgBack->getType();
+      if (!FirstArgIsList || Args.size() == 1) {
+        if (Args.size() == 1 && isa<ListRecTy>(ArgBackType)) {
+          FirstArgIsList = true; // Detect error if 2nd arg were present.
+        } else if (isa<IntRecTy>(ArgBackType)) {
+          // Assume 2nd arg should be IntRecTy
+        } else {
+          if (Args.size() != 1)
+            Error(InitLoc, Twine("expected value of type 'int', got '" +
+                                 ArgBackType->getAsString() + "'"));
+          else
+            Error(InitLoc, Twine("expected list or int, got value of type '") +
+                               ArgBackType->getAsString() + "'");
+          return nullptr;
+        }
+      } else {
+        // Don't come here unless 1st arg is ListRecTy.
+        assert(isa<ListRecTy>(cast<TypedInit>(Args[0])->getType()));
+        Error(InitLoc, Twine("expected one list, got extra value of type '") +
+                           ArgBackType->getAsString() + "'");
+        return nullptr;
+      }
+      if (!consume(tgtok::comma))
+        break;
+    }
+
+    if (!consume(tgtok::r_paren)) {
+      TokError("expected ')' in operator");
+      return nullptr;
+    }
+
+    Init *LHS, *MHS, *RHS;
+    auto ArgCount = Args.size();
+    assert(ArgCount >= 1);
+    auto *Arg0 = cast<TypedInit>(Args[0]);
+    auto *Arg0Ty = Arg0->getType();
+    if (ArgCount == 1) {
+      if (isa<ListRecTy>(Arg0Ty)) {
+        // (0, !size(arg), 1)
+        LHS = IntInit::get(Records, 0);
+        MHS = UnOpInit::get(UnOpInit::SIZE, Arg0, IntRecTy::get(Records))
+                  ->Fold(CurRec);
+        RHS = IntInit::get(Records, 1);
+      } else {
+        assert(isa<IntRecTy>(Arg0Ty));
+        // (0, arg, 1)
+        LHS = IntInit::get(Records, 0);
+        MHS = Arg0;
+        RHS = IntInit::get(Records, 1);
+      }
+    } else {
+      assert(isa<IntRecTy>(Arg0Ty));
+      auto *Arg1 = cast<TypedInit>(Args[1]);
+      assert(isa<IntRecTy>(Arg1->getType()));
+      LHS = Arg0;
+      MHS = Arg1;
+      if (ArgCount == 3) {
+        // (start, end, step)
+        auto *Arg2 = cast<TypedInit>(Args[2]);
+        assert(isa<IntRecTy>(Arg2->getType()));
+        RHS = Arg2;
+      } else
+        // (start, end, 1)
+        RHS = IntInit::get(Records, 1);
+    }
+    return TernOpInit::get(TernOpInit::RANGE, LHS, MHS, RHS,
+                           IntRecTy::get(Records)->getListTy())
+        ->Fold(CurRec);
+  }
+
   case tgtok::XSetDagArg:
   case tgtok::XSetDagName:
   case tgtok::XDag:
@@ -2534,6 +2575,7 @@ Init *TGParser::ParseOperationCond(Record *CurRec, RecTy *ItemType) {
 ///   SimpleValue ::= LISTREMOVETOK '(' Value ',' Value ')'
 ///   SimpleValue ::= RANGE '(' Value ')'
 ///   SimpleValue ::= RANGE '(' Value ',' Value ')'
+///   SimpleValue ::= RANGE '(' Value ',' Value ',' Value ')'
 ///   SimpleValue ::= STRCONCATTOK '(' Value ',' Value ')'
 ///   SimpleValue ::= COND '(' [Value ':' Value,]+ ')'
 ///
diff --git a/llvm/test/TableGen/range-op-fail.td b/llvm/test/TableGen/range-op-fail.td
index 3238d60abe7f89c..e20bc2eefa692b4 100644
--- a/llvm/test/TableGen/range-op-fail.td
+++ b/llvm/test/TableGen/range-op-fail.td
@@ -16,8 +16,8 @@ defvar errs = !range(0, list_int);
 
 #ifdef ERR_TOO_MANY_ARGS
 // RUN: not llvm-tblgen %s -DERR_TOO_MANY_ARGS 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERR_TOO_MANY_ARGS
-// ERR_TOO_MANY_ARGS: [[FILE]]:[[@LINE+1]]:15: error: expected at most two values of integer
-defvar errs = !range(0, 42, 255);
+// ERR_TOO_MANY_ARGS: [[FILE]]:[[@LINE+1]]:34: error: expected at most three values of integer
+defvar errs = !range(0, 42, 255, 233);
 #endif
 
 #ifdef ERR_UNEXPECTED_TYPE_0
diff --git a/llvm/test/TableGen/range-op.td b/llvm/test/TableGen/range-op.td
index b7e587d47b8cba5..55af93a3aedb5d6 100644
--- a/llvm/test/TableGen/range-op.td
+++ b/llvm/test/TableGen/range-op.td
@@ -11,10 +11,26 @@ def range_op_ranges {
   // CHECK: list<int> idxs4 = [4, 5, 6, 7];
   list<int> idxs4 = !range(4, 8);
 
+  // !range(m, n, s) generates half-open interval "[m,n)" with step s
+  // CHECK: list<int> step = [0, 2, 4, 6];
+  list<int> step = !range(0, 8, 2);
+
+  // Step can be negative.
+  // CHECK: list<int> negative_step = [8, 6, 4, 2];
+  list<int> negative_step = !range(8, 0, -2);
+
   // !range(m, n) generates empty set if m >= n
   // CHECK: list<int> idxs84 = [];
   list<int> idxs84 = !range(8, 4);
 
+  // !range(m, n, s) generates empty set if m < n and s < 0
+  // CHECK: list<int> emptycase0 = [];
+  list<int> emptycase0 = !range(4, 8, -1);
+
+  // !range(m, n, s) generates empty set if m > n and s > 0
+  // CHECK: list<int> emptycase1 = [];
+  list<int> emptycase1 = !range(8, 4, 1);
+
   // !range(list) generates index values for the list. (Not a copy of the list)
   // CHECK: list<int> idxsl = [0, 1, 2, 3];
   list<int> idxsl = !range(idxs4);



More information about the llvm-commits mailing list