[llvm] [TableGen][GlobalISel] Add MIFlags matching & rewriting (PR #71179)

Pierre van Houtryve via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 3 06:02:10 PDT 2023


https://github.com/Pierre-vh created https://github.com/llvm/llvm-project/pull/71179

NOTE: This review is part of a stack. Please only review the last commit. See #66377 to review the first commit.

Also disables generation of MutateOpcode. It's almost never used in combiners anyway.
If we really want to use it, it needs to be investigated & properly fixed (see TODO)
    
Fixes #70780

>From b96b57757bc45e56a8680f962901ab903fbeef7e Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Thu, 14 Sep 2023 15:52:29 +0200
Subject: [PATCH 1/2] [TableGen][GlobalISel] Add rule-wide type inference

The inference is trivial and leverages the MCOI OperandTypes encoded in
CodeGenInstructions to infer types across patterns in a CombineRule. It's
thus very limited and only supports CodeGenInstructions (but that's the main
use case so it's fine).

We only try to infer untyped operands in apply patterns when they're temp
reg defs, or immediates. Inference always outputs a `TypeOf<$x>` where $x is
a named operand from a match pattern.
---
 llvm/include/llvm/Target/GenericOpcodes.td    |   7 +
 .../include/llvm/Target/GlobalISel/Combine.td |   2 +-
 .../pattern-errors.td                         |   4 +-
 .../type-inference.td                         |  75 ++
 .../typeof-errors.td                          |   7 +-
 .../TableGen/GlobalISelCombinerEmitter.cpp    | 748 +++++++++++++++---
 6 files changed, 713 insertions(+), 130 deletions(-)
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td

diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index a1afc3b8042c284..9a9c09d3c20d612 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -17,6 +17,10 @@
 
 class GenericInstruction : StandardPseudoInstruction {
   let isPreISelOpcode = true;
+
+  // When all variadic ops share a type with another operand,
+  // this is the type they share. Used by MIR patterns type inference.
+  TypedOperand variadicOpsType = ?;
 }
 
 // Provide a variant of an instruction with the same operands, but
@@ -1228,6 +1232,7 @@ def G_UNMERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst0, variable_ops);
   let InOperandList = (ins type1:$src);
   let hasSideEffects = false;
+  let variadicOpsType = type0;
 }
 
 // Insert a smaller register into a larger one at the specified bit-index.
@@ -1245,6 +1250,7 @@ def G_MERGE_VALUES : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Create a vector from multiple scalar registers. No implicit
@@ -1254,6 +1260,7 @@ def G_BUILD_VECTOR : GenericInstruction {
   let OutOperandList = (outs type0:$dst);
   let InOperandList = (ins type1:$src0, variable_ops);
   let hasSideEffects = false;
+  let variadicOpsType = type1;
 }
 
 /// Like G_BUILD_VECTOR, but truncates the larger operand types to fit the
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index 63c485a5a6c6070..ee0209eb9e5593a 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -796,7 +796,7 @@ def trunc_shift: GICombineRule <
 def mul_by_neg_one: GICombineRule <
   (defs root:$dst),
   (match (G_MUL $dst, $x, -1)),
-  (apply (G_SUB $dst, (GITypeOf<"$x"> 0), $x))
+  (apply (G_SUB $dst, 0, $x))
 >;
 
 // Fold (xor (and x, y), y) -> (and (not x), y)
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
index 48a06474da78a10..77321a3f2b59dbe 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
@@ -151,7 +151,7 @@ def bad_imm_too_many_args : GICombineRule<
   (match (COPY $x, (i32 0, 0)):$d),
   (apply (COPY $x, $b):$d)>;
 
-// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: cannot parse immediate '(COPY 0)', 'COPY' is not a ValueType
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: unknown type 'COPY'
 // CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(COPY ?:$x, (COPY 0))
 def bad_imm_not_a_valuetype : GICombineRule<
   (defs root:$a),
@@ -186,7 +186,7 @@ def expected_op_name : GICombineRule<
   (match (G_FNEG $x, i32)),
   (apply (COPY $x, (i32 0)))>;
 
-// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: invalid operand type: 'not_a_type' is not a ValueType
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: unknown type 'not_a_type'
 // CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_FNEG ?:$x, not_a_type:$y)'
 def not_a_type;
 def bad_mo_type_not_a_valuetype : GICombineRule<
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
new file mode 100644
index 000000000000000..7ed14dd5e6cc0eb
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/type-inference.td
@@ -0,0 +1,75 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -gicombiner-debug-typeinfer -combiners=MyCombiner %s 2>&1 | \
+// RUN: FileCheck %s
+
+// Checks reasoning of the inference rules.
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+// CHECK:      Rule Operand Type Equivalence Classes for inference_mul_by_neg_one:
+// CHECK-NEXT:         __inference_mul_by_neg_one_match_0:             [dst, x]
+// CHECK-NEXT:         __inference_mul_by_neg_one_apply_0:             [dst, x]
+// CHECK-NEXT: (merging [dst, x] | [dst, x])
+// CHECK-NEXT: Result: [dst, x]
+// CHECK-NEXT: INFER: imm 0 -> GITypeOf<$x>
+// CHECK-NEXT: Apply patterns for rule inference_mul_by_neg_one after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__inference_mul_by_neg_one_apply_0 G_SUB operands:[<def>$dst, (GITypeOf<$x> 0), $x])
+def inference_mul_by_neg_one: GICombineRule <
+  (defs root:$dst),
+  (match (G_MUL $dst, $x, -1)),
+  (apply (G_SUB $dst, 0, $x))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_complex_tempreg:
+// CHECK-NEXT:         __infer_complex_tempreg_match_0:                [dst]   [x, y, z]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_0:                [tmp2]  [x, y]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_1:                [tmp, tmp2]
+// CHECK-NEXT:         __infer_complex_tempreg_apply_2:                [dst, tmp]
+// CHECK-NEXT: (merging [dst] | [dst, tmp])
+// CHECK-NEXT: (merging [x, y, z] | [x, y])
+// CHECK-NEXT: (merging [tmp2] | [tmp, tmp2])
+// CHECK-NEXT: (merging [dst, tmp] | [tmp2, tmp])
+// CHECK-NEXT: Result: [dst, tmp, tmp2] [x, y, z]
+// CHECK-NEXT: INFER: MachineOperand $tmp2 -> GITypeOf<$dst>
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$dst>
+// CHECK-NEXT: Apply patterns for rule infer_complex_tempreg after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_0 G_BUILD_VECTOR operands:[<def>GITypeOf<$dst>:$tmp2, $x, $y])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_1 G_FNEG operands:[<def>GITypeOf<$dst>:$tmp, GITypeOf<$dst>:$tmp2])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_complex_tempreg_apply_2 G_FNEG operands:[<def>$dst, GITypeOf<$dst>:$tmp])
+def infer_complex_tempreg: GICombineRule <
+  (defs root:$dst),
+  (match (G_MERGE_VALUES $dst, $x, $y, $z)),
+  (apply (G_BUILD_VECTOR $tmp2, $x, $y),
+         (G_FNEG $tmp, $tmp2),
+         (G_FNEG $dst, $tmp))
+>;
+
+// CHECK:      Rule Operand Type Equivalence Classes for infer_variadic_outs:
+// CHECK-NEXT:         __infer_variadic_outs_match_0:          [x, y]  [vec]
+// CHECK-NEXT:         __infer_variadic_outs_match_1:          [dst, x]
+// CHECK-NEXT:         __infer_variadic_outs_apply_0:          [tmp, y]
+// CHECK-NEXT:         __infer_variadic_outs_apply_1:  (empty)
+// CHECK-NEXT: (merging [x, y] | [dst, x])
+// CHECK-NEXT: (merging [x, y, dst] | [tmp, y])
+// CHECK-NEXT: Result: [x, y, dst, tmp] [vec]
+// CHECK-NEXT: INFER: MachineOperand $tmp -> GITypeOf<$x>
+// CHECK-NEXT: Apply patterns for rule infer_variadic_outs after inference:
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_0 G_FNEG operands:[<def>GITypeOf<$x>:$tmp, $y])
+// CHECK-NEXT: (CodeGenInstructionPattern name:__infer_variadic_outs_apply_1 COPY operands:[<def>$dst, GITypeOf<$x>:$tmp])
+def infer_variadic_outs: GICombineRule <
+  (defs root:$dst),
+  (match  (G_UNMERGE_VALUES $x, $y, $vec),
+          (G_FNEG $dst, $x)),
+  (apply (G_FNEG $tmp, $y),
+         (COPY $dst, $tmp))
+>;
+
+def MyCombiner: GICombiner<"GenMyCombiner", [
+  inference_mul_by_neg_one,
+  infer_complex_tempreg,
+  infer_variadic_outs
+]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
index 6040d6def449766..b86b4ec19564488 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/typeof-errors.td
@@ -8,7 +8,8 @@ include "llvm/Target/GlobalISel/Combine.td"
 def MyTargetISA : InstrInfo;
 def MyTarget : Target { let InstructionSet = MyTargetISA; }
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: invalid operand name format 'unknown' in GITypeOf: expected '$' followed by an operand name
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ANYEXT ?:$dst, (anonymous_
 def NoDollarSign : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, $src)),
@@ -47,7 +48,9 @@ def InferredUseInMatch : GICombineRule<
   (match (G_ZEXT $dst, $src)),
   (apply (G_ANYEXT $dst, GITypeOf<"$dst">:$src))>;
 
-// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: conflicting types for operand 'src': first seen with 'i32' in '__InferenceConflict_match_0, now seen with 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+3]]:{{[0-9]+}}: error: conflicting types for operand 'src': 'i32' vs 'GITypeOf<$dst>'
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: note: 'src' seen with type 'GITypeOf<$dst>' in '__InferenceConflict_apply_0'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: 'src' seen with type 'i32' in '__InferenceConflict_match_0'
 def InferenceConflict : GICombineRule<
   (defs root:$dst),
   (match (G_ZEXT $dst, i32:$src)),
diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 0c7b33a7b9d889d..82df0c5e58f041b 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -37,6 +37,8 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/CommandLine.h"
@@ -68,6 +70,9 @@ cl::opt<bool> DebugCXXPreds(
     "gicombiner-debug-cxxpreds",
     cl::desc("Add Contextual/Debug comments to all C++ predicates"),
     cl::cat(GICombinerEmitterCat));
+cl::opt<bool> DebugTypeInfer("gicombiner-debug-typeinfer",
+                             cl::desc("Print type inference debug logs"),
+                             cl::cat(GICombinerEmitterCat));
 
 constexpr StringLiteral CXXApplyPrefix = "GICXXCustomAction_CombineApply";
 constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
@@ -125,6 +130,20 @@ template <typename Container> auto values(Container &&C) {
   return map_range(C, [](auto &Entry) -> auto & { return Entry.second; });
 }
 
+template <typename SetTy> bool doSetsIntersect(const SetTy &A, const SetTy &B) {
+  for (const auto &Elt : A) {
+    if (B.contains(Elt))
+      return true;
+  }
+  return false;
+}
+
+template <class Ty>
+void setVectorUnion(SetVector<Ty> &S1, const SetVector<Ty> &S2) {
+  for (auto SI = S2.begin(), SE = S2.end(); SI != SE; ++SI)
+    S1.insert(*SI);
+}
+
 //===- MatchData Handling -------------------------------------------------===//
 
 /// Represents MatchData defined by the match stage and required by the apply
@@ -298,23 +317,30 @@ CXXPredicateCode::CXXPredicateCodePool CXXPredicateCode::AllCXXApplyCode;
 ///   - Special types, e.g. GITypeOf
 class PatternType {
 public:
-  PatternType() = default;
-  PatternType(const Record *R) : R(R) {}
+  enum PTKind : uint8_t {
+    PT_None,
+
+    PT_ValueType,
+    PT_TypeOf,
+  };
 
-  bool isValidType() const { return !R || isLLT() || isSpecial(); }
+  PatternType() : Kind(PT_None), Data() {}
 
-  bool isLLT() const { return R && R->isSubClassOf("ValueType"); }
-  bool isSpecial() const { return R && R->isSubClassOf(SpecialTyClassName); }
-  bool isTypeOf() const { return R && R->isSubClassOf(TypeOfClassName); }
+  static std::optional<PatternType> get(ArrayRef<SMLoc> DiagLoc,
+                                        const Record *R);
+  static PatternType getTypeOf(StringRef OpName);
+
+  bool isNone() const { return Kind == PT_None; }
+  bool isLLT() const { return Kind == PT_ValueType; }
+  bool isSpecial() const { return isTypeOf(); }
+  bool isTypeOf() const { return Kind == PT_TypeOf; }
 
   StringRef getTypeOfOpName() const;
   LLTCodeGen getLLTCodeGen() const;
 
-  bool checkSemantics(ArrayRef<SMLoc> DiagLoc) const;
-
   LLTCodeGenOrTempType getLLTCodeGenOrTempType(RuleMatcher &RM) const;
 
-  explicit operator bool() const { return R != nullptr; }
+  explicit operator bool() const { return !isNone(); }
 
   bool operator==(const PatternType &Other) const;
   bool operator!=(const PatternType &Other) const { return !operator==(Other); }
@@ -322,26 +348,66 @@ class PatternType {
   std::string str() const;
 
 private:
-  StringRef getRawOpName() const { return R->getValueAsString("OpName"); }
+  PatternType(PTKind Kind) : Kind(Kind), Data() {}
+
+  PTKind Kind;
+  union DataT {
+    DataT() : Str() {}
 
-  const Record *R = nullptr;
+    /// PT_ValueType -> ValueType Def.
+    const Record *Def;
+
+    /// PT_TypeOf -> Operand name (without the '$')
+    StringRef Str;
+  } Data;
 };
 
+std::optional<PatternType> PatternType::get(ArrayRef<SMLoc> DiagLoc,
+                                            const Record *R) {
+  assert(R);
+  if (R->isSubClassOf("ValueType")) {
+    PatternType PT(PT_ValueType);
+    PT.Data.Def = R;
+    return PT;
+  }
+
+  if (R->isSubClassOf(TypeOfClassName)) {
+    auto RawOpName = R->getValueAsString("OpName");
+    if (!RawOpName.starts_with("$")) {
+      PrintError(DiagLoc, "invalid operand name format '" + RawOpName +
+                              "' in " + TypeOfClassName +
+                              ": expected '$' followed by an operand name");
+      return std::nullopt;
+    }
+
+    PatternType PT(PT_TypeOf);
+    PT.Data.Str = RawOpName.drop_front(1);
+    return PT;
+  }
+
+  PrintError(DiagLoc, "unknown type '" + R->getName() + "'");
+  return std::nullopt;
+}
+
+PatternType PatternType::getTypeOf(StringRef OpName) {
+  PatternType PT(PT_TypeOf);
+  PT.Data.Str = OpName;
+  return PT;
+}
+
 StringRef PatternType::getTypeOfOpName() const {
   assert(isTypeOf());
-  StringRef Name = getRawOpName();
-  Name.consume_front("$");
-  return Name;
+  return Data.Str;
 }
 
 LLTCodeGen PatternType::getLLTCodeGen() const {
   assert(isLLT());
-  return *MVTToLLT(getValueType(R));
+  return *MVTToLLT(getValueType(Data.Def));
 }
 
 LLTCodeGenOrTempType
 PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
-  assert(isValidType());
+  assert(!isNone());
 
   if (isLLT())
     return getLLTCodeGen();
@@ -351,50 +417,31 @@ PatternType::getLLTCodeGenOrTempType(RuleMatcher &RM) const {
   return OM.getTempTypeIdx(RM);
 }
 
-bool PatternType::checkSemantics(ArrayRef<SMLoc> DiagLoc) const {
-  if (!isTypeOf())
-    return true;
-
-  auto RawOpName = getRawOpName();
-  if (RawOpName.starts_with("$"))
-    return true;
-
-  PrintError(DiagLoc, "invalid operand name format '" + RawOpName + "' in " +
-                          TypeOfClassName +
-                          ": expected '$' followed by an operand name");
-  return false;
-}
-
 bool PatternType::operator==(const PatternType &Other) const {
-  if (R == Other.R) {
-    if (R && R->getName() != Other.R->getName()) {
-      dbgs() << "Same ptr but: " << R->getName() << " and "
-             << Other.R->getName() << "?\n";
-      assert(false);
-    }
+  if (Kind != Other.Kind)
+    return false;
+
+  switch (Kind) {
+  case PT_None:
     return true;
+  case PT_ValueType:
+    return Data.Def == Other.Data.Def;
+  case PT_TypeOf:
+    return Data.Str == Other.Data.Str;
   }
 
-  if (isTypeOf() && Other.isTypeOf())
-    return getTypeOfOpName() == Other.getTypeOfOpName();
-
-  return false;
+  llvm_unreachable("Unknown Type Kind");
 }
 
 std::string PatternType::str() const {
-  if (!R)
+  switch (Kind) {
+  case PT_None:
     return "";
-
-  if (!isValidType())
-    return "<invalid>";
-
-  if (isLLT())
-    return R->getName().str();
-
-  assert(isSpecial());
-
-  if (isTypeOf())
+  case PT_ValueType:
+    return Data.Def->getName().str();
+  case PT_TypeOf:
     return (TypeOfClassName + "<$" + getTypeOfOpName() + ">").str();
+  }
 
   llvm_unreachable("Unknown type!");
 }
@@ -607,14 +654,10 @@ class InstructionOperand {
   using IntImmTy = int64_t;
 
   InstructionOperand(IntImmTy Imm, StringRef Name, PatternType Type)
-      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Value(Imm), Name(insertStrRef(Name)), Type(Type) {}
 
   InstructionOperand(StringRef Name, PatternType Type)
-      : Name(insertStrRef(Name)), Type(Type) {
-    assert(Type.isValidType());
-  }
+      : Name(insertStrRef(Name)), Type(Type) {}
 
   bool isNamedImmediate() const { return hasImmValue() && isNamedOperand(); }
 
@@ -638,7 +681,6 @@ class InstructionOperand {
 
   void setType(PatternType NewType) {
     assert((!Type || (Type == NewType)) && "Overwriting type!");
-    assert(NewType.isValidType());
     Type = NewType;
   }
   PatternType getType() const { return Type; }
@@ -809,12 +851,10 @@ void InstructionPattern::print(raw_ostream &OS, bool PrintName) const {
 /// Maps InstructionPattern operands to their definitions. This allows us to tie
 /// different patterns of a (apply), (match) or (patterns) set of patterns
 /// together.
-template <typename DefTy = InstructionPattern> class OperandTable {
+class OperandTable {
 public:
-  static_assert(std::is_base_of_v<InstructionPattern, DefTy>,
-                "DefTy should be a derived class from InstructionPattern");
-
-  bool addPattern(DefTy *P, function_ref<void(StringRef)> DiagnoseRedef) {
+  bool addPattern(InstructionPattern *P,
+                  function_ref<void(StringRef)> DiagnoseRedef) {
     for (const auto &Op : P->named_operands()) {
       StringRef OpName = Op.getOperandName();
 
@@ -843,10 +883,10 @@ template <typename DefTy = InstructionPattern> class OperandTable {
 
   struct LookupResult {
     LookupResult() = default;
-    LookupResult(DefTy *Def) : Found(true), Def(Def) {}
+    LookupResult(InstructionPattern *Def) : Found(true), Def(Def) {}
 
     bool Found = false;
-    DefTy *Def = nullptr;
+    InstructionPattern *Def = nullptr;
 
     bool isLiveIn() const { return Found && !Def; }
   };
@@ -857,7 +897,9 @@ template <typename DefTy = InstructionPattern> class OperandTable {
     return LookupResult();
   }
 
-  DefTy *getDef(StringRef OpName) const { return lookup(OpName).Def; }
+  InstructionPattern *getDef(StringRef OpName) const {
+    return lookup(OpName).Def;
+  }
 
   void print(raw_ostream &OS, StringRef Name = "",
              StringRef Indent = "") const {
@@ -887,7 +929,7 @@ template <typename DefTy = InstructionPattern> class OperandTable {
   void dump() const { print(dbgs()); }
 
 private:
-  StringMap<DefTy *> Table;
+  StringMap<InstructionPattern *> Table;
 };
 
 //===- CodeGenInstructionPattern ------------------------------------------===//
@@ -956,62 +998,76 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
 ///
 /// It infers the type of each operand, check it's consistent with the known
 /// type of the operand, and then sets all of the types in all operands in
-/// setAllOperandTypes.
+/// propagateTypes.
 ///
 /// It also handles verifying correctness of special types.
 class OperandTypeChecker {
 public:
   OperandTypeChecker(ArrayRef<SMLoc> DiagLoc) : DiagLoc(DiagLoc) {}
 
-  bool check(InstructionPattern *P,
+  /// Step 1: Check each pattern one by one. All patterns that pass through here
+  /// are added to a common worklist so propagateTypes can access them.
+  bool check(InstructionPattern &P,
              std::function<bool(const PatternType &)> VerifyTypeOfOperand);
 
-  void setAllOperandTypes();
+  /// Step 2: Propagate all types. e.g. if one use of "$a" has type i32, make
+  /// all uses of "$a" have type i32.
+  void propagateTypes();
+
+protected:
+  ArrayRef<SMLoc> DiagLoc;
 
 private:
+  using InconsistentTypeDiagFn = std::function<void()>;
+
+  void PrintSeenWithTypeIn(InstructionPattern &P, StringRef OpName,
+                           PatternType Ty) const {
+    PrintNote(DiagLoc, "'" + OpName + "' seen with type '" + Ty.str() +
+                           "' in '" + P.getName() + "'");
+  }
+
   struct OpTypeInfo {
     PatternType Type;
-    InstructionPattern *TypeSrc = nullptr;
+    InconsistentTypeDiagFn PrintTypeSrcNote = []() {};
   };
 
-  ArrayRef<SMLoc> DiagLoc;
   StringMap<OpTypeInfo> Types;
 
   SmallVector<InstructionPattern *, 16> Pats;
 };
 
 bool OperandTypeChecker::check(
-    InstructionPattern *P,
+    InstructionPattern &P,
     std::function<bool(const PatternType &)> VerifyTypeOfOperand) {
-  Pats.push_back(P);
+  Pats.push_back(&P);
 
-  for (auto &Op : P->operands()) {
+  for (auto &Op : P.operands()) {
     const auto Ty = Op.getType();
     if (!Ty)
       continue;
 
-    if (!Ty.checkSemantics(DiagLoc))
-      return false;
-
     if (Ty.isTypeOf() && !VerifyTypeOfOperand(Ty))
       return false;
 
     if (!Op.isNamedOperand())
       continue;
 
-    auto &Info = Types[Op.getOperandName()];
+    StringRef OpName = Op.getOperandName();
+    auto &Info = Types[OpName];
     if (!Info.Type) {
       Info.Type = Ty;
-      Info.TypeSrc = P;
+      Info.PrintTypeSrcNote = [this, OpName, Ty, &P]() {
+        PrintSeenWithTypeIn(P, OpName, Ty);
+      };
       continue;
     }
 
     if (Info.Type != Ty) {
       PrintError(DiagLoc, "conflicting types for operand '" +
-                              Op.getOperandName() + "': first seen with '" +
-                              Info.Type.str() + "' in '" +
-                              Info.TypeSrc->getName() + ", now seen with '" +
-                              Ty.str() + "' in '" + P->getName() + "'");
+                              Op.getOperandName() + "': '" + Info.Type.str() +
+                              "' vs '" + Ty.str() + "'");
+      PrintSeenWithTypeIn(P, OpName, Ty);
+      Info.PrintTypeSrcNote();
       return false;
     }
   }
@@ -1019,7 +1075,7 @@ bool OperandTypeChecker::check(
   return true;
 }
 
-void OperandTypeChecker::setAllOperandTypes() {
+void OperandTypeChecker::propagateTypes() {
   for (auto *Pat : Pats) {
     for (auto &Op : Pat->named_operands()) {
       if (auto &Info = Types[Op.getOperandName()]; Info.Type)
@@ -1073,7 +1129,7 @@ class PatFrag {
   /// Each argument to the `pattern` DAG operator is parsed into a Pattern
   /// instance.
   struct Alternative {
-    OperandTable<> OpTable;
+    OperandTable OpTable;
     SmallVector<std::unique_ptr<Pattern>, 4> Pats;
   };
 
@@ -1297,11 +1353,11 @@ bool PatFrag::checkSemantics() {
     OperandTypeChecker OTC(Def.getLoc());
     for (auto &Pat : Alt.Pats) {
       if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-        if (!OTC.check(IP, CheckTypeOf))
+        if (!OTC.check(*IP, CheckTypeOf))
           return false;
       }
     }
-    OTC.setAllOperandTypes();
+    OTC.propagateTypes();
   }
 
   return true;
@@ -1639,6 +1695,471 @@ class PrettyStackTraceEmit : public PrettyStackTraceEntry {
   }
 };
 
+//===- CombineRuleOperandTypeChecker --------------------------------------===//
+
+/// This is a wrapper around OperandTypeChecker specialized for Combiner Rules.
+/// On top of doing the same things as OperandTypeChecker, this also attempts to
+/// infer as many types as possible for temporary register defs & immediates in
+/// apply patterns.
+///
+/// The inference is trivial and leverages the MCOI OperandTypes encoded in
+/// CodeGenInstructions to infer types across patterns in a CombineRule. It's
+/// thus very limited and only supports CodeGenInstructions (but that's the main
+/// use case so it's fine).
+///
+/// We only try to infer untyped operands in apply patterns when they're temp
+/// reg defs, or immediates. Inference always outputs a `TypeOf<$x>` where $x is
+/// a named operand from a match pattern.
+class CombineRuleOperandTypeChecker : private OperandTypeChecker {
+public:
+  CombineRuleOperandTypeChecker(const Record &RuleDef,
+                                const OperandTable &MatchOpTable)
+      : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef),
+        MatchOpTable(MatchOpTable) {}
+
+  /// Records and checks a 'match' pattern.
+  bool processMatchPattern(InstructionPattern &P);
+
+  /// Records and checks an 'apply' pattern.
+  bool processApplyPattern(InstructionPattern &P);
+
+  /// Propagates types, then performs type inference and does a second round of
+  /// propagation in the apply patterns only if any types were inferred.
+  void propagateAndInferTypes();
+
+private:
+  /// TypeEquivalenceClasses are groups of operands of an instruction that share
+  /// a common type.
+  ///
+  /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and
+  /// d have the same type too. b/c and a/d don't have to have the same type,
+  /// though.
+  ///
+  /// NOTE: We use a SetVector, not a Set. This is to guarantee a stable
+  /// iteration order which is important because:
+  ///   - During inference, we iterate that set and pick the first suitable
+  ///   candidate. Using a normal set could make inference inconsistent across
+  ///   runs if the Set uses the StringRef ptr to cache values.
+  ///   - We print this set if DebugTypeInfer is set, and we don't want our
+  ///   tests to fail randomly due to the Set's iteration order changing.
+  using TypeEquivalenceClasses = std::vector<SetVector<StringRef>>;
+
+  static std::string toString(const SetVector<StringRef> &EqClass) {
+    return "[" + join(EqClass, ", ") + "]";
+  }
+
+  /// \returns true for `OPERAND_GENERIC_` 0 through 5.
+  /// These are the MCOI types that can be registers. The other MCOI types are
+  /// either immediates, or fancier operands used only post-ISel, so we don't
+  /// care about them for combiners.
+  static bool canMCOIOperandTypeBeARegister(StringRef MCOIType) {
+    // Assume OPERAND_GENERIC_0 through 5 can be registers. The other MCOI
+    // OperandTypes are either never used in gMIR, or not relevant (e.g.
+    // OPERAND_GENERIC_IMM, which is definitely never a register).
+    return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_");
+  }
+
+  /// Finds the "MCOI::" operand types for each operand of \p CGP.
+  ///
+  /// This is a bit trickier than it looks because we need to handle variadic
+  /// in/outs.
+  ///
+  /// e.g. for
+  ///   (G_BUILD_VECTOR $vec, $x, $y) ->
+  ///   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  ///    MCOI::OPERAND_GENERIC_1]
+  ///
+  /// For unknown types (which can happen in variadics where varargs types are
+  /// inconsistent), a unique name is given, e.g. "unknown_type_0".
+  static std::vector<std::string>
+  getMCOIOperandTypes(const CodeGenInstructionPattern &CGP);
+
+  /// Adds the TypeEquivalenceClasses for \p P in \p OutTECs.
+  void getInstEqClasses(const InstructionPattern &P,
+                        TypeEquivalenceClasses &OutTECs) const;
+
+  /// Calculates the TypeEquivalenceClasses for each instruction, then merges
+  /// them into a common set of TypeEquivalenceClasses for the whole rule.
+  ///
+  /// This works by repeatedly merging intersecting type equivalence classes
+  /// until no more merging occurs.
+  ///
+  /// This essentially applies the "transitive" part of type inference. Let's
+  /// take the following equivalence classes:
+  ///   inst0: [a, b], [c, d]
+  ///   inst1: [b, c]
+  ///
+  /// If we see inst0 alone, we can't say that a and d have the same type -
+  /// they're not in the same equivalence classes. However if we just use logic,
+  /// we can say: "a == d because a == b, b == c and c == d".
+  ///
+  /// Merging condenses that information into a single big equivalence class
+  /// which can be looked at alone to make the same deduction.
+  ///   rule: [a, b, c, d]
+  TypeEquivalenceClasses getRuleEqClasses() const;
+
+  /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p
+  /// TECs.
+  ///
+  /// This is achieved by trying to find a named operand in \p IP that shares
+  /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that
+  /// operand instead.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferImmediateType(const InstructionPattern &IP,
+                                 unsigned ImmOpIdx,
+                                 const TypeEquivalenceClasses &TECs) const;
+
+  /// Looks inside \p TECs to infer \p OpName's type.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferNamedOperandType(const InstructionPattern &IP,
+                                    StringRef OpName,
+                                    const TypeEquivalenceClasses &TECs) const;
+
+  const Record &RuleDef;
+  SmallVector<InstructionPattern *, 8> MatchPats;
+  SmallVector<InstructionPattern *, 8> ApplyPats;
+
+  const OperandTable &MatchOpTable;
+};
+
+bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) {
+  MatchPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ [](const auto &) {
+    // GITypeOf in 'match' is currently always rejected by the
+    // CombineRuleBuilder after inference is done.
+    return true;
+  });
+}
+
+bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) {
+  ApplyPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ [&](const PatternType &Ty) {
+    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
+    const auto OpName = Ty.getTypeOfOpName();
+    if (MatchOpTable.lookup(OpName).Found)
+      return true;
+
+    PrintError(RuleDef.getLoc(), "'" + OpName + "' ('" + Ty.str() +
+                                     "') does not refer to a matched operand!");
+    return false;
+  });
+}
+
+void CombineRuleOperandTypeChecker::propagateAndInferTypes() {
+  /// First step here is to propagate types using the OperandTypeChecker. That
+  /// way we ensure all uses of a given register have consistent types.
+  propagateTypes();
+
+  /// Build the TypeEquivalenceClasses for the whole rule.
+  const TypeEquivalenceClasses TECs = getRuleEqClasses();
+
+  /// Look at the apply patterns and find operands that need to be
+  /// inferred. We then try to find an equivalence class that they're a part of
+  /// and select the best operand to use for the `GITypeOf` type. We prioritize
+  /// defs of matched instructions because those are guaranteed to be registers.
+  bool InferredAny = false;
+  for (auto *Pat : ApplyPats) {
+    for (unsigned K = 0; K < Pat->operands_size(); ++K) {
+      auto &Op = Pat->getOperand(K);
+
+      // We only want to take a look at untyped defs or immediates.
+      if ((!Op.isDef() && !Op.hasImmValue()) || Op.getType())
+        continue;
+
+      // Infer defs & named immediates.
+      if (Op.isDef() || Op.isNamedImmediate()) {
+        // Check it's not a redefinition of a matched operand.
+        // In such cases, inference is not necessary because we just copy
+        // operands and don't create temporary registers.
+        if (MatchOpTable.lookup(Op.getOperandName()).Found)
+          continue;
+
+        // Inference is needed here, so try to do it.
+        if (PatternType Ty =
+                inferNamedOperandType(*Pat, Op.getOperandName(), TECs)) {
+          if (DebugTypeInfer)
+            errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << "\n";
+          Op.setType(Ty);
+          InferredAny = true;
+        }
+
+        continue;
+      }
+
+      // Infer immediates
+      if (Op.hasImmValue()) {
+        if (PatternType Ty = inferImmediateType(*Pat, K, TECs)) {
+          if (DebugTypeInfer)
+            errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << "\n";
+          Op.setType(Ty);
+          InferredAny = true;
+        }
+        continue;
+      }
+    }
+  }
+
+  // If we've inferred any types, we want to propagate them across the apply
+  // patterns. Type inference only adds GITypeOf types that point to Matched
+  // operands, so we definitely don't want to propagate types into the match
+  // patterns as well, otherwise bad things happen.
+  if (InferredAny) {
+    OperandTypeChecker OTC(RuleDef.getLoc());
+    for (auto *Pat : ApplyPats) {
+      if (!OTC.check(*Pat, [&](const auto &) { return true; }))
+        PrintFatalError(RuleDef.getLoc(),
+                        "OperandTypeChecker unexpectedly failed on '" +
+                            Pat->getName() + "' during Type Inference");
+    }
+    OTC.propagateTypes();
+
+    if (DebugTypeInfer) {
+      errs() << "Apply patterns for rule " << RuleDef.getName()
+             << " after inference:\n";
+      for (auto *Pat : ApplyPats) {
+        Pat->print(errs(), /*PrintName*/ true);
+        errs() << "\n";
+      }
+      errs() << "\n";
+    }
+  }
+}
+
+PatternType CombineRuleOperandTypeChecker::inferImmediateType(
+    const InstructionPattern &IP, unsigned ImmOpIdx,
+    const TypeEquivalenceClasses &TECs) const {
+  // We can only infer CGPs.
+  const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP);
+  if (!CGP)
+    return {};
+
+  // For CGPs, we try to infer immediates by trying to infer another named
+  // operand that shares its type.
+  //
+  // e.g.
+  //    Pattern: G_BUILD_VECTOR $x, $y, 0
+  //    MCOIs:   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  //              MCOI::OPERAND_GENERIC_1]
+  //    $y has the same type as 0, so we can infer $y and get the type 0 should
+  //    have.
+
+  // We infer immediates by looking for a named operand that shares the same
+  // MCOI type.
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  StringRef ImmOpTy = MCOITypes[ImmOpIdx];
+
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes)) {
+    if (Idx != ImmOpIdx && Ty == ImmOpTy) {
+      const auto &Op = IP.getOperand(Idx);
+      if (!Op.isNamedOperand())
+        continue;
+
+      // Named operand with the same MCOI type, try to infer that.
+      if (PatternType InferTy =
+              inferNamedOperandType(IP, Op.getOperandName(), TECs))
+        return InferTy;
+    }
+  }
+
+  return {};
+}
+
+PatternType CombineRuleOperandTypeChecker::inferNamedOperandType(
+    const InstructionPattern &IP, StringRef OpName,
+    const TypeEquivalenceClasses &TECs) const {
+  // This is the simplest possible case, we just need to find a TEC that
+  // contains OpName.
+  for (const auto &TEC : TECs) {
+    if (!TEC.contains(OpName))
+      continue;
+
+    // This TEC mentions the operand. Look at all other operands in this TEC and
+    // try to find a suitable one.
+
+    // First, check for a def of a matched pattern. This is guaranteed to always
+    // be a register so we can blindly use that.
+    StringRef GoodOpName;
+    for (const auto &EqOp : TEC) {
+      if (EqOp == OpName)
+        continue;
+
+      const auto LookupRes = MatchOpTable.lookup(EqOp);
+      if (LookupRes.Def) // Favor defs
+        return PatternType::getTypeOf(EqOp);
+
+      // Otherwise just save this in case we don't find any def.
+      if (GoodOpName.empty() && LookupRes.Found)
+        GoodOpName = EqOp;
+    }
+
+    if (!GoodOpName.empty())
+      return PatternType::getTypeOf(GoodOpName);
+
+    // No good operand found, give up.
+    return {};
+  }
+
+  return {};
+}
+
+std::vector<std::string> CombineRuleOperandTypeChecker::getMCOIOperandTypes(
+    const CodeGenInstructionPattern &CGP) {
+  // FIXME?: Should we cache this? We call it twice when inferring immediates.
+
+  static unsigned UnknownTypeIdx = 0;
+
+  std::vector<std::string> OpTypes;
+  auto &CGI = CGP.getInst();
+  Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction")
+                          ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType")
+                          : nullptr;
+  std::string VarArgsTyName =
+      VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str()
+                : ("unknown_type_" + Twine(UnknownTypeIdx++)).str();
+
+  // First, handle defs.
+  for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K)
+    OpTypes.push_back(CGI.Operands[K].OperandType);
+
+  // Then, handle variadic defs if there are any.
+  if (CGP.hasVariadicDefs()) {
+    for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K)
+      OpTypes.push_back(VarArgsTyName);
+  }
+
+  // If we had variadic defs, the op idx in the pattern won't match the op idx
+  // in the CGI anymore.
+  int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs();
+  assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0));
+
+  // Handle all remaining use operands, including variadic ones.
+  for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) {
+    unsigned CGIOpIdx = K + CGIOpOffset;
+    if (CGIOpIdx >= CGI.Operands.size()) {
+      assert(CGP.isVariadic());
+      OpTypes.push_back(VarArgsTyName);
+    } else {
+      OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType);
+    }
+  }
+
+  assert(OpTypes.size() == CGP.operands_size());
+  return OpTypes;
+}
+
+void CombineRuleOperandTypeChecker::getInstEqClasses(
+    const InstructionPattern &P, TypeEquivalenceClasses &OutTECs) const {
+  // Determine the TypeEquivalenceClasses by:
+  //    - Getting the MCOI Operand Types.
+  //    - Creating a Map of MCOI Type -> [Operand Indexes]
+  //    - Iterating over the map, filtering types we don't like, and just adding
+  //      the array of Operand Indexes to \p OutTECs.
+
+  // We can only do this on CodeGenInstructions. Other InstructionPatterns have
+  // no type inference information associated with them.
+  // TODO: Could we add some inference information to builtins at least? e.g.
+  // ReplaceReg should always replace with a reg of the same type, for instance.
+  // Though, those patterns are often used alone so it might not be worth the
+  // trouble to infer their types.
+  auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P);
+  if (!CGP)
+    return;
+
+  const auto MCOITypes = getMCOIOperandTypes(*CGP);
+  assert(MCOITypes.size() == P.operands_size());
+
+  DenseMap<StringRef, std::vector<unsigned>> TyToOpIdx;
+  for (const auto &[Idx, Ty] : enumerate(MCOITypes))
+    TyToOpIdx[Ty].push_back(Idx);
+
+  const unsigned FirstNewTEC = OutTECs.size();
+  for (const auto &[Ty, Idxs] : TyToOpIdx) {
+    if (!canMCOIOperandTypeBeARegister(Ty))
+      continue;
+
+    SetVector<StringRef> OpNames;
+    // We only collect named operands.
+    for (unsigned Idx : Idxs) {
+      const auto &Op = P.getOperand(Idx);
+      if (Op.isNamedOperand())
+        OpNames.insert(Op.getOperandName());
+    }
+    OutTECs.emplace_back(std::move(OpNames));
+  }
+
+  if (DebugTypeInfer) {
+    errs() << "\t" << P.getName() << ":\t";
+    if (FirstNewTEC == OutTECs.size())
+      errs() << "(empty)";
+    else {
+      for (unsigned K = FirstNewTEC; K < OutTECs.size(); ++K)
+        errs() << "\t" << toString(OutTECs[K]);
+    }
+    errs() << "\n";
+  }
+}
+
+CombineRuleOperandTypeChecker::TypeEquivalenceClasses
+CombineRuleOperandTypeChecker::getRuleEqClasses() const {
+  StringMap<unsigned> OpNameToEqClassIdx;
+  TypeEquivalenceClasses TECs;
+
+  if (DebugTypeInfer)
+    errs() << "Rule Operand Type Equivalence Classes for " << RuleDef.getName()
+           << ":\n";
+
+  for (const auto *Pat : MatchPats)
+    getInstEqClasses(*Pat, TECs);
+  for (const auto *Pat : ApplyPats)
+    getInstEqClasses(*Pat, TECs);
+
+  bool Merged;
+  do {
+    Merged = false;
+    for (auto &X : TECs) {
+      if (X.empty()) // Already merged
+        continue;
+
+      for (auto &Y : TECs) {
+        if (&X == &Y || Y.empty()) // Same set or already merged.
+          continue;
+
+        if (doSetsIntersect(X, Y)) {
+          if (DebugTypeInfer)
+            errs() << "(merging " << toString(X) << " | " << toString(Y)
+                   << ")\n";
+          setVectorUnion(X, Y);
+          Merged = true;
+
+          // To avoid invalidating iterators during iteration, we just clear the
+          // other set then clean it up later.
+          Y.clear();
+        }
+      }
+    }
+
+    // Remove empty sets.
+    erase_if(TECs, [&](auto &Set) { return Set.empty(); });
+  } while (Merged);
+
+  if (DebugTypeInfer) {
+    errs() << "Result: ";
+    if (TECs.empty())
+      errs() << "(empty)";
+    else {
+      for (const auto &TEC : TECs)
+        errs() << toString(TEC) << " ";
+    }
+    errs() << "\n";
+  }
+
+  return TECs;
+}
+
 //===- CombineRuleBuilder -------------------------------------------------===//
 
 /// Parses combine rule and builds a small intermediate representation to tie
@@ -1820,8 +2341,8 @@ class CombineRuleBuilder {
   PatternMap ApplyPats;
 
   /// Operand tables to tie match/apply patterns together.
-  OperandTable<> MatchOpTable;
-  OperandTable<> ApplyOpTable;
+  OperandTable MatchOpTable;
+  OperandTable ApplyOpTable;
 
   /// Set by findRoots.
   Pattern *MatchRoot = nullptr;
@@ -2102,39 +2623,23 @@ bool CombineRuleBuilder::hasEraseRoot() const {
 }
 
 bool CombineRuleBuilder::typecheckPatterns() {
-  OperandTypeChecker OTC(RuleDef.getLoc());
-
-  const auto CheckMatchTypeOf = [&](const PatternType &) -> bool {
-    // We'll reject those after we're done inferring
-    return true;
-  };
+  CombineRuleOperandTypeChecker OTC(RuleDef, MatchOpTable);
 
   for (auto &Pat : values(MatchPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckMatchTypeOf))
+      if (!OTC.processMatchPattern(*IP))
         return false;
     }
   }
 
-  const auto CheckApplyTypeOf = [&](const PatternType &Ty) {
-    // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
-    const auto OpName = Ty.getTypeOfOpName();
-    if (MatchOpTable.lookup(OpName).Found)
-      return true;
-
-    PrintError("'" + OpName + "' ('" + Ty.str() +
-               "') does not refer to a matched operand!");
-    return false;
-  };
-
   for (auto &Pat : values(ApplyPats)) {
     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
-      if (!OTC.check(IP, CheckApplyTypeOf))
+      if (!OTC.processApplyPattern(*IP))
         return false;
     }
   }
 
-  OTC.setAllOperandTypes();
+  OTC.propagateAndInferTypes();
 
   // Always check this after in case inference adds some special types to the
   // match patterns.
@@ -2630,7 +3135,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
   // untyped immediate, e.g. 0
   if (const auto *IntImm = dyn_cast<IntInit>(OpInit)) {
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(IntImm->getValue(), Name, /*Type=*/nullptr);
+    IP.addOperand(IntImm->getValue(), Name, PatternType());
     return true;
   }
 
@@ -2640,13 +3145,9 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     const Record *TyDef = DagOp->getOperatorAsDef(RuleDef.getLoc());
-    PatternType ImmTy(TyDef);
-    if (!ImmTy.isValidType()) {
-      PrintError("cannot parse immediate '" + OpInit->getAsUnquotedString() +
-                 "', '" + TyDef->getName() + "' is not a ValueType or " +
-                 SpecialTyClassName);
+    auto ImmTy = PatternType::get(RuleDef.getLoc(), TyDef);
+    if (!ImmTy)
       return false;
-    }
 
     if (!IP.hasAllDefs()) {
       PrintError("out operand of '" + IP.getInstName() +
@@ -2659,7 +3160,7 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return ParseErr();
 
     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
-    IP.addOperand(Val->getValue(), Name, ImmTy);
+    IP.addOperand(Val->getValue(), Name, *ImmTy);
     return true;
   }
 
@@ -2671,20 +3172,17 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
       return false;
     }
     const Record *Def = DefI->getDef();
-    PatternType Ty(Def);
-    if (!Ty.isValidType()) {
-      PrintError("invalid operand type: '" + Def->getName() +
-                 "' is not a ValueType");
+    auto Ty = PatternType::get(RuleDef.getLoc(), Def);
+    if (!Ty)
       return false;
-    }
-    IP.addOperand(OpName->getAsUnquotedString(), Ty);
+    IP.addOperand(OpName->getAsUnquotedString(), *Ty);
     return true;
   }
 
   // Untyped operand e.g. $x/$z in (G_FNEG $x, $z)
   if (isa<UnsetInit>(OpInit)) {
     assert(OpName && "Unset w/ no OpName?");
-    IP.addOperand(OpName->getAsUnquotedString(), /*Type=*/nullptr);
+    IP.addOperand(OpName->getAsUnquotedString(), PatternType());
     return true;
   }
 

>From b7518d48cbd030a85a1ad416b89a603270cbf7d7 Mon Sep 17 00:00:00 2001
From: pvanhout <pierre.vanhoutryve at amd.com>
Date: Fri, 3 Nov 2023 13:59:04 +0100
Subject: [PATCH 2/2] [TableGen][GlobalISel] Add MIFlags matching & rewriting

Also disables generation of MutateOpcode. It's almost never used in combiners anyway.
If we really want to use it, it needs to be investigated & properly fixed.

Fixes #70780
---
 llvm/docs/GlobalISel/MIRPatterns.rst          |  38 ++++
 .../CodeGen/GlobalISel/GIMatchTableExecutor.h |  21 ++
 .../GlobalISel/GIMatchTableExecutorImpl.h     |  61 +++++-
 .../include/llvm/Target/GlobalISel/Combine.td |  19 ++
 .../match-table-miflags.td                    |  47 +++++
 .../patfrag-errors.td                         |  17 +-
 .../pattern-errors.td                         |  52 ++++-
 .../pattern-parsing.td                        |  25 ++-
 .../TableGen/GlobalISelCombinerEmitter.cpp    | 198 +++++++++++++++++-
 llvm/utils/TableGen/GlobalISelMatchTable.cpp  |  46 ++++
 llvm/utils/TableGen/GlobalISelMatchTable.h    |  33 +++
 11 files changed, 547 insertions(+), 10 deletions(-)
 create mode 100644 llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-miflags.td

diff --git a/llvm/docs/GlobalISel/MIRPatterns.rst b/llvm/docs/GlobalISel/MIRPatterns.rst
index a3883b14b3e0bd6..9c363a38d29551d 100644
--- a/llvm/docs/GlobalISel/MIRPatterns.rst
+++ b/llvm/docs/GlobalISel/MIRPatterns.rst
@@ -183,6 +183,44 @@ Semantics:
 * The root cannot have any output operands.
 * The root must be a CodeGenInstruction
 
+Instruction Flags
+-----------------
+
+MIR Patterns support both matching & writing ``MIFlags``.
+``MIFlags`` are never preserved; output instructions never have
+any flags unless explicitly set.
+
+.. code-block:: text
+  :caption: Example
+
+  def Test : GICombineRule<
+    (defs root:$dst),
+    (match (G_FOO $dst, $src, (MIFlags FmNoNans, FmNoInfs))),
+    (apply (G_BAR $dst, $src, (MIFlags FmReassoc)))>;
+
+In ``apply`` patterns, we also support referring to a matched instruction to
+"take" its MIFlags.
+
+.. code-block:: text
+  :caption: Example
+
+  ; We match NoNans/NoInfs, but $zext may have more flags.
+  ; Copy them all into the output instruction, and also set Reassoc on it.
+  def TestCpyFlags : GICombineRule<
+    (defs root:$dst),
+    (match (G_FOO $dst, $src, (MIFlags FmNoNans, FmNoInfs)):$zext),
+    (apply (G_BAR $dst, $src, (MIFlags $zext, FmReassoc)))>;
+
+The ``not`` operator can be used to check that a flag is NOT present
+on a matched instruction, and to remove a flag from a generated instruction.
+
+.. code-block:: text
+  :caption: Example
+
+  def TestNot : GICombineRule<
+    (defs root:$dst),
+    (match (G_FOO $dst, $src, (MIFlags FmNoInfs, (not FmNoNans, FmReassoc))):$zext),
+    (apply (G_BAR $dst, $src, (MIFlags $zext, (not FmNoInfs))))>;
 
 Limitations
 -----------
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
index 6fcd9d09e1863cc..f5d9f5f40881cb5 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
@@ -266,6 +266,13 @@ enum {
   /// - NewOpIdx
   GIM_CheckCanReplaceReg,
 
+  /// Check that a matched instruction has, or doesn't have a MIFlag.
+  ///
+  /// - InsnID  - Instruction to check.
+  /// - Flag(s) - (can be one or more flags OR'd together)
+  GIM_MIFlags,
+  GIM_MIFlagsNot,
+
   /// Predicates with 'let PredicateCodeUsesOperands = 1' need to examine some
   /// named operands that will be recorded in RecordedOperands. Names of these
   /// operands are referenced in predicate argument list. Emitter determines
@@ -344,6 +351,20 @@ enum {
   /// OpIdx starts at 0 for the first implicit def.
   GIR_SetImplicitDefDead,
 
+  /// Set or unset a MIFlag on an instruction.
+  ///
+  /// - InsnID  - Instruction to modify.
+  /// - Flag(s) - (can be one or more flags OR'd together)
+  GIR_SetMIFlags,
+  GIR_UnsetMIFlags,
+
+  /// Copy the MIFlags of a matched instruction into an
+  /// output instruction. The flags are OR'd together.
+  ///
+  /// - InsnID     - Instruction to modify.
+  /// - OldInsnID  - Matched instruction to copy flags from.
+  GIR_CopyMIFlags,
+
   /// Add a temporary register to the specified instruction
   /// - InsnID - Instruction ID to modify
   /// - TempRegID - The temporary register ID to add
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
index 32e2f21d775f303..f0ee76c097bcab5 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
@@ -88,8 +88,6 @@ bool GIMatchTableExecutor::executeMatchTable(
       if (Observer)
         Observer->changedInstr(*MIB);
     }
-
-    return true;
   };
 
   // If the index is >= 0, it's an index in the type objects generated by
@@ -919,6 +917,32 @@ bool GIMatchTableExecutor::executeMatchTable(
       }
       break;
     }
+    case GIM_MIFlags: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      uint32_t Flags = (uint32_t)MatchTable[CurrentIdx++];
+
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIM_MIFlags(MIs[" << InsnID
+                             << "], " << Flags << ")\n");
+      if ((State.MIs[InsnID]->getFlags() & Flags) != Flags) {
+        if (handleReject() == RejectAndGiveUp)
+          return false;
+      }
+      break;
+    }
+    case GIM_MIFlagsNot: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      uint32_t Flags = (uint32_t)MatchTable[CurrentIdx++];
+
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIM_MIFlagsNot(MIs[" << InsnID
+                             << "], " << Flags << ")\n");
+      if ((State.MIs[InsnID]->getFlags() & Flags)) {
+        if (handleReject() == RejectAndGiveUp)
+          return false;
+      }
+      break;
+    }
     case GIM_Reject:
       DEBUG_WITH_TYPE(TgtExecutor::getName(),
                       dbgs() << CurrentIdx << ": GIM_Reject\n");
@@ -1062,6 +1086,39 @@ bool GIMatchTableExecutor::executeMatchTable(
       MI->getOperand(MI->getNumExplicitOperands() + OpIdx).setIsDead();
       break;
     }
+    case GIR_SetMIFlags: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      uint32_t Flags = (uint32_t)MatchTable[CurrentIdx++];
+
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIR_SetMIFlags(OutMIs["
+                             << InsnID << "], " << Flags << ")\n");
+      MachineInstr *MI = OutMIs[InsnID];
+      MI->setFlags(MI->getFlags() | Flags);
+      break;
+    }
+    case GIR_UnsetMIFlags: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      uint32_t Flags = (uint32_t)MatchTable[CurrentIdx++];
+
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIR_UnsetMIFlags(OutMIs["
+                             << InsnID << "], " << Flags << ")\n");
+      MachineInstr *MI = OutMIs[InsnID];
+      MI->setFlags(MI->getFlags() & ~Flags);
+      break;
+    }
+    case GIR_CopyMIFlags: {
+      int64_t InsnID = MatchTable[CurrentIdx++];
+      int64_t OldInsnID = MatchTable[CurrentIdx++];
+
+      DEBUG_WITH_TYPE(TgtExecutor::getName(),
+                      dbgs() << CurrentIdx << ": GIR_CopyMIFlags(OutMIs["
+                             << InsnID << "], MIs[" << OldInsnID << "])\n");
+      MachineInstr *MI = OutMIs[InsnID];
+      MI->setFlags(MI->getFlags() | State.MIs[OldInsnID]->getFlags());
+      break;
+    }
     case GIR_AddTempRegister:
     case GIR_AddTempSubRegister: {
       int64_t InsnID = MatchTable[CurrentIdx++];
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
index ee0209eb9e5593a..76b83cc5df073ae 100644
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -164,6 +164,25 @@ def GIReplaceReg : GIBuiltinInst;
 // TODO: Allow using this directly, like (apply GIEraseRoot)
 def GIEraseRoot : GIBuiltinInst;
 
+//===----------------------------------------------------------------------===//
+// Pattern MIFlags
+//===----------------------------------------------------------------------===//
+
+class MIFlagEnum<string enumName> {
+  string EnumName = "MachineInstr::" # enumName;
+}
+
+def FmNoNans    : MIFlagEnum<"FmNoNans">;
+def FmNoInfs    : MIFlagEnum<"FmNoInfs">;
+def FmNsz       : MIFlagEnum<"FmNsz">;
+def FmArcp      : MIFlagEnum<"FmArcp">;
+def FmContract  : MIFlagEnum<"FmContract">;
+def FmAfn       : MIFlagEnum<"FmAfn">;
+def FmReassoc   : MIFlagEnum<"FmReassoc">;
+
+def MIFlags;
+// def not; -> Already defined as a SDNode
+
 //===----------------------------------------------------------------------===//
 
 def extending_load_matchdata : GIDefMatchData<"PreferredTuple">;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-miflags.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-miflags.td
new file mode 100644
index 000000000000000..9f02ff17493652d
--- /dev/null
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/match-table-miflags.td
@@ -0,0 +1,47 @@
+// RUN: llvm-tblgen -I %p/../../../include -gen-global-isel-combiner \
+// RUN:     -combiners=MyCombiner %s | \
+// RUN: FileCheck %s
+
+include "llvm/Target/Target.td"
+include "llvm/Target/GlobalISel/Combine.td"
+
+def MyTargetISA : InstrInfo;
+def MyTarget : Target { let InstructionSet = MyTargetISA; }
+
+def MIFlagsTest : GICombineRule<
+  (defs root:$dst),
+  (match (G_SEXT $dst, $tmp), (G_ZEXT $tmp, $src, (MIFlags FmReassoc, FmNsz, (not FmNoNans, FmArcp))):$mi),
+  (apply (G_MUL $dst, $src, $src, (MIFlags $mi, FmReassoc, (not FmNsz, FmArcp))))>;
+
+def MyCombiner: GICombiner<"GenMyCombiner", [MIFlagsTest]>;
+
+// CHECK:      const int64_t *GenMyCombiner::getMatchTable() const {
+// CHECK-NEXT:   constexpr static int64_t MatchTable0[] = {
+// CHECK-NEXT:     GIM_Try, /*On fail goto*//*Label 0*/ 49, // Rule ID 0 //
+// CHECK-NEXT:       GIM_CheckSimplePredicate, GICXXPred_Simple_IsRule0Enabled,
+// CHECK-NEXT:       GIM_CheckOpcode, /*MI*/0, TargetOpcode::G_SEXT,
+// CHECK-NEXT:       // MIs[0] dst
+// CHECK-NEXT:       // No operand predicates
+// CHECK-NEXT:       // MIs[0] tmp
+// CHECK-NEXT:       GIM_RecordInsnIgnoreCopies, /*DefineMI*/1, /*MI*/0, /*OpIdx*/1, // MIs[1]
+// CHECK-NEXT:       GIM_CheckOpcode, /*MI*/1, TargetOpcode::G_ZEXT,
+// CHECK-NEXT:       GIM_MIFlags, /*MI*/1, MachineInstr::FmNsz | MachineInstr::FmReassoc,
+// CHECK-NEXT:       GIM_MIFlagsNot, /*MI*/1, MachineInstr::FmArcp | MachineInstr::FmNoNans,
+// CHECK-NEXT:       // MIs[1] src
+// CHECK-NEXT:       // No operand predicates
+// CHECK-NEXT:       GIM_CheckIsSafeToFold, /*InsnID*/1,
+// CHECK-NEXT:       // Combiner Rule #0: MIFlagsTest
+// CHECK-NEXT:       GIR_BuildMI, /*InsnID*/0, /*Opcode*/TargetOpcode::G_MUL,
+// CHECK-NEXT:       GIR_Copy, /*NewInsnID*/0, /*OldInsnID*/0, /*OpIdx*/0, // dst
+// CHECK-NEXT:       GIR_Copy, /*NewInsnID*/0, /*OldInsnID*/1, /*OpIdx*/1, // src
+// CHECK-NEXT:       GIR_Copy, /*NewInsnID*/0, /*OldInsnID*/1, /*OpIdx*/1, // src
+// CHECK-NEXT:       GIR_CopyMIFlags, /*InsnID*/0, /*OldInsnID*/1,
+// CHECK-NEXT:       GIR_SetMIFlags, /*InsnID*/0, MachineInstr::FmReassoc,
+// CHECK-NEXT:       GIR_UnsetMIFlags, /*InsnID*/0, MachineInstr::FmNsz | MachineInstr::FmArcp,
+// CHECK-NEXT:       GIR_EraseFromParent, /*InsnID*/0,
+// CHECK-NEXT:       GIR_Done,
+// CHECK-NEXT:     // Label 0: @49
+// CHECK-NEXT:     GIM_Reject,
+// CHECK-NEXT:     };
+// CHECK-NEXT:   return MatchTable0;
+// CHECK-NEXT: }
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/patfrag-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/patfrag-errors.td
index 68bec4fa722d191..6f5c2b93609f428 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/patfrag-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/patfrag-errors.td
@@ -271,6 +271,19 @@ def root_def_has_multi_defs : GICombineRule<
   (match (RootDefHasMultiDefs $root, (i32 10))),
   (apply (COPY $root, (i32 0)))>;
 
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: matching/writing MIFlags is only allowed on CodeGenInstruction patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(DummyCXXPF ?:$x, (MIFlags FmArcp))'
+def miflags_in_pf : GICombineRule<
+  (defs root:$x),
+  (match (COPY $x, $y), (DummyCXXPF $x, (MIFlags FmArcp))),
+  (apply (COPY $x, $y))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: '$pf' does not refer to a CodeGenInstruction in MIFlags of '__badtype_for_flagref_in_apply_apply_0'
+def badtype_for_flagref_in_apply : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src), (DummyCXXPF $src):$pf),
+  (apply (G_MUL $dst, $src, $src, (MIFlags $pf)))>;
+
 // CHECK: error: Failed to parse one or more rules
 
 def MyCombiner: GICombiner<"GenMyCombiner", [
@@ -293,5 +306,7 @@ def MyCombiner: GICombiner<"GenMyCombiner", [
   patfrag_in_apply,
   patfrag_cannot_be_root,
   inconsistent_arg_type,
-  root_def_has_multi_defs
+  root_def_has_multi_defs,
+  miflags_in_pf,
+  badtype_for_flagref_in_apply
 ]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
index 77321a3f2b59dbe..e45a1c866a51544 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-errors.td
@@ -217,6 +217,50 @@ def def_named_imm_apply : GICombineRule<
   (apply (COPY i32:$tmp, $y),
          (COPY $x, (i32 0):$tmp):$foo)>;
 
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: MIFlags can only be present once on an instruction
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ZEXT ?:$dst, ?:$src, (MIFlags FmArcp), (MIFlags FmArcp))'
+def multi_miflags : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src, (MIFlags FmArcp), (MIFlags FmArcp)):$mi),
+  (apply (G_MUL $dst, $src, $src))>;
+
+def NotAMIFlagEnum;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: 'NotAMIFlagEnum' is not a subclass of 'MIFlagEnum'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ZEXT ?:$dst, ?:$src, (MIFlags NotAMIFlagEnum))'
+def not_miflagenum_1 : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src, (MIFlags NotAMIFlagEnum)):$mi),
+  (apply (G_MUL $dst, $src, $src))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: 'NotAMIFlagEnum' is not a subclass of 'MIFlagEnum'
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(G_ZEXT ?:$dst, ?:$src, (MIFlags (not NotAMIFlagEnum)))'
+def not_miflagenum_2 : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src, (MIFlags (not NotAMIFlagEnum))):$mi),
+
+  (apply (G_MUL $dst, $src, $src))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: matching/writing MIFlags is only allowed on CodeGenInstruction patterns
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: Failed to parse pattern: '(GIReplaceReg ?:$x, ?:$y, (MIFlags FmArcp))'
+def miflags_in_builtin : GICombineRule<
+  (defs root:$x),
+  (match (COPY $x, $y)),
+  (apply (GIReplaceReg $x, $y, (MIFlags FmArcp)))>;
+
+// CHECK: :[[@LINE+2]]:{{[0-9]+}}: error: 'match' patterns cannot refer to flags from other instructions
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: note: MIFlags in 'mi' refer to: impostor
+def using_flagref_in_match : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src, (MIFlags $impostor)):$mi),
+  (apply (G_MUL $dst, $src, $src))>;
+
+// CHECK: :[[@LINE+1]]:{{[0-9]+}}: error: unknown instruction '$impostor' referenced in MIFlags of '__badflagref_in_apply_apply_0'
+def badflagref_in_apply : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src):$mi),
+  (apply (G_MUL $dst, $src, $src, (MIFlags $impostor)))>;
+
 // CHECK: error: Failed to parse one or more rules
 
 def MyCombiner: GICombiner<"GenMyCombiner", [
@@ -251,5 +295,11 @@ def MyCombiner: GICombiner<"GenMyCombiner", [
   bad_mo_type_not_a_valuetype,
   untyped_new_reg_in_apply,
   def_named_imm_match,
-  def_named_imm_apply
+  def_named_imm_apply,
+  multi_miflags,
+  not_miflagenum_1,
+  not_miflagenum_2,
+  miflags_in_builtin,
+  using_flagref_in_match,
+  badflagref_in_apply
 ]>;
diff --git a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
index fd41a7d1d72417e..26f3bb88da951c4 100644
--- a/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
+++ b/llvm/test/TableGen/GlobalISelCombinerEmitter/pattern-parsing.td
@@ -320,6 +320,28 @@ def TypeOfTest : GICombineRule<
          (G_ZEXT $tmp, $src)),
   (apply (G_MUL $dst, (GITypeOf<"$src"> 0), (GITypeOf<"$dst"> -1)))>;
 
+
+// CHECK:      (CombineRule name:MIFlagsTest id:11 root:dst
+// CHECK-NEXT:   (MatchPats
+// CHECK-NEXT:     <match_root>mi:(CodeGenInstructionPattern G_ZEXT operands:[<def>$dst, $src] (MIFlags (set MachineInstr::FmReassoc) (unset MachineInstr::FmNoNans, MachineInstr::FmArcp)))
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (ApplyPats
+// CHECK-NEXT:     <apply_root>__MIFlagsTest_apply_0:(CodeGenInstructionPattern G_MUL operands:[<def>$dst, $src, $src] (MIFlags (set MachineInstr::FmReassoc) (unset MachineInstr::FmNsz, MachineInstr::FmArcp) (copy mi)))
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable MatchPats
+// CHECK-NEXT:     dst -> mi
+// CHECK-NEXT:     src -> <live-in>
+// CHECK-NEXT:   )
+// CHECK-NEXT:   (OperandTable ApplyPats
+// CHECK-NEXT:     dst -> __MIFlagsTest_apply_0
+// CHECK-NEXT:     src -> <live-in>
+// CHECK-NEXT:   )
+// CHECK-NEXT: )
+def MIFlagsTest : GICombineRule<
+  (defs root:$dst),
+  (match (G_ZEXT $dst, $src, (MIFlags FmReassoc, (not FmNoNans, FmArcp))):$mi),
+  (apply (G_MUL $dst, $src, $src, (MIFlags $mi, FmReassoc, (not FmNsz, FmArcp))))>;
+
 def MyCombiner: GICombiner<"GenMyCombiner", [
   WipOpcodeTest0,
   WipOpcodeTest1,
@@ -331,5 +353,6 @@ def MyCombiner: GICombiner<"GenMyCombiner", [
   PatFragTest1,
   VariadicsInTest,
   VariadicsOutTest,
-  TypeOfTest
+  TypeOfTest,
+  MIFlagsTest
 ]>;
diff --git a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
index 82df0c5e58f041b..5f05a282086f97c 100644
--- a/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp
@@ -80,6 +80,7 @@ constexpr StringLiteral PatFragClassName = "GICombinePatFrag";
 constexpr StringLiteral BuiltinInstClassName = "GIBuiltinInst";
 constexpr StringLiteral SpecialTyClassName = "GISpecialType";
 constexpr StringLiteral TypeOfClassName = "GITypeOf";
+constexpr StringLiteral MIFlagsEnumClassName = "MIFlagEnum";
 
 std::string getIsEnabledPredicateEnumName(unsigned CombinerRuleID) {
   return "GICXXPred_Simple_IsRule" + to_string(CombinerRuleID) + "Enabled";
@@ -786,6 +787,8 @@ class InstructionPattern : public Pattern {
 protected:
   InstructionPattern(unsigned K, StringRef Name) : Pattern(K, Name) {}
 
+  virtual void printExtras(raw_ostream &OS) const {}
+
   SmallVector<InstructionOperand, 4> Operands;
 };
 
@@ -843,6 +846,8 @@ void InstructionPattern::print(raw_ostream &OS, bool PrintName) const {
       Sep = ", ";
     }
     OS << "]";
+
+    printExtras(OS);
   });
 }
 
@@ -934,6 +939,25 @@ class OperandTable {
 
 //===- CodeGenInstructionPattern ------------------------------------------===//
 
+/// Helper class to contain data associated with a MIFlags operator.
+class MIFlagsInfo {
+public:
+  void addSetFlag(const Record *R) {
+    SetF.insert(R->getValueAsString("EnumName"));
+  }
+  void addUnsetFlag(const Record *R) {
+    UnsetF.insert(R->getValueAsString("EnumName"));
+  }
+  void addCopyFlag(StringRef InstName) { CopyF.insert(insertStrRef(InstName)); }
+
+  const auto &set_flags() const { return SetF; }
+  const auto &unset_flags() const { return UnsetF; }
+  const auto &copy_flags() const { return CopyF; }
+
+private:
+  SetVector<StringRef> SetF, UnsetF, CopyF;
+};
+
 /// Matches an instruction, e.g. `G_ADD $x, $y, $z`.
 class CodeGenInstructionPattern : public InstructionPattern {
 public:
@@ -953,11 +977,17 @@ class CodeGenInstructionPattern : public InstructionPattern {
   unsigned getNumInstDefs() const override;
   unsigned getNumInstOperands() const override;
 
+  MIFlagsInfo &getOrCreateMIFlagsInfo();
+  const MIFlagsInfo *getMIFlagsInfo() const { return FI.get(); }
+
   const CodeGenInstruction &getInst() const { return I; }
   StringRef getInstName() const override { return I.TheDef->getName(); }
 
 private:
+  void printExtras(raw_ostream &OS) const override;
+
   const CodeGenInstruction &I;
+  std::unique_ptr<MIFlagsInfo> FI;
 };
 
 bool CodeGenInstructionPattern::hasVariadicDefs() const {
@@ -991,6 +1021,26 @@ unsigned CodeGenInstructionPattern::getNumInstOperands() const {
                       : NumCGIOps;
 }
 
+MIFlagsInfo &CodeGenInstructionPattern::getOrCreateMIFlagsInfo() {
+  if (!FI)
+    FI = std::make_unique<MIFlagsInfo>();
+  return *FI;
+}
+
+void CodeGenInstructionPattern::printExtras(raw_ostream &OS) const {
+  if (!FI)
+    return;
+
+  OS << " (MIFlags";
+  if (!FI->set_flags().empty())
+    OS << " (set " << join(FI->set_flags(), ", ") << ")";
+  if (!FI->unset_flags().empty())
+    OS << " (unset " << join(FI->unset_flags(), ", ") << ")";
+  if (!FI->copy_flags().empty())
+    OS << " (copy " << join(FI->copy_flags(), ", ") << ")";
+  OS << ')';
+}
+
 //===- OperandTypeChecker -------------------------------------------------===//
 
 /// This is a trivial type checker for all operands in a set of
@@ -2275,6 +2325,8 @@ class CombineRuleBuilder {
   bool parseInstructionPatternOperand(InstructionPattern &IP,
                                       const Init *OpInit,
                                       const StringInit *OpName) const;
+  bool parseInstructionPatternMIFlags(InstructionPattern &IP,
+                                      const DagInit *Op) const;
   std::unique_ptr<PatFrag> parsePatFragImpl(const Record *Def) const;
   bool parsePatFragParamList(
       ArrayRef<SMLoc> DiagLoc, const DagInit &OpsList,
@@ -2721,6 +2773,19 @@ bool CombineRuleBuilder::checkSemantics() {
       continue;
     }
 
+    // MIFlags in match cannot use the following syntax: (MIFlags $mi)
+    if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(Pat)) {
+      if (auto *FI = CGP->getMIFlagsInfo()) {
+        if (!FI->copy_flags().empty()) {
+          PrintError(
+              "'match' patterns cannot refer to flags from other instructions");
+          PrintNote("MIFlags in '" + CGP->getName() +
+                    "' refer to: " + join(FI->copy_flags(), ", "));
+          return false;
+        }
+      }
+    }
+
     const auto *AOP = dyn_cast<AnyOpcodePattern>(Pat);
     if (!AOP)
       continue;
@@ -2745,6 +2810,28 @@ bool CombineRuleBuilder::checkSemantics() {
       return false;
     }
 
+    // Check that the insts mentioned in copy_flags exist.
+    if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(IP)) {
+      if (auto *FI = CGP->getMIFlagsInfo()) {
+        for (auto InstName : FI->copy_flags()) {
+          auto It = MatchPats.find(InstName);
+          if (It == MatchPats.end()) {
+            PrintError("unknown instruction '$" + InstName +
+                       "' referenced in MIFlags of '" + CGP->getName() + "'");
+            return false;
+          }
+
+          if (!isa<CodeGenInstructionPattern>(It->second.get())) {
+            PrintError(
+                "'$" + InstName +
+                "' does not refer to a CodeGenInstruction in MIFlags of '" +
+                CGP->getName() + "'");
+            return false;
+          }
+        }
+      }
+    }
+
     const auto *BIP = dyn_cast<BuiltinPattern>(IP);
     if (!BIP)
       continue;
@@ -3083,8 +3170,14 @@ CombineRuleBuilder::parseInstructionPattern(const Init &Arg,
   }
 
   for (unsigned K = 0; K < DagPat->getNumArgs(); ++K) {
-    if (!parseInstructionPatternOperand(*Pat, DagPat->getArg(K),
-                                        DagPat->getArgName(K)))
+    Init *Arg = DagPat->getArg(K);
+    if (auto *DagArg = getDagWithSpecificOperator(*Arg, "MIFlags")) {
+      if (!parseInstructionPatternMIFlags(*Pat, DagArg))
+        return nullptr;
+      continue;
+    }
+
+    if (!parseInstructionPatternOperand(*Pat, Arg, DagPat->getArgName(K)))
       return nullptr;
   }
 
@@ -3189,6 +3282,75 @@ bool CombineRuleBuilder::parseInstructionPatternOperand(
   return ParseErr();
 }
 
+bool CombineRuleBuilder::parseInstructionPatternMIFlags(
+    InstructionPattern &IP, const DagInit *Op) const {
+  auto *CGIP = dyn_cast<CodeGenInstructionPattern>(&IP);
+  if (!CGIP) {
+    PrintError("matching/writing MIFlags is only allowed on CodeGenInstruction "
+               "patterns");
+    return false;
+  }
+
+  const auto CheckFlagEnum = [&](const Record *R) {
+    if (!R->isSubClassOf(MIFlagsEnumClassName)) {
+      PrintError("'" + R->getName() + "' is not a subclass of '" +
+                 MIFlagsEnumClassName + "'");
+      return false;
+    }
+
+    return true;
+  };
+
+  if (CGIP->getMIFlagsInfo()) {
+    PrintError("MIFlags can only be present once on an instruction");
+    return false;
+  }
+
+  auto &FI = CGIP->getOrCreateMIFlagsInfo();
+  for (unsigned K = 0; K < Op->getNumArgs(); ++K) {
+    const Init *Arg = Op->getArg(K);
+
+    // Match/set a flag: (MIFlags FmNoNans)
+    if (const auto *Def = dyn_cast<DefInit>(Arg)) {
+      const Record *R = Def->getDef();
+      if (!CheckFlagEnum(R))
+        return false;
+
+      FI.addSetFlag(R);
+      continue;
+    }
+
+    // Match when a flag is absent/unset a flag: (MIFlags (not FmNoNans))
+    if (const DagInit *NotDag = getDagWithSpecificOperator(*Arg, "not")) {
+      for (const Init *NotArg : NotDag->getArgs()) {
+        const DefInit *DefArg = dyn_cast<DefInit>(NotArg);
+        if (!DefArg) {
+          PrintError("cannot parse '" + NotArg->getAsUnquotedString() +
+                     "': expected a '" + MIFlagsEnumClassName + "'");
+          return false;
+        }
+
+        const Record *R = DefArg->getDef();
+        if (!CheckFlagEnum(R))
+          return false;
+
+        FI.addUnsetFlag(R);
+        continue;
+      }
+
+      continue;
+    }
+
+    // Copy flags from a matched instruction: (MIFlags $mi)
+    if (isa<UnsetInit>(Arg)) {
+      FI.addCopyFlag(Op->getArgName(K)->getAsUnquotedString());
+      continue;
+    }
+  }
+
+  return true;
+}
+
 std::unique_ptr<PatFrag>
 CombineRuleBuilder::parsePatFragImpl(const Record *Def) const {
   auto StackTrace = PrettyStackTraceParse(*Def);
@@ -3319,7 +3481,7 @@ bool CombineRuleBuilder::emitMatchPattern(CodeExpansions &CE,
   auto StackTrace = PrettyStackTraceEmit(RuleDef, &IP);
 
   auto &M = addRuleMatcher(Alts);
-  InstructionMatcher &IM = M.addInstructionMatcher("root");
+  InstructionMatcher &IM = M.addInstructionMatcher(IP.getName());
   declareInstExpansion(CE, IM, IP.getName());
 
   DenseSet<const Pattern *> SeenPats;
@@ -3751,8 +3913,23 @@ bool CombineRuleBuilder::emitInstructionApplyPattern(
     DstMI.addRenderer<TempRegRenderer>(TempRegID);
   }
 
-  // TODO: works?
-  DstMI.chooseInsnToMutate(M);
+  // Render MIFlags
+  if (const auto *FI = CGIP.getMIFlagsInfo()) {
+    for (StringRef InstName : FI->copy_flags())
+      DstMI.addCopiedMIFlags(M.getInstructionMatcher(InstName));
+    for (StringRef F : FI->set_flags())
+      DstMI.addSetMIFlags(F);
+    for (StringRef F : FI->unset_flags())
+      DstMI.addUnsetMIFlags(F);
+  }
+
+  // Don't allow mutating opcodes for GISel combiners. We want a more precise
+  // handling of MIFlags so we require them to be explicitly preserved.
+  //
+  // TODO: We don't mutate very often, if at all, in combiners, but it'd be nice
+  // to re-enable this. We'd then need to always clear MIFlags when mutating
+  // opcodes, and never mutate an inst that we copy flags from.
+  // DstMI.chooseInsnToMutate(M);
   declareInstExpansion(CE, DstMI, P.getName());
 
   return true;
@@ -3871,6 +4048,17 @@ bool CombineRuleBuilder::emitCodeGenInstructionMatchPattern(
   IM.addPredicate<InstructionOpcodeMatcher>(&P.getInst());
   declareInstExpansion(CE, IM, P.getName());
 
+  // Check flags if needed.
+  if (const auto *FI = P.getMIFlagsInfo()) {
+    assert(FI->copy_flags().empty());
+
+    if (const auto &SetF = FI->set_flags(); !SetF.empty())
+      IM.addPredicate<MIFlagsInstructionPredicateMatcher>(SetF.getArrayRef());
+    if (const auto &UnsetF = FI->unset_flags(); !UnsetF.empty())
+      IM.addPredicate<MIFlagsInstructionPredicateMatcher>(UnsetF.getArrayRef(),
+                                                          /*CheckNot=*/true);
+  }
+
   for (const auto &[Idx, OriginalO] : enumerate(P.operands())) {
     // Remap the operand. This is used when emitting InstructionPatterns inside
     // PatFrags, so it can remap them to the arguments passed to the pattern.
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.cpp b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
index 6ec85269e6e20d0..5a4d32a34e2bcb8 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.cpp
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.cpp
@@ -1541,6 +1541,24 @@ void GenericInstructionPredicateMatcher::emitPredicateOpcodes(
         << MatchTable::LineBreak;
 }
 
+//===- MIFlagsInstructionPredicateMatcher ---------------------------------===//
+
+bool MIFlagsInstructionPredicateMatcher::isIdentical(
+    const PredicateMatcher &B) const {
+  if (!InstructionPredicateMatcher::isIdentical(B))
+    return false;
+  const auto &Other =
+      static_cast<const MIFlagsInstructionPredicateMatcher &>(B);
+  return Flags == Other.Flags && CheckNot == Other.CheckNot;
+}
+
+void MIFlagsInstructionPredicateMatcher::emitPredicateOpcodes(
+    MatchTable &Table, RuleMatcher &Rule) const {
+  Table << MatchTable::Opcode(CheckNot ? "GIM_MIFlagsNot" : "GIM_MIFlags")
+        << MatchTable::Comment("MI") << MatchTable::IntValue(InsnVarID)
+        << MatchTable::NamedValue(join(Flags, " | ")) << MatchTable::LineBreak;
+}
+
 //===- InstructionMatcher -------------------------------------------------===//
 
 OperandMatcher &
@@ -1956,6 +1974,30 @@ void BuildMIAction::chooseInsnToMutate(RuleMatcher &Rule) {
 
 void BuildMIAction::emitActionOpcodes(MatchTable &Table,
                                       RuleMatcher &Rule) const {
+  const auto AddMIFlags = [&]() {
+    for (const InstructionMatcher *IM : CopiedFlags) {
+      Table << MatchTable::Opcode("GIR_CopyMIFlags")
+            << MatchTable::Comment("InsnID") << MatchTable::IntValue(InsnID)
+            << MatchTable::Comment("OldInsnID")
+            << MatchTable::IntValue(IM->getInsnVarID())
+            << MatchTable::LineBreak;
+    }
+
+    if (!SetFlags.empty()) {
+      Table << MatchTable::Opcode("GIR_SetMIFlags")
+            << MatchTable::Comment("InsnID") << MatchTable::IntValue(InsnID)
+            << MatchTable::NamedValue(join(SetFlags, " | "))
+            << MatchTable::LineBreak;
+    }
+
+    if (!UnsetFlags.empty()) {
+      Table << MatchTable::Opcode("GIR_UnsetMIFlags")
+            << MatchTable::Comment("InsnID") << MatchTable::IntValue(InsnID)
+            << MatchTable::NamedValue(join(UnsetFlags, " | "))
+            << MatchTable::LineBreak;
+    }
+  };
+
   if (Matched) {
     assert(canMutate(Rule, Matched) &&
            "Arranged to mutate an insn that isn't mutatable");
@@ -1992,6 +2034,8 @@ void BuildMIAction::emitActionOpcodes(MatchTable &Table,
               << MatchTable::LineBreak;
       }
     }
+
+    AddMIFlags();
     return;
   }
 
@@ -2039,6 +2083,8 @@ void BuildMIAction::emitActionOpcodes(MatchTable &Table,
           << MatchTable::LineBreak;
   }
 
+  AddMIFlags();
+
   // FIXME: This is a hack but it's sufficient for ISel. We'll need to do
   //        better for combines. Particularly when there are multiple match
   //        roots.
diff --git a/llvm/utils/TableGen/GlobalISelMatchTable.h b/llvm/utils/TableGen/GlobalISelMatchTable.h
index 364f2a1ec725d53..469390d7312329b 100644
--- a/llvm/utils/TableGen/GlobalISelMatchTable.h
+++ b/llvm/utils/TableGen/GlobalISelMatchTable.h
@@ -790,6 +790,7 @@ class PredicateMatcher {
     IPM_VectorSplatImm,
     IPM_NoUse,
     IPM_GenericPredicate,
+    IPM_MIFlags,
     OPM_SameOperand,
     OPM_ComplexPattern,
     OPM_IntrinsicID,
@@ -1628,6 +1629,28 @@ class GenericInstructionPredicateMatcher : public InstructionPredicateMatcher {
                             RuleMatcher &Rule) const override;
 };
 
+class MIFlagsInstructionPredicateMatcher : public InstructionPredicateMatcher {
+  SmallVector<StringRef, 2> Flags;
+  bool CheckNot; // false = GIM_MIFlags, true = GIM_MIFlagsNot
+
+public:
+  MIFlagsInstructionPredicateMatcher(unsigned InsnVarID,
+                                     ArrayRef<StringRef> FlagsToCheck,
+                                     bool CheckNot = false)
+      : InstructionPredicateMatcher(IPM_MIFlags, InsnVarID),
+        Flags(FlagsToCheck), CheckNot(CheckNot) {
+    sort(Flags);
+  }
+
+  static bool classof(const InstructionPredicateMatcher *P) {
+    return P->getKind() == IPM_MIFlags;
+  }
+
+  bool isIdentical(const PredicateMatcher &B) const override;
+  void emitPredicateOpcodes(MatchTable &Table,
+                            RuleMatcher &Rule) const override;
+};
+
 /// Generates code to check for the absence of use of the result.
 // TODO? Generalize this to support checking for one use.
 class NoUsePredicateMatcher : public InstructionPredicateMatcher {
@@ -2233,6 +2256,10 @@ class BuildMIAction : public MatchAction {
   std::vector<std::unique_ptr<OperandRenderer>> OperandRenderers;
   SmallPtrSet<Record *, 4> DeadImplicitDefs;
 
+  std::vector<const InstructionMatcher *> CopiedFlags;
+  std::vector<StringRef> SetFlags;
+  std::vector<StringRef> UnsetFlags;
+
   /// True if the instruction can be built solely by mutating the opcode.
   bool canMutate(RuleMatcher &Rule, const InstructionMatcher *Insn) const;
 
@@ -2247,6 +2274,12 @@ class BuildMIAction : public MatchAction {
   unsigned getInsnID() const { return InsnID; }
   const CodeGenInstruction *getCGI() const { return I; }
 
+  void addSetMIFlags(StringRef Flag) { SetFlags.push_back(Flag); }
+  void addUnsetMIFlags(StringRef Flag) { UnsetFlags.push_back(Flag); }
+  void addCopiedMIFlags(const InstructionMatcher &IM) {
+    CopiedFlags.push_back(&IM);
+  }
+
   void chooseInsnToMutate(RuleMatcher &Rule);
 
   void setDeadImplicitDef(Record *R) { DeadImplicitDefs.insert(R); }



More information about the llvm-commits mailing list