[llvm] [TableGen][GlobalISel] Add rule-wide type inference (PR #66377)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 31 02:19:08 PDT 2023


https://github.com/Pierre-vh updated https://github.com/llvm/llvm-project/pull/66377

>From 7e044c27f2258b1346eebb3fbcddc556b790d7fa Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Thu, 14 Sep 2023 15:52:29 +0200
Subject: [PATCH] [TableGen][GlobalISel] Add rule-wide type inference

The inference is trivial and leverages the MCOI OperandTypes encoded in
CodeGenInstructions to infer types across patterns in a CombineRule. It is
thus very limited and only supports CodeGenInstructions (but that is the main
use case, so it is fine).

We only try to infer untyped operands in apply patterns when they are
temporary register defs or immediates. Inference always outputs a
`GITypeOf<$x>` where $x is a named operand from a match pattern.
---
 llvm/include/llvm/Target/GenericOpcodes.td    |   7 +
 .../include/llvm/Target/GlobalISel/Combine.td |   2 +-
 .../pattern-errors.td                         |   4 +-
 .../type-inference.td                         |  75 ++
 .../typeof-errors.td                          |   7 +-
 .../TableGen/GlobalISelCombinerEmitter.cpp    | 748 +++++++++++++++---
 6 files changed, 713 insertions(+), 130 deletions(-)
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td

diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index a1afc3b8042c284..9a9c09d3c20d612 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -17,6 +17,10 @@
 
 class GenericInstruction : StandardPseudoInstruction {
   let isPreISelOpcode = true;
+
+  // When all variadic ops share a type with another operand,
+  // this is the type they share. Used by MIR patterns type inference.
+  TypedOperand variadicOpsType = ?;
 }
 
 // Provide a variant of an instruction with the same operands, but
@@ -1228,6 +1232,7 @@ def G_UNMERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst0, variable_ops);
   let InOperandList = (ins type1:$src);
   let hasSideEffects = false;
+  let variadicOpsType = type0;
 }
 
 // Insert a smaller register into a larger one at the specified bit-index.
@@ -1245,6 +1250,7 @@ def G_MERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Create a vector from multiple scalar registers. No implicit
@@ -1254,6 +1260,7 @@ def G_BUILD_VECTOR : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Like G_BUILD_VECTOR, but truncates the larger operand types to fit the
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 63c485a5a6c6070..ee0209eb9e5593a 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -796,7 +796,7 @@ def trunc_shift: GICombineRule <
 def mul_by_neg_one: GICombineRule <
   (defs root:$dst),
   (match (G_MUL $dst, $x, -1)),
-  (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
+  (apply (G_SUB $dst, 0, $x))
 >;
 
 // Fold (xor (and x, y), y) -> (and (not x), y)
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
index 48a06474da78a10..77321a3f2b59dbe 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
@@ -151,7 +151,7 @@ def bad_imm_too_many_args : GICombineRule<
   (match (COPY $x, (i32 0, 0)):$d),
   (apply (COPY $x, $b):$d)>;
 
-// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(COPY 0)', 'COPY' is not a ValueType
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: unknown type 'COPY'
 // CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(COPY ?:$x, (COPY 0))
 def bad_imm_not_a_valuetype : GICombineRule<
   (defs root:$a),
@@ -186,7 +186,7 @@ def expected_op_name : GICombineRule<
   (match (G_FNEG $x, i32)),
   (apply (COPY $x, (i32 0)))>;
 
-// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: invalid operand type: 'not_a_type' is not a ValueType
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: unknown type 'not_a_type'
 // CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_FNEG ?:$x, not_a_type:$y)'
 def not_a_type;
 def bad_mo_type_not_a_valuetype : GICombineRule<
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
new file mode 100644
index 000000000000000..7ed14dd5e6cc0eb
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
@@ -0,0 +1,75 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -gicombiner-debug-typeinfer -combiners=MyCombiner %s 2>&1 | \
+// RUN: FileCheck %s
+
+// Checks reasoning of the inference rules.
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+// CHECK:      Rule Operand Type Equivalence Classes for inference_mul_by_neg_one:
+// CHECK-NEXT:         __inference_mul_by_neg_one_match_0:             [dst, x]
+// CHECK-NEXT:         __inference_mul_by_neg_one_apply_0:             [dst, x]
+// CHECK-NEXT: (merging [dst, x] | [dst, x])
+// CHECK-NEXT: Result: [dst, x]
+// CHECK-NEXT: INFER: imm 0 -> GITypeOf<$x>
+// CHECK-NEXT: Apply patterns for rule inference_mul_by_neg_one after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__inference_mul_by_neg_one_apply_0 G_SUB operands:[<def>$dst, (GITypeOf<$x> 0), $x])
+def inference_mul_by_neg_one: GICombineRule <
+  (defs root:$dst),
+  (match (G_MUL $dst, $x, -1)),
+  (apply (G_SUB $dst, 0, $x))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_complex_tempreg:
+// CHECK-NEXT:         __infer_complex_tempreg_match_0:                [dst]   [x, y, z]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_0:                [tmp2]  [x, y]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_1:                [tmp, tmp2]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_2:                [dst, tmp]
+// CHECK-NEXT: (merging [dst] | [dst, tmp])
+// CHECK-NEXT: (merging [x, y, z] | [x, y])
+// CHECK-NEXT: (merging [tmp2] | [tmp, tmp2])
+// CHECK-NEXT: (merging [dst, tmp] | [tmp2, tmp])
+// CHECK-NEXT: Result: [dst, tmp, tmp2] [x, y, z]
+// CHECK-NEXT: INFER: MachineOperand $tmp2 -> GITypeOf<$dst>
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$dst>
+// CHECK-NEXT: Apply patterns for rule infer_complex_tempreg after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_0 G_BUILD_VECTOR operands:[<def>GITypeOf<$dst>:$tmp2, $x, $y])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_1 G_FNEG operands:[<def>GITypeOf<$dst>:$tmp, GITypeOf<$dst>:$tmp2])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_2 G_FNEG operands:[<def>$dst, GITypeOf<$dst>:$tmp])
+def infer_complex_tempreg: GICombineRule <
+  (defs root:$dst),
+  (match (G_MERGE_VALUES $dst, $x, $y, $z)),
+  (apply (G_BUILD_VECTOR $tmp2, $x, $y),
+         (G_FNEG $tmp, $tmp2),
+         (G_FNEG $dst, $tmp))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_variadic_outs:
+// CHECK-NEXT:         __infer_variadic_outs_match_0:          [x, y]  [vec]
+// CHECK-NEXT:         __infer_variadic_outs_match_1:          [dst, x]
+// CHECK-NEXT:         __infer_variadic_outs_apply_0:          [tmp, y]
+// CHECK-NEXT:         __infer_variadic_outs_apply_1:  (empty)
+// CHECK-NEXT: (merging [x, y] | [dst, x])
+// CHECK-NEXT: (merging [x, y, dst] | [tmp, y])
+// CHECK-NEXT: Result: [x, y, dst, tmp] [vec]
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$x>
+// CHECK-NEXT: Apply patterns for rule infer_variadic_outs after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_0 G_FNEG operands:[<def>GITypeOf<$x>:$tmp, $y])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_1 COPY operands:[<def>$dst, GITypeOf<$x>:$tmp])
+def infer_variadic_outs: GICombineRule <
+  (defs root:$dst),
+  (match  (G_UNMERGE_VALUES $x, $y, $vec),
+          (G_FNEG $dst, $x)),
+  (apply (G_FNEG $tmp, $y),
+         (COPY $dst, $tmp))
+>;
+
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  inference_mul_by_neg_one,
+  infer_complex_tempreg,
+  infer_variadic_outs
+]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
index 6040d6def449766..b86b4ec19564488 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
@@ -8,7 +8,8 @@ include "llvm/Target/GlobalISel/Combine.td"
 def MyTargetISA : InstrInfo;
 def MyTarget : Target { let InstructionSet = MyTargetISA; }
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ANYEXT ?:$dst, (anonymous_
 def NoDollarSign : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, $src)),
@@ -47,7 +48,9 @@ def InferredUseInMatch : GICombineRule<
   (match (G_ZEXT $dst, $src)),
   (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: conflicting types for operand 'src': 'i32' vs 'GITypeOf<$dst>'
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: 'src' seen with type 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: 'src' seen with type 'i32' in '__InferenceConflict_match_0'
 def InferenceConflict : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, i32:$src)),
diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 0c7b33a7b9d889d..82df0c5e58f041b 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -37,6 +37,8 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
@@ -68,6 +70,9 @@ cl::opt<bool> DebugCXXPreds(
     "gicombiner-debug-cxxpreds",
     cl::desc("Add Contextual/Debug comments to all C++ predicates"),
     cl::cat(GICombinerEmitterCat));
+cl::opt<bool> DebugTypeInfer("gicombiner-debug-typeinfer",
+                             cl::desc("Print type inference debug logs"),
+                             cl::cat(GICombinerEmitterCat));
 
 constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
 constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
@@ -125,6 +130,20 @@ template <typename Container> auto values(Container &&C) {
   return map_range(C, [](auto &Entry) -> auto & { return Entry.second; });
 }
 
+/// \returns true if at least one element of \p A is also an element of \p B.
+/// Only requires \p A to be iterable and \p B to provide `contains`.
+template <typename SetTy> bool doSetsIntersect(const SetTy &A, const SetTy &B) {
+  return llvm::any_of(A, [&B](const auto &Elt) { return B.contains(Elt); });
+}
+
+/// Inserts every element of \p S2 into \p S1. Elements already in \p S1 keep
+/// their position; new ones are appended in \p S2's iteration order.
+template <class Ty>
+void setVectorUnion(SetVector<Ty> &S1, const SetVector<Ty> &S2) {
+  // SetVector's range insert skips elements that are already present.
+  S1.insert(S2.begin(), S2.end());
+}
+
 //===- MatchData Handling -------------------------------------------------===//
 
 /// Represents MatchData defined by the match stage and required by the apply
@@ -298,23 +317,30 @@ CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode;
 ///   - Special types, e.g. GITypeOf
 class PatternType {
 public:
-  PatternType() = default;
-  PatternType(const Record *R) : R(R) {}
+  enum PTKind : uint8_t {
+    PT_None,
+
+    PT_ValueType,
+    PT_TypeOf,
+  };
 
-  bool isValidType() const { return !R || isLLT() || isSpecial(); }
+  PatternType() : Kind(PT_None), Data() {}
 
-  bool isLLT() const { return R && R->isSubClassOf("ValueType"); }
-  bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); }
-  bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); }
+  static std::optional<PatternType> get(ArrayRef<SMLoc> DiagLoc,
+                                        const Record *R);
+  static PatternType getTypeOf(StringRef OpName);
+
+  bool isNone() const { return Kind == PT_None; }
+  bool isLLT() const { return Kind == PT_ValueType; }
+  bool isSpecial() const { return isTypeOf(); }
+  bool isTypeOf() const { return Kind == PT_TypeOf; }
 
   StringRef getTypeOfOpName() const;
   LLTCodeGen getLLTCodeGen() const;
 
-  bool checkSemantics(ArrayRef<SMLoc> DiagLoc) const;
-
   LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const;
 
-  explicit operator bool() const { return R != nullptr; }
+  explicit operator bool() const { return !isNone(); }
 
   bool operator==(const PatternType &Other) const;
   bool operator!=(const PatternType &Other) const { return !operator==(Other); }
@@ -322,26 +348,66 @@ class PatternType {
   std::string str() const;
 
 private:
-  StringRef getRawOpName() const { return R->getValueAsString("OpName"); }
+  PatternType(PTKind Kind) : Kind(Kind), Data() {}
+
+  PTKind Kind;
+  union DataT {
+    DataT() : Str() {}
 
-  const Record *R = nullptr;
+    /// PT_ValueType -> ValueType Def.
+    const Record *Def;
+
+    /// PT_TypeOf -> Operand name (without the '$')
+    StringRef Str;
+  } Data;
 };
 
+std::optional<PatternType> PatternType::get(ArrayRef<SMLoc> DiagLoc,
+                                            const Record *R) {
+  assert(R);
+  if (R->isSubClassOf("ValueType")) {
+    PatternType PT(PT_ValueType);
+    PT.Data.Def = R;
+    return PT;
+  }
+
+  if (R->isSubClassOf(TypeOfClassName)) {
+    auto RawOpName = R->getValueAsString("OpName");
+    if (!RawOpName.starts_with("$")) {
+      PrintError(DiagLoc, "invalid operand name format '" + RawOpName +
+                              "' in " + TypeOfClassName +
+                              ": expected '$' followed by an operand name");
+      return std::nullopt;
+    }
+
+    PatternType PT(PT_TypeOf);
+    PT.Data.Str = RawOpName.drop_front(1);
+    return PT;
+  }
+
+  PrintError(DiagLoc, "unknown type '" + R->getName() + "'");
+  return std::nullopt;
+}
+
+PatternType PatternType::getTypeOf(StringRef OpName) {
+  PatternType PT(PT_TypeOf);
+  PT.Data.Str = OpName;
+  return PT;
+}
+
 StringRef PatternType::getTypeOfOpName() const {
   assert(isTypeOf());
-  StringRef Name = getRawOpName();
-  Name.consume_front("$");
-  return Name;
+  return Data.Str;
 }
 
 LLTCodeGen PatternType::getLLTCodeGen() const {
   assert(isLLT());
-  return *MVTToLLT(getValueType(R));
+  return *MVTToLLT(getValueType(Data.Def));
 }
 
 LLTCodeGenOrTempType
 PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
-  assert(isValidType());
+  assert(!isNone());
 
   if (isLLT())
     return getLLTCodeGen();
@@ -351,50 +417,31 @@ PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
   return OM.getTempTypeIdx(RM);
 }
 
-bool PatternType::checkSemantics(ArrayRef<SMLoc> DiagLoc) const {
-  if (!isTypeOf())
-    return true;
-
-  auto RawOpName = getRawOpName();
-  if (RawOpName.starts_with("$"))
-    return true;
-
-  PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " +
-                          TypeOfClassName +
-                          ": expected '$' followed by an operand name");
-  return false;
-}
-
 bool PatternType::operator==(const PatternType &Other) const {
-  if (R == Other.R) {
-    if (R && R->getName() != Other.R->getName()) {
-      dbgs() << "Same ptr but: " << R->getName() << " and "
-             << Other.R->getName() << "?\n";
-      assert(false);
-    }
+  if (Kind != Other.Kind)
+    return false;
+
+  switch (Kind) {
+  case PT_None:
     return true;
+  case PT_ValueType:
+    return Data.Def == Other.Data.Def;
+  case PT_TypeOf:
+    return Data.Str == Other.Data.Str;
   }
 
-  if (isTypeOf() && Other.isTypeOf())
-    return getTypeOfOpName() == Other.getTypeOfOpName();
-
-  return false;
+  llvm_unreachable("Unknown Type Kind");
 }
 
 std::string PatternType::str() const {
-  if (!R)
+  switch (Kind) {
+  case PT_None:
     return "";
-
-  if (!isValidType())
-    return "<invalid>";
-
-  if (isLLT())
-    return R->getName().str();
-
-  assert(isSpecial());
-
-  if (isTypeOf())
+  case PT_ValueType:
+    return Data.Def->getName().str();
+  case PT_TypeOf:
     return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
+  }
 
   llvm_unreachable("Unknown type!");
 }
@@ -607,14 +654,10 @@ class InstructionOperand {
   using IntImmTy = int64_t;
 
   InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type)
-      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {}
 
   InstructionOperand(StringRef Name, PatternType Type)
-      : Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Name(insertStrRef(Name)), Type(Type) {}
 
   bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); }
 
@@ -638,7 +681,6 @@ class InstructionOperand {
 
   void setType(PatternType NewType) {
     assert((!Type || (Type == NewType)) && "Overwriting type!");
-    assert(NewType.isValidType());
     Type = NewType;
   }
   PatternType getType() const { return Type; }
@@ -809,12 +851,10 @@ void InstructionPattern::print(raw_ostream &OS, bool PrintName) const {
 /// Maps InstructionPattern operands to their definitions. This allows us to tie
 /// different patterns of a (apply), (match) or (patterns) set of patterns
 /// together.
-template <typename DefTy = InstructionPattern> class OperandTable {
+class OperandTable {
 public:
-  static_assert(std::is_base_of_v<InstructionPattern, DefTy>,
-                "DefTy should be a derived class from InstructionPattern");
-
-  bool addPattern(DefTy *P, function_ref<void(StringRef)> DiagnoseRedef) {
+  bool addPattern(InstructionPattern *P,
+                  function_ref<void(StringRef)> DiagnoseRedef) {
     for (const auto &Op : P->named_operands()) {
       StringRef OpName = Op.getOperandName();
 
@@ -843,10 +883,10 @@ template <typename DefTy = InstructionPattern> class OperandTable {
 
   struct LookupResult {
     LookupResult() = default;
-    LookupResult(DefTy *Def) : Found(true), Def(Def) {}
+    LookupResult(InstructionPattern *Def) : Found(true), Def(Def) {}
 
     bool Found = false;
-    DefTy *Def = nullptr;
+    InstructionPattern *Def = nullptr;
 
     bool isLiveIn() const { return Found && !Def; }
   };
@@ -857,7 +897,9 @@ template <typename DefTy = InstructionPattern> class OperandTable {
     return LookupResult();
   }
 
-  DefTy *getDef(StringRef OpName) const { return lookup(OpName).Def; }
+  InstructionPattern *getDef(StringRef OpName) const {
+    return lookup(OpName).Def;
+  }
 
   void print(raw_ostream &OS, StringRef Name = "",
              StringRef Indent = "") const {
@@ -887,7 +929,7 @@ template <typename DefTy = InstructionPattern> class OperandTable {
   void dump() const { print(dbgs()); }
 
 private:
-  StringMap<DefTy *> Table;
+  StringMap<InstructionPattern *> Table;
 };
 
 //===- CodeGenInstructionPattern ------------------------------------------===//
@@ -956,62 +998,76 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
 ///
 /// It infers the type of each operand, check it's consistent with the known
 /// type of the operand, and then sets all of the types in all operands in
-/// setAllOperandTypes.
+/// propagateTypes.
 ///
 /// It also handles verifying correctness of special types.
 class OperandTypeChecker {
 public:
   OperandTypeChecker(ArrayRef<SMLoc> DiagLoc) : DiagLoc(DiagLoc) {}
 
-  bool check(InstructionPattern *P,
+  /// Step 1: Check each pattern one by one. All patterns that pass through here
+  /// are added to a common worklist so propagateTypes can access them.
+  bool check(InstructionPattern &P,
              std::function<bool(const PatternType &)> VerifyTypeOfOperand);
 
-  void setAllOperandTypes();
+  /// Step 2: Propagate all types. e.g. if one use of "$a" has type i32, make
+  /// all uses of "$a" have type i32.
+  void propagateTypes();
+
+protected:
+  ArrayRef<SMLoc> DiagLoc;
 
 private:
+  using InconsistentTypeDiagFn = std::function<void()>;
+
+  void PrintSeenWithTypeIn(InstructionPattern &P, StringRef OpName,
+                           PatternType Ty) const {
+    PrintNote(DiagLoc, "'" + OpName + "' seen with type '" + Ty.str() +
+                           "' in '" + P.getName() + "'");
+  }
+
   struct OpTypeInfo {
     PatternType Type;
-    InstructionPattern *TypeSrc = nullptr;
+    InconsistentTypeDiagFn PrintTypeSrcNote = []() {};
   };
 
-  ArrayRef<SMLoc> DiagLoc;
   StringMap<OpTypeInfo> Types;
 
   SmallVector<InstructionPattern *, 16> Pats;
 };
 
 bool OperandTypeChecker::check(
-    InstructionPattern *P,
+    InstructionPattern &P,
     std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
-  Pats.push_back(P);
+  Pats.push_back(&P);
 
-  for (auto &Op : P->operands()) {
+  for (auto &Op : P.operands()) {
     const auto Ty = Op.getType();
     if (!Ty)
       continue;
 
-    if (!Ty.checkSemantics(DiagLoc))
-      return false;
-
     if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
       return false;
 
     if (!Op.isNamedOperand())
       continue;
 
-    auto &Info = Types[Op.getOperandName()];
+    StringRef OpName = Op.getOperandName();
+    auto &Info = Types[OpName];
     if (!Info.Type) {
       Info.Type = Ty;
-      Info.TypeSrc = P;
+      Info.PrintTypeSrcNote = [this, OpName, Ty, &P]() {
+        PrintSeenWithTypeIn(P, OpName, Ty);
+      };
       continue;
     }
 
     if (Info.Type != Ty) {
       PrintError(DiagLoc, "conflicting types for operand '" +
-                              Op.getOperandName() + "': first seen with '" +
-                              Info.Type.str() + "' in '" +
-                              Info.TypeSrc->getName() + ", now seen with '" +
-                              Ty.str() + "' in '" + P->getName() + "'");
+                              Op.getOperandName() + "': '" + Info.Type.str() +
+                              "' vs '" + Ty.str() + "'");
+      PrintSeenWithTypeIn(P, OpName, Ty);
+      Info.PrintTypeSrcNote();
       return false;
     }
   }
@@ -1019,7 +1075,7 @@ bool OperandTypeChecker::check(
   return true;
 }
 
-void OperandTypeChecker::setAllOperandTypes() {
+void OperandTypeChecker::propagateTypes() {
   for (auto *Pat : Pats) {
     for (auto &Op : Pat->named_operands()) {
       if (auto &Info = Types[Op.getOperandName()]; Info.Type)
@@ -1073,7 +1129,7 @@ class PatFrag {
   /// Each argument to the `pattern` DAG operator is parsed into a Pattern
   /// instance.
   struct Alternative {
-    OperandTable<> OpTable;
+    OperandTable OpTable;
     SmallVector<std::unique_ptr<Pattern>, 4> Pats;
   };
 
@@ -1297,11 +1353,11 @@ bool PatFrag::checkSemantics() {
     OperandTypeChecker OTC(Def.getLoc());
     for (auto &Pat : Alt.Pats) {
       if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-        if (!OTC.check(IP, CheckTypeOf))
+        if (!OTC.check(*IP, CheckTypeOf))
           return false;
       }
     }
-    OTC.setAllOperandTypes();
+    OTC.propagateTypes();
   }
 
   return true;
@@ -1639,6 +1695,471 @@ class PrettyStackTraceEmit : public PrettyStackTraceEntry {
   }
 };
 
+//===- CombineRuleOperandTypeChecker --------------------------------------===//
+
+/// This is a wrapper around OperandTypeChecker specialized for Combiner Rules.
+/// On top of doing the same things as OperandTypeChecker, this also attempts to
+/// infer as many types as possible for temporary register defs & immediates in
+/// apply patterns.
+///
+/// The inference is trivial and leverages the MCOI OperandTypes encoded in
+/// CodeGenInstructions to infer types across patterns in a CombineRule. It's
+/// thus very limited and only supports CodeGenInstructions (but that's the main
+/// use case, so it's fine).
+///
+/// We only try to infer untyped operands in apply patterns when they are
+/// temporary register defs or immediates. Inference always outputs a
+/// `GITypeOf<$x>` where $x is a named operand from a match pattern.
+class CombineRuleOperandTypeChecker : private OperandTypeChecker {
+public:
+  CombineRuleOperandTypeChecker(const Record &RuleDef,
+                                const OperandTable &MatchOpTable)
+      : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef),
+        MatchOpTable(MatchOpTable) {}
+
+  /// Records and checks a 'match' pattern.
+  bool processMatchPattern(InstructionPattern &P);
+
+  /// Records and checks an 'apply' pattern.
+  bool processApplyPattern(InstructionPattern &P);
+
+  /// Propagates types, then performs type inference and, only if any types
+  /// were inferred, does a second round of propagation in the apply patterns.
+  void propagateAndInferTypes();
+
+private:
+  /// TypeEquivalenceClasses are groups of operands of an instruction that share
+  /// a common type.
+  ///
+  /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and
+  /// d have the same type too. b/c and a/d don't have to have the same type,
+  /// though.
+  ///
+  /// NOTE: We use a SetVector, not a Set. This is to guarantee a stable
+  /// iteration order which is important because:
+  ///   - During inference, we iterate that set and pick the first suitable
+  ///   candidate. Using a normal set could make inference inconsistent across
+  ///   runs if the Set uses the StringRef ptr to cache values.
+  ///   - We print this set if DebugTypeInfer is set, and we don't want our tests
+  ///   to fail randomly due to the Set's iteration order changing.
+  using TypeEquivalenceClasses = std::vector<SetVector<StringRef>>;
+
+  static std::string toString(const SetVector<StringRef> &EqClass) {
+    return "[" + join(EqClass, ", ") + "]";
+  }
+
+  /// \returns true for `OPERAND_GENERIC_` 0 through 5.
+  /// These are the MCOI types that can be registers. The other MCOI types are
+  /// either immediates, or fancier operands used only post-ISel, so we don't
+  /// care about them for combiners.
+  static bool canMCOIOperandTypeBeARegister(StringRef MCOIType) {
+    // Check the trailing digit as well: matching "OPERAND_GENERIC_" alone would
+    // accept any single trailing character, not just the documented 0-5.
+    return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_") &&
+           MCOIType.back() >= '0' && MCOIType.back() <= '5';
+  }
+
+  /// Finds the "MCOI::" operand types for each operand of \p CGP.
+  ///
+  /// This is a bit trickier than it looks because we need to handle variadic
+  /// in/outs.
+  ///
+  /// e.g. for
+  ///   (G_BUILD_VECTOR $vec, $x, $y) ->
+  ///   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  ///    MCOI::OPERAND_GENERIC_1]
+  ///
+  /// For unknown types (which can happen in variadics where varargs types are
+  /// inconsistent), a unique name is given, e.g. "unknown_type_0".
+  static std::vector<std::string>
+  getMCOIOperandTypes(const CodeGenInstructionPattern &CGP);
+
+  /// Adds the TypeEquivalenceClasses for \p P to \p OutTECs.
+  void getInstEqClasses(const InstructionPattern &P,
+                        TypeEquivalenceClasses &OutTECs) const;
+
+  /// Calculates the TypeEquivalenceClasses for each instruction, then merges
+  /// them into a common set of TypeEquivalenceClasses for the whole rule.
+  ///
+  /// This works by repeatedly merging intersecting type equivalence classes
+  /// until no more merging occurs.
+  ///
+  /// This essentially applies the "transitive" part of type inference. Let's
+  /// take the following equivalence classes:
+  ///   inst0: [a, b], [c, d]
+  ///   inst1: [b, c]
+  ///
+  /// If we see inst0 alone, we can't say that a and d have the same type -
+  /// they're not in the same equivalence classes. However if we just use logic,
+  /// we can say: "a == d because a == b, b == c and c == d".
+  ///
+  /// Merging condenses that information into a single big equivalence class
+  /// which can be looked at alone to make the same deduction.
+  ///   rule: [a, b, c, d]
+  TypeEquivalenceClasses getRuleEqClasses() const;
+
+  /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p
+  /// TECs.
+  ///
+  /// This is achieved by trying to find a named operand in \p IP that shares
+  /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that
+  /// operand instead.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferImmediateType(const InstructionPattern &IP,
+                                 unsigned ImmOpIdx,
+                                 const TypeEquivalenceClasses &TECs) const;
+
+  /// Looks inside \p TECs to infer \p OpName's type.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferNamedOperandType(const InstructionPattern &IP,
+                                    StringRef OpName,
+                                    const TypeEquivalenceClasses &TECs) const;
+
+  const Record &RuleDef;
+  SmallVector<InstructionPattern *, 8> MatchPats;
+  SmallVector<InstructionPattern *, 8> ApplyPats;
+
+  const OperandTable &MatchOpTable;
+};
+
+/// Records \p P as a 'match' pattern and runs the shared per-pattern checks.
+bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) {
+  // Every GITypeOf is accepted here: the CombineRuleBuilder currently rejects
+  // GITypeOf in 'match' unconditionally, after inference is done.
+  const auto AcceptAnyTypeOf = [](const PatternType &) { return true; };
+  MatchPats.push_back(&P);
+  return check(P, AcceptAnyTypeOf);
+}
+
+bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) {
+  // In 'apply', GITypeOf<"$x"> is only valid if "$x" is a matched operand.
+  const auto RefersToMatchedOperand = [&](const PatternType &Ty) {
+    const StringRef Name = Ty.getTypeOfOpName();
+    if (MatchOpTable.lookup(Name).Found)
+      return true;
+    PrintError(RuleDef.getLoc(), "'" + Name + "' ('" + Ty.str() +
+                                     "') does not refer to a matched operand!");
+    return false;
+  };
+  ApplyPats.push_back(&P);
+  return check(P, RefersToMatchedOperand);
+}
+
+void CombineRuleOperandTypeChecker::propagateAndInferTypes() {
+  // Step 1: let the OperandTypeChecker propagate the explicitly written types
+  // so that all uses of a given register agree on their type.
+  propagateTypes();
+
+  // Step 2: compute the type equivalence classes of the whole rule.
+  const TypeEquivalenceClasses TECs = getRuleEqClasses();
+
+  // Step 3: walk the apply patterns and give a type to the untyped operands
+  // that need one (temporary register defs and immediates). The type chosen
+  // is a `GITypeOf` of an operand from the same equivalence class; defs of
+  // matched instructions are preferred because they are guaranteed to be
+  // registers.
+  bool InferredAny = false;
+  for (auto *Pat : ApplyPats) {
+    for (unsigned OpIdx = 0, NumOps = Pat->operands_size(); OpIdx != NumOps;
+         ++OpIdx) {
+      auto &Op = Pat->getOperand(OpIdx);
+
+      // Already-typed operands, and operands that are neither defs nor
+      // immediates, are left untouched.
+      if (Op.getType() || !(Op.isDef() || Op.hasImmValue()))
+        continue;
+
+      PatternType Ty;
+      if (Op.isDef() || Op.isNamedImmediate()) {
+        // A redefinition of a matched operand is copied over as-is; no
+        // temporary register is created, so no inference is needed.
+        if (MatchOpTable.lookup(Op.getOperandName()).Found)
+          continue;
+        Ty = inferNamedOperandType(*Pat, Op.getOperandName(), TECs);
+      } else {
+        // Only unnamed immediates reach this point.
+        Ty = inferImmediateType(*Pat, OpIdx, TECs);
+      }
+
+      if (!Ty)
+        continue;
+
+      if (DebugTypeInfer)
+        errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << "\n";
+      Op.setType(Ty);
+      InferredAny = true;
+    }
+  }
+
+  // Nothing was inferred, so there is nothing new to propagate.
+  if (!InferredAny)
+    return;
+
+  // Step 4: propagate the inferred types across the apply patterns only.
+  // Inference only produces GITypeOf types that point to matched operands,
+  // so the match patterns must not receive these types.
+  OperandTypeChecker ApplyOTC(RuleDef.getLoc());
+  for (auto *Pat : ApplyPats) {
+    if (!ApplyOTC.check(*Pat, [&](const auto &) { return true; }))
+      PrintFatalError(RuleDef.getLoc(),
+                      "OperandTypeChecker unexpectedly failed on '" +
+                          Pat->getName() + "' during Type Inference");
+  }
+  ApplyOTC.propagateTypes();
+
+  if (!DebugTypeInfer)
+    return;
+
+  errs() << "Apply patterns for rule " << RuleDef.getName()
+         << " after inference:\n";
+  for (auto *Pat : ApplyPats) {
+    Pat->print(errs(), /*PrintName*/ true);
+    errs() << "\n";
+  }
+  errs() << "\n";
+}
+
+PatternType CombineRuleOperandTypeChecker::inferImmediateType(
+    const InstructionPattern &IP, unsigned ImmOpIdx,
+    const TypeEquivalenceClasses &TECs) const {
+  // We can only infer CGPs.
+  const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP);
+  if (!CGP)
+    return {};
+
+  // For CGPs, we try to infer immediates by trying to infer another named
+  // operand that shares its type.
+  //
+  // e.g.
+  //    Pattern: G_BUILD_VECTOR $x, $y, 0
+  //    MCOIs:   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  //              MCOI::OPERAND_GENERIC_1]
+  //    $y has the same type as 0, so we can infer $y and get the type 0 should
+  //    have.
+
+  // We infer immediates by looking for a named operand that shares the same
+  // MCOI type.
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  StringRef ImmOpTy = MCOITypes[ImmOpIdx];
+
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes)) {
+    if (Idx == ImmOpIdx || Ty != ImmOpTy)
+      continue;
+
+    const auto &Op = IP.getOperand(Idx);
+    if (!Op.isNamedOperand())
+      continue;
+
+    // Named operand with the same MCOI type: its type is the type the
+    // immediate should have, so try to infer it.
+    const StringRef OpName = Op.getOperandName();
+    if (PatternType InferTy = inferNamedOperandType(IP, OpName, TECs))
+      return InferTy;
+
+    // inferNamedOperandType only considers the *other* members of OpName's
+    // equivalence class. When OpName is itself a matched operand (e.g. $y
+    // above comes from the match pattern), GITypeOf<OpName> is a valid answer
+    // even if no other member of its class is matched.
+    if (MatchOpTable.lookup(OpName).Found)
+      return PatternType::getTypeOf(OpName);
+  }
+
+  return {};
+}
+
+PatternType CombineRuleOperandTypeChecker::inferNamedOperandType(
+    const InstructionPattern &IP, StringRef OpName,
+    const TypeEquivalenceClasses &TECs) const {
+  // Find the first equivalence class that mentions OpName. Without one,
+  // nothing is known about the operand's type.
+  const auto TECIt = llvm::find_if(
+      TECs, [&](const auto &TEC) { return TEC.contains(OpName); });
+  if (TECIt == TECs.end())
+    return {};
+
+  // Scan the other members of that class. A def of a matched pattern is
+  // guaranteed to be a register, so the first one found is used right away;
+  // otherwise remember the first matched operand as a fallback.
+  StringRef FallbackOpName;
+  for (const auto &Candidate : *TECIt) {
+    if (Candidate == OpName)
+      continue;
+
+    const auto LookupRes = MatchOpTable.lookup(Candidate);
+    if (LookupRes.Def) // Favor defs
+      return PatternType::getTypeOf(Candidate);
+
+    if (LookupRes.Found && FallbackOpName.empty())
+      FallbackOpName = Candidate;
+  }
+
+  // No good operand found, give up.
+  if (FallbackOpName.empty())
+    return {};
+  return PatternType::getTypeOf(FallbackOpName);
+}
+
+std::vector<std::string> CombineRuleOperandTypeChecker::getMCOIOperandTypes(
+    const CodeGenInstructionPattern &CGP) {
+  // FIXME?: Should we cache this? We call it twice when inferring immediates.
+
+  // Counter used to build placeholder type names ("unknown_type_N") that
+  // never compare equal to any other operand's type.
+  static unsigned UnknownTypeIdx = 0;
+
+  std::vector<std::string> OpTypes;
+  OpTypes.reserve(CGP.operands_size());
+
+  auto &CGI = CGP.getInst();
+  Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction")
+                          ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType")
+                          : nullptr;
+  const std::string VarArgsTyName =
+      VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str()
+                : std::string();
+
+  // Returns the type to use for a single variadic operand. Without a
+  // `variadicOpsType`, nothing is known about how the variadic operands relate
+  // to each other, so each one gets its own unique placeholder. Sharing one
+  // placeholder across all of them would make them look like they have the
+  // same type (e.g. to inferImmediateType, which compares these strings).
+  const auto GetVarArgTyName = [&]() -> std::string {
+    if (VarArgsTy)
+      return VarArgsTyName;
+    return ("unknown_type_" + Twine(UnknownTypeIdx++)).str();
+  };
+
+  // First, handle defs.
+  for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K)
+    OpTypes.push_back(CGI.Operands[K].OperandType);
+
+  // Then, handle variadic defs if there are any.
+  if (CGP.hasVariadicDefs()) {
+    for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K)
+      OpTypes.push_back(GetVarArgTyName());
+  }
+
+  // If we had variadic defs, the op idx in the pattern won't match the op idx
+  // in the CGI anymore.
+  int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs();
+  assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0));
+
+  // Handle all remaining use operands, including variadic ones.
+  for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) {
+    unsigned CGIOpIdx = K + CGIOpOffset;
+    if (CGIOpIdx >= CGI.Operands.size()) {
+      assert(CGP.isVariadic());
+      OpTypes.push_back(GetVarArgTyName());
+    } else {
+      OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType);
+    }
+  }
+
+  assert(OpTypes.size() == CGP.operands_size());
+  return OpTypes;
+}
+
+void CombineRuleOperandTypeChecker::getInstEqClasses(
+    const InstructionPattern &P, TypeEquivalenceClasses &OutTECs) const {
+  // Determine the TypeEquivalenceClasses by:
+  //    - Getting the MCOI Operand Types.
+  //    - Creating a Map of MCOI Type -> [Operand Indexes]
+  //    - Iterating over the map, filtering types we don't like, and just adding
+  //      the named operands at those indexes to \p OutTECs.
+
+  // We can only do this on CodeGenInstructions. Other InstructionPatterns have
+  // no type inference information associated with them.
+  // TODO: Could we add some inference information to builtins at least? e.g.
+  // ReplaceReg should always replace with a reg of the same type, for instance.
+  // Though, those patterns are often used alone so it might not be worth the
+  // trouble to infer their types.
+  auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P);
+  if (!CGP)
+    return;
+
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  assert(MCOITypes.size() == P.operands_size());
+
+  // Use an insertion-ordered map: the order in which classes are appended
+  // (and later merged) determines the order of names inside the final
+  // classes, which in turn decides which operand type inference picks. With a
+  // DenseMap, that order would depend on StringRef hash values.
+  MapVector<StringRef, std::vector<unsigned>> TyToOpIdx;
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes))
+    TyToOpIdx[Ty].push_back(Idx);
+
+  const unsigned FirstNewTEC = OutTECs.size();
+  for (const auto &[Ty, Idxs] : TyToOpIdx) {
+    if (!canMCOIOperandTypeBeARegister(Ty))
+      continue;
+
+    SetVector<StringRef> OpNames;
+    // We only collect named operands.
+    for (unsigned Idx : Idxs) {
+      const auto &Op = P.getOperand(Idx);
+      if (Op.isNamedOperand())
+        OpNames.insert(Op.getOperandName());
+    }
+
+    // A class without any named operand carries no information; don't add it.
+    if (!OpNames.empty())
+      OutTECs.emplace_back(std::move(OpNames));
+  }
+
+  if (DebugTypeInfer) {
+    errs() << "\t" << P.getName() << ":\t";
+    if (FirstNewTEC == OutTECs.size())
+      errs() << "(empty)";
+    else {
+      for (unsigned K = FirstNewTEC; K < OutTECs.size(); ++K)
+        errs() << "\t" << toString(OutTECs[K]);
+    }
+    errs() << "\n";
+  }
+}
+
+CombineRuleOperandTypeChecker::TypeEquivalenceClasses
+CombineRuleOperandTypeChecker::getRuleEqClasses() const {
+  TypeEquivalenceClasses TECs;
+
+  if (DebugTypeInfer)
+    errs() << "Rule Operand Type Equivalence Classes for " << RuleDef.getName()
+           << ":\n";
+
+  // Start from the per-instruction classes of every match and apply pattern.
+  for (const auto *Pat : MatchPats)
+    getInstEqClasses(*Pat, TECs);
+  for (const auto *Pat : ApplyPats)
+    getInstEqClasses(*Pat, TECs);
+
+  // Repeatedly merge any two classes that share an operand name until a full
+  // pass performs no merge. At that point no two remaining classes intersect,
+  // so each operand name appears in at most one class.
+  bool Merged;
+  do {
+    Merged = false;
+    for (auto &X : TECs) {
+      if (X.empty()) // Already merged
+        continue;
+
+      for (auto &Y : TECs) {
+        if (&X == &Y || Y.empty()) // Same set or already merged.
+          continue;
+
+        if (doSetsIntersect(X, Y)) {
+          if (DebugTypeInfer)
+            errs() << "(merging " << toString(X) << " | " << toString(Y)
+                   << ")\n";
+          setVectorUnion(X, Y);
+          Merged = true;
+
+          // To avoid invalidating iterators during iteration, we just clear the
+          // other set then clean it up later.
+          Y.clear();
+        }
+      }
+    }
+
+    // Remove empty sets.
+    erase_if(TECs, [](const auto &Set) { return Set.empty(); });
+  } while (Merged);
+
+  if (DebugTypeInfer) {
+    errs() << "Result: ";
+    if (TECs.empty())
+      errs() << "(empty)";
+    else {
+      for (const auto &TEC : TECs)
+        errs() << toString(TEC) << " ";
+    }
+    errs() << "\n";
+  }
+
+  return TECs;
+}
+
 //===- CombineRuleBuilder -------------------------------------------------===//
 
 /// Parses combine rule and builds a small intermediate representation to tie
@@ -1820,8 +2341,8 @@ class CombineRuleBuilder {
   PatternMap ApplyPats;
 
   /// Operand tables to tie match/apply patterns together.
-  OperandTable<> MatchOpTable;
-  OperandTable<> ApplyOpTable;
+  OperandTable MatchOpTable;
+  OperandTable ApplyOpTable;
 
   /// Set by findRoots.
   Pattern *MatchRoot = nullptr;
@@ -2102,39 +2623,23 @@ bool CombineRuleBuilder::hasEraseRoot() const {
 }
 
 bool CombineRuleBuilder::typecheckPatterns() {
-  OperandTypeChecker OTC(RuleDef.getLoc());
-
-  const auto CheckMatchTypeOf = [&](const PatternType &) -> bool {
-    // We'll reject those after we're done inferring
-    return true;
-  };
+  CombineRuleOperandTypeChecker OTC(RuleDef, MatchOpTable);
 
   for (auto &Pat : values(MatchPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckMatchTypeOf))
+      if (!OTC.processMatchPattern(*IP))
         return false;
     }
   }
 
-  const auto CheckApplyTypeOf = [&](const PatternType &Ty) {
-    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
-    const auto OpName = Ty.getTypeOfOpName();
-    if (MatchOpTable.lookup(OpName).Found)
-      return true;
-
-    PrintError("'" + OpName + "' ('" + Ty.str() +
-               "') does not refer to a matched operand!");
-    return false;
-  };
-
   for (auto &Pat : values(ApplyPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckApplyTypeOf))
+      if (!OTC.processApplyPattern(*IP))
         return false;
     }
   }
 
-  OTC.setAllOperandTypes();
+  OTC.propagateAndInferTypes();
 
   // Always check this after in case inference adds some special types to the
   // match patterns.
@@ -2630,7 +3135,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
   // untyped immediate, e.g. 0
   if (const auto *IntImm = dyn_cast<IntInit>(OpInit)) {
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(IntImm->getValue(), Name, /*Type=*/nullptr);
+    IP.addOperand(IntImm->getValue(), Name, PatternType());
     return true;
   }
 
@@ -2640,13 +3145,9 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
-    PatternType ImmTy(TyDef);
-    if (!ImmTy.isValidType()) {
-      PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() +
-                 "', '" + TyDef->getName() + "' is not a ValueType or " +
-                 SpecialTyClassName);
+    auto ImmTy = PatternType::get(RuleDef.getLoc(), TyDef);
+    if (!ImmTy)
       return false;
-    }
 
     if (!IP.hasAllDefs()) {
       PrintError("out operand of '" + IP.getInstName() +
@@ -2659,7 +3160,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(Val->getValue(), Name, ImmTy);
+    IP.addOperand(Val->getValue(), Name, *ImmTy);
     return true;
   }
 
@@ -2671,20 +3172,17 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return false;
     }
     const Record *Def = DefI->getDef();
-    PatternType Ty(Def);
-    if (!Ty.isValidType()) {
-      PrintError("invalid operand type: '" + Def->getName() +
-                 "' is not a ValueType");
+    auto Ty = PatternType::get(RuleDef.getLoc(), Def);
+    if (!Ty)
       return false;
-    }
-    IP.addOperand(OpName->getAsUnquotedString(), Ty);
+    IP.addOperand(OpName->getAsUnquotedString(), *Ty);
     return true;
   }
 
   // Untyped operand e.g. $x/$z in (G_FNEG $x, $z)
   if (isa<UnsetInit>(OpInit)) {
     assert(OpName && "Unset w/ no OpName?");
-    IP.addOperand(OpName->getAsUnquotedString(), /*Type=*/nullptr);
+    IP.addOperand(OpName->getAsUnquotedString(), PatternType());
     return true;
   }
 



More information about the llvm-commits mailing list