[Mlir-commits] [mlir] 13d0578 - [MLIR][TableGen] Fix ambiguous build methods when inferring result types.
Rahul Joshi
llvmlistbot at llvm.org
Mon Aug 10 10:05:17 PDT 2020
Author: Rahul Joshi
Date: 2020-08-10T10:05:06-07:00
New Revision: 13d05787d0d2dfdfd81939c2e5c41b6a913f5619
URL: https://github.com/llvm/llvm-project/commit/13d05787d0d2dfdfd81939c2e5c41b6a913f5619
DIFF: https://github.com/llvm/llvm-project/commit/13d05787d0d2dfdfd81939c2e5c41b6a913f5619.diff
LOG: [MLIR][TableGen] Fix ambiguous build methods when inferring result types.
- Fix the ODS framework to suppress build methods that infer result types and are
ambiguous with the collective variants. This applies to operations with a single variadic
input whose result types can be inferred.
- Extend OpBuildGenTest to test these kinds of ops.
Differential Revision: https://reviews.llvm.org/D85060
Added:
Modified:
mlir/include/mlir/TableGen/Operator.h
mlir/lib/TableGen/Operator.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-result.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/TableGen/OpBuildGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 29d4caa32467..d7fac87af0be 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -151,6 +151,17 @@ class Operator {
// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }
+ // Returns true of the operation has a single variadic arg.
+ bool hasSingleVariadicArg() const;
+
+ // Returns true if the operation has a single variadic result.
+ bool hasSingleVariadicResult() const {
+ return getNumResults() == 1 && getResult(0).isVariadic();
+ }
+
+ // Returns true of the operation has no variadic regions.
+ bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
+
using arg_iterator = const Argument *;
using arg_range = llvm::iterator_range<arg_iterator>;
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 3dd924566a8f..9d3995641bca 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -134,6 +134,11 @@ unsigned tblgen::Operator::getNumVariableLengthOperands() const {
});
}
+bool tblgen::Operator::hasSingleVariadicArg() const {
+ return getNumArgs() == 1 && getArg(0).is<tblgen::NamedTypeConstraint *>() &&
+ getOperand(0).isVariadic();
+}
+
tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
return arguments.begin();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 742033b130bc..c1bc754da804 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1526,4 +1526,31 @@ def TableGenBuildOp3 : TEST_Op<"tblgen_build_3", [SameVariadicResultSize]> {
let results = (outs Variadic<AnyType>:$resultA, Variadic<AnyType>:$resultB);
}
+// Single variadic arg, non variadic results, with SameOperandsAndResultType.
+// Tests suppression of ambiguious build methods for operations with
+// SameOperandsAndResultType trait.
+def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
+ let arguments = (ins Variadic<AnyType>:$inputs);
+ let results = (outs AnyType:$result);
+}
+
+// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
+// Tests suppression of ambiguious build methods for operations with
+// SameOperandsAndResultType and InferTypeOpInterface.
+def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
+ [SameOperandsAndResultType, InferTypeOpInterface]> {
+ let arguments = (ins Variadic<AnyType>:$inputs);
+ let results = (outs AnyType:$result);
+
+ let extraClassDeclaration = [{
+ static LogicalResult inferReturnTypes(MLIRContext *,
+ Optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.assign({operands[0].getType()});
+ return success();
+ }
+ }];
+}
+
#endif // TEST_OPS
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 4b091e4ad49f..bdb0765ab541 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -110,8 +110,8 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
let results = (outs AnyTensor:$result);
}
-// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
-// CHECK: odsState.addTypes({input.front().getType()});
+// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes )
+// CHECK: odsState.addTypes({operands[0].getType()});
// Test with inferred shapes and interleaved with operands/attributes.
//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 989008d53f9f..9f00b801710d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -232,6 +232,10 @@ class OpEmitter {
// operand's type as all results' types.
void genUseOperandAsResultTypeCollectiveParamBuilder();
+ // Returns true if the inferred collective param build method should be
+ // generated.
+ bool shouldGenerateInferredTypeCollectiveParamBuilder();
+
// Generates the build() method that takes aggregate operands/attributes
// parameters. This build() method uses inferred types as result types.
// Requires: The type needs to be inferable via InferTypeOpInterface.
@@ -984,40 +988,37 @@ void OpEmitter::genSeparateArgParamBuilder() {
// result
//
// In that case, skip generating such ambiguous build methods here.
- bool hasSingleVariadicResult =
- op.getNumResults() == 1 && op.getResult(0).isVariadic();
-
- bool hasSingleVariadicArg =
- op.getNumArgs() == 1 &&
- op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
- op.getOperand(0).isVariadic();
- bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
-
for (auto attrType : attrBuilderType) {
// Case 3b above.
- if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
- hasSingleVariadicResult))
+ if (!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg() &&
+ op.hasSingleVariadicResult()))
emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
- if (canInferType(op))
- emit(attrType, TypeParamKind::None, /*inferType=*/true);
+ if (canInferType(op)) {
+ // When inferType = true, the generated build method does not have
+ // result types. If the op has a single variadic arg, then this build
+ // method will be ambiguious with the collective inferred build method
+ // generated in `genInferredTypeCollectiveParamBuilder`. If we are going
+ // to generate that collective inferred method, suppress generating the
+ // ambiguious build method here.
+ bool buildMethodAmbiguious =
+ op.hasSingleVariadicArg() &&
+ shouldGenerateInferredTypeCollectiveParamBuilder();
+ if (!buildMethodAmbiguious)
+ emit(attrType, TypeParamKind::None, /*inferType=*/true);
+ }
// The separate arg + collective param kind method will be:
// (a) Same as the separate arg + separate param kind method if there is
// only one variadic result.
// (b) Ambiguous with the collective params method under conditions in (3a)
// above.
// In either case, skip generating such build method.
- if (!hasSingleVariadicResult &&
- !(hasNoVariadicRegions && hasSingleVariadicArg))
+ if (!op.hasSingleVariadicResult() &&
+ !(op.hasNoVariadicRegions() && op.hasSingleVariadicArg()))
emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
}
}
void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
- // If this op has a variadic result, we cannot generate this builder because
- // we don't know how many results to create.
- if (op.getNumVariableLengthResults() != 0)
- return;
-
int numResults = op.getNumResults();
// Signature
@@ -1055,6 +1056,10 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
<< llvm::join(resultTypes, ", ") << "});\n\n";
}
+bool OpEmitter::shouldGenerateInferredTypeCollectiveParamBuilder() {
+ return canInferType(op) && op.getNumSuccessors() == 0;
+}
+
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
// TODO: Expand to support regions.
std::string params =
@@ -1209,8 +1214,21 @@ void OpEmitter::genBuilder() {
// to facilitate different call patterns.
if (op.getNumVariableLengthResults() == 0) {
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
- genUseOperandAsResultTypeSeparateParamBuilder();
- genUseOperandAsResultTypeCollectiveParamBuilder();
+ // If the operation has a single variadic input, then the build method
+ // generated by `genUseOperandAsResultTypeSeparateParamBuilder` will be
+ // ambiguious with the one generated by
+ // `genUseOperandAsResultTypeCollectiveParamBuilder` (they both will have
+ // a single `ValueRange` argument for operands, and the collective one
+ // will have a `ArrayRef<NamedAttribute>` argument initalized to empty).
+ // Suppress such ambiguious build method.
+ if (!op.hasSingleVariadicArg())
+ genUseOperandAsResultTypeSeparateParamBuilder();
+
+ // The build method generated by the inferred type collective param
+ // builder and one generated here have the same arguments and hence
+ // generating both will be ambiguious. Enable just one of them.
+ if (!shouldGenerateInferredTypeCollectiveParamBuilder())
+ genUseOperandAsResultTypeCollectiveParamBuilder();
}
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
@@ -1269,7 +1287,7 @@ void OpEmitter::genCollectiveParamBuilder() {
// Generate builder that infers type too.
// TODO: Expand to handle regions and successors.
- if (canInferType(op) && op.getNumSuccessors() == 0)
+ if (shouldGenerateInferredTypeCollectiveParamBuilder())
genInferredTypeCollectiveParamBuilder();
}
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index e90f96b87d63..3e3256e96cd0 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -63,6 +63,28 @@ class OpBuildGenTest : public ::testing::Test {
concreteOp.erase();
}
+ // Helper method to test ops with inferred result types and single variadic
+ // input.
+ template <typename OpTy>
+ void testSingleVariadicInputInferredType() {
+ // Test separate arg, separate param build method.
+ auto op = builder.create<OpTy>(loc, i32Ty, ArrayRef<Value>{cstI32, cstI32});
+ verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+ // Test collective params build method.
+ op = builder.create<OpTy>(loc, ArrayRef<Type>{i32Ty},
+ ArrayRef<Value>{cstI32, cstI32});
+ verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+ // Test build method with no result types, default value of attributes.
+ op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32});
+ verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+ // Test build method with no result types and supplied attributes.
+ op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32}, attrs);
+ verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, attrs);
+ }
+
protected:
MLIRContext ctx;
OpBuilder builder;
@@ -178,4 +200,19 @@ TEST_F(OpBuildGenTest,
verifyOp(std::move(op), {i32Ty, f32Ty}, {cstI32}, attrs);
}
+// The next 2 tests test supression of ambiguious build methods for ops that
+// have a single variadic input, and single non-variadic result, and which
+// support the SameOperandsAndResultType trait and and optionally the
+// InferOpTypeInterface interface. For such ops, the ODS framework generates
+// build methods with no result types as they are inferred from the input types.
+TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
+ testSingleVariadicInputInferredType<TableGenBuildOp4>();
+}
+
+TEST_F(
+ OpBuildGenTest,
+ BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
+ testSingleVariadicInputInferredType<TableGenBuildOp5>();
+}
+
} // namespace mlir
More information about the Mlir-commits
mailing list