[Mlir-commits] [mlir] 72df59d - [mlir] resolve types from attributes in assemblyFormat
Mehdi Amini
llvmlistbot at llvm.org
Mon Jul 6 21:40:33 PDT 2020
Author: Martin Waitz
Date: 2020-07-07T04:40:01Z
New Revision: 72df59d59097a0c9ed5f420401cf377eaff896aa
URL: https://github.com/llvm/llvm-project/commit/72df59d59097a0c9ed5f420401cf377eaff896aa
DIFF: https://github.com/llvm/llvm-project/commit/72df59d59097a0c9ed5f420401cf377eaff896aa.diff
LOG: [mlir] resolve types from attributes in assemblyFormat
An operation can specify that an operand or result type matches the
type of another operand, result, or attribute via the `AllTypesMatch`
or `TypesMatchWith` constraints.
Use these constraints to also resolve types automatically in the
generated assembly parser.
This way, only the attribute needs to be listed in `assemblyFormat`,
e.g. for constant operations.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D78434
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 01dcb722ad07..025eaf616d73 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -780,8 +780,8 @@ There are many operations that have known type equality constraints registered
as traits on the operation; for example the true, false, and result values of a
`select` operation often have the same type. The assembly format may inspect
these equal constraints to discern the types of missing variables. The currently
-supported traits are: `AllTypesMatch`, `SameTypeOperands`, and
-`SameOperandsAndResultType`.
+supported traits are: `AllTypesMatch`, `TypesMatchWith`, `SameTypeOperands`,
+and `SameOperandsAndResultType`.
### `hasCanonicalizer`
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 14b2a7851be1..19e636b3df32 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1352,6 +1352,46 @@ def FormatInferVariadicTypeFromNonVariadic
let assemblyFormat = "$operands attr-dict `:` type($result)";
}
+//===----------------------------------------------------------------------===//
+// AllTypesMatch type inference
+//===----------------------------------------------------------------------===//
+
+def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
+ AllTypesMatch<["value1", "value2", "result"]>
+ ]> {
+ let arguments = (ins AnyType:$value1, AnyType:$value2);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)";
+}
+
+def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
+ AllTypesMatch<["value1", "value2", "result"]>
+ ]> {
+ let arguments = (ins AnyAttr:$value1, AnyType:$value2);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value1 `,` $value2";
+}
+
+//===----------------------------------------------------------------------===//
+// TypesMatchWith type inference
+//===----------------------------------------------------------------------===//
+
+def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
+ TypesMatchWith<"result type matches operand", "value", "result", "$_self">
+ ]> {
+ let arguments = (ins AnyType:$value);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
+def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
+ TypesMatchWith<"result type matches constant", "value", "result", "$_self">
+ ]> {
+ let arguments = (ins AnyAttr:$value);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value";
+}
+
//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 157f6cbf4159..49ac3d26f926 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -108,3 +108,23 @@ test.format_optional_operand_result_b_op : i64
// CHECK: test.format_infer_variadic_type_from_non_variadic %[[I64]], %[[I64]] : i64
test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
+
+//===----------------------------------------------------------------------===//
+// AllTypesMatch type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_all_types_match_var %[[I64]], %[[I64]] : i64
+%ignored_res1 = test.format_all_types_match_var %i64, %i64 : i64
+
+// CHECK: test.format_all_types_match_attr 1 : i64, %[[I64]]
+%ignored_res2 = test.format_all_types_match_attr 1 : i64, %i64
+
+//===----------------------------------------------------------------------===//
+// TypesMatchWith type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_types_match_var %[[I64]] : i64
+%ignored_res3 = test.format_types_match_var %i64 : i64
+
+// CHECK: test.format_types_match_attr 1 : i64
+%ignored_res4 = test.format_types_match_attr 1 : i64
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 8628cf252712..1cfcf32f8c06 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -270,6 +270,10 @@ class OptionalElement : public Element {
//===----------------------------------------------------------------------===//
namespace {
+
+using ConstArgument =
+ llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
+
struct OperationFormat {
/// This class represents a specific resolver for an operand or result type.
class TypeResolution {
@@ -280,15 +284,22 @@ struct OperationFormat {
Optional<int> getBuilderIdx() const { return builderIdx; }
void setBuilderIdx(int idx) { builderIdx = idx; }
- /// Get the variable this type is resolved to, or None.
- const NamedTypeConstraint *getVariable() const { return variable; }
+ /// Get the variable this type is resolved to, or nullptr.
+ const NamedTypeConstraint *getVariable() const {
+ return resolver.dyn_cast<const NamedTypeConstraint *>();
+ }
+ /// Get the attribute this type is resolved to, or nullptr.
+ const NamedAttribute *getAttribute() const {
+ return resolver.dyn_cast<const NamedAttribute *>();
+ }
+ /// Get the transformer for the type of the variable, or None.
Optional<StringRef> getVarTransformer() const {
return variableTransformer;
}
- void setVariable(const NamedTypeConstraint *var,
- Optional<StringRef> transformer) {
- variable = var;
+ void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
+ resolver = arg;
variableTransformer = transformer;
+ assert(getVariable() || getAttribute());
}
private:
@@ -296,8 +307,8 @@ struct OperationFormat {
/// 'buildableTypes' in the parent format.
Optional<int> builderIdx;
/// If the type is resolved based upon another operand or result, this is
- /// the variable that this type is resolved to.
- const NamedTypeConstraint *variable;
+ /// the variable or the attribute that this type is resolved to.
+ ConstArgument resolver;
/// If the type is resolved based upon another operand or result, this is
/// a transformer to apply to the variable when resolving.
Optional<StringRef> variableTransformer;
@@ -729,7 +740,7 @@ void OperationFormat::genParserTypeResolution(Operator &op,
continue;
// Ensure that we don't verify the same variables twice.
const NamedTypeConstraint *variable = resolver.getVariable();
- if (!verifiedVariables.insert(variable).second)
+ if (!variable || !verifiedVariables.insert(variable).second)
continue;
auto constraint = variable->constraint;
@@ -764,6 +775,12 @@ void OperationFormat::genParserTypeResolution(Operator &op,
body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
else
body << var->name << "Types";
+ } else if (const NamedAttribute *attr = resolver.getAttribute()) {
+ if (Optional<StringRef> tform = resolver.getVarTransformer())
+ body << tgfmt(*tform,
+ &FmtContext().withSelf(attr->name + "Attr.getType()"));
+ else
+ body << attr->name << "Attr.getType()";
} else {
body << curVar << "Types";
}
@@ -1353,7 +1370,7 @@ class FormatParser {
/// type as well as an optional transformer to apply to that type in order to
/// properly resolve the type of a variable.
struct TypeResolutionInstance {
- const NamedTypeConstraint *type;
+ ConstArgument resolver;
Optional<StringRef> transformer;
};
@@ -1392,10 +1409,15 @@ class FormatParser {
void handleSameTypesConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults);
+ /// Check for inferable type resolution based on another operand, result, or
+ /// attribute.
+ void handleTypesMatchConstraint(
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+ llvm::Record def);
- /// Returns an argument with the given name that has been seen within the
- /// format.
- const NamedTypeConstraint *findSeenArg(StringRef name);
+ /// Returns an argument or attribute with the given name that has been seen
+ /// within the format.
+ ConstArgument findSeenArg(StringRef name);
/// Parse a specific element.
LogicalResult parseElement(std::unique_ptr<Element> &element,
@@ -1504,9 +1526,7 @@ LogicalResult FormatParser::parse() {
} else if (def.getName() == "SameOperandsAndResultType") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
- if (const auto *lhsArg = findSeenArg(def.getValueAsString("lhs")))
- variableTyResolver[def.getValueAsString("rhs")] = {
- lhsArg, def.getValueAsString("transformer")};
+ handleTypesMatchConstraint(variableTyResolver, def);
}
}
@@ -1615,8 +1635,8 @@ LogicalResult FormatParser::verifyOperands(
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
if (varResolverIt != variableTyResolver.end()) {
- fmt.operandTypes[i].setVariable(varResolverIt->second.type,
- varResolverIt->second.transformer);
+ TypeResolutionInstance &resolver = varResolverIt->second;
+ fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
continue;
}
@@ -1654,8 +1674,8 @@ LogicalResult FormatParser::verifyResults(
// Check to see if we can infer this type from another variable.
auto varResolverIt = variableTyResolver.find(op.getResultName(i));
if (varResolverIt != variableTyResolver.end()) {
- fmt.resultTypes[i].setVariable(varResolverIt->second.type,
- varResolverIt->second.transformer);
+ TypeResolutionInstance resolver = varResolverIt->second;
+ fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
continue;
}
@@ -1702,7 +1722,7 @@ void FormatParser::handleAllTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) {
// Check to see if this value matches a resolved operand or result type.
- const NamedTypeConstraint *arg = findSeenArg(values[i]);
+ ConstArgument arg = findSeenArg(values[i]);
if (!arg)
continue;
@@ -1739,11 +1759,23 @@ void FormatParser::handleSameTypesConstraint(
}
}
-const NamedTypeConstraint *FormatParser::findSeenArg(StringRef name) {
- if (auto *arg = findArg(op.getOperands(), name))
+void FormatParser::handleTypesMatchConstraint(
+ llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
+ llvm::Record def) {
+ StringRef lhsName = def.getValueAsString("lhs");
+ StringRef rhsName = def.getValueAsString("rhs");
+ StringRef transformer = def.getValueAsString("transformer");
+ if (ConstArgument arg = findSeenArg(lhsName))
+ variableTyResolver[rhsName] = {arg, transformer};
+}
+
+ConstArgument FormatParser::findSeenArg(StringRef name) {
+ if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
- if (auto *arg = findArg(op.getResults(), name))
+ if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
+ if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
+ return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
return nullptr;
}
More information about the Mlir-commits
mailing list