[llvm] 573fa77 - [TableGen][GlobalISel] Add rule-wide type inference (#66377)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 7 23:10:26 PST 2023


Author: Pierre van Houtryve
Date: 2023-11-08T08:10:22+01:00
New Revision: 573fa770d07bacd9276370454be32731fca99eea

URL: https://github.com/llvm/llvm-project/commit/573fa770d07bacd9276370454be32731fca99eea
DIFF: https://github.com/llvm/llvm-project/commit/573fa770d07bacd9276370454be32731fca99eea.diff

LOG: [TableGen][GlobalISel] Add rule-wide type inference (#66377)

The inference is trivial and leverages the MCOI OperandTypes encoded in
CodeGenInstructions to infer types across patterns in a CombineRule.
It's thus very limited and only supports CodeGenInstructions (but that's the
main use case so it's fine).

We only try to infer untyped operands in apply patterns when they're
temp reg defs, or immediates. Inference always outputs a `GITypeOf<$x>` where
$x is a named operand from a match pattern.

This allows us to drop the `GITypeOf` in most cases without any errors.

Added: 
    llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td

Modified: 
    llvm/include/llvm/Target/GenericOpcodes.td
    llvm/include/llvm/Target/GlobalISel/Combine.td
    llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
    llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
    llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index a1afc3b8042c284..9a9c09d3c20d612 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -17,6 +17,10 @@
 
 class GenericInstruction : StandardPseudoInstruction {
   let isPreISelOpcode = true;
+
+  // When all variadic ops share a type with another operand,
+  // this is the type they share. Used by MIR patterns type inference.
+  TypedOperand variadicOpsType = ?;
 }
 
 // Provide a variant of an instruction with the same operands, but
@@ -1228,6 +1232,7 @@ def G_UNMERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst0, variable_ops);
   let InOperandList = (ins type1:$src);
   let hasSideEffects = false;
+  let variadicOpsType = type0;
 }
 
 // Insert a smaller register into a larger one at the specified bit-index.
@@ -1245,6 +1250,7 @@ def G_MERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Create a vector from multiple scalar registers. No implicit
@@ -1254,6 +1260,7 @@ def G_BUILD_VECTOR : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Like G_BUILD_VECTOR, but truncates the larger operand types to fit the

diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 63c485a5a6c6070..ee0209eb9e5593a 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -796,7 +796,7 @@ def trunc_shift: GICombineRule <
 def mul_by_neg_one: GICombineRule <
   (defs root:$dst),
   (match (G_MUL $dst, $x, -1)),
-  (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
+  (apply (G_SUB $dst, 0, $x))
 >;
 
 // Fold (xor (and x, y), y) -> (and (not x), y)

diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
index 48a06474da78a10..4f705589b92e90b 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
@@ -151,7 +151,7 @@ def bad_imm_too_many_args : GICombineRule<
   (match (COPY $x, (i32 0, 0)):$d),
   (apply (COPY $x, $b):$d)>;
 
-// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(COPY 0)', 'COPY' is not a ValueType
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(COPY 0)': unknown type 'COPY'
 // CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(COPY ?:$x, (COPY 0))
 def bad_imm_not_a_valuetype : GICombineRule<
   (defs root:$a),
@@ -186,7 +186,7 @@ def expected_op_name : GICombineRule<
   (match (G_FNEG $x, i32)),
   (apply (COPY $x, (i32 0)))>;
 
-// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: invalid operand type: 'not_a_type' is not a ValueType
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: cannot parse operand type: unknown type 'not_a_type'
 // CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_FNEG ?:$x, not_a_type:$y)'
 def not_a_type;
 def bad_mo_type_not_a_valuetype : GICombineRule<

diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
new file mode 100644
index 000000000000000..c9ffe4e7adb3db1
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
@@ -0,0 +1,68 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -gicombiner-debug-typeinfer -combiners=MyCombiner %s 2>&1 | \
+// RUN: FileCheck %s
+
+// Checks reasoning of the inference rules.
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+// CHECK:      Rule Operand Type Equivalence Classes for inference_mul_by_neg_one:
+// CHECK-NEXT:         Groups for __inference_mul_by_neg_one_match_0:             [dst, x]
+// CHECK-NEXT:         Groups for __inference_mul_by_neg_one_apply_0:             [dst, x]
+// CHECK-NEXT: Final Type Equivalence Classes: [dst, x]
+// CHECK-NEXT: INFER: imm 0 -> GITypeOf<$x>
+// CHECK-NEXT: Apply patterns for rule inference_mul_by_neg_one after inference:
+// CHECK-NEXT:   (CodeGenInstructionPattern name:__inference_mul_by_neg_one_apply_0 G_SUB operands:[<def>$dst, (GITypeOf<$x> 0), $x])
+def inference_mul_by_neg_one: GICombineRule <
+  (defs root:$dst),
+  (match (G_MUL $dst, $x, -1)),
+  (apply (G_SUB $dst, 0, $x))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_complex_tempreg:
+// CHECK-NEXT:         Groups for __infer_complex_tempreg_match_0:                [dst]   [x, y, z]
+// CHECK-NEXT:         Groups for __infer_complex_tempreg_apply_0:                [tmp2]  [x, y]
+// CHECK-NEXT:         Groups for __infer_complex_tempreg_apply_1:                [tmp, tmp2]
+// CHECK-NEXT:         Groups for __infer_complex_tempreg_apply_2:                [dst, tmp]
+// CHECK-NEXT: Final Type Equivalence Classes: [dst, tmp, tmp2] [x, y, z]
+// CHECK-NEXT: INFER: MachineOperand $tmp2 -> GITypeOf<$dst>
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$dst>
+// CHECK-NEXT: Apply patterns for rule infer_complex_tempreg after inference:
+// CHECK-NEXT:   (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_0 G_BUILD_VECTOR operands:[<def>GITypeOf<$dst>:$tmp2, $x, $y])
+// CHECK-NEXT:   (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_1 G_FNEG operands:[<def>GITypeOf<$dst>:$tmp, GITypeOf<$dst>:$tmp2])
+// CHECK-NEXT:   (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_2 G_FNEG operands:[<def>$dst, GITypeOf<$dst>:$tmp])
+def infer_complex_tempreg: GICombineRule <
+  (defs root:$dst),
+  (match (G_MERGE_VALUES $dst, $x, $y, $z)),
+  (apply (G_BUILD_VECTOR $tmp2, $x, $y),
+         (G_FNEG $tmp, $tmp2),
+         (G_FNEG $dst, $tmp))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_variadic_outs:
+// CHECK-NEXT:         Groups for  __infer_variadic_outs_match_0:          [x, y]  [vec]
+// CHECK-NEXT:         Groups for  __infer_variadic_outs_match_1:          [dst, x]
+// CHECK-NEXT:         Groups for  __infer_variadic_outs_apply_0:          [tmp, y]
+// CHECK-NEXT:         Groups for  __infer_variadic_outs_apply_1:
+// CHECK-NEXT: Final Type Equivalence Classes: [tmp, dst, x, y]  [vec]
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$dst>
+// CHECK-NEXT: Apply patterns for rule infer_variadic_outs after inference:
+// CHECK-NEXT:   (CodeGenInstructionPattern name:__infer_variadic_outs_apply_0 G_FNEG operands:[<def>GITypeOf<$dst>:$tmp, $y])
+// CHECK-NEXT:   (CodeGenInstructionPattern name:__infer_variadic_outs_apply_1 COPY operands:[<def>$dst, GITypeOf<$dst>:$tmp])
+def infer_variadic_outs: GICombineRule <
+  (defs root:$dst),
+  (match  (G_UNMERGE_VALUES $x, $y, $vec),
+          (G_FNEG $dst, $x)),
+  (apply (G_FNEG $tmp, $y),
+         (COPY $dst, $tmp))
+>;
+
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  inference_mul_by_neg_one,
+  infer_complex_tempreg,
+  infer_variadic_outs
+]>;

diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
index 6040d6def449766..4e182d555db3362 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
@@ -8,7 +8,8 @@ include "llvm/Target/GlobalISel/Combine.td"
 def MyTargetISA : InstrInfo;
 def MyTarget : Target { let InstructionSet = MyTargetISA; }
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(anonymous_7029 0)': invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ANYEXT ?:$dst, (anonymous_
 def NoDollarSign : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, $src)),
@@ -47,7 +48,9 @@ def InferredUseInMatch : GICombineRule<
   (match (G_ZEXT $dst, $src)),
   (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: conflicting types for operand 'src': 'i32' vs 'GITypeOf<$dst>'
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: 'src' seen with type 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: 'src' seen with type 'i32' in '__InferenceConflict_match_0'
 def InferenceConflict : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, i32:$src)),

diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 0c7b33a7b9d889d..b4b3db70c076a1a 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -35,6 +35,7 @@
 #include "GlobalISelMatchTableExecutorEmitter.h"
 #include "SubtargetFeatureInfo.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Statistic.h"
@@ -68,6 +69,9 @@ cl::opt<bool> DebugCXXPreds(
     "gicombiner-debug-cxxpreds",
     cl::desc("Add Contextual/Debug comments to all C++ predicates"),
     cl::cat(GICombinerEmitterCat));
+cl::opt<bool> DebugTypeInfer("gicombiner-debug-typeinfer",
+                             cl::desc("Print type inference debug logs"),
+                             cl::cat(GICombinerEmitterCat));
 
 constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
 constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
@@ -298,23 +302,30 @@ CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode;
 ///   - Special types, e.g. GITypeOf
 class PatternType {
 public:
-  PatternType() = default;
-  PatternType(const Record *R) : R(R) {}
+  enum PTKind : uint8_t {
+    PT_None,
 
-  bool isValidType() const { return !R || isLLT() || isSpecial(); }
+    PT_ValueType,
+    PT_TypeOf,
+  };
+
+  PatternType() : Kind(PT_None), Data() {}
+
+  static std::optional<PatternType> get(ArrayRef<SMLoc> DiagLoc,
+                                        const Record *R, Twine DiagCtx);
+  static PatternType getTypeOf(StringRef OpName);
 
-  bool isLLT() const { return R && R->isSubClassOf("ValueType"); }
-  bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); }
-  bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); }
+  bool isNone() const { return Kind == PT_None; }
+  bool isLLT() const { return Kind == PT_ValueType; }
+  bool isSpecial() const { return isTypeOf(); }
+  bool isTypeOf() const { return Kind == PT_TypeOf; }
 
   StringRef getTypeOfOpName() const;
   LLTCodeGen getLLTCodeGen() const;
 
-  bool checkSemantics(ArrayRef<SMLoc> DiagLoc) const;
-
   LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const;
 
-  explicit operator bool() const { return R != nullptr; }
+  explicit operator bool() const { return !isNone(); }
 
   bool operator==(const PatternType &Other) const;
   bool operator!=(const PatternType &Other) const { return !operator==(Other); }
@@ -322,26 +333,66 @@ class PatternType {
   std::string str() const;
 
 private:
-  StringRef getRawOpName() const { return R->getValueAsString("OpName"); }
+  PatternType(PTKind Kind) : Kind(Kind), Data() {}
+
+  PTKind Kind;
+  union DataT {
+    DataT() : Str() {}
+
+    /// PT_ValueType -> ValueType Def.
+    const Record *Def;
 
-  const Record *R = nullptr;
+    /// PT_TypeOf -> Operand name (without the '$')
+    StringRef Str;
+  } Data;
 };
 
+std::optional<PatternType> PatternType::get(ArrayRef<SMLoc> DiagLoc,
+                                            const Record *R, Twine DiagCtx) {
+  assert(R);
+  if (R->isSubClassOf("ValueType")) {
+    PatternType PT(PT_ValueType);
+    PT.Data.Def = R;
+    return PT;
+  }
+
+  if (R->isSubClassOf(TypeOfClassName)) {
+    auto RawOpName = R->getValueAsString("OpName");
+    if (!RawOpName.starts_with("$")) {
+      PrintError(DiagLoc, DiagCtx + ": invalid operand name format '" +
+                              RawOpName + "' in " + TypeOfClassName +
+                              ": expected '$' followed by an operand name");
+      return std::nullopt;
+    }
+
+    PatternType PT(PT_TypeOf);
+    PT.Data.Str = RawOpName.drop_front(1);
+    return PT;
+  }
+
+  PrintError(DiagLoc, DiagCtx + ": unknown type '" + R->getName() + "'");
+  return std::nullopt;
+}
+
+PatternType PatternType::getTypeOf(StringRef OpName) {
+  PatternType PT(PT_TypeOf);
+  PT.Data.Str = OpName;
+  return PT;
+}
+
 StringRef PatternType::getTypeOfOpName() const {
   assert(isTypeOf());
-  StringRef Name = getRawOpName();
-  Name.consume_front("$");
-  return Name;
+  return Data.Str;
 }
 
 LLTCodeGen PatternType::getLLTCodeGen() const {
   assert(isLLT());
-  return *MVTToLLT(getValueType(R));
+  return *MVTToLLT(getValueType(Data.Def));
 }
 
 LLTCodeGenOrTempType
 PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
-  assert(isValidType());
+  assert(!isNone());
 
   if (isLLT())
     return getLLTCodeGen();
@@ -351,50 +402,31 @@ PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
   return OM.getTempTypeIdx(RM);
 }
 
-bool PatternType::checkSemantics(ArrayRef<SMLoc> DiagLoc) const {
-  if (!isTypeOf())
-    return true;
-
-  auto RawOpName = getRawOpName();
-  if (RawOpName.starts_with("$"))
-    return true;
-
-  PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " +
-                          TypeOfClassName +
-                          ": expected '$' followed by an operand name");
-  return false;
-}
-
 bool PatternType::operator==(const PatternType &Other) const {
-  if (R == Other.R) {
-    if (R && R->getName() != Other.R->getName()) {
-      dbgs() << "Same ptr but: " << R->getName() << " and "
-             << Other.R->getName() << "?\n";
-      assert(false);
-    }
+  if (Kind != Other.Kind)
+    return false;
+
+  switch (Kind) {
+  case PT_None:
     return true;
+  case PT_ValueType:
+    return Data.Def == Other.Data.Def;
+  case PT_TypeOf:
+    return Data.Str == Other.Data.Str;
   }
 
-  if (isTypeOf() && Other.isTypeOf())
-    return getTypeOfOpName() == Other.getTypeOfOpName();
-
-  return false;
+  llvm_unreachable("Unknown Type Kind");
 }
 
 std::string PatternType::str() const {
-  if (!R)
+  switch (Kind) {
+  case PT_None:
     return "";
-
-  if (!isValidType())
-    return "<invalid>";
-
-  if (isLLT())
-    return R->getName().str();
-
-  assert(isSpecial());
-
-  if (isTypeOf())
+  case PT_ValueType:
+    return Data.Def->getName().str();
+  case PT_TypeOf:
     return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
+  }
 
   llvm_unreachable("Unknown type!");
 }
@@ -607,14 +639,10 @@ class InstructionOperand {
   using IntImmTy = int64_t;
 
   InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type)
-      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {}
 
   InstructionOperand(StringRef Name, PatternType Type)
-      : Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Name(insertStrRef(Name)), Type(Type) {}
 
   bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); }
 
@@ -638,7 +666,6 @@ class InstructionOperand {
 
   void setType(PatternType NewType) {
     assert((!Type || (Type == NewType)) && "Overwriting type!");
-    assert(NewType.isValidType());
     Type = NewType;
   }
   PatternType getType() const { return Type; }
@@ -809,12 +836,10 @@ void InstructionPattern::print(raw_ostream &OS, bool PrintName) const {
 /// Maps InstructionPattern operands to their definitions. This allows us to tie
 /// different patterns of a (apply), (match) or (patterns) set of patterns
 /// together.
-template <typename DefTy = InstructionPattern> class OperandTable {
+class OperandTable {
 public:
-  static_assert(std::is_base_of_v<InstructionPattern, DefTy>,
-                "DefTy should be a derived class from InstructionPattern");
-
-  bool addPattern(DefTy *P, function_ref<void(StringRef)> DiagnoseRedef) {
+  bool addPattern(InstructionPattern *P,
+                  function_ref<void(StringRef)> DiagnoseRedef) {
     for (const auto &Op : P->named_operands()) {
       StringRef OpName = Op.getOperandName();
 
@@ -843,10 +868,10 @@ template <typename DefTy = InstructionPattern> class OperandTable {
 
   struct LookupResult {
     LookupResult() = default;
-    LookupResult(DefTy *Def) : Found(true), Def(Def) {}
+    LookupResult(InstructionPattern *Def) : Found(true), Def(Def) {}
 
     bool Found = false;
-    DefTy *Def = nullptr;
+    InstructionPattern *Def = nullptr;
 
     bool isLiveIn() const { return Found && !Def; }
   };
@@ -857,7 +882,9 @@ template <typename DefTy = InstructionPattern> class OperandTable {
     return LookupResult();
   }
 
-  DefTy *getDef(StringRef OpName) const { return lookup(OpName).Def; }
+  InstructionPattern *getDef(StringRef OpName) const {
+    return lookup(OpName).Def;
+  }
 
   void print(raw_ostream &OS, StringRef Name = "",
              StringRef Indent = "") const {
@@ -872,11 +899,11 @@ template <typename DefTy = InstructionPattern> class OperandTable {
     SmallVector<StringRef, 0> Keys(Table.keys());
     sort(Keys);
 
-    OS << "\n";
+    OS << '\n';
     for (const auto &Key : Keys) {
       const auto *Def = Table.at(Key);
       OS << Indent << "  " << Key << " -> "
-         << (Def ? Def->getName() : "<live-in>") << "\n";
+         << (Def ? Def->getName() : "<live-in>") << '\n';
     }
     OS << Indent << ")\n";
   }
@@ -887,7 +914,7 @@ template <typename DefTy = InstructionPattern> class OperandTable {
   void dump() const { print(dbgs()); }
 
 private:
-  StringMap<DefTy *> Table;
+  StringMap<InstructionPattern *> Table;
 };
 
 //===- CodeGenInstructionPattern ------------------------------------------===//
@@ -956,62 +983,76 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
 ///
 /// It infers the type of each operand, check it's consistent with the known
 /// type of the operand, and then sets all of the types in all operands in
-/// setAllOperandTypes.
+/// propagateTypes.
 ///
 /// It also handles verifying correctness of special types.
 class OperandTypeChecker {
 public:
   OperandTypeChecker(ArrayRef<SMLoc> DiagLoc) : DiagLoc(DiagLoc) {}
 
-  bool check(InstructionPattern *P,
+  /// Step 1: Check each pattern one by one. All patterns that pass through here
+  /// are added to a common worklist so propagateTypes can access them.
+  bool check(InstructionPattern &P,
              std::function<bool(const PatternType &)> VerifyTypeOfOperand);
 
-  void setAllOperandTypes();
+  /// Step 2: Propagate all types. e.g. if one use of "$a" has type i32, make
+  /// all uses of "$a" have type i32.
+  void propagateTypes();
+
+protected:
+  ArrayRef<SMLoc> DiagLoc;
 
 private:
+  using InconsistentTypeDiagFn = std::function<void()>;
+
+  void PrintSeenWithTypeIn(InstructionPattern &P, StringRef OpName,
+                           PatternType Ty) const {
+    PrintNote(DiagLoc, "'" + OpName + "' seen with type '" + Ty.str() +
+                           "' in '" + P.getName() + "'");
+  }
+
   struct OpTypeInfo {
     PatternType Type;
-    InstructionPattern *TypeSrc = nullptr;
+    InconsistentTypeDiagFn PrintTypeSrcNote = []() {};
   };
 
-  ArrayRef<SMLoc> DiagLoc;
   StringMap<OpTypeInfo> Types;
 
   SmallVector<InstructionPattern *, 16> Pats;
 };
 
 bool OperandTypeChecker::check(
-    InstructionPattern *P,
+    InstructionPattern &P,
     std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
-  Pats.push_back(P);
+  Pats.push_back(&P);
 
-  for (auto &Op : P->operands()) {
+  for (auto &Op : P.operands()) {
     const auto Ty = Op.getType();
     if (!Ty)
       continue;
 
-    if (!Ty.checkSemantics(DiagLoc))
-      return false;
-
     if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
       return false;
 
     if (!Op.isNamedOperand())
       continue;
 
-    auto &Info = Types[Op.getOperandName()];
+    StringRef OpName = Op.getOperandName();
+    auto &Info = Types[OpName];
     if (!Info.Type) {
       Info.Type = Ty;
-      Info.TypeSrc = P;
+      Info.PrintTypeSrcNote = [this, OpName, Ty, &P]() {
+        PrintSeenWithTypeIn(P, OpName, Ty);
+      };
       continue;
     }
 
     if (Info.Type != Ty) {
       PrintError(DiagLoc, "conflicting types for operand '" +
-                              Op.getOperandName() + "': first seen with '" +
-                              Info.Type.str() + "' in '" +
-                              Info.TypeSrc->getName() + ", now seen with '" +
-                              Ty.str() + "' in '" + P->getName() + "'");
+                              Op.getOperandName() + "': '" + Info.Type.str() +
+                              "' vs '" + Ty.str() + "'");
+      PrintSeenWithTypeIn(P, OpName, Ty);
+      Info.PrintTypeSrcNote();
       return false;
     }
   }
@@ -1019,7 +1060,7 @@ bool OperandTypeChecker::check(
   return true;
 }
 
-void OperandTypeChecker::setAllOperandTypes() {
+void OperandTypeChecker::propagateTypes() {
   for (auto *Pat : Pats) {
     for (auto &Op : Pat->named_operands()) {
       if (auto &Info = Types[Op.getOperandName()]; Info.Type)
@@ -1073,7 +1114,7 @@ class PatFrag {
   /// Each argument to the `pattern` DAG operator is parsed into a Pattern
   /// instance.
   struct Alternative {
-    OperandTable<> OpTable;
+    OperandTable OpTable;
     SmallVector<std::unique_ptr<Pattern>, 4> Pats;
   };
 
@@ -1297,11 +1338,11 @@ bool PatFrag::checkSemantics() {
     OperandTypeChecker OTC(Def.getLoc());
     for (auto &Pat : Alt.Pats) {
       if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-        if (!OTC.check(IP, CheckTypeOf))
+        if (!OTC.check(*IP, CheckTypeOf))
           return false;
       }
     }
-    OTC.setAllOperandTypes();
+    OTC.propagateTypes();
   }
 
   return true;
@@ -1369,7 +1410,7 @@ bool PatFrag::buildOperandsTables() {
 }
 
 void PatFrag::print(raw_ostream &OS, StringRef Indent) const {
-  OS << Indent << "(PatFrag name:" << getName() << "\n";
+  OS << Indent << "(PatFrag name:" << getName() << '\n';
   if (!in_params().empty()) {
     OS << Indent << "  (ins ";
     printParamsList(OS, in_params());
@@ -1613,7 +1654,7 @@ class PrettyStackTraceParse : public PrettyStackTraceEntry {
       OS << "Parsing " << PatFragClassName << " '" << Def.getName() << "'";
     else
       OS << "Parsing '" << Def.getName() << "'";
-    OS << "\n";
+    OS << '\n';
   }
 };
 
@@ -1635,10 +1676,429 @@ class PrettyStackTraceEmit : public PrettyStackTraceEntry {
 
     if (Pat)
       OS << " [" << Pat->getKindName() << " '" << Pat->getName() << "']";
-    OS << "\n";
+    OS << '\n';
   }
 };
 
+//===- CombineRuleOperandTypeChecker --------------------------------------===//
+
+/// This is a wrapper around OperandTypeChecker specialized for Combiner Rules.
+/// On top of doing the same things as OperandTypeChecker, this also attempts to
+/// infer as many types as possible for temporary register defs & immediates in
+/// apply patterns.
+///
+/// The inference is trivial and leverages the MCOI OperandTypes encoded in
+/// CodeGenInstructions to infer types across patterns in a CombineRule. It's
+/// thus very limited and only supports CodeGenInstructions (but that's the main
+/// use case so it's fine).
+///
+/// We only try to infer untyped operands in apply patterns when they're temp
+/// reg defs, or immediates. Inference always outputs a `GITypeOf<$x>` where $x
+/// is a named operand from a match pattern.
+class CombineRuleOperandTypeChecker : private OperandTypeChecker {
+public:
+  CombineRuleOperandTypeChecker(const Record &RuleDef,
+                                const OperandTable &MatchOpTable)
+      : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef),
+        MatchOpTable(MatchOpTable) {}
+
+  /// Records and checks a 'match' pattern.
+  bool processMatchPattern(InstructionPattern &P);
+
+  /// Records and checks an 'apply' pattern.
+  bool processApplyPattern(InstructionPattern &P);
+
+  /// Propagates types, then performs type inference and, if any types were
+  /// inferred, does a second round of propagation in the apply patterns only.
+  void propagateAndInferTypes();
+
+private:
+  /// TypeEquivalenceClasses are groups of operands of an instruction that share
+  /// a common type.
+  ///
+  /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and
+  /// d have the same type too. b/c and a/d don't have to have the same type,
+  /// though.
+  using TypeEquivalenceClasses = EquivalenceClasses<StringRef>;
+
+  /// \returns true for `OPERAND_GENERIC_` 0 through 5.
+  /// These are the MCOI types that can be registers. The other MCOI types are
+  /// either immediates, or fancier operands used only post-ISel, so we don't
+  /// care about them for combiners.
+  static bool canMCOIOperandTypeBeARegister(StringRef MCOIType) {
+    // Assume OPERAND_GENERIC_0 through 5 can be registers. The other MCOI
+    // OperandTypes are either never used in gMIR, or not relevant (e.g.
+    // OPERAND_GENERIC_IMM, which is definitely never a register).
+    return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_");
+  }
+
+  /// Finds the "MCOI::" operand types for each operand of \p CGP.
+  ///
+  /// This is a bit trickier than it looks because we need to handle variadic
+  /// in/outs.
+  ///
+  /// e.g. for
+  ///   (G_BUILD_VECTOR $vec, $x, $y) ->
+  ///   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  ///    MCOI::OPERAND_GENERIC_1]
+  ///
+  /// For unknown types (which can happen in variadics where varargs types are
+  /// inconsistent), a unique name is given, e.g. "unknown_type_0".
+  static std::vector<std::string>
+  getMCOIOperandTypes(const CodeGenInstructionPattern &CGP);
+
+  /// Adds the TypeEquivalenceClasses for \p P in \p OutTECs.
+  void getInstEqClasses(const InstructionPattern &P,
+                        TypeEquivalenceClasses &OutTECs) const;
+
+  /// Calls `getInstEqClasses` on all patterns of the rule to produce the whole
+  /// rule's TypeEquivalenceClasses.
+  TypeEquivalenceClasses getRuleEqClasses() const;
+
+  /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p
+  /// TECs.
+  ///
+  /// This is achieved by trying to find a named operand in \p IP that shares
+  /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that
+  /// operand instead.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferImmediateType(const InstructionPattern &IP,
+                                 unsigned ImmOpIdx,
+                                 const TypeEquivalenceClasses &TECs) const;
+
+  /// Looks inside \p TECs to infer \p OpName's type.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferNamedOperandType(const InstructionPattern &IP,
+                                    StringRef OpName,
+                                    const TypeEquivalenceClasses &TECs) const;
+
+  const Record &RuleDef;
+  SmallVector<InstructionPattern *, 8> MatchPats;
+  SmallVector<InstructionPattern *, 8> ApplyPats;
+
+  const OperandTable &MatchOpTable;
+};
+
+bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) {
+  MatchPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ [](const auto &) {
+    // GITypeOf in 'match' is currently always rejected by the
+    // CombineRuleBuilder after inference is done.
+    return true;
+  });
+}
+
+bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) {
+  ApplyPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ [&](const PatternType &Ty) {
+    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
+    const auto OpName = Ty.getTypeOfOpName();
+    if (MatchOpTable.lookup(OpName).Found)
+      return true;
+
+    PrintError(RuleDef.getLoc(), "'" + OpName + "' ('" + Ty.str() +
+                                     "') does not refer to a matched operand!");
+    return false;
+  });
+}
+
+void CombineRuleOperandTypeChecker::propagateAndInferTypes() {
+  /// First step here is to propagate types using the OperandTypeChecker. That
+  /// way we ensure all uses of a given register have consistent types.
+  propagateTypes();
+
+  /// Build the TypeEquivalenceClasses for the whole rule.
+  const TypeEquivalenceClasses TECs = getRuleEqClasses();
+
+  /// Look at the apply patterns and find operands that need to be
+  /// inferred. We then try to find an equivalence class that they're a part of
+  /// and select the best operand to use for the `GITypeOf` type. We prioritize
+  /// defs of matched instructions because those are guaranteed to be registers.
+  bool InferredAny = false;
+  for (auto *Pat : ApplyPats) {
+    for (unsigned K = 0; K < Pat->operands_size(); ++K) {
+      auto &Op = Pat->getOperand(K);
+
+      // We only want to take a look at untyped defs or immediates.
+      if ((!Op.isDef() && !Op.hasImmValue()) || Op.getType())
+        continue;
+
+      // Infer defs & named immediates.
+      if (Op.isDef() || Op.isNamedImmediate()) {
+        // Check it's not a redefinition of a matched operand.
+        // In such cases, inference is not necessary because we just copy
+        // operands and don't create temporary registers.
+        if (MatchOpTable.lookup(Op.getOperandName()).Found)
+          continue;
+
+        // Inference is needed here, so try to do it.
+        if (PatternType Ty =
+                inferNamedOperandType(*Pat, Op.getOperandName(), TECs)) {
+          if (DebugTypeInfer)
+            errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n';
+          Op.setType(Ty);
+          InferredAny = true;
+        }
+
+        continue;
+      }
+
+      // Infer immediates
+      if (Op.hasImmValue()) {
+        if (PatternType Ty = inferImmediateType(*Pat, K, TECs)) {
+          if (DebugTypeInfer)
+            errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n';
+          Op.setType(Ty);
+          InferredAny = true;
+        }
+        continue;
+      }
+    }
+  }
+
+  // If we've inferred any types, we want to propagate them across the apply
+  // patterns. Type inference only adds GITypeOf types that point to Matched
+  // operands, so we definitely don't want to propagate types into the match
+  // patterns as well, otherwise bad things happen.
+  if (InferredAny) {
+    OperandTypeChecker OTC(RuleDef.getLoc());
+    for (auto *Pat : ApplyPats) {
+      if (!OTC.check(*Pat, [&](const auto &) { return true; }))
+        PrintFatalError(RuleDef.getLoc(),
+                        "OperandTypeChecker unexpectedly failed on '" +
+                            Pat->getName() + "' during Type Inference");
+    }
+    OTC.propagateTypes();
+
+    if (DebugTypeInfer) {
+      errs() << "Apply patterns for rule " << RuleDef.getName()
+             << " after inference:\n";
+      for (auto *Pat : ApplyPats) {
+        errs() << "  ";
+        Pat->print(errs(), /*PrintName*/ true);
+        errs() << '\n';
+      }
+      errs() << '\n';
+    }
+  }
+}
+
+PatternType CombineRuleOperandTypeChecker::inferImmediateType(
+    const InstructionPattern &IP, unsigned ImmOpIdx,
+    const TypeEquivalenceClasses &TECs) const {
+  // We can only infer CGPs.
+  const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP);
+  if (!CGP)
+    return {};
+
+  // For CGPs, we try to infer immediates by trying to infer another named
+  // operand that shares its type.
+  //
+  // e.g.
+  //    Pattern: G_BUILD_VECTOR $x, $y, 0
+  //    MCOIs:   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  //              MCOI::OPERAND_GENERIC_1]
+  //    $y has the same type as 0, so we can infer $y and get the type 0 should
+  //    have.
+
+  // We infer immediates by looking for a named operand that shares the same
+  // MCOI type.
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  StringRef ImmOpTy = MCOITypes[ImmOpIdx];
+
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes)) {
+    if (Idx != ImmOpIdx && Ty == ImmOpTy) {
+      const auto &Op = IP.getOperand(Idx);
+      if (!Op.isNamedOperand())
+        continue;
+
+      // Named operand with the same MCOI type, try to infer that.
+      if (PatternType InferTy =
+              inferNamedOperandType(IP, Op.getOperandName(), TECs))
+        return InferTy;
+    }
+  }
+
+  return {};
+}
+
+PatternType CombineRuleOperandTypeChecker::inferNamedOperandType(
+    const InstructionPattern &IP, StringRef OpName,
+    const TypeEquivalenceClasses &TECs) const {
+  // This is the simplest possible case, we just need to find a TEC that
+  // contains OpName. Look at all other operands in equivalence class and try to
+  // find a suitable one.
+
+  // Check for a def of a matched pattern. This is guaranteed to always
+  // be a register so we can blindly use that.
+  StringRef GoodOpName;
+  for (auto It = TECs.findLeader(OpName); It != TECs.member_end(); ++It) {
+    if (*It == OpName)
+      continue;
+
+    const auto LookupRes = MatchOpTable.lookup(*It);
+    if (LookupRes.Def) // Favor defs
+      return PatternType::getTypeOf(*It);
+
+    // Otherwise just save this in case we don't find any def.
+    if (GoodOpName.empty() && LookupRes.Found)
+      GoodOpName = *It;
+  }
+
+  if (!GoodOpName.empty())
+    return PatternType::getTypeOf(GoodOpName);
+
+  // No good operand found, give up.
+  return {};
+}
+
+std::vector<std::string> CombineRuleOperandTypeChecker::getMCOIOperandTypes(
+    const CodeGenInstructionPattern &CGP) {
+  // FIXME?: Should we cache this? We call it twice when inferring immediates.
+
+  static unsigned UnknownTypeIdx = 0;
+
+  std::vector<std::string> OpTypes;
+  auto &CGI = CGP.getInst();
+  Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction")
+                          ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType")
+                          : nullptr;
+  std::string VarArgsTyName =
+      VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str()
+                : ("unknown_type_" + Twine(UnknownTypeIdx++)).str();
+
+  // First, handle defs.
+  for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K)
+    OpTypes.push_back(CGI.Operands[K].OperandType);
+
+  // Then, handle variadic defs if there are any.
+  if (CGP.hasVariadicDefs()) {
+    for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K)
+      OpTypes.push_back(VarArgsTyName);
+  }
+
+  // If we had variadic defs, the op idx in the pattern won't match the op idx
+  // in the CGI anymore.
+  int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs();
+  assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0));
+
+  // Handle all remaining use operands, including variadic ones.
+  for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) {
+    unsigned CGIOpIdx = K + CGIOpOffset;
+    if (CGIOpIdx >= CGI.Operands.size()) {
+      assert(CGP.isVariadic());
+      OpTypes.push_back(VarArgsTyName);
+    } else {
+      OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType);
+    }
+  }
+
+  assert(OpTypes.size() == CGP.operands_size());
+  return OpTypes;
+}
+
+void CombineRuleOperandTypeChecker::getInstEqClasses(
+    const InstructionPattern &P, TypeEquivalenceClasses &OutTECs) const {
+  // Determine the TypeEquivalenceClasses by:
+  //    - Getting the MCOI Operand Types.
+  //    - Creating a Map of MCOI Type -> [Operand Indexes]
+  //    - Iterating over the map, filtering types we don't like, and adding the
+  //      names of the named operands at those indexes to \p OutTECs.
+
+  // We can only do this on CodeGenInstructions. Other InstructionPatterns have
+  // no type inference information associated with them.
+  // TODO: Could we add some inference information to builtins at least? e.g.
+  // ReplaceReg should always replace with a reg of the same type, for instance.
+  // Though, those patterns are often used alone so it might not be worth the
+  // trouble to infer their types.
+  auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P);
+  if (!CGP)
+    return;
+
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  assert(MCOITypes.size() == P.operands_size());
+
+  DenseMap<StringRef, std::vector<unsigned>> TyToOpIdx;
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes))
+    TyToOpIdx[Ty].push_back(Idx);
+
+  if (DebugTypeInfer)
+    errs() << "\tGroups for " << P.getName() << ":\t";
+
+  for (const auto &[Ty, Idxs] : TyToOpIdx) {
+    if (!canMCOIOperandTypeBeARegister(Ty))
+      continue;
+
+    if (DebugTypeInfer)
+      errs() << '[';
+    StringRef Sep = "";
+
+    // We only collect named operands.
+    StringRef Leader;
+    for (unsigned Idx : Idxs) {
+      const auto &Op = P.getOperand(Idx);
+      if (!Op.isNamedOperand())
+        continue;
+
+      const auto OpName = Op.getOperandName();
+      if (DebugTypeInfer) {
+        errs() << Sep << OpName;
+        Sep = ", ";
+      }
+
+      if (Leader.empty())
+        OutTECs.insert((Leader = OpName));
+      else
+        OutTECs.unionSets(Leader, OpName);
+    }
+
+    if (DebugTypeInfer)
+      errs() << "] ";
+  }
+
+  if (DebugTypeInfer)
+    errs() << '\n';
+}
+
+CombineRuleOperandTypeChecker::TypeEquivalenceClasses
+CombineRuleOperandTypeChecker::getRuleEqClasses() const {
+  StringMap<unsigned> OpNameToEqClassIdx;
+  TypeEquivalenceClasses TECs;
+
+  if (DebugTypeInfer)
+    errs() << "Rule Operand Type Equivalence Classes for " << RuleDef.getName()
+           << ":\n";
+
+  for (const auto *Pat : MatchPats)
+    getInstEqClasses(*Pat, TECs);
+  for (const auto *Pat : ApplyPats)
+    getInstEqClasses(*Pat, TECs);
+
+  if (DebugTypeInfer) {
+    errs() << "Final Type Equivalence Classes: ";
+    for (auto ClassIt = TECs.begin(); ClassIt != TECs.end(); ++ClassIt) {
+      // only print non-empty classes.
+      if (auto MembIt = TECs.member_begin(ClassIt);
+          MembIt != TECs.member_end()) {
+        errs() << '[';
+        StringRef Sep = "";
+        for (; MembIt != TECs.member_end(); ++MembIt) {
+          errs() << Sep << *MembIt;
+          Sep = ", ";
+        }
+        errs() << "] ";
+      }
+    }
+    errs() << '\n';
+  }
+
+  return TECs;
+}
+
 //===- CombineRuleBuilder -------------------------------------------------===//
 
 /// Parses combine rule and builds a small intermediate representation to tie
@@ -1820,8 +2280,8 @@ class CombineRuleBuilder {
   PatternMap ApplyPats;
 
   /// Operand tables to tie match/apply patterns together.
-  OperandTable<> MatchOpTable;
-  OperandTable<> ApplyOpTable;
+  OperandTable MatchOpTable;
+  OperandTable ApplyOpTable;
 
   /// Set by findRoots.
   Pattern *MatchRoot = nullptr;
@@ -1893,14 +2353,14 @@ bool CombineRuleBuilder::emitRuleMatchers() {
 
 void CombineRuleBuilder::print(raw_ostream &OS) const {
   OS << "(CombineRule name:" << RuleDef.getName() << " id:" << RuleID
-     << " root:" << RootName << "\n";
+     << " root:" << RootName << '\n';
 
   if (!MatchDatas.empty()) {
     OS << "  (MatchDatas\n";
     for (const auto &MD : MatchDatas) {
       OS << "    ";
       MD.print(OS);
-      OS << "\n";
+      OS << '\n';
     }
     OS << "  )\n";
   }
@@ -1909,7 +2369,7 @@ void CombineRuleBuilder::print(raw_ostream &OS) const {
     OS << "  (PatFrags\n";
     for (const auto *PF : SeenPatFrags) {
       PF->print(OS, /*Indent=*/"    ");
-      OS << "\n";
+      OS << '\n';
     }
     OS << "  )\n";
   }
@@ -1921,7 +2381,7 @@ void CombineRuleBuilder::print(raw_ostream &OS) const {
       return;
     }
 
-    OS << "\n";
+    OS << '\n';
     for (const auto &[Name, Pat] : Pats) {
       OS << "    ";
       if (Pat.get() == MatchRoot)
@@ -1931,7 +2391,7 @@ void CombineRuleBuilder::print(raw_ostream &OS) const {
         OS << "<apply_root>";
       OS << Name << ":";
       Pat->print(OS, /*PrintName=*/false);
-      OS << "\n";
+      OS << '\n';
     }
     OS << "  )\n";
   };
@@ -1972,9 +2432,9 @@ void CombineRuleBuilder::verify() const {
       // Both strings are allocated in the pool using insertStrRef.
       if (Name.data() != Pat->getName().data()) {
         dbgs() << "Map StringRef: '" << Name << "' @ "
-               << (const void *)Name.data() << "\n";
+               << (const void *)Name.data() << '\n';
         dbgs() << "Pat String: '" << Pat->getName() << "' @ "
-               << (const void *)Pat->getName().data() << "\n";
+               << (const void *)Pat->getName().data() << '\n';
         PrintFatalError("StringRef stored in the PatternMap is not referencing "
                         "the same string as its Pattern!");
       }
@@ -2074,7 +2534,7 @@ void CombineRuleBuilder::addCXXPredicate(RuleMatcher &M,
       P.expandCode(CE, RuleDef.getLoc(), [&](raw_ostream &OS) {
         OS << "// Pattern Alternatives: ";
         print(OS, Alts);
-        OS << "\n";
+        OS << '\n';
       });
   IM->addPredicate<GenericInstructionPredicateMatcher>(
       ExpandedCode.getEnumNameWithPrefix(CXXPredPrefix));
@@ -2102,39 +2562,23 @@ bool CombineRuleBuilder::hasEraseRoot() const {
 }
 
 bool CombineRuleBuilder::typecheckPatterns() {
-  OperandTypeChecker OTC(RuleDef.getLoc());
-
-  const auto CheckMatchTypeOf = [&](const PatternType &) -> bool {
-    // We'll reject those after we're done inferring
-    return true;
-  };
+  CombineRuleOperandTypeChecker OTC(RuleDef, MatchOpTable);
 
   for (auto &Pat : values(MatchPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckMatchTypeOf))
+      if (!OTC.processMatchPattern(*IP))
         return false;
     }
   }
 
-  const auto CheckApplyTypeOf = [&](const PatternType &Ty) {
-    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
-    const auto OpName = Ty.getTypeOfOpName();
-    if (MatchOpTable.lookup(OpName).Found)
-      return true;
-
-    PrintError("'" + OpName + "' ('" + Ty.str() +
-               "') does not refer to a matched operand!");
-    return false;
-  };
-
   for (auto &Pat : values(ApplyPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckApplyTypeOf))
+      if (!OTC.processApplyPattern(*IP))
         return false;
     }
   }
 
-  OTC.setAllOperandTypes();
+  OTC.propagateAndInferTypes();
 
   // Always check this after in case inference adds some special types to the
   // match patterns.
@@ -2630,7 +3074,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
   // untyped immediate, e.g. 0
   if (const auto *IntImm = dyn_cast<IntInit>(OpInit)) {
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(IntImm->getValue(), Name, /*Type=*/nullptr);
+    IP.addOperand(IntImm->getValue(), Name, PatternType());
     return true;
   }
 
@@ -2640,13 +3084,11 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
-    PatternType ImmTy(TyDef);
-    if (!ImmTy.isValidType()) {
-      PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() +
-                 "', '" + TyDef->getName() + "' is not a ValueType or " +
-                 SpecialTyClassName);
+    auto ImmTy = PatternType::get(RuleDef.getLoc(), TyDef,
+                                  "cannot parse immediate '" +
+                                      DagOp->getAsUnquotedString() + "'");
+    if (!ImmTy)
       return false;
-    }
 
     if (!IP.hasAllDefs()) {
       PrintError("out operand of '" + IP.getInstName() +
@@ -2659,7 +3101,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(Val->getValue(), Name, ImmTy);
+    IP.addOperand(Val->getValue(), Name, *ImmTy);
     return true;
   }
 
@@ -2671,20 +3113,18 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return false;
     }
     const Record *Def = DefI->getDef();
-    PatternType Ty(Def);
-    if (!Ty.isValidType()) {
-      PrintError("invalid operand type: '" + Def->getName() +
-                 "' is not a ValueType");
+    auto Ty =
+        PatternType::get(RuleDef.getLoc(), Def, "cannot parse operand type");
+    if (!Ty)
       return false;
-    }
-    IP.addOperand(OpName->getAsUnquotedString(), Ty);
+    IP.addOperand(OpName->getAsUnquotedString(), *Ty);
     return true;
   }
 
   // Untyped operand e.g. $x/$z in (G_FNEG $x, $z)
   if (isa<UnsetInit>(OpInit)) {
     assert(OpName && "Unset w/ no OpName?");
-    IP.addOperand(OpName->getAsUnquotedString(), /*Type=*/nullptr);
+    IP.addOperand(OpName->getAsUnquotedString(), PatternType());
     return true;
   }
 
@@ -3746,7 +4186,7 @@ void GICombinerEmitter::emitRunCustomAction(raw_ostream &OS) {
     OS << "  switch(ApplyID) {\n";
     for (const auto &Apply : ApplyCode) {
       OS << "  case " << Apply->getEnumNameWithPrefix(CXXApplyPrefix) << ":{\n"
-         << "    " << join(split(Apply->Code, "\n"), "\n    ") << "\n"
+         << "    " << join(split(Apply->Code, '\n'), "\n    ") << '\n'
          << "    return;\n";
       OS << "  }\n";
     }


        


More information about the llvm-commits mailing list