[llvm] [GlobalISel] Add `GITypeOf` special type (PR #66079)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 18 23:00:25 PDT 2023


https://github.com/Pierre-vh updated https://github.com/llvm/llvm-project/pull/66079

>From 27f5e90fd433ed419dd3fbc22017d42f28854522 Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Tue, 12 Sep 2023 13:14:37 +0200
Subject: [PATCH 1/2] [GlobalISel] Add `GITypeOf` special type

Allows creating a register/immediate that uses the same type as a matched operand.
---
 llvm/docs/GlobalISel/MIRPatterns.rst          |  42 +++
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |   3 -
 .../CodeGen/GlobalISel/GIMatchTableExecutor.h |  16 +-
 .../GlobalISel/GIMatchTableExecutorImpl.h     |  50 ++--
 .../include/llvm/Target/GlobalISel/Combine.td |  25 +-
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  12 -
 .../match-table-typeof.td                     |  49 +++
 .../operand-types.td                          |  28 +-
 .../pattern-parsing.td                        |  25 +-
 .../typeof-errors.td                          |  72 +++++
 .../TableGen/GlobalISelCombinerEmitter.cpp    | 278 +++++++++++++++---
 llvm/utils/TableGen/GlobalISelMatchTable.cpp  |  36 ++-
 llvm/utils/TableGen/GlobalISelMatchTable.h    |  89 +++++-
 13 files changed, 622 insertions(+), 103 deletions(-)
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td

diff --git a/llvm/docs/GlobalISel/MIRPatterns.rst b/llvm/docs/GlobalISel/MIRPatterns.rst
index fa70311f48572de..a3883b14b3e0bd6 100644
--- a/llvm/docs/GlobalISel/MIRPatterns.rst
+++ b/llvm/docs/GlobalISel/MIRPatterns.rst
@@ -101,6 +101,48 @@ pattern, you can try naming your patterns to see exactly where the issue is.
   // using $x again here copies operand 1 from G_AND into the new inst.
   (apply (COPY $root, $x))
 
+Types
+-----
+
+ValueType
+~~~~~~~~~
+
+Subclasses of ``ValueType`` are valid types, e.g. ``i32``.
+
+GITypeOf
+~~~~~~~~
+
+``GITypeOf<"$x">`` is a ``GISpecialType`` that allows for the creation of a
+register or immediate with the same type as another (register) operand.
+
+Operand:
+
+* An operand name as a string, prefixed by ``$``.
+
+Semantics:
+
+* Can only appear in an 'apply' pattern.
+* The operand name used must appear in the 'match' pattern of the
+  same ``GICombineRule``.
+
+.. code-block:: text
+  :caption: Example: Immediate
+
+  def mul_by_neg_one: GICombineRule <
+    (defs root:$root),
+    (match (G_MUL $dst, $x, -1)),
+    (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
+  >;
+
+.. code-block:: text
+  :caption: Example: Temp Reg
+
+  def Test0 : GICombineRule<
+    (defs root:$dst),
+    (match (G_FMUL $dst, $src, -1)),
+    (apply (G_FSUB $dst, $src, $tmp),
+           (G_FNEG GITypeOf<"$dst">:$tmp, $src))>;
+
 Builtin Operations
 ------------------
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index d64b414f2747621..c6270fbff1b5496 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -403,9 +403,6 @@ class CombinerHelper {
   void applyCombineTruncOfShift(MachineInstr &MI,
                                 std::pair<MachineInstr *, LLT> &MatchInfo);
 
-  /// Transform G_MUL(x, -1) to G_SUB(0, x)
-  void applyCombineMulByNegativeOne(MachineInstr &MI);
-
   /// Return true if any explicit use operand on \p MI is defined by a
   /// G_IMPLICIT_DEF.
   bool matchAnyExplicitUseIsUndef(MachineInstr &MI);
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
index 209f80c6d6d2877..eefa0724222a9fa 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
@@ -275,6 +275,12 @@ enum {
   /// - StoreIdx - Store location in RecordedOperands.
   GIM_RecordNamedOperand,
 
+  /// Records an operand's register type into the set of temporary types.
+  /// - InsnID - Instruction ID
+  /// - OpIdx - Operand index
+  /// - TempTypeIdx - Temp Type Index, always negative.
+  GIM_RecordRegType,
+
   /// Fail the current try-block, or completely fail to match if there is no
   /// current try-block.
   GIM_Reject,
@@ -356,12 +362,6 @@ enum {
   /// - Imm - The immediate to add
   GIR_AddImm,
 
-  /// Add an CImm to the specified instruction
-  /// - InsnID - Instruction ID to modify
-  /// - Ty - Type of the constant immediate.
-  /// - Imm - The immediate to add
-  GIR_AddCImm,
-
   /// Render complex operands to the specified instruction
   /// - InsnID - Instruction ID to modify
   /// - RendererID - The renderer to call
@@ -522,6 +522,10 @@ class GIMatchTableExecutor {
     /// list. Currently such predicates don't have more then 3 arguments.
     std::array<const MachineOperand *, 3> RecordedOperands;
 
+    /// Types extracted from an instruction's operand.
+    /// Whenever a type index is negative, we look here instead.
+    SmallVector<LLT, 4> RecordedTypes;
+
     MatcherState(unsigned MaxRenderers);
   };
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index fb03d5ec0bc89a9..9fd56a2c670e0ac 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -92,6 +92,14 @@ bool GIMatchTableExecutor::executeMatchTable(
     return true;
   };
 
+  // If the index is >= 0, it's an index in the type objects generated by
+  // TableGen. If it is < 0, the type is State.RecordedTypes[1 - Idx].
+  auto getTypeFromIdx = [&](int64_t Idx) -> const LLT & {
+    if (Idx >= 0)
+      return ExecInfo.TypeObjects[Idx];
+    return State.RecordedTypes[1 - Idx];
+  };
+
   while (true) {
     assert(CurrentIdx != ~0u && "Invalid MatchTable index");
     int64_t MatcherOpcode = MatchTable[CurrentIdx++];
@@ -627,8 +635,7 @@ bool GIMatchTableExecutor::executeMatchTable(
                              << "), TypeID=" << TypeID << ")\n");
       assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
       MachineOperand &MO = State.MIs[InsnID]->getOperand(OpIdx);
-      if (!MO.isReg() ||
-          MRI.getType(MO.getReg()) != ExecInfo.TypeObjects[TypeID]) {
+      if (!MO.isReg() || MRI.getType(MO.getReg()) != getTypeFromIdx(TypeID)) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       }
@@ -679,6 +686,25 @@ bool GIMatchTableExecutor::executeMatchTable(
       State.RecordedOperands[StoreIdx] = &State.MIs[InsnID]->getOperand(OpIdx);
       break;
     }
+    case GIM_RecordRegType: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      int64_t OpIdx = MatchTable[CurrentIdx++];
+      int64_t TypeIdx = MatchTable[CurrentIdx++];
+      // Save the operand's register LLT; negative type IDs refer to it later.
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIM_RecordRegType(MIs["
+                             << InsnID << "]->getOperand(" << OpIdx
+                             << "), TypeIdx=" << TypeIdx << ")\n");
+      assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
+      assert(TypeIdx < 0 && "Temp types always have negative indexes!");
+      // Same mapping as getTypeFromIdx: temp type -1 lives in slot 2, etc.
+      TypeIdx = 1 - TypeIdx;
+      const auto &Op = State.MIs[InsnID]->getOperand(OpIdx);
+      if (State.RecordedTypes.size() <= (uint64_t)TypeIdx)
+        State.RecordedTypes.resize(TypeIdx + 1, LLT());
+      State.RecordedTypes[TypeIdx] = MRI.getType(Op.getReg());
+      break;
+    }
     case GIM_CheckRegBankForClass: {
       int64_t InsnID = MatchTable[CurrentIdx++];
       int64_t OpIdx = MatchTable[CurrentIdx++];
@@ -1068,24 +1094,6 @@ bool GIMatchTableExecutor::executeMatchTable(
                              << "], " << Imm << ")\n");
       break;
     }
-
-    case GIR_AddCImm: {
-      int64_t InsnID = MatchTable[CurrentIdx++];
-      int64_t TypeID = MatchTable[CurrentIdx++];
-      int64_t Imm = MatchTable[CurrentIdx++];
-      assert(OutMIs[InsnID] && "Attempted to add to undefined instruction");
-
-      unsigned Width = ExecInfo.TypeObjects[TypeID].getScalarSizeInBits();
-      LLVMContext &Ctx = MF->getFunction().getContext();
-      OutMIs[InsnID].addCImm(
-          ConstantInt::get(IntegerType::get(Ctx, Width), Imm, /*signed*/ true));
-      DEBUG_WITH_TYPE(TgtExecutor::getName(),
-                      dbgs() << CurrentIdx << ": GIR_AddCImm(OutMIs[" << InsnID
-                             << "], TypeID=" << TypeID << ", Imm=" << Imm
-                             << ")\n");
-      break;
-    }
-
     case GIR_ComplexRenderer: {
       int64_t InsnID = MatchTable[CurrentIdx++];
       int64_t RendererID = MatchTable[CurrentIdx++];
@@ -1275,7 +1283,7 @@ bool GIMatchTableExecutor::executeMatchTable(
       int64_t TypeID = MatchTable[CurrentIdx++];
 
       State.TempRegisters[TempRegID] =
-          MRI.createGenericVirtualRegister(ExecInfo.TypeObjects[TypeID]);
+          MRI.createGenericVirtualRegister(getTypeFromIdx(TypeID));
       DEBUG_WITH_TYPE(TgtExecutor::getName(),
                       dbgs() << CurrentIdx << ": TempRegs[" << TempRegID
                              << "] = GIR_MakeTempReg(" << TypeID << ")\n");
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 7e0691e1ee95048..31a337feed847ec 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -110,6 +110,24 @@ class GICombinePatFrag<dag outs, dag ins, list<dag> alts> {
   list<dag> Alternatives = alts;
 }
 
+//===----------------------------------------------------------------------===//
+// Pattern Special Types
+//===----------------------------------------------------------------------===//
+
+class GISpecialType;
+
+// In an apply pattern, GITypeOf gives a new temporary register or an
+// immediate the same type as a matched register operand.
+// It can only be applied to temporary registers defined by the apply
+// pattern and to immediates appearing in it.
+//
+// TODO: Make this work in matchers as well?
+//
+// FIXME: Syntax is very ugly.
+class GITypeOf<string opName> : GISpecialType {
+  string OpName = opName;
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern Builtins
 //===----------------------------------------------------------------------===//
@@ -776,10 +794,9 @@ def trunc_shift: GICombineRule <
 
 // Transform (mul x, -1) -> (sub 0, x)
 def mul_by_neg_one: GICombineRule <
-  (defs root:$root),
-  (match (wip_match_opcode G_MUL):$root,
-         [{ return Helper.matchConstantOp(${root}->getOperand(2), -1); }]),
-  (apply [{ Helper.applyCombineMulByNegativeOne(*${root}); }])
+  (defs root:$dst),
+  (match (G_MUL $dst, $x, -1)),
+  (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
 >;
 
 // Fold (xor (and x, y), y) -> (and (not x), y)
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 9efb70f28fee3ee..8a817d211957a5e 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -2224,18 +2224,6 @@ void CombinerHelper::applyCombineExtOfExt(
   }
 }
 
-void CombinerHelper::applyCombineMulByNegativeOne(MachineInstr &MI) {
-  assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
-  Register DstReg = MI.getOperand(0).getReg();
-  Register SrcReg = MI.getOperand(1).getReg();
-  LLT DstTy = MRI.getType(DstReg);
-
-  Builder.setInstrAndDebugLoc(MI);
-  Builder.buildSub(DstReg, Builder.buildConstant(DstTy, 0), SrcReg,
-                   MI.getFlags());
-  MI.eraseFromParent();
-}
-
 bool CombinerHelper::matchCombineTruncOfExt(
     MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
   assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
new file mode 100644
index 000000000000000..496d86aeef2d10a
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
@@ -0,0 +1,49 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -combiners=MyCombiner %s | \
+// RUN: FileCheck %s
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+def Test0 : GICombineRule<
+  (defs root:$dst),
+  (match (G_MUL $dst, $src, -1)),
+  (apply (G_SUB $dst, (GITypeOf<"$src"> 0), $tmp),
+         (G_CONSTANT GITypeOf<"$dst">:$tmp, (GITypeOf<"$src"> 42)))>;
+
+// CHECK:      const int64_t *GenMyCombiner::getMatchTable() const {
+// CHECK-NEXT:   constexpr static int64_t MatchTable0[] = {
+// CHECK-NEXT:     GIM_Try, /*On fail goto*//*Label 0*/ 57, // Rule ID 0 //
+// CHECK-NEXT:       GIM_CheckSimplePredicate, GICXXPred_Simple_IsRule0Enabled,
+// CHECK-NEXT:       GIM_CheckOpcode, /*MI*/0, TargetOpcode::G_MUL,
+// CHECK-NEXT:       // MIs[0] dst
+// CHECK-NEXT:       GIM_RecordRegType, /*MI*/0, /*Op*/0, /*TempTypeIdx*/-1,
+// CHECK-NEXT:       // MIs[0] src
+// CHECK-NEXT:       GIM_RecordRegType, /*MI*/0, /*Op*/1, /*TempTypeIdx*/-2,
+// CHECK-NEXT:       // MIs[0] Operand 2
+// CHECK-NEXT:       GIM_CheckConstantInt, /*MI*/0, /*Op*/2, -1,
+// CHECK-NEXT:       GIR_MakeTempReg, /*TempRegID*/1, /*TypeID*/-2,
+// CHECK-NEXT:       GIR_BuildConstant, /*TempRegID*/1, /*Val*/0,
+// CHECK-NEXT:       GIR_MakeTempReg, /*TempRegID*/0, /*TypeID*/-1,
+// CHECK-NEXT:       // Combiner Rule #0: Test0
+// CHECK-NEXT:       GIR_BuildMI, /*InsnID*/0, /*Opcode*/TargetOpcode::G_CONSTANT,
+// CHECK-NEXT:       GIR_AddTempRegister, /*InsnID*/0, /*TempRegID*/0, /*TempRegFlags*/0,
+// CHECK-NEXT:       GIR_AddCImm, /*InsnID*/0, /*Type*/-2, /*Imm*/42,
+// CHECK-NEXT:       GIR_EraseFromParent, /*InsnID*/0,
+// CHECK-NEXT:       GIR_BuildMI, /*InsnID*/1, /*Opcode*/TargetOpcode::G_SUB,
+// CHECK-NEXT:       GIR_Copy, /*NewInsnID*/1, /*OldInsnID*/0, /*OpIdx*/0, // dst
+// CHECK-NEXT:       GIR_AddTempRegister, /*InsnID*/1, /*TempRegID*/1, /*TempRegFlags*/0,
+// CHECK-NEXT:       GIR_AddTempRegister, /*InsnID*/1, /*TempRegID*/0, /*TempRegFlags*/0,
+// CHECK-NEXT:       GIR_Done,
+// CHECK-NEXT:     // Label 0: @57
+// CHECK-NEXT:     GIM_Reject,
+// CHECK-NEXT:     };
+// CHECK-NEXT:   return MatchTable0;
+// CHECK-NEXT: }
+
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  Test0
+]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
index c871e603e4e05aa..4769bed97240125 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
@@ -79,7 +79,33 @@ def PatFragTest0 : GICombineRule<
   (match (FooPF $dst)),
   (apply (COPY $dst, (i32 0)))>;
 
+
+// CHECK:      (CombineRule name:TypeOfProp id:2 root:x
+// CHECK-NEXT:   (MatchPats
+// CHECK-NEXT:     <match_root>__TypeOfProp_match_0:(CodeGenInstructionPattern G_ZEXT operands:[<def>$x, $y])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (ApplyPats
+// CHECK-NEXT:     <apply_root>__TypeOfProp_apply_0:(CodeGenInstructionPattern G_ANYEXT operands:[<def>$x, GITypeOf<$y>:$tmp])
+// CHECK-NEXT:     __TypeOfProp_apply_1:(CodeGenInstructionPattern G_ANYEXT operands:[<def>GITypeOf<$y>:$tmp, $y])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable MatchPats
+// CHECK-NEXT:     x -> __TypeOfProp_match_0
+// CHECK-NEXT:     y -> <live-in>
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable ApplyPats
+// CHECK-NEXT:     tmp -> __TypeOfProp_apply_1
+// CHECK-NEXT:     x -> __TypeOfProp_apply_0
+// CHECK-NEXT:     y -> <live-in>
+// CHECK-NEXT:   )
+// CHECK-NEXT: )
+def TypeOfProp : GICombineRule<
+  (defs root:$x),
+  (match (G_ZEXT $x, $y)),
+  (apply (G_ANYEXT $x, GITypeOf<"$y">:$tmp),
+         (G_ANYEXT $tmp, $y))>;
+
 def MyCombiner: GICombiner<"GenMyCombiner", [
   InstTest0,
-  PatFragTest0
+  PatFragTest0,
+  TypeOfProp
 ]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
index bc75b15233b5519..fd41a7d1d72417e 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
@@ -297,6 +297,28 @@ def VariadicsOutTest : GICombineRule<
   (apply (COPY $a, (i32 0)),
          (COPY $b, (i32 0)))>;
 
+// CHECK:      (CombineRule name:TypeOfTest id:10 root:dst
+// CHECK-NEXT:   (MatchPats
+// CHECK-NEXT:     <match_root>__TypeOfTest_match_0:(CodeGenInstructionPattern COPY operands:[<def>$dst, $tmp])
+// CHECK-NEXT:     __TypeOfTest_match_1:(CodeGenInstructionPattern G_ZEXT operands:[<def>$tmp, $src])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (ApplyPats
+// CHECK-NEXT:     <apply_root>__TypeOfTest_apply_0:(CodeGenInstructionPattern G_MUL operands:[<def>$dst, (GITypeOf<$src> 0), (GITypeOf<$dst> -1)])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable MatchPats
+// CHECK-NEXT:     dst -> __TypeOfTest_match_0
+// CHECK-NEXT:     src -> <live-in>
+// CHECK-NEXT:     tmp -> __TypeOfTest_match_1
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable ApplyPats
+// CHECK-NEXT:     dst -> __TypeOfTest_apply_0
+// CHECK-NEXT:   )
+// CHECK-NEXT: )
+def TypeOfTest : GICombineRule<
+  (defs root:$dst),
+  (match (COPY $dst, $tmp),
+         (G_ZEXT $tmp, $src)),
+  (apply (G_MUL $dst, (GITypeOf<"$src"> 0), (GITypeOf<"$dst"> -1)))>;
 
 def MyCombiner: GICombiner<"GenMyCombiner", [
   WipOpcodeTest0,
@@ -308,5 +330,6 @@ def MyCombiner: GICombiner<"GenMyCombiner", [
   PatFragTest0,
   PatFragTest1,
   VariadicsInTest,
-  VariadicsOutTest
+  VariadicsOutTest,
+  TypeOfTest
 ]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
new file mode 100644
index 000000000000000..6040d6def449766
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
@@ -0,0 +1,72 @@
+// RUN: not llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -combiners=MyCombiner %s 2>&1| \
+// RUN: FileCheck %s -implicit-check-not=error:
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+def NoDollarSign : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, (GITypeOf<"unknown"> 0)))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: 'unknown' ('GITypeOf<$unknown>') does not refer to a matched operand!
+def UnknownOperand : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, (GITypeOf<"$unknown"> 0)))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: GISpecialType is not supported in 'match' patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: operand 1 of '__UseInMatch_match_0' has type 'GITypeOf<$dst>'
+def UseInMatch : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, (GITypeOf<"$dst"> 0))),
+  (apply (G_ANYEXT $dst, (i32 0)))>;
+
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: GISpecialType is not supported in GICombinePatFrag
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: operand 1 of '__PFWithTypeOF_alt0_pattern_0' has type 'GITypeOf<$dst>
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Could not parse GICombinePatFrag 'PFWithTypeOF'
+def PFWithTypeOF: GICombinePatFrag<
+    (outs $dst), (ins),
+    [(pattern (G_ANYEXT $dst, (GITypeOf<"$dst"> 0)))]>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(PFWithTypeOF ?:$dst)'
+def UseInPF: GICombineRule<
+  (defs root:$dst),
+  (match (PFWithTypeOF $dst)),
+  (apply (G_ANYEXT $dst, (i32 0)))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: GISpecialType is not supported in 'match' patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: operand 1 of '__InferredUseInMatch_match_0' has type 'GITypeOf<$dst>'
+def InferredUseInMatch : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+def InferenceConflict : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, i32:$src)),
+  (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: 'tmp' ('GITypeOf<$tmp>') does not refer to a matched operand!
+def TypeOfApplyTmp : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, i32:$tmp),
+         (G_ANYEXT $tmp, (GITypeOf<"$tmp"> 0)))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse one or more rules
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  NoDollarSign,
+  UnknownOperand,
+  UseInMatch,
+  UseInPF,
+  InferredUseInMatch,
+  InferenceConflict,
+  TypeOfApplyTmp
+]>;
diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 7992cb4362a1718..844d8b1051bd471 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -73,6 +73,8 @@ constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
 constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
 constexpr StringLiteral PatFragClassName = "GICombinePatFrag";
 constexpr StringLiteral BuiltinInstClassName = "GIBuiltinInst";
+constexpr StringLiteral SpecialTyClassName = "GISpecialType";
+constexpr StringLiteral TypeOfClassName = "GITypeOf";
 
 std::string getIsEnabledPredicateEnumName(unsigned CombinerRuleID) {
   return "GICXXPred_Simple_IsRule" + to_string(CombinerRuleID) + "Enabled";
@@ -123,11 +125,6 @@ template <typename Container> auto values(Container &&C) {
   return map_range(C, [](auto &Entry) -> auto & { return Entry.second; });
 }
 
-LLTCodeGen getLLTCodeGenFromRecord(const Record *Ty) {
-  assert(Ty->isSubClassOf("ValueType"));
-  return LLTCodeGen(*MVTToLLT(getValueType(Ty)));
-}
-
 //===- MatchData Handling -------------------------------------------------===//
 
 /// Represents MatchData defined by the match stage and required by the apply
@@ -292,6 +289,117 @@ class CXXPredicateCode {
 CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXMatchCode;
 CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode;
 
+//===- PatternType --------------------------------------------------------===//
+
+/// Simple wrapper around a Record* of a type.
+///
+/// Types have two forms:
+///   - LLTs, which are straightforward.
+///   - Special types, e.g. GITypeOf
+class PatternType {
+public:
+  PatternType() = default;
+  PatternType(const Record *R) : R(R) {}
+
+  bool isValidType() const { return !R || isLLT() || isSpecial(); }
+
+  bool isLLT() const { return R && R->isSubClassOf("ValueType"); }
+  bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); }
+  bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); }
+
+  StringRef getTypeOfOpName() const;
+  LLTCodeGen getLLTCodeGen() const;
+
+  bool checkSemantics(ArrayRef<SMLoc> DiagLoc) const;
+
+  LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const;
+
+  explicit operator bool() const { return R != nullptr; }
+
+  bool operator==(const PatternType &Other) const;
+  bool operator!=(const PatternType &Other) const { return !operator==(Other); }
+
+  std::string str() const;
+
+private:
+  StringRef getRawOpName() const { return R->getValueAsString("OpName"); }
+
+  const Record *R = nullptr;
+};
+
+StringRef PatternType::getTypeOfOpName() const {
+  assert(isTypeOf());
+  StringRef Name = getRawOpName();
+  Name.consume_front("$");
+  return Name;
+}
+
+LLTCodeGen PatternType::getLLTCodeGen() const {
+  assert(isLLT());
+  return *MVTToLLT(getValueType(R));
+}
+
+LLTCodeGenOrTempType
+PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
+  assert(isValidType());
+
+  if (isLLT())
+    return getLLTCodeGen();
+
+  assert(isTypeOf());
+  auto &OM = RM.getOperandMatcher(getTypeOfOpName());
+  return OM.getTempTypeIdx(RM);
+}
+
+bool PatternType::checkSemantics(ArrayRef<SMLoc> DiagLoc) const {
+  if (!isTypeOf())
+    return true;
+
+  auto RawOpName = getRawOpName();
+  if (RawOpName.starts_with("$"))
+    return true;
+
+  PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " +
+                          TypeOfClassName +
+                          ": expected '$' followed by an operand name");
+  return false;
+}
+
+bool PatternType::operator==(const PatternType &Other) const {
+  // Same underlying Record (or both untyped, i.e. null): trivially equal.
+  // A single Record pointer cannot carry two different names, so no
+  // further sanity check is needed here.
+  if (R == Other.R)
+    return true;
+
+  // Distinct GITypeOf records are equal when they name the same operand.
+  if (isTypeOf() && Other.isTypeOf())
+    return getTypeOfOpName() == Other.getTypeOfOpName();
+
+  // Everything else (different LLT records, LLT vs special type, typed vs
+  // untyped) is a mismatch.
+  return false;
+}
+
+std::string PatternType::str() const {
+  if (!R)
+    return "";
+
+  if (!isValidType())
+    return "<invalid>";
+
+  if (isLLT())
+    return R->getName().str();
+
+  assert(isSpecial());
+
+  // TODO: Use an enum for these?
+  if (isTypeOf())
+    return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
+
+  llvm_unreachable("Unknown type!");
+}
+
 //===- Pattern Base Class -------------------------------------------------===//
 
 /// Base class for all patterns that can be written in an `apply`, `match` or
@@ -499,13 +607,15 @@ class InstructionOperand {
 public:
   using IntImmTy = int64_t;
 
-  InstructionOperand(IntImmTy Imm, StringRef Name, const Record *Type)
+  InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type)
       : Value(Imm), Name(insertStrRef(Name)), Type(Type) {
-    assert(!Type || Type->isSubClassOf("ValueType"));
+    assert(Type.isValidType());
   }
 
-  InstructionOperand(StringRef Name, const Record *Type)
-      : Name(insertStrRef(Name)), Type(Type) {}
+  InstructionOperand(StringRef Name, PatternType Type)
+      : Name(insertStrRef(Name)), Type(Type) {
+    assert(Type.isValidType());
+  }
 
   bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); }
 
@@ -527,11 +637,12 @@ class InstructionOperand {
   void setIsDef(bool Value = true) { Def = Value; }
   bool isDef() const { return Def; }
 
-  void setType(const Record *R) {
-    assert((!Type || (Type == R)) && "Overwriting type!");
-    Type = R;
+  void setType(PatternType NewType) {
+    assert((!Type || (Type == NewType)) && "Overwriting type!");
+    assert(NewType.isValidType());
+    Type = NewType;
   }
-  const Record *getType() const { return Type; }
+  PatternType getType() const { return Type; }
 
   std::string describe() const {
     if (!hasImmValue())
@@ -547,11 +658,11 @@ class InstructionOperand {
       OS << "<def>";
 
     bool NeedsColon = true;
-    if (const Record *Ty = getType()) {
+    if (Type) {
       if (hasImmValue())
-        OS << "(" << Ty->getName() << " " << getImmValue() << ")";
+        OS << "(" << Type.str() << " " << getImmValue() << ")";
       else
-        OS << Ty->getName();
+        OS << Type.str();
     } else if (hasImmValue())
       OS << getImmValue();
     else
@@ -566,7 +677,7 @@ class InstructionOperand {
 private:
   std::optional<int64_t> Value;
   StringRef Name;
-  const Record *Type = nullptr;
+  PatternType Type;
   bool Def = false;
 };
 
@@ -622,6 +733,10 @@ class InstructionPattern : public Pattern {
 
   virtual StringRef getInstName() const = 0;
 
+  /// Diagnoses all uses of special types in this Pattern and returns true if at
+  /// least one diagnostic was emitted.
+  bool diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc, Twine Msg) const;
+
   void reportUnreachable(ArrayRef<SMLoc> Locs) const;
   virtual bool checkSemantics(ArrayRef<SMLoc> Loc);
 
@@ -633,6 +748,20 @@ class InstructionPattern : public Pattern {
   SmallVector<InstructionOperand, 4> Operands;
 };
 
+bool InstructionPattern::diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc,
+                                                 Twine Msg) const {
+  bool HasDiag = false;
+  for (const auto &[Idx, Op] : enumerate(operands())) {
+    if (Op.getType().isSpecial()) {
+      PrintError(Loc, Msg);
+      PrintNote(Loc, "operand " + Twine(Idx) + " of '" + getName() +
+                         "' has type '" + Op.getType().str() + "'");
+      HasDiag = true;
+    }
+  }
+  return HasDiag;
+}
+
 void InstructionPattern::reportUnreachable(ArrayRef<SMLoc> Locs) const {
   PrintError(Locs, "pattern '" + getName() + "' ('" + getInstName() +
                        "') is unreachable from the pattern root!");
@@ -829,17 +958,20 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
 /// It infers the type of each operand, check it's consistent with the known
 /// type of the operand, and then sets all of the types in all operands in
 /// setAllOperandTypes.
+///
+/// It also handles verifying correctness of special types.
 class OperandTypeChecker {
 public:
   OperandTypeChecker(ArrayRef<SMLoc> DiagLoc) : DiagLoc(DiagLoc) {}
 
-  bool check(InstructionPattern *P);
+  bool check(InstructionPattern *P,
+             std::function<bool(const PatternType &)> VerifyTypeOfOperand);
 
   void setAllOperandTypes();
 
 private:
   struct OpTypeInfo {
-    const Record *Type = nullptr;
+    PatternType Type;
     InstructionPattern *TypeSrc = nullptr;
   };
 
@@ -849,16 +981,26 @@ class OperandTypeChecker {
   SmallVector<InstructionPattern *, 16> Pats;
 };
 
-bool OperandTypeChecker::check(InstructionPattern *P) {
+bool OperandTypeChecker::check(
+    InstructionPattern *P,
+    std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
   Pats.push_back(P);
 
-  for (auto &Op : P->named_operands()) {
-    const Record *Ty = Op.getType();
+  for (auto &Op : P->operands()) {
+    const auto Ty = Op.getType();
     if (!Ty)
       continue;
 
-    auto &Info = Types[Op.getOperandName()];
+    if (!Ty.checkSemantics(DiagLoc))
+      return false;
+
+    if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
+      return false;
 
+    if (!Op.isNamedOperand())
+      continue;
+
+    auto &Info = Types[Op.getOperandName()];
     if (!Info.Type) {
       Info.Type = Ty;
       Info.TypeSrc = P;
@@ -868,9 +1010,9 @@ bool OperandTypeChecker::check(InstructionPattern *P) {
     if (Info.Type != Ty) {
       PrintError(DiagLoc, "conflicting types for operand '" +
                               Op.getOperandName() + "': first seen with '" +
-                              Info.Type->getName() + "' in '" +
+                              Info.Type.str() + "' in '" +
                               Info.TypeSrc->getName() + ", now seen with '" +
-                              Ty->getName() + "' in '" + P->getName() + "'");
+                              Ty.str() + "' in '" + P->getName() + "'");
       return false;
     }
   }
@@ -1058,7 +1200,12 @@ bool PatFrag::checkSemantics() {
                    PatFragClassName);
         return false;
       case Pattern::K_CXX:
+        continue;
       case Pattern::K_CodeGenInstruction:
+        if (cast<CodeGenInstructionPattern>(Pat.get())->diagnoseAllSpecialTypes(
+                Def.getLoc(), SpecialTyClassName + " is not supported in " +
+                                  PatFragClassName))
+          return false;
         continue;
       case Pattern::K_PatFrag:
         // TODO: It's just that the emitter doesn't handle it but technically
@@ -1142,12 +1289,16 @@ bool PatFrag::checkSemantics() {
 
   // TODO: find unused params
 
+  const auto CheckTypeOf = [&](const PatternType &) -> bool {
+    llvm_unreachable("GITypeOf should have been rejected earlier!");
+  };
+
   // Now, typecheck all alternatives.
   for (auto &Alt : Alts) {
     OperandTypeChecker OTC(Def.getLoc());
     for (auto &Pat : Alt.Pats) {
       if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-        if (!OTC.check(IP))
+        if (!OTC.check(IP, CheckTypeOf))
           return false;
       }
     }
@@ -1954,21 +2105,49 @@ bool CombineRuleBuilder::hasEraseRoot() const {
 bool CombineRuleBuilder::typecheckPatterns() {
   OperandTypeChecker OTC(RuleDef.getLoc());
 
+  const auto CheckMatchTypeOf = [&](const PatternType &) -> bool {
+    // Special types in 'match' patterns are rejected after type inference.
+    return true;
+  };
+
   for (auto &Pat : values(MatchPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP))
+      if (!OTC.check(IP, CheckMatchTypeOf))
         return false;
     }
   }
 
+  const auto CheckApplyTypeOf = [&](const PatternType &Ty) {
+    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
+    const auto OpName = Ty.getTypeOfOpName();
+    if (MatchOpTable.lookup(OpName).Found)
+      return true;
+
+    PrintError("'" + OpName + "' ('" + Ty.str() +
+               "') does not refer to a matched operand!");
+    return false;
+  };
+
   for (auto &Pat : values(ApplyPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP))
+      if (!OTC.check(IP, CheckApplyTypeOf))
         return false;
     }
   }
 
   OTC.setAllOperandTypes();
+
+  // Always check this after inference, in case it adds some special types to
+  // the match patterns.
+  for (auto &Pat : values(MatchPats)) {
+    if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
+      if (IP->diagnoseAllSpecialTypes(
+              RuleDef.getLoc(),
+              SpecialTyClassName + " is not supported in 'match' patterns")) {
+        return false;
+      }
+    }
+  }
   return true;
 }
 
@@ -2461,10 +2640,12 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
     if (DagOp->getNumArgs() != 1)
       return ParseErr();
 
-    Record *ImmTy = DagOp->getOperatorAsDef(RuleDef.getLoc());
-    if (!ImmTy->isSubClassOf("ValueType")) {
+    const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
+    PatternType ImmTy(TyDef);
+    if (!ImmTy.isValidType()) {
       PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() +
-                 "', '" + ImmTy->getName() + "' is not a ValueType!");
+                 "', '" + TyDef->getName() + "' is not a ValueType or " +
+                 SpecialTyClassName);
       return false;
     }
 
@@ -2491,12 +2672,13 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return false;
     }
     const Record *Def = DefI->getDef();
-    if (!Def->isSubClassOf("ValueType")) {
+    PatternType Ty(Def);
+    if (!Ty.isValidType()) {
       PrintError("invalid operand type: '" + Def->getName() +
                  "' is not a ValueType");
       return false;
     }
-    IP.addOperand(OpName->getAsUnquotedString(), Def);
+    IP.addOperand(OpName->getAsUnquotedString(), Ty);
     return true;
   }
 
@@ -2823,8 +3005,8 @@ bool CombineRuleBuilder::emitPatFragMatchPattern(
       StringRef PFName = PF.getName();
       PrintWarning("impossible type constraints: operand " + Twine(PIdx) +
                    " of '" + PFP.getName() + "' has type '" +
-                   ArgOp.getType()->getName() + "', but '" + PFName +
-                   "' constrains it to '" + O.getType()->getName() + "'");
+                   ArgOp.getType().str() + "', but '" + PFName +
+                   "' constrains it to '" + O.getType().str() + "'");
       if (ArgOp.isNamedOperand())
         PrintNote("operand " + Twine(PIdx) + " of '" + PFP.getName() +
                   "' is '" + ArgOp.getOperandName() + "'");
@@ -3055,17 +3237,18 @@ bool CombineRuleBuilder::emitInstructionApplyPattern(
       // This is a brand new register.
       TempRegID = M.allocateTempRegID();
       OperandToTempRegID[OpName] = TempRegID;
-      const Record *Ty = Op.getType();
+      const auto Ty = Op.getType();
       if (!Ty) {
         PrintError("def of a new register '" + OpName +
                    "' in the apply patterns must have a type");
         return false;
       }
+
       declareTempRegExpansion(CE, TempRegID, OpName);
       // Always insert the action at the beginning, otherwise we may end up
       // using the temp reg before it's available.
       M.insertAction<MakeTempRegisterAction>(
-          M.actions_begin(), getLLTCodeGenFromRecord(Ty), TempRegID);
+          M.actions_begin(), Ty.getLLTCodeGenOrTempType(M), TempRegID);
     }
 
     DstMI.addRenderer<TempRegRenderer>(TempRegID);
@@ -3088,7 +3271,7 @@ bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
   // G_CONSTANT is a special case and needs a CImm though so this is likely a
   // mistake.
   const bool isGConstant = P.is("G_CONSTANT");
-  const Record *Ty = O.getType();
+  const auto Ty = O.getType();
   if (!Ty) {
     if (isGConstant) {
       PrintError("'G_CONSTANT' immediate must be typed!");
@@ -3101,16 +3284,17 @@ bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
     return true;
   }
 
-  LLTCodeGen LLT = getLLTCodeGenFromRecord(Ty);
+  auto ImmTy = Ty.getLLTCodeGenOrTempType(M);
+
   if (isGConstant) {
-    DstMI.addRenderer<ImmRenderer>(O.getImmValue(), LLT);
+    DstMI.addRenderer<ImmRenderer>(O.getImmValue(), ImmTy);
     return true;
   }
 
   unsigned TempRegID = M.allocateTempRegID();
   // Ensure MakeTempReg & the BuildConstantAction occur at the beginning.
-  auto InsertIt =
-      M.insertAction<MakeTempRegisterAction>(M.actions_begin(), LLT, TempRegID);
+  auto InsertIt = M.insertAction<MakeTempRegisterAction>(M.actions_begin(),
+                                                         ImmTy, TempRegID);
   M.insertAction<BuildConstantAction>(++InsertIt, TempRegID, O.getImmValue());
   DstMI.addRenderer<TempRegRenderer>(TempRegID);
   return true;
@@ -3227,8 +3411,14 @@ bool CombineRuleBuilder::emitCodeGenInstructionMatchPattern(
     // Always emit a check for unnamed operands.
     if (OpName.empty() ||
         !M.getOperandMatcher(OpName).contains<LLTOperandMatcher>()) {
-      if (const Record *Ty = RemappedO.getType())
-        OM.addPredicate<LLTOperandMatcher>(getLLTCodeGenFromRecord(Ty));
+      if (const auto Ty = RemappedO.getType()) {
+        // TODO: We could support GITypeOf here on the condition that the
+        // OperandMatcher exists already. However, it's clunky to make this
+        // work and isn't all that useful, so it's just rejected in
+        // typecheckPatterns at this time.
+        assert(Ty.isLLT() && "Only LLTs are supported in match patterns!");
+        OM.addPredicate<LLTOperandMatcher>(Ty.getLLTCodeGen());
+      }
     }
 
     // Stop here if the operand is a def, or if it had no name.
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.cpp b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
index 9a4a375f34bdb91..6ec85269e6e20d0 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.cpp
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
@@ -822,6 +822,15 @@ const OperandMatcher &RuleMatcher::getPhysRegOperandMatcher(Record *Reg) const {
   return *I->second;
 }
 
+OperandMatcher &RuleMatcher::getOperandMatcher(StringRef Name) {
+  const auto &I = DefinedOperands.find(Name);
+
+  if (I == DefinedOperands.end())
+    PrintFatalError(SrcLoc, "Operand " + Name + " was not declared in matcher");
+
+  return *I->second;
+}
+
 const OperandMatcher &RuleMatcher::getOperandMatcher(StringRef Name) const {
   const auto &I = DefinedOperands.find(Name);
 
@@ -1081,6 +1090,17 @@ void RecordNamedOperandMatcher::emitPredicateOpcodes(MatchTable &Table,
         << MatchTable::Comment("Name : " + Name) << MatchTable::LineBreak;
 }
 
+//===- RecordRegisterType ------------------------------------------===//
+
+void RecordRegisterType::emitPredicateOpcodes(MatchTable &Table,
+                                              RuleMatcher &Rule) const {
+  assert(Idx < 0 && "Temp types always have negative indexes!");
+  Table << MatchTable::Opcode("GIM_RecordRegType") << MatchTable::Comment("MI")
+        << MatchTable::IntValue(InsnVarID) << MatchTable::Comment("Op")
+        << MatchTable::IntValue(OpIdx) << MatchTable::Comment("TempTypeIdx")
+        << MatchTable::IntValue(Idx) << MatchTable::LineBreak;
+}
+
 //===- ComplexPatternOperandMatcher ---------------------------------------===//
 
 void ComplexPatternOperandMatcher::emitPredicateOpcodes(
@@ -1196,6 +1216,18 @@ std::string OperandMatcher::getOperandExpr(unsigned InsnVarID) const {
 
 unsigned OperandMatcher::getInsnVarID() const { return Insn.getInsnVarID(); }
 
+TempTypeIdx OperandMatcher::getTempTypeIdx(RuleMatcher &Rule) {
+  if (TTIdx >= 0) {
+    // Temp type index not assigned yet, so assign one and add the necessary
+    // predicate.
+    TTIdx = Rule.getNextTempTypeIdx();
+    assert(TTIdx < 0);
+    addPredicate<RecordRegisterType>(TTIdx);
+    return TTIdx;
+  }
+  return TTIdx;
+}
+
 void OperandMatcher::emitPredicateOpcodes(MatchTable &Table,
                                           RuleMatcher &Rule) {
   if (!Optimized) {
@@ -2092,9 +2124,7 @@ void MakeTempRegisterAction::emitActionOpcodes(MatchTable &Table,
                                                RuleMatcher &Rule) const {
   Table << MatchTable::Opcode("GIR_MakeTempReg")
         << MatchTable::Comment("TempRegID") << MatchTable::IntValue(TempRegID)
-        << MatchTable::Comment("TypeID")
-        << MatchTable::NamedValue(Ty.getCxxEnumValue())
-        << MatchTable::LineBreak;
+        << MatchTable::Comment("TypeID") << Ty << MatchTable::LineBreak;
 }
 
 } // namespace gi
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.h b/llvm/utils/TableGen/GlobalISelMatchTable.h
index 5608bab482bfd34..364f2a1ec725d53 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.h
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.h
@@ -273,6 +273,40 @@ extern std::set<LLTCodeGen> KnownTypes;
 /// MVTs that don't map cleanly to an LLT (e.g., iPTR, *any, ...).
 std::optional<LLTCodeGen> MVTToLLT(MVT::SimpleValueType SVT);
 
+using TempTypeIdx = int64_t;
+class LLTCodeGenOrTempType {
+public:
+  LLTCodeGenOrTempType(const LLTCodeGen &LLT) : Data(LLT) {}
+  LLTCodeGenOrTempType(TempTypeIdx TempTy) : Data(TempTy) {}
+
+  bool isLLTCodeGen() const { return std::holds_alternative<LLTCodeGen>(Data); }
+  bool isTempTypeIdx() const {
+    return std::holds_alternative<TempTypeIdx>(Data);
+  }
+
+  const LLTCodeGen &getLLTCodeGen() const {
+    assert(isLLTCodeGen());
+    return std::get<LLTCodeGen>(Data);
+  }
+
+  TempTypeIdx getTempTypeIdx() const {
+    assert(isTempTypeIdx());
+    return std::get<TempTypeIdx>(Data);
+  }
+
+private:
+  std::variant<LLTCodeGen, TempTypeIdx> Data;
+};
+
+inline MatchTable &operator<<(MatchTable &Table,
+                              const LLTCodeGenOrTempType &Ty) {
+  if (Ty.isLLTCodeGen())
+    Table << MatchTable::NamedValue(Ty.getLLTCodeGen().getCxxEnumValue());
+  else
+    Table << MatchTable::IntValue(Ty.getTempTypeIdx());
+  return Table;
+}
+
 //===- Matchers -----------------------------------------------------------===//
 class Matcher {
 public:
@@ -459,6 +493,9 @@ class RuleMatcher : public Matcher {
   /// ID for the next temporary register ID allocated with allocateTempRegID()
   unsigned NextTempRegID;
 
+  /// ID for the next recorded type. Starts at -1 and counts down.
+  TempTypeIdx NextTempTypeIdx = -1;
+
   // HwMode predicate index for this rule. -1 if no HwMode.
   int HwModeIdx = -1;
 
@@ -498,6 +535,8 @@ class RuleMatcher : public Matcher {
   RuleMatcher(RuleMatcher &&Other) = default;
   RuleMatcher &operator=(RuleMatcher &&Other) = default;
 
+  TempTypeIdx getNextTempTypeIdx() { return NextTempTypeIdx--; }
+
   uint64_t getRuleID() const { return RuleID; }
 
   InstructionMatcher &addInstructionMatcher(StringRef SymbolicName);
@@ -602,6 +641,7 @@ class RuleMatcher : public Matcher {
   }
 
   InstructionMatcher &getInstructionMatcher(StringRef SymbolicName) const;
+  OperandMatcher &getOperandMatcher(StringRef Name);
   const OperandMatcher &getOperandMatcher(StringRef Name) const;
   const OperandMatcher &getPhysRegOperandMatcher(Record *) const;
 
@@ -762,6 +802,7 @@ class PredicateMatcher {
     OPM_RegBank,
     OPM_MBB,
     OPM_RecordNamedOperand,
+    OPM_RecordRegType,
   };
 
 protected:
@@ -963,6 +1004,30 @@ class RecordNamedOperandMatcher : public OperandPredicateMatcher {
                             RuleMatcher &Rule) const override;
 };
 
+/// Generates code to store a register operand's type into the set of temporary
+/// LLTs.
+class RecordRegisterType : public OperandPredicateMatcher {
+protected:
+  TempTypeIdx Idx;
+
+public:
+  RecordRegisterType(unsigned InsnVarID, unsigned OpIdx, TempTypeIdx Idx)
+      : OperandPredicateMatcher(OPM_RecordRegType, InsnVarID, OpIdx), Idx(Idx) {
+  }
+
+  static bool classof(const PredicateMatcher *P) {
+    return P->getKind() == OPM_RecordRegType;
+  }
+
+  bool isIdentical(const PredicateMatcher &B) const override {
+    return OperandPredicateMatcher::isIdentical(B) &&
+           Idx == cast<RecordRegisterType>(&B)->Idx;
+  }
+
+  void emitPredicateOpcodes(MatchTable &Table,
+                            RuleMatcher &Rule) const override;
+};
+
 /// Generates code to check that an operand is a particular target constant.
 class ComplexPatternOperandMatcher : public OperandPredicateMatcher {
 protected:
@@ -1169,6 +1234,8 @@ class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
   /// countRendererFns().
   unsigned AllocatedTemporariesBaseID;
 
+  TempTypeIdx TTIdx = 0;
+
 public:
   OperandMatcher(InstructionMatcher &Insn, unsigned OpIdx,
                  const std::string &SymbolicName,
@@ -1196,6 +1263,11 @@ class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
   unsigned getOpIdx() const { return OpIdx; }
   unsigned getInsnVarID() const;
 
+  /// If this OperandMatcher has not been assigned a TempTypeIdx yet, assigns it
+  /// one and adds a `RecordRegisterType` predicate to this matcher. If one has
+  /// already been assigned, simply returns it.
+  TempTypeIdx getTempTypeIdx(RuleMatcher &Rule);
+
   std::string getOperandExpr(unsigned InsnVarID) const;
 
   InstructionMatcher &getInstructionMatcher() const { return Insn; }
@@ -1955,15 +2027,16 @@ class ImmRenderer : public OperandRenderer {
 protected:
   unsigned InsnID;
   int64_t Imm;
-  std::optional<LLTCodeGen> CImmLLT;
+  std::optional<LLTCodeGenOrTempType> CImmLLT;
 
 public:
   ImmRenderer(unsigned InsnID, int64_t Imm)
       : OperandRenderer(OR_Imm), InsnID(InsnID), Imm(Imm) {}
 
-  ImmRenderer(unsigned InsnID, int64_t Imm, const LLTCodeGen &CImmLLT)
+  ImmRenderer(unsigned InsnID, int64_t Imm, const LLTCodeGenOrTempType &CImmLLT)
       : OperandRenderer(OR_Imm), InsnID(InsnID), Imm(Imm), CImmLLT(CImmLLT) {
-    KnownTypes.insert(CImmLLT);
+    if (CImmLLT.isLLTCodeGen())
+      KnownTypes.insert(CImmLLT.getLLTCodeGen());
   }
 
   static bool classof(const OperandRenderer *R) {
@@ -1976,8 +2049,7 @@ class ImmRenderer : public OperandRenderer {
              "ConstantInt immediate are only for combiners!");
       Table << MatchTable::Opcode("GIR_AddCImm")
             << MatchTable::Comment("InsnID") << MatchTable::IntValue(InsnID)
-            << MatchTable::Comment("Type")
-            << MatchTable::NamedValue(CImmLLT->getCxxEnumValue())
+            << MatchTable::Comment("Type") << *CImmLLT
             << MatchTable::Comment("Imm") << MatchTable::IntValue(Imm)
             << MatchTable::LineBreak;
     } else {
@@ -2290,13 +2362,14 @@ class ConstrainOperandToRegClassAction : public MatchAction {
 /// instructions together.
 class MakeTempRegisterAction : public MatchAction {
 private:
-  LLTCodeGen Ty;
+  LLTCodeGenOrTempType Ty;
   unsigned TempRegID;
 
 public:
-  MakeTempRegisterAction(const LLTCodeGen &Ty, unsigned TempRegID)
+  MakeTempRegisterAction(const LLTCodeGenOrTempType &Ty, unsigned TempRegID)
       : MatchAction(AK_MakeTempReg), Ty(Ty), TempRegID(TempRegID) {
-    KnownTypes.insert(Ty);
+    if (Ty.isLLTCodeGen())
+      KnownTypes.insert(Ty.getLLTCodeGen());
   }
 
   static bool classof(const MatchAction *A) {

>From e803b89e25ae43782f7a36b331665c6e72715997 Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Thu, 19 Oct 2023 08:00:12 +0200
Subject: [PATCH 2/2] Restore AddCImm

---
 .../CodeGen/GlobalISel/GIMatchTableExecutor.h  |  6 ++++++
 .../GlobalISel/GIMatchTableExecutorImpl.h      | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
index eefa0724222a9fa..6fcd9d09e1863cc 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
@@ -362,6 +362,12 @@ enum {
   /// - Imm - The immediate to add
   GIR_AddImm,
 
+  /// Add a CImm to the specified instruction
+  /// - InsnID - Instruction ID to modify
+  /// - Ty - Type of the constant immediate.
+  /// - Imm - The immediate to add
+  GIR_AddCImm,
+
   /// Render complex operands to the specified instruction
   /// - InsnID - Instruction ID to modify
   /// - RendererID - The renderer to call
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index 9fd56a2c670e0ac..48e1d388b93996a 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -1094,6 +1094,24 @@ bool GIMatchTableExecutor::executeMatchTable(
                              << "], " << Imm << ")\n");
       break;
     }
+
+    case GIR_AddCImm: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      int64_t TypeID = MatchTable[CurrentIdx++];
+      int64_t Imm = MatchTable[CurrentIdx++];
+      assert(OutMIs[InsnID] && "Attempted to add to undefined instruction");
+
+      unsigned Width = ExecInfo.TypeObjects[TypeID].getScalarSizeInBits();
+      LLVMContext &Ctx = MF->getFunction().getContext();
+      OutMIs[InsnID].addCImm(
+          ConstantInt::get(IntegerType::get(Ctx, Width), Imm, /*signed*/ true));
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIR_AddCImm(OutMIs[" << InsnID
+                             << "], TypeID=" << TypeID << ", Imm=" << Imm
+                             << ")\n");
+      break;
+    }
+
     case GIR_ComplexRenderer: {
       int64_t InsnID = MatchTable[CurrentIdx++];
       int64_t RendererID = MatchTable[CurrentIdx++];



More information about the llvm-commits mailing list