[llvm] [LLVM][TableGen] Support type casts of nodes with multiple results (PR #109728)

Stephen Chou via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 25 22:45:47 PDT 2024


https://github.com/stephenchouca updated https://github.com/llvm/llvm-project/pull/109728

>From a1fae648ae9cc6ce08e8d965d4ade39951a9c0bb Mon Sep 17 00:00:00 2001
From: Stephen Chou <stephenchou at google.com>
Date: Thu, 26 Sep 2024 05:44:26 +0000
Subject: [PATCH] [LLVM][TableGen] Support type casts of nodes with multiple
 results

Currently, type casts can only be used to pattern-match intrinsics that have a single overloaded return value. For instance:
```
def int_foo : Intrinsic<[llvm_anyint_ty], []>;
def : Pat<(i32 (int_foo)), ...>;
```

This patch extends type casts to accept a list of value types, so that intrinsics with multiple overloaded return values can also be matched. For example, the following pattern matches only if the overloaded intrinsic call returns an `i16` as its first result and an `i32` as its second result:
```
def int_bar : Intrinsic<[llvm_anyint_ty, llvm_anyint_ty], []>;
def : Pat<([i16, i32] (int_bar)), ...>;
```
---
 llvm/lib/TableGen/TGParser.cpp                |  8 ++--
 .../TableGen/invalid-type-cast-patfrags.td    | 43 ++++++++++++++++++
 .../TableGen/multiple-type-casts-patfrags.td  | 37 ++++++++++++++++
 .../TableGen/Common/CodeGenDAGPatterns.cpp    | 44 +++++++++++++++----
 4 files changed, 120 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/TableGen/invalid-type-cast-patfrags.td
 create mode 100644 llvm/test/TableGen/multiple-type-casts-patfrags.td

diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index 54c9a902ec27a1..b83d8b304f5991 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -2866,11 +2866,13 @@ Init *TGParser::ParseSimpleValue(Record *CurRec, RecTy *ItemType,
 
     return ListInit::get(Vals, DeducedEltTy);
   }
-  case tgtok::l_paren: {         // Value ::= '(' IDValue DagArgList ')'
+  case tgtok::l_paren: { // Value ::= '(' IDValue DagArgList ')'
+                         // Value ::= '(' '[' ValueList ']' DagArgList ')'
     Lex.Lex();   // eat the '('
     if (Lex.getCode() != tgtok::Id && Lex.getCode() != tgtok::XCast &&
-        Lex.getCode() != tgtok::question && Lex.getCode() != tgtok::XGetDagOp) {
-      TokError("expected identifier in dag init");
+        Lex.getCode() != tgtok::question && Lex.getCode() != tgtok::XGetDagOp &&
+        Lex.getCode() != tgtok::l_square) {
+      TokError("expected identifier or list of value types in dag init");
       return nullptr;
     }
 
diff --git a/llvm/test/TableGen/invalid-type-cast-patfrags.td b/llvm/test/TableGen/invalid-type-cast-patfrags.td
new file mode 100644
index 00000000000000..c25cf0d8d027b0
--- /dev/null
+++ b/llvm/test/TableGen/invalid-type-cast-patfrags.td
@@ -0,0 +1,43 @@
+// RUN: not llvm-tblgen -gen-dag-isel -I %p/../../include -I %p/Common -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s
+// RUN: not llvm-tblgen -gen-dag-isel -I %p/../../include -I %p/Common -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s
+// RUN: not llvm-tblgen -gen-dag-isel -I %p/../../include -I %p/Common -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s
+// RUN: not llvm-tblgen -gen-dag-isel -I %p/../../include -I %p/Common -DERROR4 %s 2>&1 | FileCheck --check-prefix=ERROR4 %s
+
+include "llvm/Target/Target.td"
+include "GlobalISelEmitterCommon.td"
+
+def int_foo : Intrinsic<[llvm_anyint_ty, llvm_anyint_ty], [llvm_i32_ty]>;
+def int_bar : Intrinsic<[], []>;
+
+def INSTR_FOO : Instruction {
+  let OutOperandList = (outs GPR32:$a, GPR32:$b);
+  let InOperandList = (ins GPR32:$c);
+}
+def INSTR_BAR : Instruction {
+  let OutOperandList = (outs);
+  let InOperandList = (ins);
+}
+
+def Defs {
+  list<ValueType> empty = [];
+}
+
+#ifdef ERROR1
+// ERROR1: [[@LINE+1]]:1: error: {{.*}} Invalid number of type casts!
+def : Pat<([i32, i32, i32] (int_foo (i32 GPR32:$a))), ([i32, i32, i32] (INSTR_FOO $a))>;
+#endif
+
+#ifdef ERROR2
+// ERROR2: [[@LINE+1]]:1: error: {{.*}} Invalid number of type casts!
+def : Pat<(Defs.empty (int_bar)), (Defs.empty (INSTR_BAR))>;
+#endif
+
+#ifdef ERROR3
+// ERROR3: [[@LINE+1]]:1: error: {{.*}} Type cast only takes one operand!
+def : Pat<([i32, i32] (int_foo), (int_foo)), ([i32, i32] (INSTR_FOO))>;
+#endif
+
+#ifdef ERROR4
+// ERROR4: [[@LINE+1]]:1: error: {{.*}} Type cast should not have a name!
+def : Pat<([i32, i32] ([i32, i32] (int_foo)):$name), ([i32, i32] (INSTR_FOO))>;
+#endif
diff --git a/llvm/test/TableGen/multiple-type-casts-patfrags.td b/llvm/test/TableGen/multiple-type-casts-patfrags.td
new file mode 100644
index 00000000000000..c4b4b62995fdb3
--- /dev/null
+++ b/llvm/test/TableGen/multiple-type-casts-patfrags.td
@@ -0,0 +1,37 @@
+// RUN: llvm-tblgen -gen-dag-isel -I %p/../../include -I %p/Common %s | FileCheck -check-prefix=SDAG %s
+// RUN: llvm-tblgen -gen-global-isel -optimize-match-table=false -warn-on-skipped-patterns -I %p/../../include -I %p/Common %s -o - < %s | FileCheck -check-prefix=GISEL %s
+
+include "llvm/Target/Target.td"
+include "GlobalISelEmitterCommon.td"
+
+def REG : Register<"REG">;
+def GPR : RegisterClass<"MyTarget", [i16, i32], 32, (add REG)>;
+
+def int_foo : Intrinsic<[llvm_anyint_ty, llvm_anyint_ty], []>;
+
+def INSTR_FOO_I16_I32 : Instruction {
+  let OutOperandList = (outs GPR:$a, GPR:$b);
+  let InOperandList = (ins);
+}
+def INSTR_FOO_I32_I16 : Instruction {
+  let OutOperandList = (outs GPR:$a, GPR:$b);
+  let InOperandList = (ins);
+}
+
+// SDAG: 7*/ OPC_SwitchType {{.*}}, 10, /*MVT::i16*/6
+// SDAG: OPC_CheckTypeRes, 1, /*MVT::i32*/7
+// SDAG: OPC_MorphNodeTo2Chain, TARGET_VAL(::INSTR_FOO_I16_I32)
+
+// GISEL: GIM_RootCheckType, /*Op*/0, /*Type*/GILLT_s16
+// GISEL: GIM_RootCheckType, /*Op*/1, /*Type*/GILLT_s32
+// GISEL: GIR_BuildRootMI, /*Opcode*/GIMT_Encode2(::INSTR_FOO_I16_I32)
+def : Pat<([i16, i32] (int_foo)), ([i16, i32] (INSTR_FOO_I16_I32))>;
+
+// SDAG: 20*/ /*SwitchType*/ {{.*}} /*MVT::i32*/7
+// SDAG: OPC_CheckTypeRes, 1, /*MVT::i16*/6
+// SDAG: OPC_MorphNodeTo2Chain, TARGET_VAL(::INSTR_FOO_I32_I16)
+
+// GISEL: GIM_RootCheckType, /*Op*/0, /*Type*/GILLT_s32
+// GISEL: GIM_RootCheckType, /*Op*/1, /*Type*/GILLT_s16
+// GISEL: GIR_BuildRootMI, /*Opcode*/GIMT_Encode2(::INSTR_FOO_I32_I16)
+def : Pat<([i32, i16] (int_foo)), ([i32, i16] (INSTR_FOO_I32_I16))>;
diff --git a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
index e8cf7e3998e125..cac1d9b4e0bbcf 100644
--- a/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/Common/CodeGenDAGPatterns.cpp
@@ -2886,6 +2886,35 @@ TreePatternNodePtr TreePattern::ParseTreePattern(Init *TheInit,
     error("Pattern has unexpected init kind!");
     return nullptr;
   }
+
+  auto ParseCastOperand = [this](DagInit *Dag, StringRef OpName) {
+    if (Dag->getNumArgs() != 1)
+      error("Type cast only takes one operand!");
+
+    if (!OpName.empty())
+      error("Type cast should not have a name!");
+
+    return ParseTreePattern(Dag->getArg(0), Dag->getArgNameStr(0));
+  };
+
+  if (ListInit *LI = dyn_cast<ListInit>(Dag->getOperator())) {
+    // If the operator is a list (of value types), then this must be "type cast"
+    // of a leaf node with multiple results.
+    TreePatternNodePtr New = ParseCastOperand(Dag, OpName);
+
+    size_t NumTypes = New->getNumTypes();
+    if (LI->size() == 0 || LI->size() != NumTypes)
+      error("Invalid number of type casts!");
+
+    // Apply the type casts.
+    const CodeGenHwModes &CGH = getDAGPatterns().getTargetInfo().getHwModes();
+    for (unsigned i = 0; i < std::min(NumTypes, LI->size()); ++i)
+      New->UpdateNodeType(
+          i, getValueTypeByHwMode(LI->getElementAsRecord(i), CGH), *this);
+
+    return New;
+  }
+
   DefInit *OpDef = dyn_cast<DefInit>(Dag->getOperator());
   if (!OpDef) {
     error("Pattern has unexpected operator type!");
@@ -2896,20 +2925,15 @@ TreePatternNodePtr TreePattern::ParseTreePattern(Init *TheInit,
   if (Operator->isSubClassOf("ValueType")) {
     // If the operator is a ValueType, then this must be "type cast" of a leaf
     // node.
-    if (Dag->getNumArgs() != 1)
-      error("Type cast only takes one operand!");
+    TreePatternNodePtr New = ParseCastOperand(Dag, OpName);
 
-    TreePatternNodePtr New =
-        ParseTreePattern(Dag->getArg(0), Dag->getArgNameStr(0));
+    if (New->getNumTypes() != 1)
+      error("ValueType cast can only have one type!");
 
     // Apply the type cast.
-    if (New->getNumTypes() != 1)
-      error("Type cast can only have one type!");
     const CodeGenHwModes &CGH = getDAGPatterns().getTargetInfo().getHwModes();
     New->UpdateNodeType(0, getValueTypeByHwMode(Operator, CGH), *this);
 
-    if (!OpName.empty())
-      error("ValueType cast should not have a name!");
     return New;
   }
 
@@ -4223,8 +4247,10 @@ void CodeGenDAGPatterns::ParseOnePattern(
   Pattern.InlinePatternFragments();
   Result.InlinePatternFragments();
 
-  if (Result.getNumTrees() != 1)
+  if (Result.getNumTrees() != 1) {
     Result.error("Cannot use multi-alternative fragments in result pattern!");
+    return;
+  }
 
   // Infer types.
   bool IterateInference;



More information about the llvm-commits mailing list