[llvm] [TableGen][GlobalISel] Add rule-wide type inference (PR #66377)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 14 09:15:52 PDT 2023


================
@@ -1640,6 +1695,447 @@ class PrettyStackTraceEmit : public PrettyStackTraceEntry {
   }
 };
 
+//===- CombineRuleOperandTypeChecker --------------------------------------===//
+
+/// This is a wrapper around OperandTypeChecker specialized for Combiner Rules.
+/// On top of doing the same things as OperandTypeChecker, this also attempts to
+/// infer as many types as possible for temporary register defs & immediates in
+/// apply patterns.
+///
+/// The inference is trivial and leverages the MCOI OperandTypes encoded in
+/// CodeGenInstructions to infer types across patterns in a CombineRule. It's
+/// thus very limited and only supports CodeGenInstructions (but that's the main
+/// use case so it's fine).
+///
+/// We only try to infer untyped operands in apply patterns when they're temp
+/// reg defs, or immediates. Inference always outputs a `TypeOf<$x>` where $x is
+/// a named operand from a match pattern.
+class CombineRuleOperandTypeChecker : private OperandTypeChecker {
+public:
+  /// \param RuleDef the record of the rule being checked. Its location is
+  /// forwarded to the base OperandTypeChecker.
+  /// \param MatchOpTable the operand table that is queried by operand name to
+  /// tell whether an operand is a matched operand.
+  CombineRuleOperandTypeChecker(const Record &RuleDef,
+                                const OperandTable &MatchOpTable)
+      : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef),
+        MatchOpTable(MatchOpTable) {}
+
+  /// Records and checks a 'match' pattern.
+  bool processMatchPattern(InstructionPattern &P);
+
+  /// Records and checks an 'apply' pattern.
+  bool processApplyPattern(InstructionPattern &P);
+
+  /// Propagates types, then performs type inference and does a second round of
+  /// propagation in the apply patterns only if any types were inferred.
+  void propagateAndInferTypes();
+
+private:
+  /// Groups of operands of an instruction that share a common type.
+  /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and
+  /// d have the same type too. b/c and a/d don't have to have the same type,
+  /// though.
+  using TypeEquivalenceClasses = std::vector<SetVector<StringRef>>;
+
+  /// \returns \p EqClass formatted as "[a, b, c]".
+  static std::string toString(const SetVector<StringRef> &EqClass) {
+    return "[" + join(EqClass, ", ") + "]";
+  }
+
+  /// \returns true for `OPERAND_GENERIC_` 0 through 5.
+  /// These are the MCOI types that can be registers. The other MCOI types are
+  /// either immediates, or fancier operands used only post-ISel, so we don't
+  /// care about them for combiners.
+  static bool canMCOIOperandTypeBeARegister(StringRef MCOIType);
+
+  /// Finds the "MCOI::" operand types for each operand of \p CGP.
+  ///
+  /// This is a bit trickier than it looks because we need to handle variadic
+  /// in/outs.
+  ///
+  /// e.g. for
+  ///   (G_BUILD_VECTOR $vec, $x, $y) ->
+  ///   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+  ///    MCOI::OPERAND_GENERIC_1]
+  ///
+  /// For unknown types (which can happen in variadics where varargs types are
+  /// inconsistent), a unique name is given, e.g. "unknown_type_0".
+  static std::vector<std::string>
+  getMCOIOperandTypes(const CodeGenInstructionPattern &CGP);
+
+  /// Adds the TypeEquivalenceClasses for \p P in \p OutTEC.
+  void getInstEqClasses(const InstructionPattern &P,
+                        TypeEquivalenceClasses &OutTEC) const;
+
+  /// Calculates the TypeEquivalenceClasses for each instruction, then merges
+  /// them into a common set of TypeEquivalenceClasses for the whole rule.
+  ///
+  /// This works by looking for intersecting type equivalence classes, then
+  /// merging them. It loops until there are no more reductions possible.
+  ///
+  /// This essentially applies the "transitive" part of type inference. Let's
+  /// take the following equivalence classes:
+  ///   inst0: [a, b], [c, d]
+  ///   inst1: [b, c]
+  ///
+  /// If we see inst0 alone, we can't say that a and d have the same type -
+  /// they're not in the same equivalence classes. However if we just use logic,
+  /// we can say: "a == d because a == b, c == d and b == c".
+  ///
+  /// Merging condenses that information into a single big equivalence class
+  /// which can be looked at alone to make the same deduction.
+  ///   rule: [a, b, c, d]
+  TypeEquivalenceClasses getRuleEqClasses() const;
+
+  /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p
+  /// TECs.
+  ///
+  /// This is achieved by trying to find a named operand in \p IP that shares
+  /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that
+  /// operand instead.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferImmediateType(const InstructionPattern &IP,
+                                 unsigned ImmOpIdx,
+                                 const TypeEquivalenceClasses &TECs) const;
+
+  /// Looks inside \p TECs to infer \p OpName's type.
+  ///
+  /// \returns the inferred type or an empty PatternType if inference didn't
+  /// succeed.
+  PatternType inferNamedOperandType(const InstructionPattern &IP,
+                                    StringRef OpName,
+                                    const TypeEquivalenceClasses &TECs) const;
+
+  /// Record of the rule being checked; its location is used for diagnostics.
+  const Record &RuleDef;
+  /// Patterns recorded by processMatchPattern, in call order.
+  SmallVector<InstructionPattern *, 8> MatchPats;
+  /// Patterns recorded by processApplyPattern, in call order.
+  SmallVector<InstructionPattern *, 8> ApplyPats;
+
+  /// Looked up by operand name to tell whether an operand is a matched operand
+  /// (and whether it is a def).
+  const OperandTable &MatchOpTable;
+};
+
+/// Remembers \p P as a match pattern and runs the base type check on it.
+///
+/// \returns the result of the base checker's `check`.
+bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) {
+  // Every GITypeOf is accepted at this stage: GITypeOf in match should be
+  // rejected after propagation is done.
+  const auto AcceptAnyTypeOf = [](const auto &) { return true; };
+
+  MatchPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ AcceptAnyTypeOf);
+}
+
+/// Remembers \p P as an apply pattern and runs the base type check on it.
+///
+/// Any GITypeOf<"$x"> used in the pattern must name an operand found in the
+/// match operand table; otherwise an error is printed at the rule's location.
+///
+/// \returns the result of the base checker's `check`.
+bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) {
+  const auto RefersToMatchedOperand = [&](const PatternType &Ty) {
+    const auto OpName = Ty.getTypeOfOpName();
+    if (!MatchOpTable.lookup(OpName).Found) {
+      // "$x" is not known to the match side, so its type can't be queried.
+      PrintError(RuleDef.getLoc(), "'" + OpName + "' ('" + Ty.str() +
+                                       "') does not refer to a matched operand!");
+      return false;
+    }
+    return true;
+  };
+
+  ApplyPats.push_back(&P);
+  return check(P, /*CheckTypeOf*/ RefersToMatchedOperand);
+}
+
+/// Propagates types across the recorded patterns, infers types for untyped
+/// defs and immediates of the apply patterns, and, if anything was inferred,
+/// propagates once more across the apply patterns only.
+void CombineRuleOperandTypeChecker::propagateAndInferTypes() {
+  // This works by first creating a series of "type equivalence classes".
+  //
+  // These are lists of InstructionPattern operands that are known to have the
+  // same type.
+  //
+  // Then, we look at the apply patterns and find operands that need to be
+  // inferred, try to find an equivalence class that they're a part of, then
+  // select the best operand to use for the GITypeOf type. We try to aim for
+  // defs of matched instructions because those are guaranteed to be registers.
+  propagateTypes();
+
+  TypeEquivalenceClasses TECs = getRuleEqClasses();
+
+  // Look at each operand that we want to infer one by one.
+  bool InferredAny = false;
+  for (auto *Pat : ApplyPats) {
+    for (unsigned K = 0; K < Pat->operands_size(); ++K) {
+      auto &Op = Pat->getOperand(K);
+
+      // We only want to take a look at untyped defs or immediates.
+      if ((!Op.isDef() && !Op.hasImmValue()) || Op.getType())
+        continue;
+
+      // Defs & named immediates are inferred through their own name. Anything
+      // else reaching this point is an unnamed immediate, which is inferred
+      // through its position in the instruction.
+      const bool IsNamed = Op.isDef() || Op.isNamedImmediate();
+
+      // Check it's not a redefinition of a matched operand.
+      // In such cases, inference is not necessary because we just copy
+      // operands and don't create temporary registers.
+      if (IsNamed && MatchOpTable.lookup(Op.getOperandName()).Found)
+        continue;
+
+      // Inference is needed here, so try to do it.
+      PatternType Ty =
+          IsNamed ? inferNamedOperandType(*Pat, Op.getOperandName(), TECs)
+                  : inferImmediateType(*Pat, K, TECs);
+      if (!Ty)
+        continue;
+
+      if (DebugTypeInfer)
+        errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << "\n";
+      Op.setType(Ty);
+      InferredAny = true;
+    }
+  }
+
+  // Nothing was inferred: the first round of propagation is all we need.
+  if (!InferredAny)
+    return;
+
+  // If we've inferred any types, we want to propagate them across the apply
+  // patterns. We may have added GITypeOfs referencing Match pattern operands so
+  // we don't want to also propagate there.
+  //
+  // FIXME: this is hacky, it'd be nice to have another entry point that
+  // doesn't do checking and just propagates.
+  OperandTypeChecker OTC(RuleDef.getLoc());
+  for (auto *Pat : ApplyPats) {
+    bool Res = OTC.check(*Pat, [&](const auto &) { return true; });
+    (void)Res;
+    assert(Res && "Cannot fail checking here!");
+  }
+  OTC.propagateTypes();
+
+  if (DebugTypeInfer) {
+    errs() << "Apply patterns for rule " << RuleDef.getName()
+           << " after inference:\n";
+    for (auto *Pat : ApplyPats) {
+      Pat->print(errs(), /*PrintName*/ true);
+      errs() << "\n";
+    }
+    errs() << "\n";
+  }
+}
+
+/// Infers the type of the immediate at index \p ImmOpIdx of \p IP by finding
+/// a named operand of the same instruction with the same MCOI operand type,
+/// and inferring that operand's type instead.
+///
+/// e.g.
+///    Pattern: G_BUILD_VECTOR $x, $y, 0
+///    MCOIs:   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
+///              MCOI::OPERAND_GENERIC_1]
+///    $y has the same MCOI type as 0, so whatever is inferred for $y is also
+///    the type 0 should have.
+///
+/// \returns the inferred type, or an empty PatternType if \p IP isn't a
+/// CodeGenInstructionPattern or no suitable named operand could be inferred.
+PatternType CombineRuleOperandTypeChecker::inferImmediateType(
+    const InstructionPattern &IP, unsigned ImmOpIdx,
+    const TypeEquivalenceClasses &TECs) const {
+  // Only CodeGenInstruction patterns carry MCOI operand types.
+  const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP);
+  if (!CGP)
+    return {};
+
+  const auto OperandTys = getMCOIOperandTypes(*CGP);
+  const StringRef WantedTy = OperandTys[ImmOpIdx];
+
+  for (unsigned Idx = 0, NumTys = OperandTys.size(); Idx != NumTys; ++Idx) {
+    // Skip the immediate itself and operands of a different MCOI type.
+    if (Idx == ImmOpIdx || StringRef(OperandTys[Idx]) != WantedTy)
+      continue;
+
+    // Only named operands can be looked up in the equivalence classes.
+    const auto &Sibling = IP.getOperand(Idx);
+    if (!Sibling.isNamedOperand())
+      continue;
+
+    // Named operand with the same MCOI type, try to infer that.
+    if (PatternType SiblingTy =
+            inferNamedOperandType(IP, Sibling.getOperandName(), TECs))
+      return SiblingTy;
+  }
+
+  return {};
+}
+
+/// Infers the type of \p OpName as a `TypeOf` of another operand that sits in
+/// the same type equivalence class.
+///
+/// Only the first class of \p TECs that contains \p OpName is considered.
+/// Within it, an operand the match operand table reports as a def is returned
+/// as soon as it is seen; failing that, the first operand found in the table
+/// is used.
+///
+/// \returns the inferred type, or an empty PatternType if no class contains
+/// \p OpName or the class has no other operand known to the table.
+PatternType CombineRuleOperandTypeChecker::inferNamedOperandType(
+    const InstructionPattern &IP, StringRef OpName,
+    const TypeEquivalenceClasses &TECs) const {
+  for (const auto &EqClass : TECs) {
+    if (EqClass.contains(OpName)) {
+      // First non-def matched operand seen, kept in case no def shows up.
+      StringRef Fallback;
+
+      for (const auto &Candidate : EqClass) {
+        if (Candidate == OpName)
+          continue;
+
+        const auto Entry = MatchOpTable.lookup(Candidate);
+
+        // A def of a matched pattern is guaranteed to always be a register,
+        // so it can be used blindly.
+        if (Entry.Def)
+          return PatternType::getTypeOf(Candidate);
+
+        if (Entry.Found && Fallback.empty())
+          Fallback = Candidate;
+      }
+
+      // No good operand found, give up.
+      if (Fallback.empty())
+        return {};
+      return PatternType::getTypeOf(Fallback);
+    }
+  }
+
+  // OpName isn't part of any equivalence class.
+  return {};
+}
+
+/// \returns true if \p MCOIType, with its last character removed, ends with
+/// "OPERAND_GENERIC_". An empty \p MCOIType yields false.
+bool CombineRuleOperandTypeChecker::canMCOIOperandTypeBeARegister(
+    StringRef MCOIType) {
+  // drop_back(1) asserts when there is no character to drop, so reject empty
+  // names up front instead of tripping that assertion.
+  if (MCOIType.empty())
+    return false;
+
+  // Assume OPERAND_GENERIC_0 through 5 can be registers.
+  return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_");
+}
+
+/// Builds the list of operand type names for \p CGP, one entry per pattern
+/// operand and in pattern operand order: defs first, then the remaining
+/// operands.
+///
+/// Operands that map onto an operand declared by the instruction use that
+/// operand's OperandType. Operands beyond the declared ones (variadic defs or
+/// uses) all share a single name: the instruction's `variadicOpsType` if it has
+/// one, otherwise a freshly numbered "unknown_type_N".
+std::vector<std::string> CombineRuleOperandTypeChecker::getMCOIOperandTypes(
+    const CodeGenInstructionPattern &CGP) {
+  // TODO: Cache this?
+
+  // Function-local counter shared by all calls, so that every fallback name
+  // generated below is distinct.
+  static unsigned UnknownTypeIdx = 0;
+
+  std::vector<std::string> OpTypes;
+  auto &CGI = CGP.getInst();
+  // `variadicOpsType` is only queried on GenericInstruction subclasses, and is
+  // optional even there.
+  Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction")
+                          ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType")
+                          : nullptr;
+  // The counter is only bumped when the fallback name is actually needed.
+  std::string VarArgsTyName =
+      VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str()
+                : ("unknown_type_" + Twine(UnknownTypeIdx++)).str();
+
+  // Defs declared by the instruction itself.
+  for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K)
+    OpTypes.push_back(CGI.Operands[K].OperandType);
+
+  // Extra defs present in the pattern but not declared by the instruction.
+  if (CGP.hasVariadicDefs()) {
+    for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K)
+      OpTypes.push_back(VarArgsTyName);
+  }
+
+  // If we had variadic defs, the op idx in the pattern won't match the op idx
+  // in the CGI anymore.
+  int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs();
+  assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0));
+
+  // Non-def operands: translate the pattern index into an instruction operand
+  // index, and use the varargs name once past the declared operands.
+  for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) {
+    unsigned CGIOpIdx = K + CGIOpOffset;
+    if (CGIOpIdx >= CGI.Operands.size()) {
+      assert(CGP.isVariadic());
+      OpTypes.push_back(VarArgsTyName);
+    } else {
+      OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType);
+    }
+  }
+
+  // One entry must have been produced per pattern operand.
+  assert(OpTypes.size() == CGP.operands_size());
+  return OpTypes;
+}
+
+void CombineRuleOperandTypeChecker::getInstEqClasses(
+    const InstructionPattern &P, TypeEquivalenceClasses &OutTEC) const {
+  DenseMap<StringRef, std::vector<unsigned>> TyToOpIdx;
----------------
arsenm wrote:

Sink this after the early returns?

https://github.com/llvm/llvm-project/pull/66377


More information about the llvm-commits mailing list