[Mlir-commits] [mlir] 1bd1eda - [mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface
River Riddle
llvmlistbot at llvm.org
Thu Apr 28 12:58:38 PDT 2022
Author: River Riddle
Date: 2022-04-28T12:57:59-07:00
New Revision: 1bd1edaf4006ff66a88ac59e0931f22105003a26
URL: https://github.com/llvm/llvm-project/commit/1bd1edaf4006ff66a88ac59e0931f22105003a26
DIFF: https://github.com/llvm/llvm-project/commit/1bd1edaf4006ff66a88ac59e0931f22105003a26.diff
LOG: [mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface
This allows attribute types to be used in result type inference for
InferTypeOpInterface. This was previously a TODO, but properly supporting it
requires little additional work. After this commit, the InferTypeOpInterface
implementation of arith::ConstantOp can be generated automatically.
Differential Revision: https://reviews.llvm.org/D124580
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/TableGen/CodeGenHelpers.h
mlir/include/mlir/TableGen/Constraint.h
mlir/lib/TableGen/Constraint.cpp
mlir/lib/TableGen/Operator.cpp
mlir/python/mlir/dialects/_arith_ops_ext.py
mlir/test/Dialect/Arithmetic/invalid.mlir
mlir/test/IR/diagnostic-handler.mlir
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 6305c6947c75..a46c5119296b 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -124,9 +124,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- TypesMatchWith<
- "result and attribute have the same type",
- "value", "result", "$_self">]> {
+ AllTypesMatch<["value", "result"]>]> {
let summary = "integer or floating point constant";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
@@ -154,8 +152,6 @@ def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
let builders = [
- OpBuilder<(ins "Attribute":$value),
- [{ build($_builder, $_state, value.getType(), value); }]>,
OpBuilder<(ins "Attribute":$value, "Type":$type),
[{ build($_builder, $_state, type, value); }]>,
];
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index d4d3294ea2a6..69b3a897fc38 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -187,19 +187,9 @@ class StaticVerifierFunctionEmitter {
/// ensure that the static functions have a unique name.
std::string uniqueOutputLabel;
- /// Unique constraints by their predicate and summary. Constraints that share
- /// the same predicate may have different descriptions; ensure that the
- /// correct error message is reported when verification fails.
- struct ConstraintUniquer {
- static Constraint getEmptyKey();
- static Constraint getTombstoneKey();
- static unsigned getHashValue(Constraint constraint);
- static bool isEqual(Constraint lhs, Constraint rhs);
- };
/// Use a MapVector to ensure that functions are generated deterministically.
- using ConstraintMap =
- llvm::MapVector<Constraint, std::string,
- llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
+ using ConstraintMap = llvm::MapVector<Constraint, std::string,
+ llvm::DenseMap<Constraint, unsigned>>;
/// A generic function to emit constraints
void emitConstraints(const ConstraintMap &constraints, StringRef selfName,
diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h
index 0f6c2b58faaf..0c74f89189e9 100644
--- a/mlir/include/mlir/TableGen/Constraint.h
+++ b/mlir/include/mlir/TableGen/Constraint.h
@@ -94,4 +94,20 @@ struct AppliedConstraint {
} // namespace tblgen
} // namespace mlir
+namespace llvm {
+/// Unique constraints by their predicate and summary. Constraints that share
+/// the same predicate may have different descriptions; ensure that the
+/// correct error message is reported when verification fails.
+template <>
+struct DenseMapInfo<mlir::tblgen::Constraint> {
+ using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
+
+ static mlir::tblgen::Constraint getEmptyKey();
+ static mlir::tblgen::Constraint getTombstoneKey();
+ static unsigned getHashValue(mlir::tblgen::Constraint constraint);
+ static bool isEqual(mlir::tblgen::Constraint lhs,
+ mlir::tblgen::Constraint rhs);
+};
+} // namespace llvm
+
#endif // MLIR_TABLEGEN_CONSTRAINT_H_
diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
index 0c5e034a8ee6..8e62120ad837 100644
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -108,3 +108,34 @@ AppliedConstraint::AppliedConstraint(Constraint &&constraint,
std::vector<std::string> &&entities)
: constraint(constraint), self(std::string(self)),
entities(std::move(entities)) {}
+
+Constraint DenseMapInfo<Constraint>::getEmptyKey() {
+ return Constraint(RecordDenseMapInfo::getEmptyKey(),
+ Constraint::CK_Uncategorized);
+}
+
+Constraint DenseMapInfo<Constraint>::getTombstoneKey() {
+ return Constraint(RecordDenseMapInfo::getTombstoneKey(),
+ Constraint::CK_Uncategorized);
+}
+
+unsigned DenseMapInfo<Constraint>::getHashValue(Constraint constraint) {
+ if (constraint == getEmptyKey())
+ return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
+ if (constraint == getTombstoneKey()) {
+ return RecordDenseMapInfo::getHashValue(
+ RecordDenseMapInfo::getTombstoneKey());
+ }
+ return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
+}
+
+bool DenseMapInfo<Constraint>::isEqual(Constraint lhs, Constraint rhs) {
+ if (lhs == rhs)
+ return true;
+ if (lhs == getEmptyKey() || lhs == getTombstoneKey())
+ return false;
+ if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+ return false;
+ return lhs.getPredicate() == rhs.getPredicate() &&
+ lhs.getSummary() == rhs.getSummary();
+}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 2a0d49fcfccf..35afb8d7f694 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -357,10 +357,6 @@ void Operator::populateTypeInferenceInfo(
continue;
}
- if (getArg(*mi).is<NamedAttribute *>()) {
- // TODO: Handle attributes.
- continue;
- }
resultTypeMapping[i].emplace_back(*mi);
found = true;
}
diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
index e35f5f2a4794..c755df255c1e 100644
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ b/mlir/python/mlir/dialects/_arith_ops_ext.py
@@ -41,11 +41,11 @@ def __init__(self,
loc=None,
ip=None):
if isinstance(value, int):
- super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
+ super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
- super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
+ super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
- super().__init__(result, value, loc=loc, ip=ip)
+ super().__init__(value, loc=loc, ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir
index 71014b10726f..47f5f1f511c1 100644
--- a/mlir/test/Dialect/Arithmetic/invalid.mlir
+++ b/mlir/test/Dialect/Arithmetic/invalid.mlir
@@ -25,7 +25,7 @@ func.func @non_signless_constant() {
// -----
func.func @complex_constant_wrong_attribute_type() {
- // expected-error @+1 {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ // expected-error @+1 {{'arith.constant' op failed to verify that all of {value, result} have same type}}
%0 = "arith.constant" () {value = 1.0 : f32} : () -> complex<f32>
return
}
@@ -50,7 +50,7 @@ func.func @bitcast_different_bit_widths(%arg : f16) -> f32 {
func.func @constant() {
^bb:
- %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
@@ -58,7 +58,7 @@ func.func @constant() {
func.func @constant_out_of_range() {
^bb:
- %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
@@ -66,7 +66,7 @@ func.func @constant_out_of_range() {
func.func @constant_wrong_type() {
^bb:
- %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
+ %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
return
}
diff --git a/mlir/test/IR/diagnostic-handler.mlir b/mlir/test/IR/diagnostic-handler.mlir
index f94a632bbb8f..592656cefeb5 100644
--- a/mlir/test/IR/diagnostic-handler.mlir
+++ b/mlir/test/IR/diagnostic-handler.mlir
@@ -5,7 +5,7 @@
// Emit the first available call stack in the fused location.
func.func @constant_out_of_range() {
- // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that result and attribute have the same type
+ // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that all of {value, result} have same type
// CHECK-NEXT: mysource2:1:0: note: called from
// CHECK-NEXT: mysource3:2:0: note: called from
%x = "arith.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))])
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index a190a47a9eea..d4d8746ac441 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -123,7 +123,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
// CHECK-NOT: }
-// CHECK: inferredReturnTypes[0] = operands[0].getType();
+// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
def OpL2 : NS_Op<"op_with_all_types_constraint",
[AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> {
@@ -133,5 +134,18 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
// CHECK-NOT: }
-// CHECK: inferredReturnTypes[0] = operands[2].getType();
-// CHECK: inferredReturnTypes[1] = operands[0].getType();
+// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
+// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
+// CHECK: inferredReturnTypes[1] = odsInferredType1;
+
+def OpL3 : NS_Op<"op_with_all_types_constraint",
+ [AllTypesMatch<["a", "b"]>]> {
+ let arguments = (ins I32Attr:$a);
+ let results = (outs AnyType:$b);
+}
+
+// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
+// CHECK-NOT: }
+// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
+// CHECK: inferredReturnTypes[0] = odsInferredType0;
diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
index 7fb6d953b47f..80fb27b76cad 100644
--- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -234,41 +234,6 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
//===----------------------------------------------------------------------===//
// Constraint Uniquing
-using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
-
-Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
- return Constraint(RecordDenseMapInfo::getEmptyKey(),
- Constraint::CK_Uncategorized);
-}
-
-Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
- return Constraint(RecordDenseMapInfo::getTombstoneKey(),
- Constraint::CK_Uncategorized);
-}
-
-unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
- Constraint constraint) {
- if (constraint == getEmptyKey())
- return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
- if (constraint == getTombstoneKey()) {
- return RecordDenseMapInfo::getHashValue(
- RecordDenseMapInfo::getTombstoneKey());
- }
- return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
-}
-
-bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
- Constraint rhs) {
- if (lhs == rhs)
- return true;
- if (lhs == getEmptyKey() || lhs == getTombstoneKey())
- return false;
- if (rhs == getEmptyKey() || rhs == getTombstoneKey())
- return false;
- return lhs.getPredicate() == rhs.getPredicate() &&
- lhs.getSummary() == rhs.getSummary();
-}
-
/// An attribute constraint that references anything other than itself and the
/// current op cannot be generically extracted into a function. Most
/// prohibitive are operands and results, which require calls to
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6f5f82aa7dd5..c80fe5a22641 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2336,23 +2336,60 @@ void OpEmitter::genTypeInterfaceMethods() {
fctx.withBuilder("odsBuilder");
body << " ::mlir::Builder odsBuilder(context);\n";
- auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
- if (!type.isArg())
- return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
- auto argIndex = type.getArg();
- assert(!op.getArg(argIndex).is<NamedAttribute *>());
+ // Preprocess the result types and build all of the types used during
+ // inferrence. This limits the amount of duplicated work when a type is used
+ // to infer multiple others.
+ llvm::DenseMap<Constraint, int> constraintsTypes;
+ llvm::DenseMap<int, int> argumentsTypes;
+ int inferredTypeIdx = 0;
+ for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+ auto type = op.getSameTypeAsResult(i).front();
+
+ // If the type isn't an argument, it refers to a buildable type.
+ if (!type.isArg()) {
+ auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
+ if (!it.second)
+ continue;
+
+ // If we haven't seen this constraint, generate a variable for it.
+ body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
+ << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
+ continue;
+ }
+
+ // Otherwise, this is an argument.
+ int argIndex = type.getArg();
+ auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
+ if (!it.second)
+ continue;
+ body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
+
+ // If this is an operand, just index into operand list to access the type.
auto arg = op.getArgToOperandOrAttribute(argIndex);
- if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
- return body << "operands[" << arg.operandOrAttributeIndex()
- << "].getType()";
- return body << "attributes[" << arg.operandOrAttributeIndex()
- << "].getType()";
- };
+ if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
+ body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
+
+ // If this is an attribute, index into the attribute dictionary.
+ } else {
+ auto *attr =
+ op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
+ body << "attributes.get(\"" << attr->name << "\").getType()";
+ }
+ body << ";\n";
+ }
+ // Perform a second pass that handles assigning the inferred types to the
+ // results.
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
- body << " inferredReturnTypes[" << i << "] = ";
auto types = op.getSameTypeAsResult(i);
- emitType(types[0]) << ";\n";
+
+ // Append the inferred type.
+ auto type = types.front();
+ body << " inferredReturnTypes[" << i << "] = odsInferredType"
+ << (type.isArg() ? argumentsTypes[type.getArg()]
+ : constraintsTypes[type.getType()])
+ << ";\n";
+
if (types.size() == 1)
continue;
// TODO: We could verify equality here, but skipping that for verification.
More information about the Mlir-commits
mailing list