[Mlir-commits] [mlir] 13d0578 - [MLIR][TableGen] Fix ambiguous build methods when inferring result types.

Rahul Joshi llvmlistbot at llvm.org
Mon Aug 10 10:05:17 PDT 2020


Author: Rahul Joshi
Date: 2020-08-10T10:05:06-07:00
New Revision: 13d05787d0d2dfdfd81939c2e5c41b6a913f5619

URL: https://github.com/llvm/llvm-project/commit/13d05787d0d2dfdfd81939c2e5c41b6a913f5619
DIFF: https://github.com/llvm/llvm-project/commit/13d05787d0d2dfdfd81939c2e5c41b6a913f5619.diff

LOG: [MLIR][TableGen] Fix ambiguous build methods when inferring result types.

- Fix ODS framework to suppress build methods that infer result types and are
  ambiguous with collective variants. This applies to operations with a single
  variadic input whose result types can be inferred.
- Extended OpBuildGenTest to test these kinds of ops.

Differential Revision: https://reviews.llvm.org/D85060

Added: 
    

Modified: 
    mlir/include/mlir/TableGen/Operator.h
    mlir/lib/TableGen/Operator.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-result.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/TableGen/OpBuildGen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 29d4caa32467..d7fac87af0be 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -151,6 +151,17 @@ class Operator {
   // Returns the total number of arguments.
   int getNumArgs() const { return arguments.size(); }
 
+  // Returns true if the operation has a single variadic arg.
+  bool hasSingleVariadicArg() const;
+
+  // Returns true if the operation has a single variadic result.
+  bool hasSingleVariadicResult() const {
+    return getNumResults() == 1 && getResult(0).isVariadic();
+  }
+
+  // Returns true if the operation has no variadic regions.
+  bool hasNoVariadicRegions() const { return getNumVariadicRegions() == 0; }
+
   using arg_iterator = const Argument *;
   using arg_range = llvm::iterator_range<arg_iterator>;
 

diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 3dd924566a8f..9d3995641bca 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -134,6 +134,11 @@ unsigned tblgen::Operator::getNumVariableLengthOperands() const {
   });
 }
 
+bool tblgen::Operator::hasSingleVariadicArg() const {
+  return getNumArgs() == 1 && getArg(0).is<tblgen::NamedTypeConstraint *>() &&
+         getOperand(0).isVariadic();
+}
+
 tblgen::Operator::arg_iterator tblgen::Operator::arg_begin() const {
   return arguments.begin();
 }

diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 742033b130bc..c1bc754da804 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1526,4 +1526,31 @@ def TableGenBuildOp3 : TEST_Op<"tblgen_build_3", [SameVariadicResultSize]> {
   let results = (outs Variadic<AnyType>:$resultA, Variadic<AnyType>:$resultB);
 }
 
+// Single variadic arg, non-variadic results, with SameOperandsAndResultType.
+// Tests suppression of ambiguous build methods for operations with
+// SameOperandsAndResultType trait.
+def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs AnyType:$result);
+}
+
+// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
+// Tests suppression of ambiguous build methods for operations with
+// SameOperandsAndResultType and InferTypeOpInterface.
+def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
+      [SameOperandsAndResultType, InferTypeOpInterface]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs AnyType:$result);
+
+  let extraClassDeclaration = [{
+    static LogicalResult inferReturnTypes(MLIRContext *, 
+          Optional<Location> location, ValueRange operands,
+          DictionaryAttr attributes, RegionRange regions,
+          SmallVectorImpl<Type> &inferredReturnTypes) {
+      inferredReturnTypes.assign({operands[0].getType()});
+      return success();
+    }
+   }];
+}
+
 #endif // TEST_OPS

diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
index 4b091e4ad49f..bdb0765ab541 100644
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -110,8 +110,8 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA
   let results = (outs AnyTensor:$result);
 }
 
-// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
-// CHECK: odsState.addTypes({input.front().getType()});
+// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes )
+// CHECK: odsState.addTypes({operands[0].getType()});
 
 // Test with inferred shapes and interleaved with operands/attributes.
 //

diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 989008d53f9f..9f00b801710d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -232,6 +232,10 @@ class OpEmitter {
   // operand's type as all results' types.
   void genUseOperandAsResultTypeCollectiveParamBuilder();
 
+  // Returns true if the inferred collective param build method should be
+  // generated.
+  bool shouldGenerateInferredTypeCollectiveParamBuilder();
+
   // Generates the build() method that takes aggregate operands/attributes
   // parameters. This build() method uses inferred types as result types.
   // Requires: The type needs to be inferable via InferTypeOpInterface.
@@ -984,40 +988,37 @@ void OpEmitter::genSeparateArgParamBuilder() {
   //          result
   //
   // In that case, skip generating such ambiguous build methods here.
-  bool hasSingleVariadicResult =
-      op.getNumResults() == 1 && op.getResult(0).isVariadic();
-
-  bool hasSingleVariadicArg =
-      op.getNumArgs() == 1 &&
-      op.getArg(0).is<tblgen::NamedTypeConstraint *>() &&
-      op.getOperand(0).isVariadic();
-  bool hasNoVariadicRegions = op.getNumVariadicRegions() == 0;
-
   for (auto attrType : attrBuilderType) {
     // Case 3b above.
-    if (!(hasNoVariadicRegions && hasSingleVariadicArg &&
-          hasSingleVariadicResult))
+    if (!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg() &&
+          op.hasSingleVariadicResult()))
       emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
-    if (canInferType(op))
-      emit(attrType, TypeParamKind::None, /*inferType=*/true);
+    if (canInferType(op)) {
+      // When inferType = true, the generated build method does not have
+      // result types. If the op has a single variadic arg, then this build
+      // method will be ambiguous with the collective inferred build method
+      // generated in `genInferredTypeCollectiveParamBuilder`. If we are going
+      // to generate that collective inferred method, suppress generating the
+      // ambiguous build method here.
+      bool buildMethodAmbiguious =
+          op.hasSingleVariadicArg() &&
+          shouldGenerateInferredTypeCollectiveParamBuilder();
+      if (!buildMethodAmbiguious)
+        emit(attrType, TypeParamKind::None, /*inferType=*/true);
+    }
     // The separate arg + collective param kind method will be:
     // (a) Same as the separate arg + separate param kind method if there is
     //     only one variadic result.
     // (b) Ambiguous with the collective params method under conditions in (3a)
     //     above.
     // In either case, skip generating such build method.
-    if (!hasSingleVariadicResult &&
-        !(hasNoVariadicRegions && hasSingleVariadicArg))
+    if (!op.hasSingleVariadicResult() &&
+        !(op.hasNoVariadicRegions() && op.hasSingleVariadicArg()))
       emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
   }
 }
 
 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
-  // If this op has a variadic result, we cannot generate this builder because
-  // we don't know how many results to create.
-  if (op.getNumVariableLengthResults() != 0)
-    return;
-
   int numResults = op.getNumResults();
 
   // Signature
@@ -1055,6 +1056,10 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
        << llvm::join(resultTypes, ", ") << "});\n\n";
 }
 
+bool OpEmitter::shouldGenerateInferredTypeCollectiveParamBuilder() {
+  return canInferType(op) && op.getNumSuccessors() == 0;
+}
+
 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   // TODO: Expand to support regions.
   std::string params =
@@ -1209,8 +1214,21 @@ void OpEmitter::genBuilder() {
   //    to facilitate different call patterns.
   if (op.getNumVariableLengthResults() == 0) {
     if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
-      genUseOperandAsResultTypeSeparateParamBuilder();
-      genUseOperandAsResultTypeCollectiveParamBuilder();
+      // If the operation has a single variadic input, then the build method
+      // generated by `genUseOperandAsResultTypeSeparateParamBuilder` will be
+      // ambiguous with the one generated by
+      // `genUseOperandAsResultTypeCollectiveParamBuilder` (they both will have
+      // a single `ValueRange` argument for operands, and the collective one
+      // will have an `ArrayRef<NamedAttribute>` argument initialized to empty).
+      // Suppress such ambiguous build method.
+      if (!op.hasSingleVariadicArg())
+        genUseOperandAsResultTypeSeparateParamBuilder();
+
+      // The build method generated by the inferred type collective param
+      // builder and one generated here have the same arguments and hence
+      // generating both will be ambiguous. Enable just one of them.
+      if (!shouldGenerateInferredTypeCollectiveParamBuilder())
+        genUseOperandAsResultTypeCollectiveParamBuilder();
     }
     if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
       genUseAttrAsResultTypeBuilder();
@@ -1269,7 +1287,7 @@ void OpEmitter::genCollectiveParamBuilder() {
 
   // Generate builder that infers type too.
   // TODO: Expand to handle regions and successors.
-  if (canInferType(op) && op.getNumSuccessors() == 0)
+  if (shouldGenerateInferredTypeCollectiveParamBuilder())
     genInferredTypeCollectiveParamBuilder();
 }
 

diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index e90f96b87d63..3e3256e96cd0 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -63,6 +63,28 @@ class OpBuildGenTest : public ::testing::Test {
     concreteOp.erase();
   }
 
+  // Helper method to test ops with inferred result types and single variadic
+  // input.
+  template <typename OpTy>
+  void testSingleVariadicInputInferredType() {
+    // Test separate arg, separate param build method.
+    auto op = builder.create<OpTy>(loc, i32Ty, ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test collective params build method.
+    op = builder.create<OpTy>(loc, ArrayRef<Type>{i32Ty},
+                              ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test build method with no result types, default value of attributes.
+    op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32});
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, noAttrs);
+
+    // Test build method with no result types and supplied attributes.
+    op = builder.create<OpTy>(loc, ArrayRef<Value>{cstI32, cstI32}, attrs);
+    verifyOp(std::move(op), {i32Ty}, {cstI32, cstI32}, attrs);
+  }
+
 protected:
   MLIRContext ctx;
   OpBuilder builder;
@@ -178,4 +200,19 @@ TEST_F(OpBuildGenTest,
   verifyOp(std::move(op), {i32Ty, f32Ty}, {cstI32}, attrs);
 }
 
+// The next 2 tests test suppression of ambiguous build methods for ops that
+// have a single variadic input, and single non-variadic result, and which
+// support the SameOperandsAndResultType trait and optionally the
+// InferTypeOpInterface interface. For such ops, the ODS framework generates
+// build methods with no result types as they are inferred from the input types.
+TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
+  testSingleVariadicInputInferredType<TableGenBuildOp4>();
+}
+
+TEST_F(
+    OpBuildGenTest,
+    BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
+  testSingleVariadicInputInferredType<TableGenBuildOp5>();
+}
+
 } // namespace mlir


        


More information about the Mlir-commits mailing list