[llvm] [TableGen][GlobalISel] Add rule-wide type inference (PR #66377)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 17 22:50:40 PDT 2023


https://github.com/Pierre-vh updated https://github.com/llvm/llvm-project/pull/66377

>From cf3ba9df2b735920c0d435cc19c7ef306cbd3e78 Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Tue, 12 Sep 2023 13:14:37 +0200
Subject: [PATCH 1/2] [GlobalISel] Add `GITypeOf` special type

Allows creating a register/immediate that uses the same type as a matched operand.
---
 llvm/docs/GlobalISel/MIRPatterns.rst          |  42 +++
 .../llvm/CodeGen/GlobalISel/CombinerHelper.h  |   3 -
 .../CodeGen/GlobalISel/GIMatchTableExecutor.h |  16 +-
 .../GlobalISel/GIMatchTableExecutorImpl.h     |  50 ++--
 .../include/llvm/Target/GlobalISel/Combine.td |  25 +-
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp |  12 -
 .../match-table-typeof.td                     |  49 +++
 .../operand-types.td                          |  28 +-
 .../pattern-parsing.td                        |  25 +-
 .../typeof-errors.td                          |  72 +++++
 .../TableGen/GlobalISelCombinerEmitter.cpp    | 278 +++++++++++++++---
 llvm/utils/TableGen/GlobalISelMatchTable.cpp  |  36 ++-
 llvm/utils/TableGen/GlobalISelMatchTable.h    |  89 +++++-
 13 files changed, 622 insertions(+), 103 deletions(-)
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td

diff --git a/llvm/docs/GlobalISel/MIRPatterns.rst b/llvm/docs/GlobalISel/MIRPatterns.rst
index fa70311f48572de..a3883b14b3e0bd6 100644
--- a/llvm/docs/GlobalISel/MIRPatterns.rst
+++ b/llvm/docs/GlobalISel/MIRPatterns.rst
@@ -101,6 +101,48 @@ pattern, you can try naming your patterns to see exactly where the issue is.
   // using $x again here copies operand 1 from G_AND into the new inst.
   (apply (COPY $root, $x))
 
+Types
+-----
+
+ValueType
+~~~~~~~~~
+
+Subclasses of ``ValueType`` are valid types, e.g. ``i32``.
+
+GITypeOf
+~~~~~~~~
+
+``GITypeOf<"$x">`` is a ``GISpecialType`` that allows for the creation of a
+register or immediate with the same type as another (register) operand.
+
+Operand:
+
+* An operand name as a string, prefixed by ``$``.
+
+Semantics:
+
+* Can only appear in an 'apply' pattern.
+* The operand name used must appear in the 'match' pattern of the
+  same ``GICombineRule``.
+
+.. code-block:: text
+  :caption: Example: Immediate
+
+  def mul_by_neg_one: GICombineRule <
+    (defs root:$root),
+    (match (G_MUL $dst, $x, -1)),
+    (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
+  >;
+
+.. code-block:: text
+  :caption: Example: Temp Reg
+
+  def Test0 : GICombineRule<
+    (defs root:$dst),
+    (match (G_FMUL $dst, $src, -1)),
+    (apply (G_FSUB $dst, $src, $tmp),
+           (G_FNEG GITypeOf<"$dst">:$tmp, $src))>;
+
 Builtin Operations
 ------------------
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
index d64b414f2747621..c6270fbff1b5496 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -403,9 +403,6 @@ class CombinerHelper {
   void applyCombineTruncOfShift(MachineInstr &MI,
                                 std::pair<MachineInstr *, LLT> &MatchInfo);
 
-  /// Transform G_MUL(x, -1) to G_SUB(0, x)
-  void applyCombineMulByNegativeOne(MachineInstr &MI);
-
   /// Return true if any explicit use operand on \p MI is defined by a
   /// G_IMPLICIT_DEF.
   bool matchAnyExplicitUseIsUndef(MachineInstr &MI);
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
index 209f80c6d6d2877..eefa0724222a9fa 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
@@ -275,6 +275,12 @@ enum {
   /// - StoreIdx - Store location in RecordedOperands.
   GIM_RecordNamedOperand,
 
+  /// Records an operand's register type into the set of temporary types.
+  /// - InsnID - Instruction ID
+  /// - OpIdx - Operand index
+  /// - TempTypeIdx - Temp Type Index, always negative.
+  GIM_RecordRegType,
+
   /// Fail the current try-block, or completely fail to match if there is no
   /// current try-block.
   GIM_Reject,
@@ -356,12 +362,6 @@ enum {
   /// - Imm - The immediate to add
   GIR_AddImm,
 
-  /// Add an CImm to the specified instruction
-  /// - InsnID - Instruction ID to modify
-  /// - Ty - Type of the constant immediate.
-  /// - Imm - The immediate to add
-  GIR_AddCImm,
-
   /// Render complex operands to the specified instruction
   /// - InsnID - Instruction ID to modify
   /// - RendererID - The renderer to call
@@ -522,6 +522,10 @@ class GIMatchTableExecutor {
     /// list. Currently such predicates don't have more then 3 arguments.
     std::array<const MachineOperand *, 3> RecordedOperands;
 
+    /// Types extracted from an instruction's operand.
+    /// Whenever a type index is negative, we look here instead.
+    SmallVector<LLT, 4> RecordedTypes;
+
     MatcherState(unsigned MaxRenderers);
   };
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index fb03d5ec0bc89a9..9fd56a2c670e0ac 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -92,6 +92,14 @@ bool GIMatchTableExecutor::executeMatchTable(
     return true;
   };
 
+  // Idx >= 0 selects one of the type objects generated by TableGen; Idx < 0
+  // selects a recorded (temporary) type, stored at RecordedTypes[1 - Idx].
+  auto getTypeFromIdx = [&](int64_t Idx) -> const LLT & {
+    if (Idx >= 0)
+      return ExecInfo.TypeObjects[Idx];
+    return State.RecordedTypes[1 - Idx];
+  };
+
   while (true) {
     assert(CurrentIdx != ~0u && "Invalid MatchTable index");
     int64_t MatcherOpcode = MatchTable[CurrentIdx++];
@@ -627,8 +635,7 @@ bool GIMatchTableExecutor::executeMatchTable(
                              << "), TypeID=" << TypeID << ")\n");
       assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
       MachineOperand &MO = State.MIs[InsnID]->getOperand(OpIdx);
-      if (!MO.isReg() ||
-          MRI.getType(MO.getReg()) != ExecInfo.TypeObjects[TypeID]) {
+      if (!MO.isReg() || MRI.getType(MO.getReg()) != getTypeFromIdx(TypeID)) {
         if (handleReject() == RejectAndGiveUp)
           return false;
       }
@@ -679,6 +686,25 @@ bool GIMatchTableExecutor::executeMatchTable(
       State.RecordedOperands[StoreIdx] = &State.MIs[InsnID]->getOperand(OpIdx);
       break;
     }
+    case GIM_RecordRegType: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      int64_t OpIdx = MatchTable[CurrentIdx++];
+      int64_t TypeIdx = MatchTable[CurrentIdx++];
+
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIM_RecordRegType(MIs["
+                             << InsnID << "]->getOperand(" << OpIdx
+                             << "), TypeIdx=" << TypeIdx << ")\n");
+      assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
+      assert(TypeIdx < 0 && "Temp types always have negative indexes!");
+      // Temp indexes start at -1; mirror getTypeFromIdx's `1 - Idx` mapping.
+      TypeIdx = 1 - TypeIdx;
+      const auto &Op = State.MIs[InsnID]->getOperand(OpIdx);
+      if (State.RecordedTypes.size() <= (uint64_t)TypeIdx)
+        State.RecordedTypes.resize(TypeIdx + 1, LLT());
+      State.RecordedTypes[TypeIdx] = MRI.getType(Op.getReg());
+      break;
+    }
     case GIM_CheckRegBankForClass: {
       int64_t InsnID = MatchTable[CurrentIdx++];
       int64_t OpIdx = MatchTable[CurrentIdx++];
@@ -1068,24 +1094,6 @@ bool GIMatchTableExecutor::executeMatchTable(
                              << "], " << Imm << ")\n");
       break;
     }
-
-    case GIR_AddCImm: {
-      int64_t InsnID = MatchTable[CurrentIdx++];
-      int64_t TypeID = MatchTable[CurrentIdx++];
-      int64_t Imm = MatchTable[CurrentIdx++];
-      assert(OutMIs[InsnID] && "Attempted to add to undefined instruction");
-
-      unsigned Width = ExecInfo.TypeObjects[TypeID].getScalarSizeInBits();
-      LLVMContext &Ctx = MF->getFunction().getContext();
-      OutMIs[InsnID].addCImm(
-          ConstantInt::get(IntegerType::get(Ctx, Width), Imm, /*signed*/ true));
-      DEBUG_WITH_TYPE(TgtExecutor::getName(),
-                      dbgs() << CurrentIdx << ": GIR_AddCImm(OutMIs[" << InsnID
-                             << "], TypeID=" << TypeID << ", Imm=" << Imm
-                             << ")\n");
-      break;
-    }
-
     case GIR_ComplexRenderer: {
       int64_t InsnID = MatchTable[CurrentIdx++];
       int64_t RendererID = MatchTable[CurrentIdx++];
@@ -1275,7 +1283,7 @@ bool GIMatchTableExecutor::executeMatchTable(
       int64_t TypeID = MatchTable[CurrentIdx++];
 
       State.TempRegisters[TempRegID] =
-          MRI.createGenericVirtualRegister(ExecInfo.TypeObjects[TypeID]);
+          MRI.createGenericVirtualRegister(getTypeFromIdx(TypeID));
       DEBUG_WITH_TYPE(TgtExecutor::getName(),
                       dbgs() << CurrentIdx << ": TempRegs[" << TempRegID
                              << "] = GIR_MakeTempReg(" << TypeID << ")\n");
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 7e0691e1ee95048..31a337feed847ec 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -110,6 +110,24 @@ class GICombinePatFrag<dag outs, dag ins, list<dag> alts> {
   list<dag> Alternatives = alts;
 }
 
+//===----------------------------------------------------------------------===//
+// Pattern Special Types
+//===----------------------------------------------------------------------===//
+
+class GISpecialType;
+
+// In an apply pattern, GITypeOf can be used to give a new temporary register
+// or an immediate the same type as a matched register operand.
+//
+// This can only be used in 'apply' patterns; the operand must be matched.
+//
+// TODO: Make this work in matchers as well?
+//
+// FIXME: Syntax is very ugly.
+class GITypeOf<string opName> : GISpecialType {
+  string OpName = opName;
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern Builtins
 //===----------------------------------------------------------------------===//
@@ -776,10 +794,9 @@ def trunc_shift: GICombineRule <
 
 // Transform (mul x, -1) -> (sub 0, x)
 def mul_by_neg_one: GICombineRule <
-  (defs root:$root),
-  (match (wip_match_opcode G_MUL):$root,
-         [{ return Helper.matchConstantOp(${root}->getOperand(2), -1); }]),
-  (apply [{ Helper.applyCombineMulByNegativeOne(*${root}); }])
+  (defs root:$dst),
+  (match (G_MUL $dst, $x, -1)),
+  (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
 >;
 
 // Fold (xor (and x, y), y) -> (and (not x), y)
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 9efb70f28fee3ee..8a817d211957a5e 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -2224,18 +2224,6 @@ void CombinerHelper::applyCombineExtOfExt(
   }
 }
 
-void CombinerHelper::applyCombineMulByNegativeOne(MachineInstr &MI) {
-  assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
-  Register DstReg = MI.getOperand(0).getReg();
-  Register SrcReg = MI.getOperand(1).getReg();
-  LLT DstTy = MRI.getType(DstReg);
-
-  Builder.setInstrAndDebugLoc(MI);
-  Builder.buildSub(DstReg, Builder.buildConstant(DstTy, 0), SrcReg,
-                   MI.getFlags());
-  MI.eraseFromParent();
-}
-
 bool CombinerHelper::matchCombineTruncOfExt(
     MachineInstr &MI, std::pair<Register, unsigned> &MatchInfo) {
   assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC");
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
new file mode 100644
index 000000000000000..496d86aeef2d10a
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-typeof.td
@@ -0,0 +1,49 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -combiners=MyCombiner %s | \
+// RUN: FileCheck %s
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+def Test0 : GICombineRule<
+  (defs root:$dst),
+  (match (G_MUL $dst, $src, -1)),
+  (apply (G_SUB $dst, (GITypeOf<"$src"> 0), $tmp),
+         (G_CONSTANT GITypeOf<"$dst">:$tmp, (GITypeOf<"$src"> 42)))>;
+
+// CHECK:      const int64_t *GenMyCombiner::getMatchTable() const {
+// CHECK-NEXT:   constexpr static int64_t MatchTable0[] = {
+// CHECK-NEXT:     GIM_Try, /*On fail goto*//*Label 0*/ 57, // Rule ID 0 //
+// CHECK-NEXT:       GIM_CheckSimplePredicate, GICXXPred_Simple_IsRule0Enabled,
+// CHECK-NEXT:       GIM_CheckOpcode, /*MI*/0, TargetOpcode::G_MUL,
+// CHECK-NEXT:       // MIs[0] dst
+// CHECK-NEXT:       GIM_RecordRegType, /*MI*/0, /*Op*/0, /*TempTypeIdx*/-1,
+// CHECK-NEXT:       // MIs[0] src
+// CHECK-NEXT:       GIM_RecordRegType, /*MI*/0, /*Op*/1, /*TempTypeIdx*/-2,
+// CHECK-NEXT:       // MIs[0] Operand 2
+// CHECK-NEXT:       GIM_CheckConstantInt, /*MI*/0, /*Op*/2, -1,
+// CHECK-NEXT:       GIR_MakeTempReg, /*TempRegID*/1, /*TypeID*/-2,
+// CHECK-NEXT:       GIR_BuildConstant, /*TempRegID*/1, /*Val*/0,
+// CHECK-NEXT:       GIR_MakeTempReg, /*TempRegID*/0, /*TypeID*/-1,
+// CHECK-NEXT:       // Combiner Rule #0: Test0
+// CHECK-NEXT:       GIR_BuildMI, /*InsnID*/0, /*Opcode*/TargetOpcode::G_CONSTANT,
+// CHECK-NEXT:       GIR_AddTempRegister, /*InsnID*/0, /*TempRegID*/0, /*TempRegFlags*/0,
+// CHECK-NEXT:       GIR_AddCImm, /*InsnID*/0, /*Type*/-2, /*Imm*/42,
+// CHECK-NEXT:       GIR_EraseFromParent, /*InsnID*/0,
+// CHECK-NEXT:       GIR_BuildMI, /*InsnID*/1, /*Opcode*/TargetOpcode::G_SUB,
+// CHECK-NEXT:       GIR_Copy, /*NewInsnID*/1, /*OldInsnID*/0, /*OpIdx*/0, // dst
+// CHECK-NEXT:       GIR_AddTempRegister, /*InsnID*/1, /*TempRegID*/1, /*TempRegFlags*/0,
+// CHECK-NEXT:       GIR_AddTempRegister, /*InsnID*/1, /*TempRegID*/0, /*TempRegFlags*/0,
+// CHECK-NEXT:       GIR_Done,
+// CHECK-NEXT:     // Label 0: @57
+// CHECK-NEXT:     GIM_Reject,
+// CHECK-NEXT:     };
+// CHECK-NEXT:   return MatchTable0;
+// CHECK-NEXT: }
+
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  Test0
+]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
index c871e603e4e05aa..4769bed97240125 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/operand-types.td
@@ -79,7 +79,33 @@ def PatFragTest0 : GICombineRule<
   (match (FooPF $dst)),
   (apply (COPY $dst, (i32 0)))>;
 
+
+// CHECK:      (CombineRule name:TypeOfProp id:2 root:x
+// CHECK-NEXT:   (MatchPats
+// CHECK-NEXT:     <match_root>__TypeOfProp_match_0:(CodeGenInstructionPattern G_ZEXT operands:[<def>$x, $y])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (ApplyPats
+// CHECK-NEXT:     <apply_root>__TypeOfProp_apply_0:(CodeGenInstructionPattern G_ANYEXT operands:[<def>$x, GITypeOf<$y>:$tmp])
+// CHECK-NEXT:     __TypeOfProp_apply_1:(CodeGenInstructionPattern G_ANYEXT operands:[<def>GITypeOf<$y>:$tmp, $y])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable MatchPats
+// CHECK-NEXT:     x -> __TypeOfProp_match_0
+// CHECK-NEXT:     y -> <live-in>
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable ApplyPats
+// CHECK-NEXT:     tmp -> __TypeOfProp_apply_1
+// CHECK-NEXT:     x -> __TypeOfProp_apply_0
+// CHECK-NEXT:     y -> <live-in>
+// CHECK-NEXT:   )
+// CHECK-NEXT: )
+def TypeOfProp : GICombineRule<
+  (defs root:$x),
+  (match (G_ZEXT $x, $y)),
+  (apply (G_ANYEXT $x, GITypeOf<"$y">:$tmp),
+         (G_ANYEXT $tmp, $y))>;
+
 def MyCombiner: GICombiner<"GenMyCombiner", [
   InstTest0,
-  PatFragTest0
+  PatFragTest0,
+  TypeOfProp
 ]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
index bc75b15233b5519..fd41a7d1d72417e 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
@@ -297,6 +297,28 @@ def VariadicsOutTest : GICombineRule<
   (apply (COPY $a, (i32 0)),
          (COPY $b, (i32 0)))>;
 
+// CHECK:      (CombineRule name:TypeOfTest id:10 root:dst
+// CHECK-NEXT:   (MatchPats
+// CHECK-NEXT:     <match_root>__TypeOfTest_match_0:(CodeGenInstructionPattern COPY operands:[<def>$dst, $tmp])
+// CHECK-NEXT:     __TypeOfTest_match_1:(CodeGenInstructionPattern G_ZEXT operands:[<def>$tmp, $src])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (ApplyPats
+// CHECK-NEXT:     <apply_root>__TypeOfTest_apply_0:(CodeGenInstructionPattern G_MUL operands:[<def>$dst, (GITypeOf<$src> 0), (GITypeOf<$dst> -1)])
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable MatchPats
+// CHECK-NEXT:     dst -> __TypeOfTest_match_0
+// CHECK-NEXT:     src -> <live-in>
+// CHECK-NEXT:     tmp -> __TypeOfTest_match_1
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable ApplyPats
+// CHECK-NEXT:     dst -> __TypeOfTest_apply_0
+// CHECK-NEXT:   )
+// CHECK-NEXT: )
+def TypeOfTest : GICombineRule<
+  (defs root:$dst),
+  (match (COPY $dst, $tmp),
+         (G_ZEXT $tmp, $src)),
+  (apply (G_MUL $dst, (GITypeOf<"$src"> 0), (GITypeOf<"$dst"> -1)))>;
 
 def MyCombiner: GICombiner<"GenMyCombiner", [
   WipOpcodeTest0,
@@ -308,5 +330,6 @@ def MyCombiner: GICombiner<"GenMyCombiner", [
   PatFragTest0,
   PatFragTest1,
   VariadicsInTest,
-  VariadicsOutTest
+  VariadicsOutTest,
+  TypeOfTest
 ]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
new file mode 100644
index 000000000000000..6040d6def449766
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
@@ -0,0 +1,72 @@
+// RUN: not llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -combiners=MyCombiner %s 2>&1| \
+// RUN: FileCheck %s -implicit-check-not=error:
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+def NoDollarSign : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, (GITypeOf<"unknown"> 0)))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: 'unknown' ('GITypeOf<$unknown>') does not refer to a matched operand!
+def UnknownOperand : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, (GITypeOf<"$unknown"> 0)))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: GISpecialType is not supported in 'match' patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: operand 1 of '__UseInMatch_match_0' has type 'GITypeOf<$dst>'
+def UseInMatch : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, (GITypeOf<"$dst"> 0))),
+  (apply (G_ANYEXT $dst, (i32 0)))>;
+
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: GISpecialType is not supported in GICombinePatFrag
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: operand 1 of '__PFWithTypeOF_alt0_pattern_0' has type 'GITypeOf<$dst>'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Could not parse GICombinePatFrag 'PFWithTypeOF'
+def PFWithTypeOF: GICombinePatFrag<
+    (outs $dst), (ins),
+    [(pattern (G_ANYEXT $dst, (GITypeOf<"$dst"> 0)))]>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(PFWithTypeOF ?:$dst)'
+def UseInPF: GICombineRule<
+  (defs root:$dst),
+  (match (PFWithTypeOF $dst)),
+  (apply (G_ANYEXT $dst, (i32 0)))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: GISpecialType is not supported in 'match' patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: operand 1 of '__InferredUseInMatch_match_0' has type 'GITypeOf<$dst>'
+def InferredUseInMatch : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+def InferenceConflict : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, i32:$src)),
+  (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: 'tmp' ('GITypeOf<$tmp>') does not refer to a matched operand!
+def TypeOfApplyTmp : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src)),
+  (apply (G_ANYEXT $dst, i32:$tmp),
+         (G_ANYEXT $tmp, (GITypeOf<"$tmp"> 0)))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse one or more rules
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  NoDollarSign,
+  UnknownOperand,
+  UseInMatch,
+  UseInPF,
+  InferredUseInMatch,
+  InferenceConflict,
+  TypeOfApplyTmp
+]>;
diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 7992cb4362a1718..844d8b1051bd471 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -73,6 +73,8 @@ constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
 constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
 constexpr StringLiteral PatFragClassName = "GICombinePatFrag";
 constexpr StringLiteral BuiltinInstClassName = "GIBuiltinInst";
+constexpr StringLiteral SpecialTyClassName = "GISpecialType";
+constexpr StringLiteral TypeOfClassName = "GITypeOf";
 
 std::string getIsEnabledPredicateEnumName(unsigned CombinerRuleID) {
   return "GICXXPred_Simple_IsRule" + to_string(CombinerRuleID) + "Enabled";
@@ -123,11 +125,6 @@ template <typename Container> auto values(Container &&C) {
   return map_range(C, [](auto &Entry) -> auto & { return Entry.second; });
 }
 
-LLTCodeGen getLLTCodeGenFromRecord(const Record *Ty) {
-  assert(Ty->isSubClassOf("ValueType"));
-  return LLTCodeGen(*MVTToLLT(getValueType(Ty)));
-}
-
 //===- MatchData Handling -------------------------------------------------===//
 
 /// Represents MatchData defined by the match stage and required by the apply
@@ -292,6 +289,117 @@ class CXXPredicateCode {
 CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXMatchCode;
 CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode;
 
+//===- PatternType --------------------------------------------------------===//
+
+/// Simple wrapper around a Record* of a type.
+///
+/// Types have two forms:
+///   - LLTs (ValueType subclasses), which are straightforward.
+///   - Special types (GISpecialType subclasses), e.g. GITypeOf.
+class PatternType {
+public:
+  PatternType() = default;
+  PatternType(const Record *R) : R(R) {}
+
+  bool isValidType() const { return !R || isLLT() || isSpecial(); }
+
+  bool isLLT() const { return R && R->isSubClassOf("ValueType"); }
+  bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); }
+  bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); }
+
+  StringRef getTypeOfOpName() const;
+  LLTCodeGen getLLTCodeGen() const;
+
+  bool checkSemantics(ArrayRef<SMLoc> DiagLoc) const;
+
+  LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const;
+
+  explicit operator bool() const { return R != nullptr; }
+
+  bool operator==(const PatternType &Other) const;
+  bool operator!=(const PatternType &Other) const { return !operator==(Other); }
+
+  std::string str() const;
+
+private:
+  StringRef getRawOpName() const { return R->getValueAsString("OpName"); }
+
+  const Record *R = nullptr;
+};
+
+StringRef PatternType::getTypeOfOpName() const {
+  assert(isTypeOf());
+  StringRef Name = getRawOpName();
+  Name.consume_front("$");
+  return Name;
+}
+
+LLTCodeGen PatternType::getLLTCodeGen() const {
+  assert(isLLT());
+  return *MVTToLLT(getValueType(R));
+}
+
+LLTCodeGenOrTempType
+PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
+  assert(isValidType());
+
+  if (isLLT())
+    return getLLTCodeGen();
+
+  assert(isTypeOf());
+  auto &OM = RM.getOperandMatcher(getTypeOfOpName());
+  return OM.getTempTypeIdx(RM);
+}
+
+bool PatternType::checkSemantics(ArrayRef<SMLoc> DiagLoc) const {
+  if (!isTypeOf())
+    return true;
+
+  auto RawOpName = getRawOpName();
+  if (RawOpName.starts_with("$"))
+    return true;
+
+  PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " +
+                          TypeOfClassName +
+                          ": expected '$' followed by an operand name");
+  return false;
+}
+
+bool PatternType::operator==(const PatternType &Other) const {
+  // Identical records (including two untyped, null records) always denote
+  // the same type.
+  if (R == Other.R)
+    return true;
+
+  // Distinct GITypeOf records still compare equal when they name the same
+  // operand (the leading '$' is already stripped by getTypeOfOpName).
+  if (isTypeOf() && Other.isTypeOf())
+    return getTypeOfOpName() == Other.getTypeOfOpName();
+
+  // Anything else (different records that are not both GITypeOf, or typed
+  // vs. untyped) is a mismatch.
+  return false;
+}
+
+std::string PatternType::str() const {
+  if (!R)
+    return "";
+
+  if (!isValidType())
+    return "<invalid>";
+
+  if (isLLT())
+    return R->getName().str();
+
+  assert(isSpecial());
+
+  // TODO: Use an enum for these?
+  if (isTypeOf())
+    return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
+
+  llvm_unreachable("Unknown type!");
+}
+
 //===- Pattern Base Class -------------------------------------------------===//
 
 /// Base class for all patterns that can be written in an `apply`, `match` or
@@ -499,13 +607,15 @@ class InstructionOperand {
 public:
   using IntImmTy = int64_t;
 
-  InstructionOperand(IntImmTy Imm, StringRef Name, const Record *Type)
+  InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type)
       : Value(Imm), Name(insertStrRef(Name)), Type(Type) {
-    assert(!Type || Type->isSubClassOf("ValueType"));
+    assert(Type.isValidType());
   }
 
-  InstructionOperand(StringRef Name, const Record *Type)
-      : Name(insertStrRef(Name)), Type(Type) {}
+  InstructionOperand(StringRef Name, PatternType Type)
+      : Name(insertStrRef(Name)), Type(Type) {
+    assert(Type.isValidType());
+  }
 
   bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); }
 
@@ -527,11 +637,12 @@ class InstructionOperand {
   void setIsDef(bool Value = true) { Def = Value; }
   bool isDef() const { return Def; }
 
-  void setType(const Record *R) {
-    assert((!Type || (Type == R)) && "Overwriting type!");
-    Type = R;
+  void setType(PatternType NewType) {
+    assert((!Type || (Type == NewType)) && "Overwriting type!");
+    assert(NewType.isValidType());
+    Type = NewType;
   }
-  const Record *getType() const { return Type; }
+  PatternType getType() const { return Type; }
 
   std::string describe() const {
     if (!hasImmValue())
@@ -547,11 +658,11 @@ class InstructionOperand {
       OS << "<def>";
 
     bool NeedsColon = true;
-    if (const Record *Ty = getType()) {
+    if (Type) {
       if (hasImmValue())
-        OS << "(" << Ty->getName() << " " << getImmValue() << ")";
+        OS << "(" << Type.str() << " " << getImmValue() << ")";
       else
-        OS << Ty->getName();
+        OS << Type.str();
     } else if (hasImmValue())
       OS << getImmValue();
     else
@@ -566,7 +677,7 @@ class InstructionOperand {
 private:
   std::optional<int64_t> Value;
   StringRef Name;
-  const Record *Type = nullptr;
+  PatternType Type;
   bool Def = false;
 };
 
@@ -622,6 +733,10 @@ class InstructionPattern : public Pattern {
 
   virtual StringRef getInstName() const = 0;
 
+  /// Diagnoses all uses of special types in this Pattern and returns true if at
+  /// least one diagnostic was emitted.
+  bool diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc, Twine Msg) const;
+
   void reportUnreachable(ArrayRef<SMLoc> Locs) const;
   virtual bool checkSemantics(ArrayRef<SMLoc> Loc);
 
@@ -633,6 +748,20 @@ class InstructionPattern : public Pattern {
   SmallVector<InstructionOperand, 4> Operands;
 };
 
+bool InstructionPattern::diagnoseAllSpecialTypes(ArrayRef<SMLoc> Loc,
+                                                 Twine Msg) const {
+  bool HasDiag = false;
+  for (const auto &[Idx, Op] : enumerate(operands())) {
+    if (Op.getType().isSpecial()) {
+      PrintError(Loc, Msg);
+      PrintNote(Loc, "operand " + Twine(Idx) + " of '" + getName() +
+                         "' has type '" + Op.getType().str() + "'");
+      HasDiag = true;
+    }
+  }
+  return HasDiag;
+}
+
 void InstructionPattern::reportUnreachable(ArrayRef<SMLoc> Locs) const {
   PrintError(Locs, "pattern '" + getName() + "' ('" + getInstName() +
                        "') is unreachable from the pattern root!");
@@ -829,17 +958,20 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
 /// It infers the type of each operand, check it's consistent with the known
 /// type of the operand, and then sets all of the types in all operands in
 /// setAllOperandTypes.
+///
+/// It also handles verifying correctness of special types.
 class OperandTypeChecker {
 public:
   OperandTypeChecker(ArrayRef<SMLoc> DiagLoc) : DiagLoc(DiagLoc) {}
 
-  bool check(InstructionPattern *P);
+  bool check(InstructionPattern *P,
+             std::function<bool(const PatternType &)> VerifyTypeOfOperand);
 
   void setAllOperandTypes();
 
 private:
   struct OpTypeInfo {
-    const Record *Type = nullptr;
+    PatternType Type;
     InstructionPattern *TypeSrc = nullptr;
   };
 
@@ -849,16 +981,26 @@ class OperandTypeChecker {
   SmallVector<InstructionPattern *, 16> Pats;
 };
 
-bool OperandTypeChecker::check(InstructionPattern *P) {
+bool OperandTypeChecker::check(
+    InstructionPattern *P,
+    std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
   Pats.push_back(P);
 
-  for (auto &Op : P->named_operands()) {
-    const Record *Ty = Op.getType();
+  for (auto &Op : P->operands()) {
+    const auto Ty = Op.getType();
     if (!Ty)
       continue;
 
-    auto &Info = Types[Op.getOperandName()];
+    if (!Ty.checkSemantics(DiagLoc))
+      return false;
+
+    if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
+      return false;
 
+    if (!Op.isNamedOperand())
+      continue;
+
+    auto &Info = Types[Op.getOperandName()];
     if (!Info.Type) {
       Info.Type = Ty;
       Info.TypeSrc = P;
@@ -868,9 +1010,9 @@ bool OperandTypeChecker::check(InstructionPattern *P) {
     if (Info.Type != Ty) {
       PrintError(DiagLoc, "conflicting types for operand '" +
                               Op.getOperandName() + "': first seen with '" +
-                              Info.Type->getName() + "' in '" +
+                              Info.Type.str() + "' in '" +
                               Info.TypeSrc->getName() + ", now seen with '" +
-                              Ty->getName() + "' in '" + P->getName() + "'");
+                              Ty.str() + "' in '" + P->getName() + "'");
       return false;
     }
   }
@@ -1058,7 +1200,12 @@ bool PatFrag::checkSemantics() {
                    PatFragClassName);
         return false;
       case Pattern::K_CXX:
+        continue;
       case Pattern::K_CodeGenInstruction:
+        if (cast<CodeGenInstructionPattern>(Pat.get())->diagnoseAllSpecialTypes(
+                Def.getLoc(), SpecialTyClassName + " is not supported in " +
+                                  PatFragClassName))
+          return false;
         continue;
       case Pattern::K_PatFrag:
         // TODO: It's just that the emitter doesn't handle it but technically
@@ -1142,12 +1289,16 @@ bool PatFrag::checkSemantics() {
 
   // TODO: find unused params
 
+  const auto CheckTypeOf = [&](const PatternType &) -> bool {
+    llvm_unreachable("GITypeOf should have been rejected earlier!");
+  };
+
   // Now, typecheck all alternatives.
   for (auto &Alt : Alts) {
     OperandTypeChecker OTC(Def.getLoc());
     for (auto &Pat : Alt.Pats) {
       if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-        if (!OTC.check(IP))
+        if (!OTC.check(IP, CheckTypeOf))
           return false;
       }
     }
@@ -1954,21 +2105,49 @@ bool CombineRuleBuilder::hasEraseRoot() const {
 bool CombineRuleBuilder::typecheckPatterns() {
   OperandTypeChecker OTC(RuleDef.getLoc());
 
+  const auto CheckMatchTypeOf = [&](const PatternType &) -> bool {
+    // We'll reject those after we're done inferring
+    return true;
+  };
+
   for (auto &Pat : values(MatchPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP))
+      if (!OTC.check(IP, CheckMatchTypeOf))
         return false;
     }
   }
 
+  const auto CheckApplyTypeOf = [&](const PatternType &Ty) {
+    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
+    const auto OpName = Ty.getTypeOfOpName();
+    if (MatchOpTable.lookup(OpName).Found)
+      return true;
+
+    PrintError("'" + OpName + "' ('" + Ty.str() +
+               "') does not refer to a matched operand!");
+    return false;
+  };
+
   for (auto &Pat : values(ApplyPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP))
+      if (!OTC.check(IP, CheckApplyTypeOf))
         return false;
     }
   }
 
   OTC.setAllOperandTypes();
+
+  // Always check this after in case inference adds some special types to the
+  // match patterns.
+  for (auto &Pat : values(MatchPats)) {
+    if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
+      if (IP->diagnoseAllSpecialTypes(
+              RuleDef.getLoc(),
+              SpecialTyClassName + " is not supported in 'match' patterns")) {
+        return false;
+      }
+    }
+  }
   return true;
 }
 
@@ -2461,10 +2640,12 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
     if (DagOp->getNumArgs() != 1)
       return ParseErr();
 
-    Record *ImmTy = DagOp->getOperatorAsDef(RuleDef.getLoc());
-    if (!ImmTy->isSubClassOf("ValueType")) {
+    const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
+    PatternType ImmTy(TyDef);
+    if (!ImmTy.isValidType()) {
       PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() +
-                 "', '" + ImmTy->getName() + "' is not a ValueType!");
+                 "', '" + TyDef->getName() + "' is not a ValueType or " +
+                 SpecialTyClassName);
       return false;
     }
 
@@ -2491,12 +2672,13 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return false;
     }
     const Record *Def = DefI->getDef();
-    if (!Def->isSubClassOf("ValueType")) {
+    PatternType Ty(Def);
+    if (!Ty.isValidType()) {
       PrintError("invalid operand type: '" + Def->getName() +
                  "' is not a ValueType");
       return false;
     }
-    IP.addOperand(OpName->getAsUnquotedString(), Def);
+    IP.addOperand(OpName->getAsUnquotedString(), Ty);
     return true;
   }
 
@@ -2823,8 +3005,8 @@ bool CombineRuleBuilder::emitPatFragMatchPattern(
       StringRef PFName = PF.getName();
       PrintWarning("impossible type constraints: operand " + Twine(PIdx) +
                    " of '" + PFP.getName() + "' has type '" +
-                   ArgOp.getType()->getName() + "', but '" + PFName +
-                   "' constrains it to '" + O.getType()->getName() + "'");
+                   ArgOp.getType().str() + "', but '" + PFName +
+                   "' constrains it to '" + O.getType().str() + "'");
       if (ArgOp.isNamedOperand())
         PrintNote("operand " + Twine(PIdx) + " of '" + PFP.getName() +
                   "' is '" + ArgOp.getOperandName() + "'");
@@ -3055,17 +3237,18 @@ bool CombineRuleBuilder::emitInstructionApplyPattern(
       // This is a brand new register.
       TempRegID = M.allocateTempRegID();
       OperandToTempRegID[OpName] = TempRegID;
-      const Record *Ty = Op.getType();
+      const auto Ty = Op.getType();
       if (!Ty) {
         PrintError("def of a new register '" + OpName +
                    "' in the apply patterns must have a type");
         return false;
       }
+
       declareTempRegExpansion(CE, TempRegID, OpName);
       // Always insert the action at the beginning, otherwise we may end up
       // using the temp reg before it's available.
       M.insertAction<MakeTempRegisterAction>(
-          M.actions_begin(), getLLTCodeGenFromRecord(Ty), TempRegID);
+          M.actions_begin(), Ty.getLLTCodeGenOrTempType(M), TempRegID);
     }
 
     DstMI.addRenderer<TempRegRenderer>(TempRegID);
@@ -3088,7 +3271,7 @@ bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
   // G_CONSTANT is a special case and needs a CImm though so this is likely a
   // mistake.
   const bool isGConstant = P.is("G_CONSTANT");
-  const Record *Ty = O.getType();
+  const auto Ty = O.getType();
   if (!Ty) {
     if (isGConstant) {
       PrintError("'G_CONSTANT' immediate must be typed!");
@@ -3101,16 +3284,17 @@ bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
     return true;
   }
 
-  LLTCodeGen LLT = getLLTCodeGenFromRecord(Ty);
+  auto ImmTy = Ty.getLLTCodeGenOrTempType(M);
+
   if (isGConstant) {
-    DstMI.addRenderer<ImmRenderer>(O.getImmValue(), LLT);
+    DstMI.addRenderer<ImmRenderer>(O.getImmValue(), ImmTy);
     return true;
   }
 
   unsigned TempRegID = M.allocateTempRegID();
   // Ensure MakeTempReg & the BuildConstantAction occur at the beginning.
-  auto InsertIt =
-      M.insertAction<MakeTempRegisterAction>(M.actions_begin(), LLT, TempRegID);
+  auto InsertIt = M.insertAction<MakeTempRegisterAction>(M.actions_begin(),
+                                                         ImmTy, TempRegID);
   M.insertAction<BuildConstantAction>(++InsertIt, TempRegID, O.getImmValue());
   DstMI.addRenderer<TempRegRenderer>(TempRegID);
   return true;
@@ -3227,8 +3411,14 @@ bool CombineRuleBuilder::emitCodeGenInstructionMatchPattern(
     // Always emit a check for unnamed operands.
     if (OpName.empty() ||
         !M.getOperandMatcher(OpName).contains<LLTOperandMatcher>()) {
-      if (const Record *Ty = RemappedO.getType())
-        OM.addPredicate<LLTOperandMatcher>(getLLTCodeGenFromRecord(Ty));
+      if (const auto Ty = RemappedO.getType()) {
+        // TODO: We could support GITypeOf here on the condition that the
+        // OperandMatcher exists already. Though it's clunky to make this work
+        // and isn't all that useful so it's just rejected in typecheckPatterns
+        // at this time.
+        assert(Ty.isLLT() && "Only LLTs are supported in match patterns!");
+        OM.addPredicate<LLTOperandMatcher>(Ty.getLLTCodeGen());
+      }
     }
 
     // Stop here if the operand is a def, or if it had no name.
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.cpp b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
index 9a4a375f34bdb91..6ec85269e6e20d0 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.cpp
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
@@ -822,6 +822,15 @@ const OperandMatcher &RuleMatcher::getPhysRegOperandMatcher(Record *Reg) const {
   return *I->second;
 }
 
+OperandMatcher &RuleMatcher::getOperandMatcher(StringRef Name) {
+  const auto &I = DefinedOperands.find(Name);
+
+  if (I == DefinedOperands.end())
+    PrintFatalError(SrcLoc, "Operand " + Name + " was not declared in matcher");
+
+  return *I->second;
+}
+
 const OperandMatcher &RuleMatcher::getOperandMatcher(StringRef Name) const {
   const auto &I = DefinedOperands.find(Name);
 
@@ -1081,6 +1090,17 @@ void RecordNamedOperandMatcher::emitPredicateOpcodes(MatchTable &Table,
         << MatchTable::Comment("Name : " + Name) << MatchTable::LineBreak;
 }
 
+//===- RecordRegisterType ------------------------------------------===//
+
+void RecordRegisterType::emitPredicateOpcodes(MatchTable &Table,
+                                              RuleMatcher &Rule) const {
+  assert(Idx < 0 && "Temp types always have negative indexes!");
+  Table << MatchTable::Opcode("GIM_RecordRegType") << MatchTable::Comment("MI")
+        << MatchTable::IntValue(InsnVarID) << MatchTable::Comment("Op")
+        << MatchTable::IntValue(OpIdx) << MatchTable::Comment("TempTypeIdx")
+        << MatchTable::IntValue(Idx) << MatchTable::LineBreak;
+}
+
 //===- ComplexPatternOperandMatcher ---------------------------------------===//
 
 void ComplexPatternOperandMatcher::emitPredicateOpcodes(
@@ -1196,6 +1216,18 @@ std::string OperandMatcher::getOperandExpr(unsigned InsnVarID) const {
 
 unsigned OperandMatcher::getInsnVarID() const { return Insn.getInsnVarID(); }
 
+TempTypeIdx OperandMatcher::getTempTypeIdx(RuleMatcher &Rule) {
+  if (TTIdx >= 0) {
+    // Temp type index not assigned yet, so assign one and add the necessary
+    // predicate.
+    TTIdx = Rule.getNextTempTypeIdx();
+    assert(TTIdx < 0);
+    addPredicate<RecordRegisterType>(TTIdx);
+    return TTIdx;
+  }
+  return TTIdx;
+}
+
 void OperandMatcher::emitPredicateOpcodes(MatchTable &Table,
                                           RuleMatcher &Rule) {
   if (!Optimized) {
@@ -2092,9 +2124,7 @@ void MakeTempRegisterAction::emitActionOpcodes(MatchTable &Table,
                                                RuleMatcher &Rule) const {
   Table << MatchTable::Opcode("GIR_MakeTempReg")
         << MatchTable::Comment("TempRegID") << MatchTable::IntValue(TempRegID)
-        << MatchTable::Comment("TypeID")
-        << MatchTable::NamedValue(Ty.getCxxEnumValue())
-        << MatchTable::LineBreak;
+        << MatchTable::Comment("TypeID") << Ty << MatchTable::LineBreak;
 }
 
 } // namespace gi
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.h b/llvm/utils/TableGen/GlobalISelMatchTable.h
index 5608bab482bfd34..364f2a1ec725d53 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.h
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.h
@@ -273,6 +273,40 @@ extern std::set<LLTCodeGen> KnownTypes;
 /// MVTs that don't map cleanly to an LLT (e.g., iPTR, *any, ...).
 std::optional<LLTCodeGen> MVTToLLT(MVT::SimpleValueType SVT);
 
+using TempTypeIdx = int64_t;
+class LLTCodeGenOrTempType {
+public:
+  LLTCodeGenOrTempType(const LLTCodeGen &LLT) : Data(LLT) {}
+  LLTCodeGenOrTempType(TempTypeIdx TempTy) : Data(TempTy) {}
+
+  bool isLLTCodeGen() const { return std::holds_alternative<LLTCodeGen>(Data); }
+  bool isTempTypeIdx() const {
+    return std::holds_alternative<TempTypeIdx>(Data);
+  }
+
+  const LLTCodeGen &getLLTCodeGen() const {
+    assert(isLLTCodeGen());
+    return std::get<LLTCodeGen>(Data);
+  }
+
+  TempTypeIdx getTempTypeIdx() const {
+    assert(isTempTypeIdx());
+    return std::get<TempTypeIdx>(Data);
+  }
+
+private:
+  std::variant<LLTCodeGen, TempTypeIdx> Data;
+};
+
+inline MatchTable &operator<<(MatchTable &Table,
+                              const LLTCodeGenOrTempType &Ty) {
+  if (Ty.isLLTCodeGen())
+    Table << MatchTable::NamedValue(Ty.getLLTCodeGen().getCxxEnumValue());
+  else
+    Table << MatchTable::IntValue(Ty.getTempTypeIdx());
+  return Table;
+}
+
 //===- Matchers -----------------------------------------------------------===//
 class Matcher {
 public:
@@ -459,6 +493,9 @@ class RuleMatcher : public Matcher {
   /// ID for the next temporary register ID allocated with allocateTempRegID()
   unsigned NextTempRegID;
 
+  /// ID for the next recorded type. Starts at -1 and counts down.
+  TempTypeIdx NextTempTypeIdx = -1;
+
   // HwMode predicate index for this rule. -1 if no HwMode.
   int HwModeIdx = -1;
 
@@ -498,6 +535,8 @@ class RuleMatcher : public Matcher {
   RuleMatcher(RuleMatcher &&Other) = default;
   RuleMatcher &operator=(RuleMatcher &&Other) = default;
 
+  TempTypeIdx getNextTempTypeIdx() { return NextTempTypeIdx--; }
+
   uint64_t getRuleID() const { return RuleID; }
 
   InstructionMatcher &addInstructionMatcher(StringRef SymbolicName);
@@ -602,6 +641,7 @@ class RuleMatcher : public Matcher {
   }
 
   InstructionMatcher &getInstructionMatcher(StringRef SymbolicName) const;
+  OperandMatcher &getOperandMatcher(StringRef Name);
   const OperandMatcher &getOperandMatcher(StringRef Name) const;
   const OperandMatcher &getPhysRegOperandMatcher(Record *) const;
 
@@ -762,6 +802,7 @@ class PredicateMatcher {
     OPM_RegBank,
     OPM_MBB,
     OPM_RecordNamedOperand,
+    OPM_RecordRegType,
   };
 
 protected:
@@ -963,6 +1004,30 @@ class RecordNamedOperandMatcher : public OperandPredicateMatcher {
                             RuleMatcher &Rule) const override;
 };
 
+/// Generates code to store a register operand's type into the set of temporary
+/// LLTs.
+class RecordRegisterType : public OperandPredicateMatcher {
+protected:
+  TempTypeIdx Idx;
+
+public:
+  RecordRegisterType(unsigned InsnVarID, unsigned OpIdx, TempTypeIdx Idx)
+      : OperandPredicateMatcher(OPM_RecordRegType, InsnVarID, OpIdx), Idx(Idx) {
+  }
+
+  static bool classof(const PredicateMatcher *P) {
+    return P->getKind() == OPM_RecordRegType;
+  }
+
+  bool isIdentical(const PredicateMatcher &B) const override {
+    return OperandPredicateMatcher::isIdentical(B) &&
+           Idx == cast<RecordRegisterType>(&B)->Idx;
+  }
+
+  void emitPredicateOpcodes(MatchTable &Table,
+                            RuleMatcher &Rule) const override;
+};
+
 /// Generates code to check that an operand is a particular target constant.
 class ComplexPatternOperandMatcher : public OperandPredicateMatcher {
 protected:
@@ -1169,6 +1234,8 @@ class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
   /// countRendererFns().
   unsigned AllocatedTemporariesBaseID;
 
+  TempTypeIdx TTIdx = 0;
+
 public:
   OperandMatcher(InstructionMatcher &Insn, unsigned OpIdx,
                  const std::string &SymbolicName,
@@ -1196,6 +1263,11 @@ class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
   unsigned getOpIdx() const { return OpIdx; }
   unsigned getInsnVarID() const;
 
+  /// If this OperandMatcher has not been assigned a TempTypeIdx yet, assigns it
+  /// one and adds a `RecordRegisterType` predicate to this matcher. If one has
+  /// already been assigned, simply returns it.
+  TempTypeIdx getTempTypeIdx(RuleMatcher &Rule);
+
   std::string getOperandExpr(unsigned InsnVarID) const;
 
   InstructionMatcher &getInstructionMatcher() const { return Insn; }
@@ -1955,15 +2027,16 @@ class ImmRenderer : public OperandRenderer {
 protected:
   unsigned InsnID;
   int64_t Imm;
-  std::optional<LLTCodeGen> CImmLLT;
+  std::optional<LLTCodeGenOrTempType> CImmLLT;
 
 public:
   ImmRenderer(unsigned InsnID, int64_t Imm)
       : OperandRenderer(OR_Imm), InsnID(InsnID), Imm(Imm) {}
 
-  ImmRenderer(unsigned InsnID, int64_t Imm, const LLTCodeGen &CImmLLT)
+  ImmRenderer(unsigned InsnID, int64_t Imm, const LLTCodeGenOrTempType &CImmLLT)
       : OperandRenderer(OR_Imm), InsnID(InsnID), Imm(Imm), CImmLLT(CImmLLT) {
-    KnownTypes.insert(CImmLLT);
+    if (CImmLLT.isLLTCodeGen())
+      KnownTypes.insert(CImmLLT.getLLTCodeGen());
   }
 
   static bool classof(const OperandRenderer *R) {
@@ -1976,8 +2049,7 @@ class ImmRenderer : public OperandRenderer {
              "ConstantInt immediate are only for combiners!");
       Table << MatchTable::Opcode("GIR_AddCImm")
             << MatchTable::Comment("InsnID") << MatchTable::IntValue(InsnID)
-            << MatchTable::Comment("Type")
-            << MatchTable::NamedValue(CImmLLT->getCxxEnumValue())
+            << MatchTable::Comment("Type") << *CImmLLT
             << MatchTable::Comment("Imm") << MatchTable::IntValue(Imm)
             << MatchTable::LineBreak;
     } else {
@@ -2290,13 +2362,14 @@ class ConstrainOperandToRegClassAction : public MatchAction {
 /// instructions together.
 class MakeTempRegisterAction : public MatchAction {
 private:
-  LLTCodeGen Ty;
+  LLTCodeGenOrTempType Ty;
   unsigned TempRegID;
 
 public:
-  MakeTempRegisterAction(const LLTCodeGen &Ty, unsigned TempRegID)
+  MakeTempRegisterAction(const LLTCodeGenOrTempType &Ty, unsigned TempRegID)
       : MatchAction(AK_MakeTempReg), Ty(Ty), TempRegID(TempRegID) {
-    KnownTypes.insert(Ty);
+    if (Ty.isLLTCodeGen())
+      KnownTypes.insert(Ty.getLLTCodeGen());
   }
 
   static bool classof(const MatchAction *A) {

>From b3981d223e03e959a46b35c7aa546a94a5e5a0da Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Thu, 14 Sep 2023 15:52:29 +0200
Subject: [PATCH 2/2] [TableGen][GlobalISel] Add rule-wide type inference

The inference is trivial and leverages the MCOI OperandTypes encoded in
CodeGenInstructions to infer types across patterns in a CombineRule. It's
thus very limited and only supports CodeGenInstructions (but that's the main
use case so it's fine).

We only try to infer untyped operands in apply patterns when they're temp
reg defs or immediates. Inference always outputs a `GITypeOf<$x>` where $x is
a named operand from a match pattern.
---
 llvm/include/llvm/Target/GenericOpcodes.td    |   7 +
 .../include/llvm/Target/GlobalISel/Combine.td |   2 +-
 .../pattern-errors.td                         |   4 +-
 .../type-inference.td                         |  75 ++
 .../typeof-errors.td                          |   7 +-
 .../TableGen/GlobalISelCombinerEmitter.cpp    | 749 +++++++++++++++---
 6 files changed, 713 insertions(+), 131 deletions(-)
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td

diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index a1afc3b8042c284..9a9c09d3c20d612 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -17,6 +17,10 @@
 
 class GenericInstruction : StandardPseudoInstruction {
   let isPreISelOpcode = true;
+
+  // When all variadic ops share a type with another operand,
+  // this is the type they share. Used by MIR patterns type inference.
+  TypedOperand variadicOpsType = ?;
 }
 
 // Provide a variant of an instruction with the same operands, but
@@ -1228,6 +1232,7 @@ def G_UNMERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst0, variable_ops);
   let InOperandList = (ins type1:$src);
   let hasSideEffects = false;
+  let variadicOpsType = type0;
 }
 
 // Insert a smaller register into a larger one at the specified bit-index.
@@ -1245,6 +1250,7 @@ def G_MERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Create a vector from multiple scalar registers. No implicit
@@ -1254,6 +1260,7 @@ def G_BUILD_VECTOR : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Like G_BUILD_VECTOR, but truncates the larger operand types to fit the
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 31a337feed847ec..b1c9b77c23fb69d 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -796,7 +796,7 @@ def trunc_shift: GICombineRule <
 def mul_by_neg_one: GICombineRule <
   (defs root:$dst),
   (match (G_MUL $dst, $x, -1)),
-  (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
+  (apply (G_SUB $dst, 0, $x))
 >;
 
 // Fold (xor (and x, y), y) -> (and (not x), y)
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
index 48a06474da78a10..77321a3f2b59dbe 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
@@ -151,7 +151,7 @@ def bad_imm_too_many_args : GICombineRule<
   (match (COPY $x, (i32 0, 0)):$d),
   (apply (COPY $x, $b):$d)>;
 
-// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(COPY 0)', 'COPY' is not a ValueType
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: unknown type 'COPY'
 // CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(COPY ?:$x, (COPY 0))
 def bad_imm_not_a_valuetype : GICombineRule<
   (defs root:$a),
@@ -186,7 +186,7 @@ def expected_op_name : GICombineRule<
   (match (G_FNEG $x, i32)),
   (apply (COPY $x, (i32 0)))>;
 
-// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: invalid operand type: 'not_a_type' is not a ValueType
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: unknown type 'not_a_type'
 // CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_FNEG ?:$x, not_a_type:$y)'
 def not_a_type;
 def bad_mo_type_not_a_valuetype : GICombineRule<
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
new file mode 100644
index 000000000000000..7ed14dd5e6cc0eb
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
@@ -0,0 +1,75 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -gicombiner-debug-typeinfer -combiners=MyCombiner %s 2>&1 | \
+// RUN: FileCheck %s
+
+// Checks reasoning of the inference rules.
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+// CHECK:      Rule Operand Type Equivalence Classes for inference_mul_by_neg_one:
+// CHECK-NEXT:         __inference_mul_by_neg_one_match_0:             [dst, x]
+// CHECK-NEXT:         __inference_mul_by_neg_one_apply_0:             [dst, x]
+// CHECK-NEXT: (merging [dst, x] | [dst, x])
+// CHECK-NEXT: Result: [dst, x]
+// CHECK-NEXT: INFER: imm 0 -> GITypeOf<$x>
+// CHECK-NEXT: Apply patterns for rule inference_mul_by_neg_one after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__inference_mul_by_neg_one_apply_0 G_SUB operands:[<def>$dst, (GITypeOf<$x> 0), $x])
+def inference_mul_by_neg_one: GICombineRule <
+  (defs root:$dst),
+  (match (G_MUL $dst, $x, -1)),
+  (apply (G_SUB $dst, 0, $x))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_complex_tempreg:
+// CHECK-NEXT:         __infer_complex_tempreg_match_0:                [dst]   [x, y, z]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_0:                [tmp2]  [x, y]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_1:                [tmp, tmp2]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_2:                [dst, tmp]
+// CHECK-NEXT: (merging [dst] | [dst, tmp])
+// CHECK-NEXT: (merging [x, y, z] | [x, y])
+// CHECK-NEXT: (merging [tmp2] | [tmp, tmp2])
+// CHECK-NEXT: (merging [dst, tmp] | [tmp2, tmp])
+// CHECK-NEXT: Result: [dst, tmp, tmp2] [x, y, z]
+// CHECK-NEXT: INFER: MachineOperand $tmp2 -> GITypeOf<$dst>
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$dst>
+// CHECK-NEXT: Apply patterns for rule infer_complex_tempreg after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_0 G_BUILD_VECTOR operands:[<def>GITypeOf<$dst>:$tmp2, $x, $y])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_1 G_FNEG operands:[<def>GITypeOf<$dst>:$tmp, GITypeOf<$dst>:$tmp2])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_2 G_FNEG operands:[<def>$dst, GITypeOf<$dst>:$tmp])
+def infer_complex_tempreg: GICombineRule <
+  (defs root:$dst),
+  (match (G_MERGE_VALUES $dst, $x, $y, $z)),
+  (apply (G_BUILD_VECTOR $tmp2, $x, $y),
+         (G_FNEG $tmp, $tmp2),
+         (G_FNEG $dst, $tmp))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_variadic_outs:
+// CHECK-NEXT:         __infer_variadic_outs_match_0:          [x, y]  [vec]
+// CHECK-NEXT:         __infer_variadic_outs_match_1:          [dst, x]
+// CHECK-NEXT:         __infer_variadic_outs_apply_0:          [tmp, y]
+// CHECK-NEXT:         __infer_variadic_outs_apply_1:  (empty)
+// CHECK-NEXT: (merging [x, y] | [dst, x])
+// CHECK-NEXT: (merging [x, y, dst] | [tmp, y])
+// CHECK-NEXT: Result: [x, y, dst, tmp] [vec]
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$x>
+// CHECK-NEXT: Apply patterns for rule infer_variadic_outs after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_0 G_FNEG operands:[<def>GITypeOf<$x>:$tmp, $y])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_1 COPY operands:[<def>$dst, GITypeOf<$x>:$tmp])
+def infer_variadic_outs: GICombineRule <
+  (defs root:$dst),
+  (match  (G_UNMERGE_VALUES $x, $y, $vec),
+          (G_FNEG $dst, $x)),
+  (apply (G_FNEG $tmp, $y),
+         (COPY $dst, $tmp))
+>;
+
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  inference_mul_by_neg_one,
+  infer_complex_tempreg,
+  infer_variadic_outs
+]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
index 6040d6def449766..b86b4ec19564488 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
@@ -8,7 +8,8 @@ include "llvm/Target/GlobalISel/Combine.td"
 def MyTargetISA : InstrInfo;
 def MyTarget : Target { let InstructionSet = MyTargetISA; }
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ANYEXT ?:$dst, (anonymous_
 def NoDollarSign : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, $src)),
@@ -47,7 +48,9 @@ def InferredUseInMatch : GICombineRule<
   (match (G_ZEXT $dst, $src)),
   (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: conflicting types for operand 'src': 'i32' vs 'GITypeOf<$dst>'
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: 'src' seen with type 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: 'src' seen with type 'i32' in '__InferenceConflict_match_0'
 def InferenceConflict : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, i32:$src)),
diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 844d8b1051bd471..f7aa79bf8a63fb5 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -37,6 +37,8 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
@@ -68,6 +70,9 @@ cl::opt<bool> DebugCXXPreds(
     "gicombiner-debug-cxxpreds",
     cl::desc("Add Contextual/Debug comments to all C++ predicates"),
     cl::cat(GICombinerEmitterCat));
+cl::opt<bool> DebugTypeInfer("gicombiner-debug-typeinfer",
+                             cl::desc("Print type inference debug logs"),
+                             cl::cat(GICombinerEmitterCat));
 
 constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
 constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
@@ -125,6 +130,20 @@ template <typename Container> auto values(Container &&C) {
   return map_range(C, [](auto &Entry) -> auto & { return Entry.second; });
 }
 
+template <typename SetTy> bool doSetsIntersect(const SetTy &A, const SetTy &B) {
+  for (const auto &Elt : A) {
+    if (B.contains(Elt))
+      return true;
+  }
+  return false;
+}
+
+template <class Ty>
+void setVectorUnion(SetVector<Ty> &S1, const SetVector<Ty> &S2) {
+  for (auto SI = S2.begin(), SE = S2.end(); SI != SE; ++SI)
+    S1.insert(*SI);
+}
+
 //===- MatchData Handling -------------------------------------------------===//
 
 /// Represents MatchData defined by the match stage and required by the apply
@@ -298,23 +317,30 @@ CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode;
 ///   - Special types, e.g. GITypeOf
 class PatternType {
 public:
-  PatternType() = default;
-  PatternType(const Record *R) : R(R) {}
+  enum PTKind : uint8_t {
+    PT_None,
+
+    PT_ValueType,
+    PT_TypeOf,
+  };
 
-  bool isValidType() const { return !R || isLLT() || isSpecial(); }
+  PatternType() : Kind(PT_None), Data() {}
 
-  bool isLLT() const { return R && R->isSubClassOf("ValueType"); }
-  bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); }
-  bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); }
+  static std::optional<PatternType> get(ArrayRef<SMLoc> DiagLoc,
+                                        const Record *R);
+  static PatternType getTypeOf(StringRef OpName);
+
+  bool isNone() const { return Kind == PT_None; }
+  bool isLLT() const { return Kind == PT_ValueType; }
+  bool isSpecial() const { return isTypeOf(); }
+  bool isTypeOf() const { return Kind == PT_TypeOf; }
 
   StringRef getTypeOfOpName() const;
   LLTCodeGen getLLTCodeGen() const;
 
-  bool checkSemantics(ArrayRef<SMLoc> DiagLoc) const;
-
   LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const;
 
-  explicit operator bool() const { return R != nullptr; }
+  explicit operator bool() const { return !isNone(); }
 
   bool operator==(const PatternType &Other) const;
   bool operator!=(const PatternType &Other) const { return !operator==(Other); }
@@ -322,26 +348,66 @@ class PatternType {
   std::string str() const;
 
 private:
-  StringRef getRawOpName() const { return R->getValueAsString("OpName"); }
+  PatternType(PTKind Kind) : Kind(Kind), Data() {}
+
+  PTKind Kind;
+  union DataT {
+    DataT() : Str() {}
 
-  const Record *R = nullptr;
+    /// PT_ValueType -> ValueType Def.
+    const Record *Def;
+
+    /// PT_TypeOf -> Operand name (without the '$')
+    StringRef Str;
+  } Data;
 };
 
+std::optional<PatternType> PatternType::get(ArrayRef<SMLoc> DiagLoc,
+                                            const Record *R) {
+  assert(R);
+  if (R->isSubClassOf("ValueType")) {
+    PatternType PT(PT_ValueType);
+    PT.Data.Def = R;
+    return PT;
+  }
+
+  if (R->isSubClassOf(TypeOfClassName)) {
+    auto RawOpName = R->getValueAsString("OpName");
+    if (!RawOpName.starts_with("$")) {
+      PrintError(DiagLoc, "invalid operand name format '" + RawOpName +
+                              "' in " + TypeOfClassName +
+                              ": expected '$' followed by an operand name");
+      return std::nullopt;
+    }
+
+    PatternType PT(PT_TypeOf);
+    PT.Data.Str = RawOpName.drop_front(1);
+    return PT;
+  }
+
+  PrintError(DiagLoc, "unknown type '" + R->getName() + "'");
+  return std::nullopt;
+}
+
+PatternType PatternType::getTypeOf(StringRef OpName) {
+  PatternType PT(PT_TypeOf);
+  PT.Data.Str = OpName;
+  return PT;
+}
+
 StringRef PatternType::getTypeOfOpName() const {
   assert(isTypeOf());
-  StringRef Name = getRawOpName();
-  Name.consume_front("$");
-  return Name;
+  return Data.Str;
 }
 
 LLTCodeGen PatternType::getLLTCodeGen() const {
   assert(isLLT());
-  return *MVTToLLT(getValueType(R));
+  return *MVTToLLT(getValueType(Data.Def));
 }
 
 LLTCodeGenOrTempType
 PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
-  assert(isValidType());
+  assert(!isNone());
 
   if (isLLT())
     return getLLTCodeGen();
@@ -351,51 +417,31 @@ PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
   return OM.getTempTypeIdx(RM);
 }
 
-bool PatternType::checkSemantics(ArrayRef<SMLoc> DiagLoc) const {
-  if (!isTypeOf())
-    return true;
-
-  auto RawOpName = getRawOpName();
-  if (RawOpName.starts_with("$"))
-    return true;
-
-  PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " +
-                          TypeOfClassName +
-                          ": expected '$' followed by an operand name");
-  return false;
-}
-
 bool PatternType::operator==(const PatternType &Other) const {
-  if (R == Other.R) {
-    if (R && R->getName() != Other.R->getName()) {
-      dbgs() << "Same ptr but: " << R->getName() << " and "
-             << Other.R->getName() << "?\n";
-      assert(false);
-    }
+  if (Kind != Other.Kind)
+    return false;
+
+  switch (Kind) {
+  case PT_None:
     return true;
+  case PT_ValueType:
+    return Data.Def == Other.Data.Def;
+  case PT_TypeOf:
+    return Data.Str == Other.Data.Str;
   }
 
-  if (isTypeOf() && Other.isTypeOf())
-    return getTypeOfOpName() == Other.getTypeOfOpName();
-
-  return false;
+  llvm_unreachable("Unknown Type Kind");
 }
 
 std::string PatternType::str() const {
-  if (!R)
+  switch (Kind) {
+  case PT_None:
     return "";
-
-  if (!isValidType())
-    return "<invalid>";
-
-  if (isLLT())
-    return R->getName().str();
-
-  assert(isSpecial());
-
-  // TODO: Use an enum for these?
-  if (isTypeOf())
+  case PT_ValueType:
+    return Data.Def->getName().str();
+  case PT_TypeOf:
     return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
+  }
 
   llvm_unreachable("Unknown type!");
 }
@@ -608,14 +654,10 @@ class InstructionOperand {
   using IntImmTy = int64_t;
 
   InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type)
-      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {}
 
   InstructionOperand(StringRef Name, PatternType Type)
-      : Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Name(insertStrRef(Name)), Type(Type) {}
 
   bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); }
 
@@ -639,7 +681,6 @@ class InstructionOperand {
 
   void setType(PatternType NewType) {
     assert((!Type || (Type == NewType)) && "Overwriting type!");
-    assert(NewType.isValidType());
     Type = NewType;
   }
   PatternType getType() const { return Type; }
@@ -810,12 +851,10 @@ void InstructionPattern::print(raw_ostream &OS, bool PrintName) const {
 /// Maps InstructionPattern operands to their definitions. This allows us to tie
 /// different patterns of a (apply), (match) or (patterns) set of patterns
 /// together.
-template <typename DefTy = InstructionPattern> class OperandTable {
+class OperandTable {
 public:
-  static_assert(std::is_base_of_v<InstructionPattern, DefTy>,
-                "DefTy should be a derived class from InstructionPattern");
-
-  bool addPattern(DefTy *P, function_ref<void(StringRef)> DiagnoseRedef) {
+  bool addPattern(InstructionPattern *P,
+                  function_ref<void(StringRef)> DiagnoseRedef) {
     for (const auto &Op : P->named_operands()) {
       StringRef OpName = Op.getOperandName();
 
@@ -844,10 +883,10 @@ template <typename DefTy = InstructionPattern> class OperandTable {
 
   struct LookupResult {
     LookupResult() = default;
-    LookupResult(DefTy *Def) : Found(true), Def(Def) {}
+    LookupResult(InstructionPattern *Def) : Found(true), Def(Def) {}
 
     bool Found = false;
-    DefTy *Def = nullptr;
+    InstructionPattern *Def = nullptr;
 
     bool isLiveIn() const { return Found && !Def; }
   };
@@ -858,7 +897,9 @@ template <typename DefTy = InstructionPattern> class OperandTable {
     return LookupResult();
   }
 
-  DefTy *getDef(StringRef OpName) const { return lookup(OpName).Def; }
+  InstructionPattern *getDef(StringRef OpName) const {
+    return lookup(OpName).Def;
+  }
 
   void print(raw_ostream &OS, StringRef Name = "",
              StringRef Indent = "") const {
@@ -888,7 +929,7 @@ template <typename DefTy = InstructionPattern> class OperandTable {
   void dump() const { print(dbgs()); }
 
 private:
-  StringMap<DefTy *> Table;
+  StringMap<InstructionPattern *> Table;
 };
 
 //===- CodeGenInstructionPattern ------------------------------------------===//
@@ -957,62 +998,76 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
 ///
 /// It infers the type of each operand, check it's consistent with the known
 /// type of the operand, and then sets all of the types in all operands in
-/// setAllOperandTypes.
+/// propagateTypes.
 ///
 /// It also handles verifying correctness of special types.
 class OperandTypeChecker {
 public:
   OperandTypeChecker(ArrayRef<SMLoc> DiagLoc) : DiagLoc(DiagLoc) {}
 
-  bool check(InstructionPattern *P,
+  /// Step 1: Check each pattern one by one. All patterns that pass through here
+  /// are added to a common worklist so propagateTypes can access them.
+  bool check(InstructionPattern &P,
              std::function<bool(const PatternType &)> VerifyTypeOfOperand);
 
-  void setAllOperandTypes();
+  /// Step 2: Propagate all types. e.g. if one use of "$a" has type i32, make
+  /// all uses of "$a" have type i32.
+  void propagateTypes();
+
+protected:
+  ArrayRef<SMLoc> DiagLoc;
 
 private:
+  using InconsistentTypeDiagFn = std::function<void()>;
+
+  void PrintSeenWithTypeIn(InstructionPattern &P, StringRef OpName,
+                           PatternType Ty) const {
+    PrintNote(DiagLoc, "'" + OpName + "' seen with type '" + Ty.str() +
+                           "' in '" + P.getName() + "'");
+  }
+
   struct OpTypeInfo {
     PatternType Type;
-    InstructionPattern *TypeSrc = nullptr;
+    InconsistentTypeDiagFn PrintTypeSrcNote = []() {};
   };
 
-  ArrayRef<SMLoc> DiagLoc;
   StringMap<OpTypeInfo> Types;
 
   SmallVector<InstructionPattern *, 16> Pats;
 };
 
 bool OperandTypeChecker::check(
-    InstructionPattern *P,
+    InstructionPattern &P,
     std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
-  Pats.push_back(P);
+  Pats.push_back(&P);
 
-  for (auto &Op : P->operands()) {
+  for (auto &Op : P.operands()) {
     const auto Ty = Op.getType();
     if (!Ty)
       continue;
 
-    if (!Ty.checkSemantics(DiagLoc))
-      return false;
-
     if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
       return false;
 
     if (!Op.isNamedOperand())
       continue;
 
-    auto &Info = Types[Op.getOperandName()];
+    StringRef OpName = Op.getOperandName();
+    auto &Info = Types[OpName];
     if (!Info.Type) {
       Info.Type = Ty;
-      Info.TypeSrc = P;
+      Info.PrintTypeSrcNote = [this, OpName, Ty, &P]() {
+        PrintSeenWithTypeIn(P, OpName, Ty);
+      };
       continue;
     }
 
     if (Info.Type != Ty) {
       PrintError(DiagLoc, "conflicting types for operand '" +
-                              Op.getOperandName() + "': first seen with '" +
-                              Info.Type.str() + "' in '" +
-                              Info.TypeSrc->getName() + ", now seen with '" +
-                              Ty.str() + "' in '" + P->getName() + "'");
+                              Op.getOperandName() + "': '" + Info.Type.str() +
+                              "' vs '" + Ty.str() + "'");
+      PrintSeenWithTypeIn(P, OpName, Ty);
+      Info.PrintTypeSrcNote();
       return false;
     }
   }
@@ -1020,7 +1075,7 @@ bool OperandTypeChecker::check(
   return true;
 }
 
-void OperandTypeChecker::setAllOperandTypes() {
+void OperandTypeChecker::propagateTypes() {
   for (auto *Pat : Pats) {
     for (auto &Op : Pat->named_operands()) {
       if (auto &Info = Types[Op.getOperandName()]; Info.Type)
@@ -1074,7 +1129,7 @@ class PatFrag {
   /// Each argument to the `pattern` DAG operator is parsed into a Pattern
   /// instance.
   struct Alternative {
-    OperandTable<> OpTable;
+    OperandTable OpTable;
     SmallVector<std::unique_ptr<Pattern>, 4> Pats;
   };
 
@@ -1298,11 +1353,11 @@ bool PatFrag::checkSemantics() {
     OperandTypeChecker OTC(Def.getLoc());
     for (auto &Pat : Alt.Pats) {
       if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-        if (!OTC.check(IP, CheckTypeOf))
+        if (!OTC.check(*IP, CheckTypeOf))
           return false;
       }
     }
-    OTC.setAllOperandTypes();
+    OTC.propagateTypes();
   }
 
   return true;
@@ -1640,6 +1695,471 @@ class PrettyStackTraceEmit : public PrettyStackTraceEntry {
   }
 };
 
+//===- CombineRuleOperandTypeChecker --------------------------------------===//
+
+/// This is a wrapper around OperandTypeChecker specialized for Combiner Rules.
+/// On top of doing the same things as OperandTypeChecker, this also attempts to
+/// infer as many types as possible for temporary register defs & immediates in
+/// apply patterns.
+///
+/// The inference is trivial and leverages the MCOI OperandTypes encoded in
+/// CodeGenInstructions to infer types across patterns in a CombineRule. It's
+/// thus very limited and only supports CodeGenInstructions (but that's the main
+/// use case so it's fine).
+///
+/// We only try to infer untyped operands in apply patterns when they're temp
+/// reg defs, or immediates. Inference always outputs a `TypeOf<$x>` where $x is
+/// a named operand from a match pattern.
+class CombineRuleOperandTypeChecker : private OperandTypeChecker {
+public:
+  CombineRuleOperandTypeChecker(const Record &RuleDef,
+                                const OperandTable &MatchOpTable)
+      : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef),
+        MatchOpTable(MatchOpTable) {}
+
+  /// Records and checks a 'match' pattern.
+  bool processMatchPattern(InstructionPattern &P);
+
+  /// Records and checks an 'apply' pattern.
+  bool processApplyPattern(InstructionPattern &P);
+
+  /// Propagates types, then performs type inference and, if any types were
+  /// inferred, does a second round of propagation in the apply patterns only.
+  void propagateAndInferTypes();
+
+private:
+  /// TypeEquivalenceClasses are groups of operands of an instruction that share
+  /// a common type.
+  ///
+  /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and
+  /// d have the same type too. b/c and a/d don't have to have the same type,
+  /// though.
+  ///
+  /// NOTE: We use a SetVector, not a Set. This is to guarantee a stable
+  /// iteration order which is important because:
+  ///   - During inference, we iterate that set and pick the first suitable
+  ///   candidate. Using a normal set could make inference inconsistent across
+  ///   runs if the Set uses the StringRef ptr to cache values.
+  ///   - We print this set if DebugInfer is set, and we don't want our tests to
+  ///   fail randomly due to the Set's iteration order changing.
+  using TypeEquivalenceClasses = std::vector<SetVector<StringRef>>;
+
+  static std::string toString(const SetVector<StringRef> &EqClass) {
+    return "[" + join(EqClass, ", ") + "]";
+  }
+
+  /// \returns true for `OPERAND_GENERIC_` 0 through 5.
+  /// These are the MCOI types that can be registers. The other MCOI types are
+  /// either immediates, or fancier operands used only post-ISel, so we don't
+  /// care about them for combiners.
+  static bool canMCOIOperandTypeBeARegister(StringRef MCOIType) {
+    // Assume OPERAND_GENERIC_0 through 5 can be registers. The other MCOI
+    // OperandTypes are either never used in gMIR, or not relevant (e.g.
+    // OPERAND_GENERIC_IMM, which is definitely never a register).
+    return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_");
+  }
+
+  /// Finds the "MCOI::" operand types for each operand of \p CGP.
+  ///
+  /// This is a bit trickier than it looks because we need to handle variadic
+  /// in/outs.
+  ///
+  /// e.g. for
+  ///   (G_BUILD_VECTOR $vec, $x, $y) ->
+  ///   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  ///    MCOI::OPERAND_GENERIC_1]
+  ///
+  /// For unknown types (which can happen in variadics where varargs types are
+  /// inconsistent), a unique name is given, e.g. "unknown_type_0".
+  static std::vector<std::string>
+  getMCOIOperandTypes(const CodeGenInstructionPattern &CGP);
+
+  /// Adds the TypeEquivalenceClasses for \p P in \p OutTECs.
+  void getInstEqClasses(const InstructionPattern &P,
+                        TypeEquivalenceClasses &OutTECs) const;
+
+  /// Calculates the TypeEquivalenceClasses for each instruction, then merges
+  /// them into a common set of TypeEquivalenceClasses for the whole rule.
+  ///
+  /// This works by repeatedly merging intersecting type equivalence classes
+  /// until no more merging occurs.
+  ///
+  /// This essentially applies the "transitive" part of type inference. Let's
+  /// take the following equivalence classes:
+  ///   inst0: [a, b], [c, d]
+  ///   inst1: [b, c]
+  ///
+  /// If we see inst0 alone, we can't say that a and d have the same type -
+  /// they're not in the same equivalence class. However, if we just use logic,
+  /// we can say: "a == d because a == b, b == c and c == d".
+  ///
+  /// Merging condenses that information into a single big equivalence class
+  /// which can be looked at alone to make the same deduction.
+  ///   rule: [a, b, c, d]
+  TypeEquivalenceClasses getRuleEqClasses() const;
+
+  /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p
+  /// TECs.
+  ///
+  /// This is achieved by trying to find a named operand in \p IP that shares
+  /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that
+  /// operand instead.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferImmediateType(const InstructionPattern &IP,
+                                 unsigned ImmOpIdx,
+                                 const TypeEquivalenceClasses &TECs) const;
+
+  /// Looks inside \p TECs to infer \p OpName's type.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferNamedOperandType(const InstructionPattern &IP,
+                                    StringRef OpName,
+                                    const TypeEquivalenceClasses &TECs) const;
+
+  const Record &RuleDef;
+  SmallVector<InstructionPattern *, 8> MatchPats;
+  SmallVector<InstructionPattern *, 8> ApplyPats;
+
+  const OperandTable &MatchOpTable;
+};
+
+bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) {
+  MatchPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ [](const auto &) {
+    // GITypeOf in 'match' is currently always rejected by the
+    // CombineRuleBuilder after inference is done.
+    return true;
+  });
+}
+
+bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) {
+  ApplyPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ [&](const PatternType &Ty) {
+    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
+    const auto OpName = Ty.getTypeOfOpName();
+    if (MatchOpTable.lookup(OpName).Found)
+      return true;
+
+    PrintError(RuleDef.getLoc(), "'" + OpName + "' ('" + Ty.str() +
+                                     "') does not refer to a matched operand!");
+    return false;
+  });
+}
+
+void CombineRuleOperandTypeChecker::propagateAndInferTypes() {
+  /// First step here is to propagate types using the OperandTypeChecker. That
+  /// way we ensure all uses of a given register have consistent types.
+  propagateTypes();
+
+  /// Build the TypeEquivalenceClasses for the whole rule.
+  const TypeEquivalenceClasses TECs = getRuleEqClasses();
+
+  /// Look at the apply patterns and find operands that need to be
+  /// inferred. We then try to find an equivalence class that they're a part of
+  /// and select the best operand to use for the `GITypeOf` type. We prioritize
+  /// defs of matched instructions because those are guaranteed to be registers.
+  bool InferredAny = false;
+  for (auto *Pat : ApplyPats) {
+    for (unsigned K = 0; K < Pat->operands_size(); ++K) {
+      auto &Op = Pat->getOperand(K);
+
+      // We only want to take a look at untyped defs or immediates.
+      if ((!Op.isDef() && !Op.hasImmValue()) || Op.getType())
+        continue;
+
+      // Infer defs & named immediates.
+      if (Op.isDef() || Op.isNamedImmediate()) {
+        // Check it's not a redefinition of a matched operand.
+        // In such cases, inference is not necessary because we just copy
+        // operands and don't create temporary registers.
+        if (MatchOpTable.lookup(Op.getOperandName()).Found)
+          continue;
+
+        // Inference is needed here, so try to do it.
+        if (PatternType Ty =
+                inferNamedOperandType(*Pat, Op.getOperandName(), TECs)) {
+          if (DebugTypeInfer)
+            errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << "\n";
+          Op.setType(Ty);
+          InferredAny = true;
+        }
+
+        continue;
+      }
+
+      // Infer immediates
+      if (Op.hasImmValue()) {
+        if (PatternType Ty = inferImmediateType(*Pat, K, TECs)) {
+          if (DebugTypeInfer)
+            errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << "\n";
+          Op.setType(Ty);
+          InferredAny = true;
+        }
+        continue;
+      }
+    }
+  }
+
+  // If we've inferred any types, we want to propagate them across the apply
+  // patterns. Type inference only adds GITypeOf types that point to Matched
+  // operands, so we definitely don't want to propagate types into the match
+  // patterns as well, otherwise bad things happen.
+  if (InferredAny) {
+    OperandTypeChecker OTC(RuleDef.getLoc());
+    for (auto *Pat : ApplyPats) {
+      if (!OTC.check(*Pat, [&](const auto &) { return true; }))
+        PrintFatalError(RuleDef.getLoc(),
+                        "OperandTypeChecker unexpectedly failed on '" +
+                            Pat->getName() + "' during Type Inference");
+    }
+    OTC.propagateTypes();
+
+    if (DebugTypeInfer) {
+      errs() << "Apply patterns for rule " << RuleDef.getName()
+             << " after inference:\n";
+      for (auto *Pat : ApplyPats) {
+        Pat->print(errs(), /*PrintName*/ true);
+        errs() << "\n";
+      }
+      errs() << "\n";
+    }
+  }
+}
+
+PatternType CombineRuleOperandTypeChecker::inferImmediateType(
+    const InstructionPattern &IP, unsigned ImmOpIdx,
+    const TypeEquivalenceClasses &TECs) const {
+  // We can only infer CGPs.
+  const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP);
+  if (!CGP)
+    return {};
+
+  // For CGPs, we try to infer immediates by trying to infer another named
+  // operand that shares its type.
+  //
+  // e.g.
+  //    Pattern: G_BUILD_VECTOR $x, $y, 0
+  //    MCOIs:   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  //              MCOI::OPERAND_GENERIC_1]
+  //    $y has the same type as 0, so we can infer $y and get the type 0 should
+  //    have.
+
+  // We infer immediates by looking for a named operand that shares the same
+  // MCOI type.
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  StringRef ImmOpTy = MCOITypes[ImmOpIdx];
+
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes)) {
+    if (Idx != ImmOpIdx && Ty == ImmOpTy) {
+      const auto &Op = IP.getOperand(Idx);
+      if (!Op.isNamedOperand())
+        continue;
+
+      // Named operand with the same MCOI type, try to infer that.
+      if (PatternType InferTy =
+              inferNamedOperandType(IP, Op.getOperandName(), TECs))
+        return InferTy;
+    }
+  }
+
+  return {};
+}
+
+PatternType CombineRuleOperandTypeChecker::inferNamedOperandType(
+    const InstructionPattern &IP, StringRef OpName,
+    const TypeEquivalenceClasses &TECs) const {
+  // This is the simplest possible case, we just need to find a TEC that
+  // contains OpName.
+  for (const auto &TEC : TECs) {
+    if (!TEC.contains(OpName))
+      continue;
+
+    // This TEC mentions the operand. Look at all other operands in this TEC and
+    // try to find a suitable one.
+
+    // First, check for a def of a matched pattern. This is guaranteed to always
+    // be a register so we can blindly use that.
+    StringRef GoodOpName;
+    for (const auto &EqOp : TEC) {
+      if (EqOp == OpName)
+        continue;
+
+      const auto LookupRes = MatchOpTable.lookup(EqOp);
+      if (LookupRes.Def) // Favor defs
+        return PatternType::getTypeOf(EqOp);
+
+      // Otherwise just save this in case we don't find any def.
+      if (GoodOpName.empty() && LookupRes.Found)
+        GoodOpName = EqOp;
+    }
+
+    if (!GoodOpName.empty())
+      return PatternType::getTypeOf(GoodOpName);
+
+    // No good operand found, give up.
+    return {};
+  }
+
+  return {};
+}
+
+std::vector<std::string> CombineRuleOperandTypeChecker::getMCOIOperandTypes(
+    const CodeGenInstructionPattern &CGP) {
+  // FIXME?: Should we cache this? We call it twice when inferring immediates.
+
+  static unsigned UnknownTypeIdx = 0;
+
+  std::vector<std::string> OpTypes;
+  auto &CGI = CGP.getInst();
+  Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction")
+                          ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType")
+                          : nullptr;
+  std::string VarArgsTyName =
+      VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str()
+                : ("unknown_type_" + Twine(UnknownTypeIdx++)).str();
+
+  // First, handle defs.
+  for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K)
+    OpTypes.push_back(CGI.Operands[K].OperandType);
+
+  // Then, handle variadic defs if there are any.
+  if (CGP.hasVariadicDefs()) {
+    for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K)
+      OpTypes.push_back(VarArgsTyName);
+  }
+
+  // If we had variadic defs, the op idx in the pattern won't match the op idx
+  // in the CGI anymore.
+  int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs();
+  assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0));
+
+  // Handle all remaining use operands, including variadic ones.
+  for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) {
+    unsigned CGIOpIdx = K + CGIOpOffset;
+    if (CGIOpIdx >= CGI.Operands.size()) {
+      assert(CGP.isVariadic());
+      OpTypes.push_back(VarArgsTyName);
+    } else {
+      OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType);
+    }
+  }
+
+  assert(OpTypes.size() == CGP.operands_size());
+  return OpTypes;
+}
+
+void CombineRuleOperandTypeChecker::getInstEqClasses(
+    const InstructionPattern &P, TypeEquivalenceClasses &OutTECs) const {
+  // Determine the TypeEquivalenceClasses by:
+  //    - Getting the MCOI Operand Types.
+  //    - Creating a Map of MCOI Type -> [Operand Indexes]
+  //    - Iterating over the map, filtering types we don't like, and adding the
+  //      names of the named operands at those indexes to \p OutTECs.
+
+  // We can only do this on CodeGenInstructions. Other InstructionPatterns have
+  // no type inference information associated with them.
+  // TODO: Could we add some inference information to builtins at least? e.g.
+  // ReplaceReg should always replace with a reg of the same type, for instance.
+  // Though, those patterns are often used alone so it might not be worth the
+  // trouble to infer their types.
+  auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P);
+  if (!CGP)
+    return;
+
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  assert(MCOITypes.size() == P.operands_size());
+
+  DenseMap<StringRef, std::vector<unsigned>> TyToOpIdx;
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes))
+    TyToOpIdx[Ty].push_back(Idx);
+
+  const unsigned FirstNewTEC = OutTECs.size();
+  for (const auto &[Ty, Idxs] : TyToOpIdx) {
+    if (!canMCOIOperandTypeBeARegister(Ty))
+      continue;
+
+    SetVector<StringRef> OpNames;
+    // We only collect named operands.
+    for (unsigned Idx : Idxs) {
+      const auto &Op = P.getOperand(Idx);
+      if (Op.isNamedOperand())
+        OpNames.insert(Op.getOperandName());
+    }
+    OutTECs.emplace_back(std::move(OpNames));
+  }
+
+  if (DebugTypeInfer) {
+    errs() << "\t" << P.getName() << ":\t";
+    if (FirstNewTEC == OutTECs.size())
+      errs() << "(empty)";
+    else {
+      for (unsigned K = FirstNewTEC; K < OutTECs.size(); ++K)
+        errs() << "\t" << toString(OutTECs[K]);
+    }
+    errs() << "\n";
+  }
+}
+
+CombineRuleOperandTypeChecker::TypeEquivalenceClasses
+CombineRuleOperandTypeChecker::getRuleEqClasses() const {
+  StringMap<unsigned> OpNameToEqClassIdx;
+  TypeEquivalenceClasses TECs;
+
+  if (DebugTypeInfer)
+    errs() << "Rule Operand Type Equivalence Classes for " << RuleDef.getName()
+           << ":\n";
+
+  for (const auto *Pat : MatchPats)
+    getInstEqClasses(*Pat, TECs);
+  for (const auto *Pat : ApplyPats)
+    getInstEqClasses(*Pat, TECs);
+
+  bool Merged;
+  do {
+    Merged = false;
+    for (auto &X : TECs) {
+      if (X.empty()) // Already merged
+        continue;
+
+      for (auto &Y : TECs) {
+        if (&X == &Y || Y.empty()) // Same set or already merged.
+          continue;
+
+        if (doSetsIntersect(X, Y)) {
+          if (DebugTypeInfer)
+            errs() << "(merging " << toString(X) << " | " << toString(Y)
+                   << ")\n";
+          setVectorUnion(X, Y);
+          Merged = true;
+
+          // To avoid invalidating iterators during iteration, we just clear the
+          // other set then clean it up later.
+          Y.clear();
+        }
+      }
+    }
+
+    // Remove empty sets.
+    erase_if(TECs, [&](auto &Set) { return Set.empty(); });
+  } while (Merged);
+
+  if (DebugTypeInfer) {
+    errs() << "Result: ";
+    if (TECs.empty())
+      errs() << "(empty)";
+    else {
+      for (const auto &TEC : TECs)
+        errs() << toString(TEC) << " ";
+    }
+    errs() << "\n";
+  }
+
+  return TECs;
+}
+
 //===- CombineRuleBuilder -------------------------------------------------===//
 
 /// Parses combine rule and builds a small intermediate representation to tie
@@ -1821,8 +2341,8 @@ class CombineRuleBuilder {
   PatternMap ApplyPats;
 
   /// Operand tables to tie match/apply patterns together.
-  OperandTable<> MatchOpTable;
-  OperandTable<> ApplyOpTable;
+  OperandTable MatchOpTable;
+  OperandTable ApplyOpTable;
 
   /// Set by findRoots.
   Pattern *MatchRoot = nullptr;
@@ -2103,39 +2623,23 @@ bool CombineRuleBuilder::hasEraseRoot() const {
 }
 
 bool CombineRuleBuilder::typecheckPatterns() {
-  OperandTypeChecker OTC(RuleDef.getLoc());
-
-  const auto CheckMatchTypeOf = [&](const PatternType &) -> bool {
-    // We'll reject those after we're done inferring
-    return true;
-  };
+  CombineRuleOperandTypeChecker OTC(RuleDef, MatchOpTable);
 
   for (auto &Pat : values(MatchPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckMatchTypeOf))
+      if (!OTC.processMatchPattern(*IP))
         return false;
     }
   }
 
-  const auto CheckApplyTypeOf = [&](const PatternType &Ty) {
-    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
-    const auto OpName = Ty.getTypeOfOpName();
-    if (MatchOpTable.lookup(OpName).Found)
-      return true;
-
-    PrintError("'" + OpName + "' ('" + Ty.str() +
-               "') does not refer to a matched operand!");
-    return false;
-  };
-
   for (auto &Pat : values(ApplyPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckApplyTypeOf))
+      if (!OTC.processApplyPattern(*IP))
         return false;
     }
   }
 
-  OTC.setAllOperandTypes();
+  OTC.propagateAndInferTypes();
 
   // Always check this after in case inference adds some special types to the
   // match patterns.
@@ -2631,7 +3135,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
   // untyped immediate, e.g. 0
   if (const auto *IntImm = dyn_cast<IntInit>(OpInit)) {
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(IntImm->getValue(), Name, /*Type=*/nullptr);
+    IP.addOperand(IntImm->getValue(), Name, PatternType());
     return true;
   }
 
@@ -2641,13 +3145,9 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
-    PatternType ImmTy(TyDef);
-    if (!ImmTy.isValidType()) {
-      PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() +
-                 "', '" + TyDef->getName() + "' is not a ValueType or " +
-                 SpecialTyClassName);
+    auto ImmTy = PatternType::get(RuleDef.getLoc(), TyDef);
+    if (!ImmTy)
       return false;
-    }
 
     if (!IP.hasAllDefs()) {
       PrintError("out operand of '" + IP.getInstName() +
@@ -2660,7 +3160,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(Val->getValue(), Name, ImmTy);
+    IP.addOperand(Val->getValue(), Name, *ImmTy);
     return true;
   }
 
@@ -2672,20 +3172,17 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return false;
     }
     const Record *Def = DefI->getDef();
-    PatternType Ty(Def);
-    if (!Ty.isValidType()) {
-      PrintError("invalid operand type: '" + Def->getName() +
-                 "' is not a ValueType");
+    auto Ty = PatternType::get(RuleDef.getLoc(), Def);
+    if (!Ty)
       return false;
-    }
-    IP.addOperand(OpName->getAsUnquotedString(), Ty);
+    IP.addOperand(OpName->getAsUnquotedString(), *Ty);
     return true;
   }
 
   // Untyped operand e.g. $x/$z in (G_FNEG $x, $z)
   if (isa<UnsetInit>(OpInit)) {
     assert(OpName && "Unset w/ no OpName?");
-    IP.addOperand(OpName->getAsUnquotedString(), /*Type=*/nullptr);
+    IP.addOperand(OpName->getAsUnquotedString(), PatternType());
     return true;
   }
 



More information about the llvm-commits mailing list