[llvm] [TableGen] Make `!and` and `!or` short-circuit (PR #113963)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 11:26:16 PST 2024


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/113963

>From 425ddc6c5d96fdda6aa23d5c95216a5ef2fcfc1c Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 28 Oct 2024 09:46:20 -0700
Subject: [PATCH 1/3] [TableGen] Make `!and` short-circuit when either of the
 operands is zero

By preemptively simplifying the result of `!and`, we can fold some of
the conditional operators, like `!if` or `!cond`, as early as possible.
---
 llvm/lib/TableGen/Record.cpp     | 17 +++++++++++++++++
 llvm/test/TableGen/true-false.td | 14 ++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index 1d71482b020b22..1f01b82685718a 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -1543,6 +1543,23 @@ const Init *BinOpInit::resolveReferences(Resolver &R) const {
   const Init *lhs = LHS->resolveReferences(R);
   const Init *rhs = RHS->resolveReferences(R);
 
+  if (getOpcode() == AND) {
+    // Short-circuit. Regardless whether this is a logical or bitwise
+    // AND.
+    if (lhs != LHS)
+      if (const auto *LHSi = dyn_cast_or_null<IntInit>(
+              lhs->convertInitializerTo(IntRecTy::get(getRecordKeeper())))) {
+        if (!LHSi->getValue())
+          return LHSi;
+      }
+    if (rhs != RHS)
+      if (const auto *RHSi = dyn_cast_or_null<IntInit>(
+              rhs->convertInitializerTo(IntRecTy::get(getRecordKeeper())))) {
+        if (!RHSi->getValue())
+          return RHSi;
+      }
+  }
+
   if (LHS != lhs || RHS != rhs)
     return (BinOpInit::get(getOpcode(), lhs, rhs, getType()))
         ->Fold(R.getCurrentRecord());
diff --git a/llvm/test/TableGen/true-false.td b/llvm/test/TableGen/true-false.td
index 597ad9f5ecc8e7..a9094884dbdb9f 100644
--- a/llvm/test/TableGen/true-false.td
+++ b/llvm/test/TableGen/true-false.td
@@ -67,6 +67,20 @@ def rec7 {
   bits<3> flags = { true, false, true };
 }
 
+// The `!and` should be short-circuit such that `!tail` on empty list will never
+// be evaluated.
+// CHECK: def rec8
+// CHECK:   list<int> newSeq = [];
+// CHECK:   list<int> newSeq2 = [];
+
+class Foo <list<int> seq = []> {
+  bit containsStr = !ne(!find(NAME, "BAR"), -1);
+  list<int> newSeq  = !if(!and(!not(!empty(seq)), containsStr), !tail(seq), seq);
+  list<int> newSeq2 = !if(!and(containsStr, !not(!empty(seq))), !tail(seq), seq);
+}
+
+def rec8 : Foo<>;
+
 #ifdef ERROR1
 // ERROR1: Record name '1' is not a string
 

>From 8cce6cecbf060c11e710ee2c6f1fcb8e8ba5d2c1 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 30 Oct 2024 11:58:37 -0700
Subject: [PATCH 2/3] Add short-circuit for `!or` and add documentation

---
 llvm/docs/TableGen/ProgRef.rst   |  6 ++++--
 llvm/lib/TableGen/Record.cpp     | 34 +++++++++++++++++++-------------
 llvm/test/TableGen/true-false.td |  6 +++++-
 3 files changed, 29 insertions(+), 17 deletions(-)

diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index 5cf48d6ed29786..08614beb937bde 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -1646,7 +1646,8 @@ and non-0 as true.
 ``!and(``\ *a*\ ``,`` *b*\ ``, ...)``
     This operator does a bitwise AND on *a*, *b*, etc., and produces the
     result. A logical AND can be performed if all the arguments are either
-    0 or 1.
+    0 or 1. This operator is short-circuit to 0 when one of the operands
+    is 0.
 
 ``!cast<``\ *type*\ ``>(``\ *a*\ ``)``
     This operator performs a cast on *a* and produces the result.
@@ -1872,7 +1873,8 @@ and non-0 as true.
 ``!or(``\ *a*\ ``,`` *b*\ ``, ...)``
     This operator does a bitwise OR on *a*, *b*, etc., and produces the
     result. A logical OR can be performed if all the arguments are either
-    0 or 1.
+    0 or 1. This operator is short-circuit to -1 (all ones) if one of the
+    operands is -1.
 
 ``!range([``\ *start*\ ``,]`` *end*\ ``[,``\ *step*\ ``])``
     This operator produces half-open range sequence ``[start : end : step)`` as
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index 1f01b82685718a..311b1e3adedf86 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -1543,21 +1543,27 @@ const Init *BinOpInit::resolveReferences(Resolver &R) const {
   const Init *lhs = LHS->resolveReferences(R);
   const Init *rhs = RHS->resolveReferences(R);
 
-  if (getOpcode() == AND) {
+  unsigned Opc = getOpcode();
+  if (Opc == AND || Opc == OR) {
     // Short-circuit. Regardless whether this is a logical or bitwise
-    // AND.
-    if (lhs != LHS)
-      if (const auto *LHSi = dyn_cast_or_null<IntInit>(
-              lhs->convertInitializerTo(IntRecTy::get(getRecordKeeper())))) {
-        if (!LHSi->getValue())
-          return LHSi;
-      }
-    if (rhs != RHS)
-      if (const auto *RHSi = dyn_cast_or_null<IntInit>(
-              rhs->convertInitializerTo(IntRecTy::get(getRecordKeeper())))) {
-        if (!RHSi->getValue())
-          return RHSi;
-      }
+    // AND/OR.
+    // Ideally we could also short-circuit `!or(true, ...)`, but it's
+    // difficult to do it right without knowing if rest of the operands
+    // are all `bit` or not. Therefore, we're only implementing a relatively
+    // limited version of short-circuit against all ones (`true` is casted
+    // to 1 rather than all ones before we evaluate `!or`).
+    if (const auto *LHSi = dyn_cast_or_null<IntInit>(
+            lhs->convertInitializerTo(IntRecTy::get(getRecordKeeper())))) {
+      if ((Opc == AND && !LHSi->getValue()) ||
+          (Opc == OR && LHSi->getValue() == -1))
+        return LHSi;
+    }
+    if (const auto *RHSi = dyn_cast_or_null<IntInit>(
+            rhs->convertInitializerTo(IntRecTy::get(getRecordKeeper())))) {
+      if ((Opc == AND && !RHSi->getValue()) ||
+          (Opc == OR && RHSi->getValue() == -1))
+        return RHSi;
+    }
   }
 
   if (LHS != lhs || RHS != rhs)
diff --git a/llvm/test/TableGen/true-false.td b/llvm/test/TableGen/true-false.td
index a9094884dbdb9f..b86569fceb7c94 100644
--- a/llvm/test/TableGen/true-false.td
+++ b/llvm/test/TableGen/true-false.td
@@ -67,16 +67,20 @@ def rec7 {
   bits<3> flags = { true, false, true };
 }
 
-// The `!and` should be short-circuit such that `!tail` on empty list will never
+// `!and` and `!or` should short-circuit such that `!tail` on an empty list will never
 // be evaluated.
 // CHECK: def rec8
 // CHECK:   list<int> newSeq = [];
 // CHECK:   list<int> newSeq2 = [];
+// CHECK:   list<int> newSeq3 = [];
+// CHECK:   list<int> newSeq4 = [];
 
 class Foo <list<int> seq = []> {
   bit containsStr = !ne(!find(NAME, "BAR"), -1);
   list<int> newSeq  = !if(!and(!not(!empty(seq)), containsStr), !tail(seq), seq);
   list<int> newSeq2 = !if(!and(containsStr, !not(!empty(seq))), !tail(seq), seq);
+  list<int> newSeq3 = !if(!or(containsStr, -1), seq, !tail(seq));
+  list<int> newSeq4 = !if(!or(-1, containsStr), seq, !tail(seq));
 }
 
 def rec8 : Foo<>;

>From b20cf297df2f64c109a2912661e6ba946147edc5 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Mon, 4 Nov 2024 11:25:36 -0800
Subject: [PATCH 3/3] Only short-circuit against the left-most operand

And clean up the test
---
 llvm/docs/TableGen/ProgRef.rst   |  6 +++---
 llvm/lib/TableGen/Record.cpp     |  6 ------
 llvm/test/TableGen/true-false.td | 10 +++-------
 3 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index 08614beb937bde..03fe1157b4042e 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -1646,7 +1646,7 @@ and non-0 as true.
 ``!and(``\ *a*\ ``,`` *b*\ ``, ...)``
     This operator does a bitwise AND on *a*, *b*, etc., and produces the
     result. A logical AND can be performed if all the arguments are either
-    0 or 1. This operator is short-circuit to 0 when one of the operands
+    0 or 1. This operator short-circuits to 0 when the left-most operand
     is 0.
 
 ``!cast<``\ *type*\ ``>(``\ *a*\ ``)``
@@ -1873,8 +1873,8 @@ and non-0 as true.
 ``!or(``\ *a*\ ``,`` *b*\ ``, ...)``
     This operator does a bitwise OR on *a*, *b*, etc., and produces the
     result. A logical OR can be performed if all the arguments are either
-    0 or 1. This operator is short-circuit to -1 (all ones) if one of the
-    operands is -1.
+    0 or 1. This operator short-circuits to -1 (all ones) if the left-most
+    operand is -1.
 
 ``!range([``\ *start*\ ``,]`` *end*\ ``[,``\ *step*\ ``])``
     This operator produces half-open range sequence ``[start : end : step)`` as
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index 311b1e3adedf86..feef51f3d203cd 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -1558,12 +1558,6 @@ const Init *BinOpInit::resolveReferences(Resolver &R) const {
           (Opc == OR && LHSi->getValue() == -1))
         return LHSi;
     }
-    if (const auto *RHSi = dyn_cast_or_null<IntInit>(
-            rhs->convertInitializerTo(IntRecTy::get(getRecordKeeper())))) {
-      if ((Opc == AND && !RHSi->getValue()) ||
-          (Opc == OR && RHSi->getValue() == -1))
-        return RHSi;
-    }
   }
 
   if (LHS != lhs || RHS != rhs)
diff --git a/llvm/test/TableGen/true-false.td b/llvm/test/TableGen/true-false.td
index b86569fceb7c94..5a59f20b21d252 100644
--- a/llvm/test/TableGen/true-false.td
+++ b/llvm/test/TableGen/true-false.td
@@ -72,15 +72,11 @@ def rec7 {
 // CHECK: def rec8
 // CHECK:   list<int> newSeq = [];
 // CHECK:   list<int> newSeq2 = [];
-// CHECK:   list<int> newSeq3 = [];
-// CHECK:   list<int> newSeq4 = [];
 
 class Foo <list<int> seq = []> {
-  bit containsStr = !ne(!find(NAME, "BAR"), -1);
-  list<int> newSeq  = !if(!and(!not(!empty(seq)), containsStr), !tail(seq), seq);
-  list<int> newSeq2 = !if(!and(containsStr, !not(!empty(seq))), !tail(seq), seq);
-  list<int> newSeq3 = !if(!or(containsStr, -1), seq, !tail(seq));
-  list<int> newSeq4 = !if(!or(-1, containsStr), seq, !tail(seq));
+  bit unresolved = !ne(!find(NAME, "BAR"), -1);
+  list<int> newSeq  = !if(!and(false, unresolved), !tail(seq), seq);
+  list<int> newSeq2 = !if(!or(-1, unresolved), seq, !tail(seq));
 }
 
 def rec8 : Foo<>;



More information about the llvm-commits mailing list