[Mlir-commits] [mlir] 72df59d - [mlir] resolve types from attributes in assemblyFormat

Mehdi Amini llvmlistbot at llvm.org
Mon Jul 6 21:40:33 PDT 2020


Author: Martin Waitz
Date: 2020-07-07T04:40:01Z
New Revision: 72df59d59097a0c9ed5f420401cf377eaff896aa

URL: https://github.com/llvm/llvm-project/commit/72df59d59097a0c9ed5f420401cf377eaff896aa
DIFF: https://github.com/llvm/llvm-project/commit/72df59d59097a0c9ed5f420401cf377eaff896aa.diff

LOG: [mlir] resolve types from attributes in assemblyFormat

An operation can specify that an operand or result type matches the
type of another operand, result, or attribute via the `AllTypesMatch`
or `TypesMatchWith` constraints.

Use these constraints to also automatically resolve types in the
automatically generated assembly parser.
This way, only the attribute needs to be listed in `assemblyFormat`,
e.g. for constant operations.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D78434

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 01dcb722ad07..025eaf616d73 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -780,8 +780,8 @@ There are many operations that have known type equality constraints registered
 as traits on the operation; for example the true, false, and result values of a
 `select` operation often have the same type. The assembly format may inspect
 these equal constraints to discern the types of missing variables. The currently
-supported traits are: `AllTypesMatch`, `SameTypeOperands`, and
-`SameOperandsAndResultType`.
+supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`,
+and `SameOperandsAndResultType`.
 
 ### `hasCanonicalizer`
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 14b2a7851be1..19e636b3df32 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1352,6 +1352,46 @@ def FormatInferVariadicTypeFromNonVariadic
   let assemblyFormat = "$operands attr-dict `:` type($result)";
 }
 
+//===----------------------------------------------------------------------===//
+// AllTypesMatch type inference
+//===----------------------------------------------------------------------===//
+
+def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
+    AllTypesMatch<["value1", "value2", "result"]>
+  ]> {
+  let arguments = (ins AnyType:$value1, AnyType:$value2);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)";
+}
+
+def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
+    AllTypesMatch<["value1", "value2", "result"]>
+  ]> {
+  let arguments = (ins AnyAttr:$value1, AnyType:$value2);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value1 `,` $value2";
+}
+
+//===----------------------------------------------------------------------===//
+// TypesMatchWith type inference
+//===----------------------------------------------------------------------===//
+
+def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
+    TypesMatchWith<"result type matches operand", "value", "result", "$_self">
+  ]> {
+  let arguments = (ins AnyType:$value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
+def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
+    TypesMatchWith<"result type matches constant", "value", "result", "$_self">
+  ]> {
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "attr-dict $value";
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffects
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 157f6cbf4159..49ac3d26f926 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -108,3 +108,23 @@ test.format_optional_operand_result_b_op : i64
 
 // CHECK: test.format_infer_variadic_type_from_non_variadic %[[I64]], %[[I64]] : i64
 test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
+
+//===----------------------------------------------------------------------===//
+// AllTypesMatch type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_all_types_match_var %[[I64]], %[[I64]] : i64
+%ignored_res1 = test.format_all_types_match_var %i64, %i64 : i64
+
+// CHECK: test.format_all_types_match_attr 1 : i64, %[[I64]]
+%ignored_res2 = test.format_all_types_match_attr 1 : i64, %i64
+
+//===----------------------------------------------------------------------===//
+// TypesMatchWith type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_types_match_var %[[I64]] : i64
+%ignored_res3 = test.format_types_match_var %i64 : i64
+
+// CHECK: test.format_types_match_attr 1 : i64
+%ignored_res4 = test.format_types_match_attr 1 : i64

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 8628cf252712..1cfcf32f8c06 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -270,6 +270,10 @@ class OptionalElement : public Element {
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+using ConstArgument =
+    llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
+
 struct OperationFormat {
   /// This class represents a specific resolver for an operand or result type.
   class TypeResolution {
@@ -280,15 +284,22 @@ struct OperationFormat {
     Optional<int> getBuilderIdx() const { return builderIdx; }
     void setBuilderIdx(int idx) { builderIdx = idx; }
 
-    /// Get the variable this type is resolved to, or None.
-    const NamedTypeConstraint *getVariable() const { return variable; }
+    /// Get the variable this type is resolved to, or nullptr.
+    const NamedTypeConstraint *getVariable() const {
+      return resolver.dyn_cast<const NamedTypeConstraint *>();
+    }
+    /// Get the attribute this type is resolved to, or nullptr.
+    const NamedAttribute *getAttribute() const {
+      return resolver.dyn_cast<const NamedAttribute *>();
+    }
+    /// Get the transformer for the type of the variable, or None.
     Optional<StringRef> getVarTransformer() const {
       return variableTransformer;
     }
-    void setVariable(const NamedTypeConstraint *var,
-                     Optional<StringRef> transformer) {
-      variable = var;
+    void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
+      resolver = arg;
       variableTransformer = transformer;
+      assert(getVariable() || getAttribute());
     }
 
   private:
@@ -296,8 +307,8 @@ struct OperationFormat {
     /// 'buildableTypes' in the parent format.
     Optional<int> builderIdx;
     /// If the type is resolved based upon another operand or result, this is
-    /// the variable that this type is resolved to.
-    const NamedTypeConstraint *variable;
+    /// the variable or the attribute that this type is resolved to.
+    ConstArgument resolver;
     /// If the type is resolved based upon another operand or result, this is
     /// a transformer to apply to the variable when resolving.
     Optional<StringRef> variableTransformer;
@@ -729,7 +740,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
       continue;
     // Ensure that we don't verify the same variables twice.
     const NamedTypeConstraint *variable = resolver.getVariable();
-    if (!verifiedVariables.insert(variable).second)
+    if (!variable || !verifiedVariables.insert(variable).second)
       continue;
 
     auto constraint = variable->constraint;
@@ -764,6 +775,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
         body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
       else
         body << var->name << "Types";
+    } else if (const NamedAttribute *attr = resolver.getAttribute()) {
+      if (Optional<StringRef> tform = resolver.getVarTransformer())
+        body << tgfmt(*tform,
+                      &FmtContext().withSelf(attr->name + "Attr.getType()"));
+      else
+        body << attr->name << "Attr.getType()";
     } else {
       body << curVar << "Types";
     }
@@ -1353,7 +1370,7 @@ class FormatParser {
   /// type as well as an optional transformer to apply to that type in order to
   /// properly resolve the type of a variable.
   struct TypeResolutionInstance {
-    const NamedTypeConstraint *type;
+    ConstArgument resolver;
     Optional<StringRef> transformer;
   };
 
@@ -1392,10 +1409,15 @@ class FormatParser {
   void handleSameTypesConstraint(
       llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
       bool includeResults);
+  /// Check for inferable type resolution based on another operand, result, or
+  /// attribute.
+  void handleTypesMatchConstraint(
+      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+      llvm::Record def);
 
-  /// Returns an argument with the given name that has been seen within the
-  /// format.
-  const NamedTypeConstraint *findSeenArg(StringRef name);
+  /// Returns an argument or attribute with the given name that has been seen
+  /// within the format.
+  ConstArgument findSeenArg(StringRef name);
 
   /// Parse a specific element.
   LogicalResult parseElement(std::unique_ptr<Element> &element,
@@ -1504,9 +1526,7 @@ LogicalResult FormatParser::parse() {
     } else if (def.getName() == "SameOperandsAndResultType") {
       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
     } else if (def.isSubClassOf("TypesMatchWith")) {
-      if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
-        variableTyResolver[def.getValueAsString("rhs")] = {
-            lhsArg, def.getValueAsString("transformer")};
+      handleTypesMatchConstraint(variableTyResolver, def);
     }
   }
 
@@ -1615,8 +1635,8 @@ LogicalResult FormatParser::verifyOperands(
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
     if (varResolverIt != variableTyResolver.end()) {
-      fmt.operandTypes[i].setVariable(varResolverIt->second.type,
-                                      varResolverIt->second.transformer);
+      TypeResolutionInstance &resolver = varResolverIt->second;
+      fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
       continue;
     }
 
@@ -1654,8 +1674,8 @@ LogicalResult FormatParser::verifyResults(
     // Check to see if we can infer this type from another variable.
     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
     if (varResolverIt != variableTyResolver.end()) {
-      fmt.resultTypes[i].setVariable(varResolverIt->second.type,
-                                     varResolverIt->second.transformer);
+      TypeResolutionInstance resolver = varResolverIt->second;
+      fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
       continue;
     }
 
@@ -1702,7 +1722,7 @@ void FormatParser::handleAllTypesMatchConstraint(
     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
   for (unsigned i = 0, e = values.size(); i != e; ++i) {
     // Check to see if this value matches a resolved operand or result type.
-    const NamedTypeConstraint *arg = findSeenArg(values[i]);
+    ConstArgument arg = findSeenArg(values[i]);
     if (!arg)
       continue;
 
@@ -1739,11 +1759,23 @@ void FormatParser::handleSameTypesConstraint(
   }
 }
 
-const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
-  if (auto *arg = findArg(op.getOperands(), name))
+void FormatParser::handleTypesMatchConstraint(
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+    llvm::Record def) {
+  StringRef lhsName = def.getValueAsString("lhs");
+  StringRef rhsName = def.getValueAsString("rhs");
+  StringRef transformer = def.getValueAsString("transformer");
+  if (ConstArgument arg = findSeenArg(lhsName))
+    variableTyResolver[rhsName] = {arg, transformer};
+}
+
+ConstArgument FormatParser::findSeenArg(StringRef name) {
+  if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
-  if (auto *arg = findArg(op.getResults(), name))
+  if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
     return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
+  if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
+    return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
   return nullptr;
 }
 


        


More information about the Mlir-commits mailing list