[Mlir-commits] [mlir] 31f40f6 - [mlir] Add simple generator for return types

Jacques Pienaar llvmlistbot at llvm.org
Wed May 27 08:46:10 PDT 2020


Author: Jacques Pienaar
Date: 2020-05-27T08:45:55-07:00
New Revision: 31f40f603d0c00b313397196124c5f39090badf0

URL: https://github.com/llvm/llvm-project/commit/31f40f603d0c00b313397196124c5f39090badf0
DIFF: https://github.com/llvm/llvm-project/commit/31f40f603d0c00b313397196124c5f39090badf0.diff

LOG: [mlir] Add simple generator for return types

Take advantage of equality constraints to generate the type inference interface.
This is used for equality and trivially built types. The type inference method
is only generated when no type inference trait is specified already.

This reorders verification, which changes some test error messages.

Differential Revision: https://reviews.llvm.org/D80484

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/TableGen/Attribute.h
    mlir/include/mlir/TableGen/Operator.h
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/TableGen/Attribute.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-decl.td
    mlir/test/mlir-tblgen/types.mlir
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index ddabae2225e7..42c431d13f8e 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1,4 +1,4 @@
-# Table-driven Operation Definition Specification (ODS)
+# Operation Definition Specification (ODS)
 
 In addition to specializing the `mlir::Op` C++ template, MLIR also supports
 defining operations in a table-driven manner. This is achieved via
@@ -526,10 +526,9 @@ static void build(OpBuilder &odsBuilder, OperationState &odsState,
                   IntegerAttr i32_attr, FloatAttr f32_attr, ...);
 
 // All operands/attributes have aggregate parameters.
-// Generated if InferTypeOpInterface interface is specified.
+// Generated if return type can be inferred.
 static void build(OpBuilder &odsBuilder, OperationState &odsState,
-                  ValueRange operands,
-                  ArrayRef<NamedAttribute> attributes);
+                  ValueRange operands, ArrayRef<NamedAttribute> attributes);
 
 // (And manually specified builders depending on the specific op.)
 ```
@@ -554,6 +553,12 @@ restrictions.) Otherwise, the builder of the third form will still be generated
 but default values for the attributes not at the end of the `arguments` list
 will not be supplied in the builder's signature.
 
+ODS will generate a builder that does not take the result types as arguments if
+
+*   the op implements the `InferTypeOpInterface` interface; or
+*   all result types are either buildable types or are the same as a given
+    operand (e.g., an `AllTypesMatch` constraint between operand and result).
+
 And there may potentially exist other builders depending on the specific op;
 please refer to the
 [generated C++ file](#run-mlir-tblgen-to-see-the-generated-content) for the

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 0278d7bbeb06..a9759fc6a734 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -91,10 +91,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast",
   let hasFolder = 1;
 }
 
-def Shape_ConstShapeOp : Shape_Op<"const_shape",
-    [ConstantLike,
-     NoSideEffect,
-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
   let summary = "Creates a constant of !shape.shape type.";
   let description = [{
     Creates a !shape.shape with rank given by the length of `shape` and with

diff  --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index f99939392e93..4571ca8ee9b3 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -230,6 +230,9 @@ class StructAttr : public Attribute {
   std::vector<StructFieldAttr> getAllFields() const;
 };
 
+// Name of infer type op interface.
+extern const char *inferTypeOpInterface;
+
 } // end namespace tblgen
 } // end namespace mlir
 

diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index e65bc55a84f5..040f52314cea 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -23,6 +23,7 @@
 #include "mlir/TableGen/Type.h"
 #include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/SMLoc.h"
 
@@ -227,10 +228,45 @@ class Operator {
   // debugging purposes.
   void print(llvm::raw_ostream &os) const;
 
+  // Returns true if the type of every result could be determined from an
+  // operand or a buildable type constraint when the op was populated.
+  bool allResultTypesKnown() const { return allResultsHaveKnownTypes; }
+
+  // Either the index of an argument (as passed to `getArg`) or a type
+  // constraint. Exactly one of the two is set; the queries assert this.
+  struct ArgOrType {
+    explicit ArgOrType(int argIndex) : index(argIndex) {}
+    explicit ArgOrType(TypeConstraint typeConstraint)
+        : constraint(typeConstraint) {}
+
+    // True if this entry refers to an argument of the op.
+    bool isArg() const {
+      assert(constraint.hasValue() ^ index.hasValue());
+      return index.hasValue();
+    }
+    // True if this entry refers to a type constraint.
+    bool isType() const {
+      assert(constraint.hasValue() ^ index.hasValue());
+      return constraint.hasValue();
+    }
+
+    // Accessors; each is only valid for the kind of entry that is set.
+    int getArg() const { return index.getValue(); }
+    TypeConstraint getType() const { return constraint.getValue(); }
+
+  private:
+    Optional<int> index;
+    Optional<TypeConstraint> constraint;
+  };
+
+  // Return all arguments or type constraints with same type as result[index].
+  // Requires: all result types are known.
+  ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
+
 private:
   // Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
 
+  // Populates type inference info (mostly type equality), given a mapping from
+  // argument and result names to their indices.
+  void populateTypeInferenceInfo(
+      const llvm::StringMap<int> &argumentsAndResultsIndex);
+
   // The dialect of this op.
   Dialect dialect;
 
@@ -261,12 +297,18 @@ class Operator {
   // The regions of this op.
   SmallVector<NamedRegion, 1> regions;
 
+  // For each result, the arguments or type constraints with the same type.
+  SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
+
   // The number of native attributes stored in the leading positions of
   // `attributes`.
   int numNativeAttributes;
 
   // The TableGen definition of this op.
   const llvm::Record &def;
+
+  // Whether the types of all results are known.
+  bool allResultsHaveKnownTypes;
 };
 
 } // end namespace tblgen

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 095c41720fba..fa9552fc8694 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -223,15 +223,6 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
 
 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
 
-LogicalResult
-ConstShapeOp::inferReturnTypes(MLIRContext *context,
-                               Optional<Location> location, ValueRange operands,
-                               DictionaryAttr attributes, RegionRange regions,
-                               SmallVectorImpl<Type> &inferredReturnTypes) {
-  inferredReturnTypes.push_back(ShapeType::get(context));
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // ConstSizeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 89dce1958991..fe1fffbc1a69 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -288,3 +288,5 @@ tblgen::StructAttr::getAllFields() const {
 
   return attributes;
 }
+
+const char *mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 808ba7aabc76..2f77184980e2 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -14,6 +14,9 @@
 #include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Type.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
@@ -155,13 +158,13 @@ auto tblgen::Operator::getArgDecorators(int index) const
 
 const tblgen::OpTrait *tblgen::Operator::getTrait(StringRef trait) const {
   for (const auto &t : traits) {
-    if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
+    if (const auto *opTrait = dyn_cast<tblgen::NativeOpTrait>(&t)) {
       if (opTrait->getTrait() == trait)
         return opTrait;
-    } else if (auto opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
+    } else if (const auto *opTrait = dyn_cast<tblgen::InternalOpTrait>(&t)) {
       if (opTrait->getTrait() == trait)
         return opTrait;
-    } else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&t)) {
+    } else if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&t)) {
       if (opTrait->getTrait() == trait)
         return opTrait;
     }
@@ -252,22 +255,126 @@ auto tblgen::Operator::getArg(int index) const -> Argument {
   return arguments[index];
 }
 
+// Maps result number `resultNo` into the combined argument/result index
+// space. Arguments keep their non-negative `getArg` indices, so results are
+// placed at -1, -2, ... to avoid overlap. The mapping is its own inverse.
+static int resultIndex(int resultNo) {
+  return -1 - resultNo;
+}
+
+// Returns true if any operand or result of the op is variadic.
+bool tblgen::Operator::isVariadic() const {
+  auto isVariadicValue = [](const NamedTypeConstraint &value) {
+    return value.isVariadic();
+  };
+  return llvm::any_of(operands, isVariadicValue) ||
+         llvm::any_of(results, isVariadicValue);
+}
+
+// Determines, per result, which operands or buildable type constraints give
+// the result's type, using `AllTypesMatch` traits as equality information.
+// When every result type is known, the InferTypeOpInterface trait is added to
+// the op so that `inferReturnTypes` and type-inferring builders get generated.
+void tblgen::Operator::populateTypeInferenceInfo(
+    const llvm::StringMap<int> &argumentsAndResultsIndex) {
+  // If the type inference op interface is not registered, then do not attempt
+  // to determine if the result types can be inferred.
+  auto &recordKeeper = def.getRecords();
+  auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface);
+  allResultsHaveKnownTypes = false;
+  if (!inferTrait)
+    return;
+
+  // If there are no results, then skip this; otherwise the generated build
+  // method would overlap with another autogenerated builder.
+  if (getNumResults() == 0)
+    return;
+
+  // Skip for ops with variadic operands/results.
+  // TODO: This can be relaxed.
+  if (isVariadic())
+    return;
+
+  // Skip cases currently being custom generated.
+  // TODO: Remove special cases.
+  if (getTrait("OpTrait::SameOperandsAndResultType"))
+    return;
+
+  // We create equivalence classes of argument/result types where arguments
+  // and results are mapped into the same index space and indices corresponding
+  // to the same type are in the same equivalence class.
+  llvm::EquivalenceClasses<int> ecs;
+  resultTypeMapping.resize(getNumResults());
+
+  // Records every known source for the type of result `i`: operands in the
+  // same equivalence class, and buildable type constraints of results in the
+  // same class (including result `i` itself). Attributes are not handled yet.
+  // Returns true if at least one source was found.
+  auto captureMapping = [&](int i) {
+    bool found = false;
+    ecs.insert(resultIndex(i));
+    auto mi = ecs.findLeader(resultIndex(i));
+    for (auto me = ecs.member_end(); mi != me; ++mi) {
+      if (*mi < 0) {
+        // A result in the same class; use that result's own constraint
+        // (`resultIndex` is its own inverse), not the one of result `i`.
+        auto tc = getResultTypeConstraint(resultIndex(*mi));
+        if (tc.getBuilderCall().hasValue()) {
+          resultTypeMapping[i].emplace_back(tc);
+          found = true;
+        }
+        continue;
+      }
+
+      // TODO: Handle attributes.
+      if (getArg(*mi).is<NamedAttribute *>())
+        continue;
+      resultTypeMapping[i].emplace_back(*mi);
+      found = true;
+    }
+    return found;
+  };
+
+  for (const OpTrait &trait : traits) {
+    const llvm::Record &traitDef = trait.getDef();
+    // If the infer type op interface was manually added, then treat it as
+    // intention that the op needs special handling.
+    // TODO: Reconsider whether to always generate, this is more conservative
+    // and keeps existing behavior so starting that way for now.
+    if (traitDef.isSubClassOf(
+            llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
+      return;
+    if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+      if (opTrait->getTrait().startswith(inferTypeOpInterface))
+        return;
+
+    if (!traitDef.isSubClassOf("AllTypesMatch"))
+      continue;
+
+    // Union all named values of the constraint into one class. Names that do
+    // not refer to a named argument or result are skipped: StringMap::lookup
+    // would return 0 for them and wrongly tie the class to argument 0. This
+    // also handles an empty value list without dereferencing front().
+    bool haveRoot = false;
+    int root = 0;
+    for (StringRef name : traitDef.getValueAsListOfStrings("values")) {
+      auto it = argumentsAndResultsIndex.find(name);
+      if (it == argumentsAndResultsIndex.end())
+        continue;
+      if (!haveRoot) {
+        root = it->second;
+        haveRoot = true;
+      }
+      ecs.unionSets(it->second, root);
+    }
+  }
+
+  // Every result needs at least one known source for its type; only then can
+  // the result types be inferred.
+  allResultsHaveKnownTypes =
+      all_of(llvm::seq<int>(0, getNumResults()), captureMapping);
+
+  // If the types could be computed, then add type inference trait.
+  if (allResultsHaveKnownTypes)
+    traits.push_back(OpTrait::create(inferTrait->getDefInit()));
+}
+
 void tblgen::Operator::populateOpStructure() {
   auto &recordKeeper = def.getRecords();
-  auto typeConstraintClass = recordKeeper.getClass("TypeConstraint");
-  auto attrClass = recordKeeper.getClass("Attr");
-  auto derivedAttrClass = recordKeeper.getClass("DerivedAttr");
-  auto opVarClass = recordKeeper.getClass("OpVariable");
+  auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
+  auto *attrClass = recordKeeper.getClass("Attr");
+  auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
+  auto *opVarClass = recordKeeper.getClass("OpVariable");
   numNativeAttributes = 0;
 
   DagInit *argumentValues = def.getValueAsDag("arguments");
   unsigned numArgs = argumentValues->getNumArgs();
 
+  // Mapping from name of to argument or result index. Arguments are indexed
+  // to match getArg index, while the results are negatively indexed.
+  llvm::StringMap<int> argumentsAndResultsIndex;
+
   // Handle operands and native attributes.
   for (unsigned i = 0; i != numArgs; ++i) {
-    auto arg = argumentValues->getArg(i);
+    auto *arg = argumentValues->getArg(i);
     auto givenName = argumentValues->getArgNameStr(i);
-    auto argDefInit = dyn_cast<DefInit>(arg);
+    auto *argDefInit = dyn_cast<DefInit>(arg);
     if (!argDefInit)
       PrintFatalError(def.getLoc(),
                       Twine("undefined type for argument #") + Twine(i));
@@ -290,6 +397,8 @@ void tblgen::Operator::populateOpStructure() {
       PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
                                     "from TypeConstraint or Attr are allowed");
     }
+    if (!givenName.empty())
+      argumentsAndResultsIndex[givenName] = i;
   }
 
   // Handle derived attributes.
@@ -348,6 +457,8 @@ void tblgen::Operator::populateOpStructure() {
     if (resultDef->isSubClassOf(opVarClass))
       resultDef = resultDef->getValueAsDef("constraint");
     results.push_back({name, TypeConstraint(resultDef)});
+    if (!name.empty())
+      argumentsAndResultsIndex[name] = resultIndex(i);
   }
 
   // Handle successors
@@ -375,17 +486,19 @@ void tblgen::Operator::populateOpStructure() {
 
   // Create list of traits, skipping over duplicates: appending to lists in
   // tablegen is easy, making them unique less so, so dedupe here.
-  if (auto traitList = def.getValueAsListInit("traits")) {
+  if (auto *traitList = def.getValueAsListInit("traits")) {
     // This is uniquing based on pointers of the trait.
     SmallPtrSet<const llvm::Init *, 32> traitSet;
     traits.reserve(traitSet.size());
-    for (auto traitInit : *traitList) {
+    for (auto *traitInit : *traitList) {
       // Keep traits in the same order while skipping over duplicates.
       if (traitSet.insert(traitInit).second)
         traits.push_back(OpTrait::create(traitInit));
     }
   }
 
+  populateTypeInferenceInfo(argumentsAndResultsIndex);
+
   // Handle regions
   auto *regionsDag = def.getValueAsDag("regions");
   auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
@@ -415,6 +528,12 @@ void tblgen::Operator::populateOpStructure() {
   LLVM_DEBUG(print(llvm::dbgs()));
 }
 
+// Returns the arguments or type constraints whose type equals the type of
+// result `index`. Only meaningful once all result types are known.
+ArrayRef<tblgen::Operator::ArgOrType>
+tblgen::Operator::getSameTypeAsResult(int index) const {
+  assert(allResultTypesKnown());
+  const auto &sources = resultTypeMapping[index];
+  return sources;
+}
+
 ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
 
 bool tblgen::Operator::hasDescription() const {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9e95932b5680..997d8eb44ae5 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -756,15 +756,6 @@ def OpSymbolBindingA : TEST_Op<"symbol_binding_a", []> {
 def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> {
   let arguments = (ins I32:$operand);
   let results = (outs I32);
-
-  let builders = [
-    OpBuilder<
-      "OpBuilder &builder, OperationState &state, Value operand",
-      [{
-        state.types.assign({builder.getIntegerType(32)});
-        state.addOperands({operand});
-      }]>
-  ];
 }
 def OpSymbolBindingC : TEST_Op<"symbol_binding_c", []> {
   let arguments = (ins I32:$operand);
@@ -868,17 +859,6 @@ def AnotherThreeResultOp : TEST_Op<"another_three_result"> {
 def TwoResultOp : TEST_Op<"two_result"> {
   let arguments = (ins MultiResultOpEnum:$kind);
   let results = (outs I32:$result1, F32:$result2);
-
-  let builders = [
-    OpBuilder<
-      "OpBuilder &builder, OperationState &state, IntegerAttr kind",
-      [{
-        auto i32 = builder.getIntegerType(32);
-        auto f32 = builder.getF32Type();
-        state.types.assign({i32, f32});
-        state.addAttribute("kind", kind);
-      }]>
-  ];
 }
 
 def AnotherTwoResultOp : TEST_Op<"another_two_result"> {

diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index c68d03c96b30..565f1921125a 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -1,6 +1,7 @@
 // RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck --dump-input-on-failure %s
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 def Test_Dialect : Dialect {
@@ -44,8 +45,6 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
   }];
 }
 
-// CHECK: class AOp;
-
 // CHECK-LABEL: NS::AOp declarations
 
 // CHECK: class AOpOperandAdaptor {
@@ -150,6 +149,26 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
 // CHECK:   Value b();
 // CHECK:   static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a)
 
+// Check that all types match constraint results in generating builder.
+// ---
+
+def NS_FOp : NS_Op<"op_with_all_types_constraint",
+    [AllTypesMatch<["a", "b"]>]> {
+  let arguments = (ins AnyType:$a);
+  let results = (outs AnyType:$b);
+}
+
+// CHECK-LABEL: class FOp :
+// CHECK: static LogicalResult inferReturnTypes
+
+def NS_GOp : NS_Op<"op_with_fixed_return_type", []> {
+  let arguments = (ins AnyType:$a);
+  let results = (outs I32:$b);
+}
+
+// CHECK-LABEL: class GOp :
+// CHECK: static LogicalResult inferReturnTypes
+
 // Check that default builders can be suppressed.
 // ---
 

diff  --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir
index 6850b77c7672..6a0a80ca5e5f 100644
--- a/mlir/test/mlir-tblgen/types.mlir
+++ b/mlir/test/mlir-tblgen/types.mlir
@@ -438,7 +438,7 @@ func @operand_rank_equals_result_size_failure(%arg : tensor<1x2x3x4xi32>) {
 // -----
 
 func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
-  // expected-error at +1 {{all of {x, res} have same type}}
+  // expected-error at +1 {{type incompatible with return type of operation}}
   "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>
   return
 }
@@ -446,7 +446,7 @@ func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>
 // -----
 
 func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
-  // expected-error at +1 {{all of {x, res} have same type}}
+  // expected-error at +1 {{type incompatible with return type of operation}}
   "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32>
   return
 }

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 2010262f2185..0b55825d1a46 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -295,9 +295,15 @@ class OpEmitter {
   // Generate the OpInterface methods.
   void genOpInterfaceMethods();
 
+  // Generate op interface method.
+  void genOpInterfaceMethod(const tblgen::InterfaceOpTrait *trait);
+
   // Generate the side effect interface methods.
   void genSideEffectInterfaceMethods();
 
+  // Generate the type inference interface methods.
+  void genTypeInterfaceMethods();
+
 private:
   // The TableGen record for this op.
   // TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
@@ -321,6 +327,7 @@ OpEmitter::OpEmitter(const Operator &op)
   verifyCtx.withOp("(*this->getOperation())");
 
   genTraits();
+
   // Generate C++ code for various op methods. The order here determines the
   // methods in the generated file.
   genOpAsmInterface();
@@ -341,6 +348,7 @@ OpEmitter::OpEmitter(const Operator &op)
   genOpInterfaceMethods();
   generateOpFormat(op, opClass);
   genSideEffectInterfaceMethods();
+  genTypeInterfaceMethods();
 }
 
 void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
@@ -750,6 +758,10 @@ static bool canGenerateUnwrappedBuilder(Operator &op) {
   return canGenerate;
 }
 
+// Returns true if builders that infer the result types can be generated: the
+// op must carry the InferTypeOpInterface trait (declared explicitly or added
+// when all result types are known) and must have no regions. Only queries the
+// op, so it takes it by const reference.
+static bool canInferType(const Operator &op) {
+  return op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
+}
+
 void OpEmitter::genSeparateArgParamBuilder() {
   SmallVector<AttrParamKind, 2> attrBuilderType;
   attrBuilderType.push_back(AttrParamKind::WrappedAttr);
@@ -814,11 +826,9 @@ void OpEmitter::genSeparateArgParamBuilder() {
     llvm_unreachable("unhandled TypeParamKind");
   };
 
-  bool canInferType =
-      op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0;
   for (auto attrType : attrBuilderType) {
     emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
-    if (canInferType)
+    if (canInferType(op))
       emit(attrType, TypeParamKind::None, /*inferType=*/true);
     // Emit separate arg build with collective type, unless there is only one
     // variadic result, in which case the above would have already generated
@@ -1070,11 +1080,8 @@ void OpEmitter::genCollectiveParamBuilder() {
   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
 
   // Generate builder that infers type too.
-  // TODO(jpienaar): Subsume this with general checking if type can be inferred
-  // automatically.
   // TODO(jpienaar): Expand to handle regions and successors.
-  if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0 &&
-      op.getNumSuccessors() == 0)
+  if (canInferType(op) && op.getNumSuccessors() == 0)
     genInferredTypeCollectiveParamBuilder();
 }
 
@@ -1318,40 +1325,43 @@ void OpEmitter::genFolderDecls() {
   }
 }
 
+// Adds declarations to the op class for the methods of the interface behind
+// `opTrait` that the op itself has to implement.
+void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
+  auto opInterface = opTrait->getOpInterface();
+
+  // Methods the op asked to declare even though the interface provides a
+  // default implementation for them.
+  llvm::StringSet<> forcedDeclarations;
+  for (const auto &name : opTrait->getAlwaysDeclaredMethods())
+    forcedDeclarations.insert(name);
+
+  for (const OpInterfaceMethod &method : opInterface.getMethods()) {
+    // Methods with a body never get an op-side declaration; methods with a
+    // default implementation only get one when explicitly requested.
+    const bool needsDeclaration =
+        !method.getBody() && (!method.getDefaultImplementation() ||
+                              forcedDeclarations.count(method.getName()));
+    if (!needsDeclaration)
+      continue;
+
+    std::string argList;
+    llvm::raw_string_ostream argStream(argList);
+    interleaveComma(method.getArguments(), argStream,
+                    [&](const OpInterfaceMethod::Argument &arg) {
+                      argStream << arg.type << " " << arg.name;
+                    });
+    auto property =
+        method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
+    opClass.newMethod(method.getReturnType(), method.getName(),
+                      argStream.str(), property, /*declOnly=*/true);
+  }
+}
+
+// Adds op-side method declarations for each interface trait of the op.
 void OpEmitter::genOpInterfaceMethods() {
   for (const auto &trait : op.getTraits()) {
-    auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait);
-    if (!opTrait || !opTrait->shouldDeclareMethods())
-      continue;
-    auto interface = opTrait->getOpInterface();
-
-    // Get the set of methods that should always be declared.
-    auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
-    llvm::StringSet<> alwaysDeclaredMethods;
-    alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
-                                 alwaysDeclaredMethodsVec.end());
-
-    for (const OpInterfaceMethod &method : interface.getMethods()) {
-      // Don't declare if the method has a body.
-      if (method.getBody())
-        continue;
-      // Don't declare if the method has a default implementation and the op
-      // didn't request that it always be declared.
-      if (method.getDefaultImplementation() &&
-          !alwaysDeclaredMethods.count(method.getName()))
-        continue;
-
-      std::string args;
-      llvm::raw_string_ostream os(args);
-      interleaveComma(method.getArguments(), os,
-                      [&](const OpInterfaceMethod::Argument &arg) {
-                        os << arg.type << " " << arg.name;
-                      });
-      opClass.newMethod(method.getReturnType(), method.getName(), os.str(),
-                        method.isStatic() ? OpMethod::MP_Static
-                                          : OpMethod::MP_None,
-                        /*declOnly=*/true);
-    }
+    // Only interface traits that request declarations (shouldDeclareMethods)
+    // get op-side method declarations.
+    if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
+      if (opTrait->shouldDeclareMethods())
+        genOpInterfaceMethod(opTrait);
   }
 }
 
@@ -1431,6 +1441,46 @@ void OpEmitter::genSideEffectInterfaceMethods() {
   }
 }
 
+// Generates a static `inferReturnTypes` for ops whose result types were all
+// determined by Operator::populateTypeInferenceInfo: each result type is
+// either copied from an operand or built from a buildable type constraint.
+void OpEmitter::genTypeInterfaceMethods() {
+  if (!op.allResultTypesKnown())
+    return;
+
+  auto &method = opClass.newMethod(
+      "LogicalResult", "inferReturnTypes",
+      "MLIRContext* context, Optional<Location> location, "
+      "ValueRange operands, DictionaryAttr attributes, RegionRange regions, "
+      "SmallVectorImpl<Type>& inferredReturnTypes",
+      OpMethod::MP_Static,
+      /*declOnly=*/false);
+  auto &os = method.body();
+  os << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
+
+  FmtContext fctx;
+  fctx.withBuilder("odsBuilder");
+  os << "  Builder odsBuilder(context);\n";
+
+  // Emits the C++ expression producing the type described by `type`.
+  // Argument indices follow `getArg`, which interleaves operands and
+  // attributes, whereas the generated code indexes the `operands` range, so
+  // the argument index is translated by counting the operands preceding it.
+  // NOTE(review): this assumes fixed operand positions; variadic ops are
+  // rejected earlier, but confirm optional operands cannot reach here.
+  auto emitType =
+      [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
+    if (type.isType())
+      return os << tgfmt(*type.getType().getBuilderCall(), &fctx);
+
+    int argIndex = type.getArg();
+    assert(!op.getArg(argIndex).is<NamedAttribute *>());
+    int operandIndex = 0;
+    for (int j = 0; j != argIndex; ++j)
+      if (!op.getArg(j).is<NamedAttribute *>())
+        ++operandIndex;
+    return os << "operands[" << operandIndex << "].getType()";
+  };
+
+  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
+    os << "  inferredReturnTypes[" << i << "] = ";
+    // Every result has at least one known source once all types are known;
+    // any of them yields the type. Checking that multiple sources agree is
+    // left to the op verifier.
+    emitType(op.getSameTypeAsResult(i).front()) << ";\n";
+  }
+  os << "  return success();";
+}
+
 void OpEmitter::genParser() {
   if (!hasStringAttribute(def, "parser") ||
       hasStringAttribute(def, "assemblyFormat"))


        


More information about the Mlir-commits mailing list