[Mlir-commits] [mlir] [mlir][ODS] Allow inferring operand types from multiple variables (PR #127517)

Kunwar Grover llvmlistbot at llvm.org
Mon Feb 17 08:20:19 PST 2025


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/127517

This patch adds support for inferring an operand's type from the types of multiple other operands and/or results.

The patch introduces a new `InferTypesFrom` class and makes the backend rely on it, since it's a more general class. The older `TypesMatchWith` class is backwards compatible with this change and works the same as before.

Inferring result types from multiple variables could also be built on top of this change, but it is more complex: a different builder call must be generated as well, which requires more intrusive changes to the operation definition generator.

>From c0aeaf53c184dff5ba7b4d5e0807888ace77d18f Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Mon, 17 Feb 2025 15:59:47 +0000
Subject: [PATCH] [TableGen] Allow inferring operand types from multiple
 variables

---
 mlir/include/mlir/IR/OpBase.td              |  28 +++--
 mlir/lib/TableGen/Operator.cpp              |  26 ++--
 mlir/test/lib/Dialect/Test/TestOpsSyntax.td |  13 ++
 mlir/test/mlir-tblgen/op-format.mlir        |   9 ++
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp |   4 +-
 mlir/tools/mlir-tblgen/OpFormatGen.cpp      | 126 ++++++++++++--------
 6 files changed, 132 insertions(+), 74 deletions(-)

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 51b60972203e7..f992481d4aa31 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -556,20 +556,43 @@ class AllShapesMatch<list<string> names> :
 class AllTypesMatch<list<string> names> :
     AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
 
-// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+// A type constraint that denotes
+// `transform(lhs[0].getType(), ..., lhs[N-1].getType()) == rhs.getType()`.
+// Inside `transform`, the type of the i-th name in `lhsArg` is referenced by
+// the placeholder `$arg<i>` (`$arg0`, `$arg1`, ...).
 // An optional comparator function may be provided that changes the above form
-// into: `comparator(transform(lhs.getType()), rhs.getType())`.
-class TypesMatchWith<string summary, string lhsArg, string rhsArg,
-                     string transform, string comparator = "std::equal_to<>()">
+// into: `comparator(transform(...), rhs.getType())`.
+class InferTypesFrom<string summary, list<string> lhsArg, string rhsArg,
+                     string transform,
+                     string comparator = "std::equal_to<>()">
   : PredOpTrait<summary, CPred<
-      comparator # "(" #
-      !subst("$_self", "$" # lhsArg # ".getType()", transform) #
-      ", $" # rhsArg # ".getType())">> {
-  string lhs = lhsArg;
-  string rhs = rhsArg;
+    comparator # "(" #
+    // Substitute the highest index first so that `$arg1` cannot clobber the
+    // prefix of `$arg10`.
+    !foldl(transform,
+           !foreach(n, !range(lhsArg), !sub(!size(lhsArg), !add(n, 1))),
+           acc, i,
+           !subst("$arg" # i, "$" # lhsArg[i] # ".getType()", acc)) #
+    ", $" # rhsArg # ".getType()" # ")">> {
+  list<string> args = lhsArg;
+  string target = rhsArg;
   string transformer = transform;
 }
 
+// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+// An optional comparator function may be provided that changes the above form
+// into: `comparator(transform(lhs.getType()), rhs.getType())`.
+class TypesMatchWith<string summary, string lhsArg, string rhsArg,
+                     string transform, string comparator = "std::equal_to<>()">
+  : InferTypesFrom<summary, [lhsArg], rhsArg,
+                   !subst("$_self", "$arg0", transform),
+                   comparator> {
+  // Kept so that backends reading the single source/target names directly
+  // keep working.
+  string lhs = lhsArg;
+  string rhs = rhsArg;
+}
+
 // The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
 // and not present returns success.
 class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 20a43ef15d09e..ec5561b89ea74 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -387,7 +387,8 @@ void Operator::populateTypeInferenceInfo(
   if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
     // Check for a non-variable length operand to use as the type anchor.
     auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
-      NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
+      NamedTypeConstraint *operand =
+          llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
       return operand && !operand->isVariableLength();
     });
     if (operandI == arguments.end())
@@ -396,7 +397,7 @@ void Operator::populateTypeInferenceInfo(
     // All result types are inferred from the operand type.
     int operandIdx = operandI - arguments.begin();
     for (int i = 0; i < getNumResults(); ++i)
-      resultTypeMapping.emplace_back(operandIdx, "$_self");
+      resultTypeMapping.emplace_back(operandIdx, "$arg0");
 
     allResultsHaveKnownTypes = true;
     traits.push_back(Trait::create(inferTrait->getDefInit()));
@@ -424,12 +425,12 @@ void Operator::populateTypeInferenceInfo(
   for (auto [idx, infer] : llvm::enumerate(inference)) {
     if (getResult(idx).constraint.getBuilderCall()) {
       infer.sources.emplace_back(InferredResultType::mapResultIndex(idx),
-                                 "$_self");
+                                 "$arg0");
       infer.inferred = true;
     }
   }
 
-  // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
+  // Use `AllTypesMatch` and `InferTypesFrom` operation traits to build the
   // result type inference graph.
   for (const Trait &trait : traits) {
     const Record &def = trait.getDef();
@@ -445,10 +446,11 @@ void Operator::populateTypeInferenceInfo(
       if (&traitDef->getDef() == inferTrait)
         return;
 
-    // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
-    // type transformer.
-    if (def.isSubClassOf("TypesMatchWith")) {
-      int target = argumentsAndResultsIndex.lookup(def.getValueAsString("rhs"));
+    // The `InferTypesFrom` trait represents an N -> 1 type inference edge
+    // with a type transformer. Only N == 1 is currently supported for results.
+    if (def.isSubClassOf("InferTypesFrom")) {
+      int target =
+          argumentsAndResultsIndex.lookup(def.getValueAsString("target"));
       // Ignore operand type inference.
       if (InferredResultType::isArgIndex(target))
         continue;
@@ -457,8 +459,11 @@ void Operator::populateTypeInferenceInfo(
       // If the type of the result has already been inferred, do nothing.
       if (infer.inferred)
         continue;
-      int sourceIndex =
-          argumentsAndResultsIndex.lookup(def.getValueAsString("lhs"));
+      std::vector<StringRef> args = def.getValueAsListOfStrings("args");
+      if (args.size() != 1)
+        PrintFatalError(def.getLoc(), "InferTypesFrom with multiple sources is "
+                                      "not yet supported for result types");
+      int sourceIndex = argumentsAndResultsIndex.lookup(args[0]);
       infer.sources.emplace_back(sourceIndex,
                                  def.getValueAsString("transformer").str());
       // Locally propagate inferredness.
@@ -493,7 +498,7 @@ void Operator::populateTypeInferenceInfo(
       for (int resultIndex : resultIndices) {
         ResultTypeInference &infer = inference[resultIndex];
         if (!infer.inferred) {
-          infer.sources.assign(1, {*fullyInferredIndex, "$_self"});
+          infer.sources.assign(1, {*fullyInferredIndex, "$arg0"});
           infer.inferred = true;
         }
       }
@@ -504,7 +509,7 @@ void Operator::populateTypeInferenceInfo(
           if (resultIndex == otherResultIndex)
             continue;
           inference[resultIndex].sources.emplace_back(
-              InferredResultType::unmapResultIndex(otherResultIndex), "$_self");
+              InferredResultType::unmapResultIndex(otherResultIndex), "$arg0");
         }
       }
     }
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
index 2848cb994231b..33e4b9a623636 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
@@ -653,6 +653,19 @@ def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
   let assemblyFormat = "attr-dict $value `:` type($value)";
 }
 
+//===----------------------------------------------------------------------===//
+// InferTypesFrom type inference
+
+def FormatTypesMatchMultipleVarOp : TEST_Op<"format_types_match_multiple_var", [
+    InferTypesFrom<"value2 type is a tuple of the types of value1 and result",
+                   ["value1", "result"], "value2",
+                   "::mlir::TupleType::get($_ctxt, {$arg0, $arg1})">]> {
+  let arguments = (ins AnyType:$value1,
+                       AnyType:$value2);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1) `->` type($result)";
+}
+
 //===----------------------------------------------------------------------===//
 // InferTypeOpInterface type inference in assembly format
 
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 08b0c52413a75..445826afb5ed5 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -4,6 +4,8 @@
 %i64 = "foo.op"() : () -> (i64)
 // CHECK: %[[I32:.*]] =
 %i32 = "foo.op"() : () -> (i32)
+// CHECK: %[[I64_I32_TUP:.*]] =
+%i64_i32_tuple = "foo.op"() : () -> (tuple<i64, i32>)
 // CHECK: %[[MEMREF:.*]] =
 %memref = "foo.op"() : () -> (memref<1xf64>)
 
@@ -481,6 +483,13 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 // CHECK: test.format_types_match_context %[[I64]] : i64
 %ignored_res6 = test.format_types_match_context %i64 : i64
 
+//===----------------------------------------------------------------------===//
+// InferTypesFrom type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_types_match_multiple_var %[[I64]], %[[I64_I32_TUP]] : i64 -> i32
+%ignored_res6a = test.format_types_match_multiple_var %i64, %i64_i32_tuple : i64 -> i32
+
 //===----------------------------------------------------------------------===//
 // InferTypeOpInterface type inference
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 629e863dac5e3..ab50686e67ba6 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3743,7 +3743,8 @@ void OpEmitter::genTypeInterfaceMethods() {
         continue;
       }
       body << "  ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
-           << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
+           << tgfmt(infer.getTransformer(), &fctx.addSubst("arg0", typeStr))
+           << ";\n";
       constructedIndices[i] = inferredTypeIdx - 1;
     }
   }
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index fe724e86d6707..81edf36f9d4d7 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -303,22 +303,37 @@ struct OperationFormat {
     std::optional<int> getBuilderIdx() const { return builderIdx; }
     void setBuilderIdx(int idx) { builderIdx = idx; }
 
+    /// Get the number of values this type is resolved from.
+    int getNumArgs() const { return resolver.size(); }
+
-    /// Get the variable this type is resolved to, or nullptr.
-    const NamedTypeConstraint *getVariable() const {
-      return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
+    /// Get the `i`-th value this type is resolved from if it is a variable,
+    /// or nullptr if it is out of range or not a variable.
+    const NamedTypeConstraint *getVariable(int i) const {
+      if (i < 0 || i >= getNumArgs())
+        return nullptr;
+      return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(
+          resolver[i]);
     }
-    /// Get the attribute this type is resolved to, or nullptr.
-    const NamedAttribute *getAttribute() const {
-      return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
+    /// Get the `i`-th value this type is resolved from if it is an attribute,
+    /// or nullptr if it is out of range or not an attribute.
+    const NamedAttribute *getAttribute(int i) const {
+      if (i < 0 || i >= getNumArgs())
+        return nullptr;
+      return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver[i]);
     }
     /// Get the transformer for the type of the variable, or std::nullopt.
     std::optional<StringRef> getVarTransformer() const {
       return variableTransformer;
     }
-    void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
-      resolver = arg;
+    /// Set the values this type is resolved from. The transformer, if any,
+    /// refers to the type of `args[i]` through the placeholder `$arg<i>`.
+    void setResolver(ArrayRef<ConstArgument> args,
+                     std::optional<StringRef> transformer) {
+      resolver.assign(args.begin(), args.end());
       variableTransformer = transformer;
-      assert(getVariable() || getAttribute());
+      assert(llvm::all_of(llvm::seq<int>(0, getNumArgs()), [&](int i) {
+        return getVariable(i) || getAttribute(i);
+      }));
     }
 
   private:
@@ -327,7 +342,7 @@ struct OperationFormat {
     std::optional<int> builderIdx;
-    /// If the type is resolved based upon another operand or result, this is
-    /// the variable or the attribute that this type is resolved to.
-    ConstArgument resolver;
+    /// If the type is resolved based upon other operands or results, these are
+    /// the variables or attributes that this type is resolved from.
+    SmallVector<ConstArgument, 1> resolver;
     /// If the type is resolved based upon another operand or result, this is
     /// a transformer to apply to the variable when resolving.
     std::optional<StringRef> variableTransformer;
@@ -1685,23 +1700,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
     std::optional<StringRef> transformer = resolver.getVarTransformer();
     if (!transformer)
       continue;
-    // Ensure that we don't verify the same variables twice.
-    const NamedTypeConstraint *variable = resolver.getVariable();
-    if (!variable || !verifiedVariables.insert(variable).second)
-      continue;
+    for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) {
+      // Ensure that we don't verify the same variables twice.
+      const NamedTypeConstraint *variable = resolver.getVariable(i);
+      if (!variable || !verifiedVariables.insert(variable).second)
+        continue;
 
-    auto constraint = variable->constraint;
-    body << "  for (::mlir::Type type : " << variable->name << "Types) {\n"
-         << "    (void)type;\n"
-         << "    if (!("
-         << tgfmt(constraint.getConditionTemplate(),
-                  &verifierFCtx.withSelf("type"))
-         << ")) {\n"
-         << formatv("      return parser.emitError(parser.getNameLoc()) << "
-                    "\"'{0}' must be {1}, but got \" << type;\n",
-                    variable->name, constraint.getSummary())
-         << "    }\n"
-         << "  }\n";
+      auto constraint = variable->constraint;
+      body << "  for (::mlir::Type type : " << variable->name << "Types) {\n"
+           << "    (void)type;\n"
+           << "    if (!("
+           << tgfmt(constraint.getConditionTemplate(),
+                    &verifierFCtx.withSelf("type"))
+           << ")) {\n"
+           << formatv("      return parser.emitError(parser.getNameLoc()) << "
+                      "\"'{0}' must be {1}, but got \" << type;\n",
+                      variable->name, constraint.getSummary())
+           << "    }\n"
+           << "  }\n";
+    }
   }
 
   // Initialize the set of buildable types.
@@ -1717,26 +1734,30 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
   auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
     if (std::optional<int> val = resolver.getBuilderIdx()) {
       body << "odsBuildableType" << *val;
-    } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
-      if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
-        FmtContext fmtContext;
-        fmtContext.addSubst("_ctxt", "parser.getContext()");
-        if (var->isVariadic())
-          fmtContext.withSelf(var->name + "Types");
-        else
-          fmtContext.withSelf(var->name + "Types[0]");
-        body << tgfmt(*tform, &fmtContext);
-      } else {
-        body << var->name << "Types";
-        if (!var->isVariadic())
-          body << "[0]";
+    } else if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
+      FmtContext fmtContext;
+      fmtContext.addSubst("_ctxt", "parser.getContext()");
+      for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) {
+        std::string substName = "arg" + std::to_string(i);
+        if (const NamedTypeConstraint *var = resolver.getVariable(i)) {
+          if (var->isVariadic())
+            fmtContext.addSubst(substName, var->name + "Types");
+          else
+            fmtContext.addSubst(substName, var->name + "Types[0]");
+        } else if (const NamedAttribute *attr = resolver.getAttribute(i)) {
+          fmtContext.addSubst(substName, attr->name + "Attr.getType()");
+        } else {
+          llvm_unreachable("resolver arguments should be a type constraint or "
+                           "an attribute");
+        }
       }
-    } else if (const NamedAttribute *attr = resolver.getAttribute()) {
-      if (std::optional<StringRef> tform = resolver.getVarTransformer())
-        body << tgfmt(*tform,
-                      &FmtContext().withSelf(attr->name + "Attr.getType()"));
-      else
-        body << attr->name << "Attr.getType()";
+      body << tgfmt(*tform, &fmtContext);
+    } else if (const NamedTypeConstraint *var = resolver.getVariable(0)) {
+      body << var->name << "Types";
+      if (!var->isVariadic())
+        body << "[0]";
+    } else if (const NamedAttribute *attr = resolver.getAttribute(0)) {
+      body << attr->name << "Attr.getType()";
     } else {
       body << curVar << "Types";
     }
@@ -2717,7 +2738,7 @@ class OpFormatParser : public FormatParser {
   /// type as well as an optional transformer to apply to that type in order to
   /// properly resolve the type of a variable.
   struct TypeResolutionInstance {
-    ConstArgument resolver;
+    SmallVector<ConstArgument, 1> resolver;
     std::optional<StringRef> transformer;
   };
 
@@ -2827,7 +2848,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
     } else if (def.getName() == "SameOperandsAndResultType") {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
-    } else if (def.isSubClassOf("TypesMatchWith")) {
+    } else if (def.isSubClassOf("InferTypesFrom")) {
       handleTypesMatchConstraint(variableTyResolver, def);
     } else if (!op.allResultTypesKnown()) {
       // This doesn't check the name directly to handle
@@ -3228,9 +3249,9 @@ void OpFormatParser::handleAllTypesMatchConstraint(
 
     // Mark this value as the type resolver for the other variables.
     for (unsigned j = 0; j != i; ++j)
-      variableTyResolver[values[j]] = {arg, std::nullopt};
+      variableTyResolver[values[j]] = {{arg}, std::nullopt};
     for (unsigned j = i + 1; j != e; ++j)
-      variableTyResolver[values[j]] = {arg, std::nullopt};
+      variableTyResolver[values[j]] = {{arg}, std::nullopt};
   }
 }
 
@@ -3251,21 +3272,31 @@ void OpFormatParser::handleSameTypesConstraint(
   // Set the resolvers for each operand and result.
   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
     if (!seenOperandTypes.test(i))
-      variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt};
+      variableTyResolver[op.getOperand(i).name] = {{resolver}, std::nullopt};
   if (includeResults) {
     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
       if (!seenResultTypes.test(i))
-        variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt};
+        variableTyResolver[op.getResultName(i)] = {{resolver}, std::nullopt};
   }
 }
 
 void OpFormatParser::handleTypesMatchConstraint(
     StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
-  StringRef lhsName = def.getValueAsString("lhs");
-  StringRef rhsName = def.getValueAsString("rhs");
+  std::vector<StringRef> args = def.getValueAsListOfStrings("args");
+  StringRef target = def.getValueAsString("target");
   StringRef transformer = def.getValueAsString("transformer");
-  if (ConstArgument arg = findSeenArg(lhsName))
-    variableTyResolver[rhsName] = {arg, transformer};
+
+  // The target can only be resolved if every source has already been seen.
+  // Resolving from a subset would leave `$argN` placeholders unbound (or bind
+  // them to the wrong source), so bail out and keep any existing resolver.
+  SmallVector<ConstArgument, 1> resolutionArgs;
+  for (StringRef arg : args) {
+    ConstArgument seenArg = findSeenArg(arg);
+    if (!seenArg)
+      return;
+    resolutionArgs.push_back(seenArg);
+  }
+  variableTyResolver[target] = {std::move(resolutionArgs), transformer};
 }
 
 ConstArgument OpFormatParser::findSeenArg(StringRef name) {



More information about the Mlir-commits mailing list