[Mlir-commits] [mlir] cf6f217 - [mlir][tblgen] Generate generic adaptors for Ops

Markus Böck llvmlistbot at llvm.org
Wed Jan 11 05:34:16 PST 2023


Author: Markus Böck
Date: 2023-01-11T14:32:21+01:00
New Revision: cf6f21751622fcae326d1fe13bc5afd74c4e720f

URL: https://github.com/llvm/llvm-project/commit/cf6f21751622fcae326d1fe13bc5afd74c4e720f
DIFF: https://github.com/llvm/llvm-project/commit/cf6f21751622fcae326d1fe13bc5afd74c4e720f.diff

LOG: [mlir][tblgen] Generate generic adaptors for Ops

This is part of the RFC for a better fold API: https://discourse.llvm.org/t/rfc-a-better-fold-api-using-more-generic-adaptors/67374

This patch implements the generation of generic adaptors through TableGen. These are essentially a generalization of the existing Adaptors: instead of indexing into a `mlir::ValueRange`, they may index into any container, regardless of its element type. This allows the convenient getter methods of Adaptors to be reused on ranges that result from mapping an op's operands through some function.
In the case of the fold API in the RFC, this would be `ArrayRef<Attribute>`, which is a mapping of the operands to their possibly-constant values.
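To make the generalization concrete, here is a minimal std-only sketch, not the generated code. `AddOpGenericAdaptor`, `getLhs` and `getRhs` are hypothetical names, and `std::vector` stands in for both `mlir::ValueRange` and `ArrayRef<Attribute>`. The point is that the named getters only rely on indexing into `RangeT`, so the element type (a value, a constant, anything) is deduced from the range, much like the `ValueT` alias the diff declares via `::llvm::detail::ValueOfRange<RangeT>`.

```cpp
#include <cassert>
#include <iterator>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical generic adaptor for an op with two operands, "lhs" and "rhs".
// The getters only index into RangeT, so they work for any element type.
template <typename RangeT>
class AddOpGenericAdaptor {
public:
  // Element type of the range (e.g. a Value, or an Attribute when folding).
  using ValueT = typename std::iterator_traits<decltype(
      std::begin(std::declval<RangeT &>()))>::value_type;

  explicit AddOpGenericAdaptor(RangeT values)
      : odsOperands(std::move(values)) {}

  // Operand 0 and operand 1, looked up by position in the range.
  ValueT getLhs() const { return *std::begin(odsOperands); }
  ValueT getRhs() const { return *std::next(std::begin(odsOperands)); }

private:
  RangeT odsOperands;
};
```

The same accessor names then work on SSA-value-like ranges and on ranges of constants, which is exactly what a fold hook receiving `ArrayRef<Attribute>` needs.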

Implementation-wise, special care was taken neither to cause a compile-time regression nor to break source compatibility.
For that purpose, the current adaptor class was split into three classes:
* A generic adaptor base class, placed in the `detail` namespace since it is an implementation detail, which implements all APIs independent of the range type used for the operands. This is all the attribute- and region-related code. Since it is not templated, its implementation does not have to be inline and can be put into the cpp source file.
* The actual generic adaptor, which has a template parameter for the range that is indexed into to retrieve operands. It implements all the operand getters, as they depend on the range type. It publicly inherits from the generic adaptor base class.
* A class named the way adaptors have been named so far, inheriting from the generic adaptor class with `mlir::ValueRange` as the range to index into. It implements the remaining API specific to `mlir::ValueRange` adaptors that was previously part of the adaptor. This boils down to a constructor from the Op type as well as the verify function.
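A condensed, std-only sketch of the three-class layout described above. Every name here (`FooOp`, `AttrDict`, the string-based `Value`, the operand layout "one single operand `a`, then a variadic `b`") is a hypothetical stand-in, not MLIR API. It also mirrors one detail from the diff: the non-templated base cannot see the operands, so its `getODSOperandIndexAndLength` takes the operand count as an extra argument, and the templated class adds a one-argument trampoline.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for MLIR types.
using AttrDict = std::map<std::string, int>;
using Value = std::string;
using ValueRange = std::vector<Value>;

namespace detail {
// 1. Non-templated base: attribute logic; may be defined out of line.
class FooOpGenericAdaptorBase {
public:
  explicit FooOpGenericAdaptorBase(AttrDict attrs) : odsAttrs(std::move(attrs)) {}
  int getAttr1() const { return odsAttrs.at("attr1"); }
  // Operand 0 ("a") is a single value; operand 1 ("b") takes the rest.
  // The total operand count must be passed in: the base has no operands.
  std::pair<unsigned, unsigned>
  getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) const {
    if (index == 0)
      return {0, 1};
    return {1, odsOperandsSize - 1};
  }

protected:
  AttrDict odsAttrs;
};
} // namespace detail

// 2. Templated generic adaptor: all operand getters, over any range type.
template <typename RangeT>
class FooOpGenericAdaptor : public detail::FooOpGenericAdaptorBase {
  using Base = detail::FooOpGenericAdaptorBase;

public:
  FooOpGenericAdaptor(RangeT values, AttrDict attrs)
      : Base(std::move(attrs)), odsOperands(std::move(values)) {}

  // Trampoline with the op's signature; forwards the operand count.
  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) const {
    return Base::getODSOperandIndexAndLength(
        index, static_cast<unsigned>(odsOperands.size()));
  }
  RangeT getODSOperands(unsigned index) const {
    auto [start, length] = getODSOperandIndexAndLength(index);
    auto first = odsOperands.begin() + start;
    return RangeT(first, first + length);
  }
  typename RangeT::value_type getA() const { return *getODSOperands(0).begin(); }
  RangeT getB() const { return getODSOperands(1); }

protected:
  RangeT odsOperands;
};

// 3. The ValueRange instantiation keeps the historical name and adds the
//    op-specific API: a constructor from the op and verification.
struct FooOp {
  ValueRange operands;
  AttrDict attrs;
};

class FooOpAdaptor : public FooOpGenericAdaptor<ValueRange> {
public:
  using FooOpGenericAdaptor::FooOpGenericAdaptor;
  explicit FooOpAdaptor(const FooOp &op)
      : FooOpGenericAdaptor(op.operands, op.attrs) {}
  // Placeholder check; generated verifiers check the op's attribute constraints.
  bool verify() const { return odsAttrs.count("attr1") != 0; }
};
```

Only the third class needs to know about the op type, which is why it is the one that keeps the old name.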

Because the last class has exactly the same name and API surface as the previous Adaptors, full source compatibility is preserved.

Differential Revision: https://reviews.llvm.org/D140660

Added: 
    mlir/unittests/IR/AdaptorTest.cpp

Modified: 
    mlir/include/mlir/TableGen/Class.h
    mlir/include/mlir/TableGen/Operator.h
    mlir/lib/TableGen/Class.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/test/mlir-tblgen/op-decl-and-defs.td
    mlir/test/mlir-tblgen/op-operand.td
    mlir/tools/mlir-tblgen/OpClass.cpp
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/IR/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 896880bc8cdb3..954ef5bec2d04 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -436,6 +436,13 @@ operator|(mlir::tblgen::Method::Properties lhs,
                                           static_cast<unsigned>(rhs));
 }
 
+inline constexpr mlir::tblgen::Method::Properties &
+operator|=(mlir::tblgen::Method::Properties &lhs,
+           mlir::tblgen::Method::Properties rhs) {
+  return lhs = mlir::tblgen::Method::Properties(static_cast<unsigned>(lhs) |
+                                                static_cast<unsigned>(rhs));
+}
+
 namespace mlir {
 namespace tblgen {
 
@@ -488,11 +495,27 @@ class UsingDeclaration
   /// Write the using declaration.
   void writeDeclTo(raw_indented_ostream &os) const override;
 
+  /// Add a template parameter.
+  template <typename ParamT>
+  void addTemplateParam(ParamT param) {
+    templateParams.insert(stringify(param));
+  }
+
+  /// Add a list of template parameters.
+  template <typename ContainerT>
+  void addTemplateParams(ContainerT &&container) {
+    templateParams.insert(std::begin(container), std::end(container));
+  }
+
 private:
   /// The name of the declaration, or a resolved name to an inherited function.
   std::string name;
   /// The type that is being aliased. Leave empty for inheriting functions.
   std::string value;
+  /// An optional list of class template parameters.
+  /// This is simply an ordered list of parameter names that are then added as
+  /// template type parameters when the using declaration is emitted.
+  SetVector<std::string, SmallVector<std::string>, StringSet<>> templateParams;
 };
 
 /// This class describes a class field.
@@ -583,21 +606,47 @@ class Class {
   /// returns a pointer to the new constructor.
   template <Method::Properties Properties = Method::None, typename... Args>
   Constructor *addConstructor(Args &&...args) {
+    Method::Properties defaultProperties = Method::Constructor;
+    // If the class has template parameters, the constructor has to be defined
+    // inline.
+    if (!templateParams.empty())
+      defaultProperties |= Method::Inline;
     return addConstructorAndPrune(Constructor(getClassName(),
-                                              Properties | Method::Constructor,
+                                              Properties | defaultProperties,
                                               std::forward<Args>(args)...));
   }
 
   /// Add a new method to this class and prune any methods made redundant by it.
   /// Returns null if the method was not added (because an existing method would
   /// make it redundant). Else, returns a pointer to the new method.
+  template <Method::Properties Properties = Method::None, typename RetTypeT,
+            typename NameT>
+  Method *addMethod(RetTypeT &&retType, NameT &&name,
+                    Method::Properties properties,
+                    ArrayRef<MethodParameter> parameters) {
+    // If the class has template parameters, the method has to be defined inline.
+    if (!templateParams.empty())
+      properties |= Method::Inline;
+    return addMethodAndPrune(Method(std::forward<RetTypeT>(retType),
+                                    std::forward<NameT>(name),
+                                    Properties | properties, parameters));
+  }
+
+  /// Add a method with statically-known properties.
+  template <Method::Properties Properties = Method::None, typename RetTypeT,
+            typename NameT>
+  Method *addMethod(RetTypeT &&retType, NameT &&name,
+                    ArrayRef<MethodParameter> parameters) {
+    return addMethod(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+                     Properties, parameters);
+  }
+
   template <Method::Properties Properties = Method::None, typename RetTypeT,
             typename NameT, typename... Args>
   Method *addMethod(RetTypeT &&retType, NameT &&name,
                     Method::Properties properties, Args &&...args) {
-    return addMethodAndPrune(
-        Method(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
-               Properties | properties, std::forward<Args>(args)...));
+    return addMethod(std::forward<RetTypeT>(retType), std::forward<NameT>(name),
+                     properties | Properties, {std::forward<Args>(args)...});
   }
 
   /// Add a method with statically-known properties.
@@ -674,6 +723,18 @@ class Class {
   /// Add a parent class.
   ParentClass &addParent(ParentClass parent);
 
+  /// Add a template parameter.
+  template <typename ParamT>
+  void addTemplateParam(ParamT param) {
+    templateParams.insert(stringify(param));
+  }
+
+  /// Add a list of template parameters.
+  template <typename ContainerT>
+  void addTemplateParams(ContainerT &&container) {
+    templateParams.insert(std::begin(container), std::end(container));
+  }
+
   /// Return the C++ name of the class.
   StringRef getClassName() const { return className; }
 
@@ -751,6 +812,9 @@ class Class {
 
   /// A list of declarations in the class, emitted in order.
   std::vector<std::unique_ptr<ClassDeclaration>> declarations;
+
+  /// An optional list of class template parameters.
+  SetVector<std::string, SmallVector<std::string>, StringSet<>> templateParams;
 };
 
 } // namespace tblgen

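The `operator|=` added to Class.h lets `addConstructor` and `addMethod` OR in `Method::Inline` only when the class has template parameters, since members of a class template have to be defined in the header. A self-contained sketch of that pattern follows. The `Properties` enum and `constructorProperties` helper are hypothetical stand-ins for `mlir::tblgen::Method::Properties` and the logic inside `addConstructor`; the real enumerators and their values are not reproduced.

```cpp
#include <cassert>

// Hypothetical bitmask enum mirroring the shape of Method::Properties.
enum Properties : unsigned {
  None = 0,
  Static = 1,
  Constructor = 2,
  Inline = 4,
};

// OR-ing two enumerators yields unsigned; cast back explicitly.
inline constexpr Properties operator|(Properties lhs, Properties rhs) {
  return Properties(static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
}

// Compound assignment lets callers accumulate flags conditionally.
inline constexpr Properties &operator|=(Properties &lhs, Properties rhs) {
  return lhs = lhs | rhs;
}

// Sketch of the rule applied when adding a constructor: templated classes
// must have their members defined inline.
constexpr Properties constructorProperties(bool classHasTemplateParams) {
  Properties props = Constructor;
  if (classHasTemplateParams)
    props |= Inline;
  return props;
}
```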
diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 74df33eb1ed78..99c3eed731b17 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -64,6 +64,9 @@ class Operator {
   /// Returns the name of op's adaptor C++ class.
   std::string getAdaptorName() const;
 
+  /// Returns the name of op's generic adaptor C++ class.
+  std::string getGenericAdaptorName() const;
+
   /// Check invariants (like no duplicated or conflicted names) and abort the
   /// process if any invariant is broken.
   void assertInvariants() const;

diff  --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index a7c02d3ae543b..c562cfde0a9cc 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -228,6 +228,13 @@ void ParentClass::writeTo(raw_indented_ostream &os) const {
 //===----------------------------------------------------------------------===//
 
 void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const {
+  if (!templateParams.empty()) {
+    os << "template <";
+    llvm::interleaveComma(templateParams, os, [&](StringRef paramName) {
+      os << "typename " << paramName;
+    });
+    os << ">\n";
+  }
   os << "using " << name;
   if (!value.empty())
     os << " = " << value;
@@ -275,6 +282,13 @@ ParentClass &Class::addParent(ParentClass parent) {
 }
 
 void Class::writeDeclTo(raw_indented_ostream &os) const {
+  if (!templateParams.empty()) {
+    os << "template <";
+    llvm::interleaveComma(templateParams, os,
+                          [&](StringRef param) { os << "typename " << param; });
+    os << ">\n";
+  }
+
   // Declare the class.
   os << (isStruct ? "struct" : "class") << ' ' << className << ' ';
 
@@ -341,7 +355,7 @@ Visibility Class::getLastVisibilityDecl() const {
   });
   return it == reverseDecls.end()
              ? (isStruct ? Visibility::Public : Visibility::Private)
-             : cast<VisibilityDeclaration>(*it).getVisibility();
+             : cast<VisibilityDeclaration>(**it).getVisibility();
 }
 
 Method *insertAndPruneMethods(std::vector<std::unique_ptr<Method>> &methods,

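The Class.cpp change prefixes a declaration with a `template <typename ...>` line whenever template parameters were registered. A std-only analogue of `UsingDeclaration::writeDeclTo` is sketched below; `writeUsingDecl` is a hypothetical helper, the comma-joining loop stands in for `llvm::interleaveComma`, and the trailing `;\n` is assumed since that part of the function is not shown in the hunk.

```cpp
#include <cassert>
#include <cstddef>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical analogue of UsingDeclaration::writeDeclTo.
void writeUsingDecl(std::ostream &os, const std::string &name,
                    const std::string &value,
                    const std::vector<std::string> &templateParams) {
  if (!templateParams.empty()) {
    os << "template <";
    // Join "typename P" entries with ", " (what interleaveComma does).
    for (std::size_t i = 0; i < templateParams.size(); ++i) {
      if (i != 0)
        os << ", ";
      os << "typename " << templateParams[i];
    }
    os << ">\n";
  }
  os << "using " << name;
  if (!value.empty())
    os << " = " << value;
  os << ";\n"; // assumed terminator
}
```

With the `GenericAdaptor` declaration that OpClass.cpp registers for an op `AOp`, this yields an alias template: `template <typename RangeT>` followed by `using GenericAdaptor = AOpGenericAdaptor<RangeT>;`.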
diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index ab4a944da5d2d..44177052aa61c 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -69,6 +69,10 @@ std::string Operator::getAdaptorName() const {
   return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
 }
 
+std::string Operator::getGenericAdaptorName() const {
+  return std::string(llvm::formatv("{0}GenericAdaptor", getCppClassName()));
+}
+
 /// Assert the invariants of accessors generated for the given name.
 static void assertAccessorInvariants(const Operator &op, StringRef name) {
   std::string accessorName =

diff  --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
index 884c72ce9e6d5..3a5af1268fcfe 100644
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -52,20 +52,34 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 
 // CHECK-LABEL: NS::AOp declarations
 
-// CHECK: class AOpAdaptor {
+// CHECK: namespace detail {
+// CHECK: class AOpGenericAdaptorBase {
 // CHECK: public:
-// CHECK:   AOpAdaptor(::mlir::ValueRange values
-// CHECK:   ::mlir::ValueRange getODSOperands(unsigned index);
-// CHECK:   ::mlir::Value getA();
-// CHECK:   ::mlir::ValueRange getB();
 // CHECK:   ::mlir::IntegerAttr getAttr1Attr();
 // CHECK:   uint32_t getAttr1();
 // CHECK:   ::mlir::FloatAttr getSomeAttr2Attr();
 // CHECK:   ::std::optional< ::llvm::APFloat > getSomeAttr2();
 // CHECK:   ::mlir::Region &getSomeRegion();
 // CHECK:   ::mlir::RegionRange getSomeRegions();
+// CHECK: };
+// CHECK: }
+
+// CHECK: template <typename RangeT>
+// CHECK: class AOpGenericAdaptor : public detail::AOpGenericAdaptorBase {
+// CHECK: public:
+// CHECK:   AOpGenericAdaptor(RangeT values,
+// CHECK-SAME: odsOperands(values)
+// CHECK:   RangeT getODSOperands(unsigned index) {
+// CHECK:   ValueT getA() {
+// CHECK:   RangeT getB() {
 // CHECK: private:
-// CHECK:   ::mlir::ValueRange odsOperands;
+// CHECK:   RangeT odsOperands;
+// CHECK: };
+
+// CHECK: class AOpAdaptor : public AOpGenericAdaptor<::mlir::ValueRange> {
+// CHECK: public:
+// CHECK:   AOpAdaptor(AOp
+// CHECK:   ::mlir::LogicalResult verify(
 // CHECK: };
 
 // CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsIsolatedFromAbove
@@ -108,10 +122,10 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 
 // DEFS-LABEL: NS::AOp definitions
 
-// DEFS: AOpAdaptor::AOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsOperands(values), odsAttrs(attrs), odsRegions(regions)
-// DEFS: ::mlir::RegionRange AOpAdaptor::getSomeRegions()
+// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
+// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
 // DEFS-NEXT: return odsRegions.drop_front(1);
-// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions()
+// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()
 
 // Check AttrSizedOperandSegments
 // ---
@@ -127,15 +141,17 @@ def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
   );
 }
 
-// CHECK-LABEL: AttrSizedOperandOpAdaptor(
-// CHECK-SAME:    ::mlir::ValueRange values
-// CHECK-SAME:    ::mlir::DictionaryAttr attrs
-// CHECK:  ::mlir::ValueRange getA();
-// CHECK:  ::mlir::ValueRange getB();
-// CHECK:  ::mlir::Value getC();
-// CHECK:  ::mlir::ValueRange getD();
+// CHECK-LABEL: class AttrSizedOperandOpGenericAdaptorBase {
 // CHECK:  ::mlir::DenseIntElementsAttr getOperandSegmentSizes();
 
+// CHECK-LABEL: AttrSizedOperandOpGenericAdaptor(
+// CHECK-SAME:    RangeT values
+// CHECK-SAME:    ::mlir::DictionaryAttr attrs
+// CHECK:  RangeT getA() {
+// CHECK:  RangeT getB() {
+// CHECK:  ValueT getC() {
+// CHECK:  RangeT getD() {
+
 // Check op trait for different number of operands
 // ---
 
@@ -166,7 +182,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
 }
 
 // CHECK-LABEL: NS::EOp declarations
-// CHECK:   ::mlir::Value getA();
+// CHECK:   ::mlir::TypedValue<::mlir::IntegerType> getA();
 // CHECK:   ::mlir::MutableOperandRange getAMutable();
 // CHECK:   ::mlir::TypedValue<::mlir::FloatType> getB();
 // CHECK:   static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type b, /*optional*/::mlir::Value a)
@@ -335,6 +351,18 @@ def NS_SkipDefaultBuildersOp : NS_Op<"skip_default_builders", []> {
 // Check leading underscore in op name
 // ---
 
+def NS_VarOfVarOperandOp : NS_Op<"var_of_var_operand", []> {
+  let arguments = (ins
+    VariadicOfVariadic<F32, "var_size">:$var_of_var_attr,
+    DenseI32ArrayAttr:$var_size
+  );
+}
+
+// CHECK-LABEL: class VarOfVarOperandOpGenericAdaptor
+// CHECK: public:
+// CHECK: ::llvm::SmallVector<RangeT> getVarOfVarAttr() {
+
+
 def NS__AOp : NS_Op<"_op_with_leading_underscore", []>;
 
 // CHECK-LABEL: NS::_AOp declarations

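The new `VarOfVarOperandOp` test expects the generic adaptor's variadic-of-variadic getter to return `::llvm::SmallVector<RangeT>`, which is why `variadicOfVariadicAdaptorCalcCode` gains a `{2}` placeholder for the range type. A std-only sketch of the grouping logic follows. `groupOperands` is a hypothetical helper, `std::vector` replaces `SmallVector`, and the segment sizes are a plain vector rather than a `DenseI32ArrayAttr`. Like the generated code, it does no bounds checking and assumes the sizes were already verified against the operand count.

```cpp
#include <cassert>
#include <iterator>
#include <vector>

// Split one ODS operand's values into groups given per-group sizes.
// RangeT only needs begin() and construction from an iterator pair.
template <typename RangeT>
std::vector<RangeT> groupOperands(const RangeT &operands,
                                  const std::vector<int> &sizes) {
  std::vector<RangeT> groups;
  auto it = std::begin(operands);
  for (int size : sizes) {
    // Equivalent of take_front(size) followed by drop_front(size).
    groups.emplace_back(it, it + size);
    it += size;
  }
  return groups;
}
```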
diff  --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td
index 67ad724e1bd88..7a76f98ead581 100644
--- a/mlir/test/mlir-tblgen/op-operand.td
+++ b/mlir/test/mlir-tblgen/op-operand.td
@@ -14,8 +14,8 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
 
 // CHECK-LABEL: OpA definitions
 
-// CHECK:      OpAAdaptor::OpAAdaptor
-// CHECK-SAME: odsOperands(values), odsAttrs(attrs)
+// CHECK:      OpAGenericAdaptorBase::OpAGenericAdaptorBase
+// CHECK-SAME: odsAttrs(attrs)
 
 // CHECK:      void OpA::build
 // CHECK:        ::mlir::Value input
@@ -39,15 +39,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
   let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
 }
 
-// CHECK-LABEL: ::mlir::ValueRange OpDAdaptor::getInput1
-// CHECK-NEXT:    return getODSOperands(0);
-
-// CHECK-LABEL: ::mlir::Value OpDAdaptor::getInput2
-// CHECK-NEXT:    return *getODSOperands(1).begin();
-
-// CHECK-LABEL: ::mlir::ValueRange OpDAdaptor::getInput3
-// CHECK-NEXT:    return getODSOperands(2);
-
 // CHECK-LABEL: ::mlir::Operation::operand_range OpD::getInput1
 // CHECK-NEXT: return getODSOperands(0);
 

diff  --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 3512212272f4a..40b688f2b96ca 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -27,6 +27,11 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
   declare<UsingDeclaration>("Op::print");
   /// Type alias for the adaptor class.
   declare<UsingDeclaration>("Adaptor", className + "Adaptor");
+  declare<UsingDeclaration>("GenericAdaptor",
+                            className + "GenericAdaptor<RangeT>")
+      ->addTemplateParam("RangeT");
+  declare<UsingDeclaration>(
+      "FoldAdaptor", "GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>");
 }
 
 void OpClass::finalize() {

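With the OpClass.cpp change, every op class carries three nested aliases: `Adaptor`, the alias template `GenericAdaptor<RangeT>`, and `FoldAdaptor`, which instantiates the generic adaptor over `::llvm::ArrayRef<::mlir::Attribute>`. The sketch below shows how these aliases relate. `MyOp`, its adaptors, `getInput` and the `fold`-like member are hypothetical, `std::vector<int>` stands in for a range of constant attributes, and the member's signature is illustrative only, not the fold API from the RFC.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-ins for ::mlir::ValueRange and ArrayRef<Attribute>.
using ValueRange = std::vector<std::string>;
using AttributeRange = std::vector<int>;

template <typename RangeT>
class MyOpGenericAdaptor {
public:
  explicit MyOpGenericAdaptor(RangeT values) : odsOperands(std::move(values)) {}
  auto getInput() const { return odsOperands.front(); }

private:
  RangeT odsOperands;
};

class MyOpAdaptor : public MyOpGenericAdaptor<ValueRange> {
public:
  using MyOpGenericAdaptor::MyOpGenericAdaptor;
};

class MyOp {
public:
  // Mirrors the three declarations OpClass now emits.
  using Adaptor = MyOpAdaptor;
  template <typename RangeT>
  using GenericAdaptor = MyOpGenericAdaptor<RangeT>;
  using FoldAdaptor = GenericAdaptor<AttributeRange>;

  // A fold-like hook can use the named getter on constant operands.
  int fold(FoldAdaptor adaptor) const { return adaptor.getInput() * 2; }
};
```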
diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 683f812ac64b0..3b45bb5b1cb96 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -117,11 +117,12 @@ static const char *const opSegmentSizeAttrInitCode = R"(
 ///
 /// {0}: The name of the segment attribute.
 /// {1}: The index of the main operand.
+/// {2}: The range type of adaptor.
 static const char *const variadicOfVariadicAdaptorCalcCode = R"(
   auto tblgenTmpOperands = getODSOperands({1});
   auto sizes = {0}();
 
-  ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
+  ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
   for (int i = 0, e = sizes.size(); i < e; ++i) {{
     tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
     tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
@@ -1190,13 +1191,22 @@ void OpEmitter::genOptionalAttrRemovers() {
 // Generates the code to compute the start and end index of an operand or result
 // range.
 template <typename RangeT>
-static void
-generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
-                              int numVariadic, int numNonVariadic,
-                              StringRef rangeSizeCall, bool hasAttrSegmentSize,
-                              StringRef sizeAttrInit, RangeT &&odsValues) {
+static void generateValueRangeStartAndEnd(
+    Class &opClass, bool isGenericAdaptorBase, StringRef methodName,
+    int numVariadic, int numNonVariadic, StringRef rangeSizeCall,
+    bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) {
+
+  SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")};
+  if (isGenericAdaptorBase) {
+    parameters.emplace_back("unsigned", "odsOperandsSize");
+    // The range size is passed per parameter for generic adaptor bases as
+    // using the rangeSizeCall would require the operands, which are not
+    // accessible in the base class.
+    rangeSizeCall = "odsOperandsSize";
+  }
+
   auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
-                                   MethodParameter("unsigned", "index"));
+                                   parameters);
   if (!method)
     return;
   auto &body = method->body();
@@ -1218,8 +1228,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
   }
 }
 
-static std::string generateTypeForGetter(bool isAdaptor,
-                                         const NamedTypeConstraint &value) {
+static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
   std::string str = "::mlir::Value";
   /// If the CPPClassName is not a fully qualified type. Uses of types
   /// across Dialect fail because they are not in the correct namespace. So we
@@ -1229,7 +1238,7 @@ static std::string generateTypeForGetter(bool isAdaptor,
   /// https://github.com/llvm/llvm-project/issues/57279.
   /// Adaptor will have values that are not from the type of their operation and
   /// this is expected, so we dont generate TypedValue for Adaptor
-  if (!isAdaptor && value.constraint.getCPPClassName() != "::mlir::Type" &&
+  if (value.constraint.getCPPClassName() != "::mlir::Type" &&
       StringRef(value.constraint.getCPPClassName()).startswith("::"))
     str = llvm::formatv("::mlir::TypedValue<{0}>",
                         value.constraint.getCPPClassName())
@@ -1248,12 +1257,12 @@ static std::string generateTypeForGetter(bool isAdaptor,
 // "{0}" marker in the pattern.  Note that the pattern should work for any kind
 // of ops, in particular for one-operand ops that may not have the
 // `getOperand(unsigned)` method.
-static void generateNamedOperandGetters(const Operator &op, Class &opClass,
-                                        bool isAdaptor, StringRef sizeAttrInit,
-                                        StringRef rangeType,
-                                        StringRef rangeBeginCall,
-                                        StringRef rangeSizeCall,
-                                        StringRef getOperandCallPattern) {
+static void
+generateNamedOperandGetters(const Operator &op, Class &opClass,
+                            Class *genericAdaptorBase, StringRef sizeAttrInit,
+                            StringRef rangeType, StringRef rangeElementType,
+                            StringRef rangeBeginCall, StringRef rangeSizeCall,
+                            StringRef getOperandCallPattern) {
   const int numOperands = op.getNumOperands();
   const int numVariadicOperands = op.getNumVariableLengthOperands();
   const int numNormalOperands = numOperands - numVariadicOperands;
@@ -1281,10 +1290,31 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
 
   // First emit a few "sink" getter methods upon which we layer all nicer named
   // getter methods.
-  generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
-                                numVariadicOperands, numNormalOperands,
-                                rangeSizeCall, attrSizedOperands, sizeAttrInit,
-                                const_cast<Operator &>(op).getOperands());
+  // If generating for an adaptor, the method is put into the non-templated
+  // generic base class, to not require being defined in the header.
+  // Since the operand size can't be determined from the base class however,
+  // it has to be passed as an additional argument. The trampoline below
+  // generates the function with the same signature as the Op in the generic
+  // adaptor.
+  bool isGenericAdaptorBase = genericAdaptorBase != nullptr;
+  generateValueRangeStartAndEnd(
+      /*opClass=*/isGenericAdaptorBase ? *genericAdaptorBase : opClass,
+      isGenericAdaptorBase,
+      /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands,
+      numNormalOperands, rangeSizeCall, attrSizedOperands, sizeAttrInit,
+      const_cast<Operator &>(op).getOperands());
+  if (isGenericAdaptorBase) {
+    // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
+    // the index. This just calls the implementation in the base class but
+    // passes the operand size as parameter.
+    Method *method = opClass.addMethod("std::pair<unsigned, unsigned>",
+                                       "getODSOperandIndexAndLength",
+                                       MethodParameter("unsigned", "index"));
+    ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op);
+    MethodBody &body = method->body();
+    body.indent() << formatv(
+        "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall);
+  }
 
   auto *m = opClass.addMethod(rangeType, "getODSOperands",
                               MethodParameter("unsigned", "index"));
@@ -1301,19 +1331,23 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
       continue;
     std::string name = op.getGetterName(operand.name);
     if (operand.isOptional()) {
-      m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
+      m = opClass.addMethod(isGenericAdaptorBase
+                                ? rangeElementType
+                                : generateTypeForGetter(operand),
+                            name);
       ERROR_IF_PRUNED(m, name, op);
-      m->body() << "  auto operands = getODSOperands(" << i << ");\n"
-                << "  return operands.empty() ? ::mlir::Value() : "
-                   "*operands.begin();";
+      m->body().indent() << formatv(
+          "auto operands = getODSOperands({0});\n"
+          "return operands.empty() ? {1}{{} : *operands.begin();",
+          i, rangeElementType);
     } else if (operand.isVariadicOfVariadic()) {
       std::string segmentAttr = op.getGetterName(
           operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
-      if (isAdaptor) {
-        m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", name);
+      if (genericAdaptorBase) {
+        m = opClass.addMethod("::llvm::SmallVector<" + rangeType + ">", name);
         ERROR_IF_PRUNED(m, name, op);
         m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
-                                   segmentAttr, i);
+                                   segmentAttr, i, rangeType);
         continue;
       }
 
@@ -1326,7 +1360,10 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
       ERROR_IF_PRUNED(m, name, op);
       m->body() << "  return getODSOperands(" << i << ");";
     } else {
-      m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
+      m = opClass.addMethod(isGenericAdaptorBase
+                                ? rangeElementType
+                                : generateTypeForGetter(operand),
+                            name);
       ERROR_IF_PRUNED(m, name, op);
       m->body() << "  return *getODSOperands(" << i << ").begin();";
     }
@@ -1344,9 +1381,10 @@ void OpEmitter::genNamedOperandGetters() {
 
   generateNamedOperandGetters(
       op, opClass,
-      /*isAdaptor=*/false,
+      /*genericAdaptorBase=*/nullptr,
       /*sizeAttrInit=*/attrSizeInitCode,
       /*rangeType=*/"::mlir::Operation::operand_range",
+      /*rangeElementType=*/"::mlir::Value",
       /*rangeBeginCall=*/"getOperation()->operand_begin()",
       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
@@ -1431,9 +1469,9 @@ void OpEmitter::genNamedResultGetters() {
   }
 
   generateValueRangeStartAndEnd(
-      opClass, "getODSResultIndexAndLength", numVariadicResults,
-      numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
-      attrSizeInitCode, op.getResults());
+      opClass, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength",
+      numVariadicResults, numNormalResults, "getOperation()->getNumResults()",
+      attrSizedResults, attrSizeInitCode, op.getResults());
 
   auto *m =
       opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
@@ -1448,8 +1486,7 @@ void OpEmitter::genNamedResultGetters() {
       continue;
     std::string name = op.getGetterName(result.name);
     if (result.isOptional()) {
-      m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result),
-                            name);
+      m = opClass.addMethod(generateTypeForGetter(result), name);
       ERROR_IF_PRUNED(m, name, op);
       m->body()
           << "  auto results = getODSResults(" << i << ");\n"
@@ -1459,8 +1496,7 @@ void OpEmitter::genNamedResultGetters() {
       ERROR_IF_PRUNED(m, name, op);
       m->body() << "  return getODSResults(" << i << ");";
     } else {
-      m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result),
-                            name);
+      m = opClass.addMethod(generateTypeForGetter(result), name);
       ERROR_IF_PRUNED(m, name, op);
       m->body() << "  return *getODSResults(" << i << ").begin();";
     }
@@ -2906,8 +2942,19 @@ void OpEmitter::genOpAsmInterface() {
 
 namespace {
 // Helper class to emit Op operand adaptors to an output stream.  Operand
-// adaptors are wrappers around ArrayRef<Value> that provide named operand
+// adaptors are wrappers around random access ranges that provide named operand
 // getters identical to those defined in the Op.
+// This currently generates 3 classes per Op:
+// * A Base class within the 'detail' namespace, which contains all logic and
+//   members independent of the random access range that is indexed into.
+//   In other words, it contains all the attribute and region getters.
+// * A templated class named '{OpName}GenericAdaptor' with a template parameter
+//   'RangeT' that is indexed into by the getters to access the operands.
+//   It contains all getters to access operands and inherits from the previous
+//   class.
+// * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
+//   with 'mlir::ValueRange' as template parameter. It adds a constructor from
+//   an instance of the op type and a verify function.
 class OpOperandAdaptorEmitter {
 public:
   static void
@@ -2931,7 +2978,9 @@ class OpOperandAdaptorEmitter {
   // The operation for which to emit an adaptor.
   const Operator &op;
 
-  // The generated adaptor class.
+  // The generated adaptor classes.
+  Class genericAdaptorBase;
+  Class genericAdaptor;
   Class adaptor;
 
   // The emitter containing all of the locally emitted verification functions.
@@ -2945,42 +2994,47 @@ class OpOperandAdaptorEmitter {
 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
     const Operator &op,
     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
-    : op(op), adaptor(op.getAdaptorName()),
+    : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
+      genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
       staticVerifierEmitter(staticVerifierEmitter),
       emitHelper(op, /*emitForOp=*/false) {
-  adaptor.addField("::mlir::ValueRange", "odsOperands");
-  adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
-  adaptor.addField("::mlir::RegionRange", "odsRegions");
-  adaptor.addField("::std::optional<::mlir::OperationName>", "odsOpName");
+
+  genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);
+  genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs");
+  genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
+  genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>",
+                                    "odsOpName");
+
+  genericAdaptor.addTemplateParam("RangeT");
+  genericAdaptor.addField("RangeT", "odsOperands");
+  genericAdaptor.addParent(
+      ParentClass("detail::" + genericAdaptorBase.getClassName()));
+  genericAdaptor.declare<UsingDeclaration>(
+      "ValueT", "::llvm::detail::ValueOfRange<RangeT>");
+  genericAdaptor.declare<UsingDeclaration>(
+      "Base", "detail::" + genericAdaptorBase.getClassName());
 
   const auto *attrSizedOperands =
-      op.getTrait("::m::OpTrait::AttrSizedOperandSegments");
+      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
   {
     SmallVector<MethodParameter> paramList;
-    paramList.emplace_back("::mlir::ValueRange", "values");
     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
                            attrSizedOperands ? "" : "nullptr");
     paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
-    auto *constructor = adaptor.addConstructor(std::move(paramList));
-
-    constructor->addMemberInitializer("odsOperands", "values");
-    constructor->addMemberInitializer("odsAttrs", "attrs");
-    constructor->addMemberInitializer("odsRegions", "regions");
+    auto *baseConstructor = genericAdaptorBase.addConstructor(paramList);
+    baseConstructor->addMemberInitializer("odsAttrs", "attrs");
+    baseConstructor->addMemberInitializer("odsRegions", "regions");
 
-    MethodBody &body = constructor->body();
+    MethodBody &body = baseConstructor->body();
     body.indent() << "if (odsAttrs)\n";
     body.indent() << formatv(
         "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
         op.getOperationName());
-  }
 
-  {
-    auto *constructor =
-        adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
-    constructor->addMemberInitializer("odsOperands", "op->getOperands()");
-    constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
-    constructor->addMemberInitializer("odsRegions", "op->getRegions()");
-    constructor->addMemberInitializer("odsOpName", "op->getName()");
+    paramList.insert(paramList.begin(), MethodParameter("RangeT", "values"));
+    auto *constructor = genericAdaptor.addConstructor(std::move(paramList));
+    constructor->addMemberInitializer("Base", "attrs, regions");
+    constructor->addMemberInitializer("odsOperands", "values");
   }
 
   std::string sizeAttrInit;
@@ -2988,16 +3042,18 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
     sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
                            emitHelper.getAttr(operandSegmentAttrName));
   }
-  generateNamedOperandGetters(op, adaptor,
-                              /*isAdaptor=*/true, sizeAttrInit,
-                              /*rangeType=*/"::mlir::ValueRange",
+  generateNamedOperandGetters(op, genericAdaptor,
+                              /*genericAdaptorBase=*/&genericAdaptorBase,
+                              /*sizeAttrInit=*/sizeAttrInit,
+                              /*rangeType=*/"RangeT",
+                              /*rangeElementType=*/"ValueT",
                               /*rangeBeginCall=*/"odsOperands.begin()",
                               /*rangeSizeCall=*/"odsOperands.size()",
                               /*getOperandCallPattern=*/"odsOperands[{0}]");
 
   // Any invalid overlap for `getOperands` will have been diagnosed before here
   // already.
-  if (auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands"))
+  if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
     m->body() << "  return odsOperands;";
 
   FmtContext fctx;
@@ -3006,7 +3062,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
   // Generate named accessor with Attribute return type.
   auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
                                      Attribute attr) {
-    auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
+    auto *method =
+        genericAdaptorBase.addMethod(attr.getStorageType(), emitName + "Attr");
     ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
     auto &body = method->body().indent();
     body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
@@ -3028,7 +3085,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
   };
 
   {
-    auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
+    auto *m =
+        genericAdaptorBase.addMethod("::mlir::DictionaryAttr", "getAttributes");
     ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
     m->body() << "  return odsAttrs;";
   }
@@ -3039,7 +3097,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
       continue;
     std::string emitName = op.getGetterName(name);
     emitAttrWithStorageType(name, emitName, attr);
-    emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
+    emitAttrGetterWithReturnType(fctx, genericAdaptorBase, op, emitName, attr);
   }
 
   unsigned numRegions = op.getNumRegions();
@@ -3051,25 +3109,44 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
     // Generate the accessors for a variadic region.
     std::string name = op.getGetterName(region.name);
     if (region.isVariadic()) {
-      auto *m = adaptor.addMethod("::mlir::RegionRange", name);
+      auto *m = genericAdaptorBase.addMethod("::mlir::RegionRange", name);
       ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
       m->body() << formatv("  return odsRegions.drop_front({0});", i);
       continue;
     }
 
-    auto *m = adaptor.addMethod("::mlir::Region &", name);
+    auto *m = genericAdaptorBase.addMethod("::mlir::Region &", name);
     ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
     m->body() << formatv("  return *odsRegions[{0}];", i);
   }
   if (numRegions > 0) {
     // Any invalid overlap for `getRegions` will have been diagnosed before here
     // already.
-    if (auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions"))
+    if (auto *m =
+            genericAdaptorBase.addMethod("::mlir::RegionRange", "getRegions"))
       m->body() << "  return odsRegions;";
   }
 
+  StringRef genericAdaptorClassName = genericAdaptor.getClassName();
+  adaptor.addParent(ParentClass(genericAdaptorClassName))
+      .addTemplateParam("::mlir::ValueRange");
+  adaptor.declare<VisibilityDeclaration>(Visibility::Public);
+  adaptor.declare<UsingDeclaration>(genericAdaptorClassName +
+                                    "::" + genericAdaptorClassName);
+  {
+    // Constructor taking the Op as single parameter.
+    auto *constructor =
+        adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
+    constructor->addMemberInitializer(
+        adaptor.getClassName(),
+        "op->getOperands(), op->getAttrDictionary(), op->getRegions()");
+  }
+
   // Add verification function.
   addVerification();
+
+  genericAdaptorBase.finalize();
+  genericAdaptor.finalize();
   adaptor.finalize();
 }
 
@@ -3090,14 +3167,26 @@ void OpOperandAdaptorEmitter::emitDecl(
     const Operator &op,
     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
     raw_ostream &os) {
-  OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
+  OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
+  {
+    NamespaceEmitter ns(os, "detail");
+    emitter.genericAdaptorBase.writeDeclTo(os);
+  }
+  emitter.genericAdaptor.writeDeclTo(os);
+  emitter.adaptor.writeDeclTo(os);
 }
 
 void OpOperandAdaptorEmitter::emitDef(
     const Operator &op,
     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
     raw_ostream &os) {
-  OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
+  OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
+  {
+    NamespaceEmitter ns(os, "detail");
+    emitter.genericAdaptorBase.writeDefTo(os);
+  }
+  emitter.genericAdaptor.writeDefTo(os);
+  emitter.adaptor.writeDefTo(os);
 }
 
 // Emits the opcode enum and op classes.

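For readers following the emitter changes, the stand-alone C++17 sketch below hand-writes the three-class layering that the emitter now produces for each op. The layers are a non-templated `detail::...GenericAdaptorBase` holding the attribute and region accessors, a `...GenericAdaptor<RangeT>` holding the operand getters, and the classic adaptor deriving from `...GenericAdaptor<ValueRange>` and inheriting its constructors. It is a simplified model under stated assumptions, not the generated code. `AddOp`, `AttrDict`, `ValueRange` (here just a vector of strings), `getOverflowFlags` and `AddOpFoldAdaptor` are hypothetical stand-ins, not MLIR APIs.

```cpp
#include <cassert>
#include <iterator>
#include <map>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-ins for MLIR types, for illustration only.
using AttrDict = std::map<std::string, int>;  // ~ ::mlir::DictionaryAttr
using ValueRange = std::vector<std::string>;  // ~ ::mlir::ValueRange

namespace detail {
// 1. Non-templated base: everything independent of the operand range
//    (attributes, regions), so it can be defined out of line in a .cpp file.
class AddOpGenericAdaptorBase {
public:
  explicit AddOpGenericAdaptorBase(AttrDict attrs = {})
      : odsAttrs(std::move(attrs)) {}
  // Returns the hypothetical "overflow" attribute, or 0 if it is absent.
  int getOverflowFlags() const {
    auto it = odsAttrs.find("overflow");
    return it == odsAttrs.end() ? 0 : it->second;
  }

protected:
  AttrDict odsAttrs;
};
} // namespace detail

// 2. Generic adaptor: only the operand getters depend on RangeT.
template <typename RangeT>
class AddOpGenericAdaptor : public detail::AddOpGenericAdaptorBase {
  using Base = detail::AddOpGenericAdaptorBase;

public:
  // Element type of the range, similar in spirit to llvm::detail::ValueOfRange.
  using ValueT = std::decay_t<decltype(*std::begin(std::declval<RangeT &>()))>;

  AddOpGenericAdaptor(RangeT values, AttrDict attrs = {})
      : Base(std::move(attrs)), odsOperands(std::move(values)) {}

  ValueT getLhs() const { return odsOperands[0]; }
  ValueT getRhs() const { return odsOperands[1]; }
  RangeT getOperands() const { return odsOperands; }

private:
  RangeT odsOperands;
};

// 3. The classic adaptor is the generic one instantiated on ValueRange,
//    inheriting its constructors (the real one also has a ctor taking the op).
class AddOpAdaptor : public AddOpGenericAdaptor<ValueRange> {
public:
  using AddOpGenericAdaptor::AddOpGenericAdaptor;
};

// Analogue of a fold adaptor: same getters, over possibly-constant values.
using AddOpFoldAdaptor = AddOpGenericAdaptor<std::vector<std::optional<int>>>;
```

Only the middle layer is a template, so each new range type instantiates just the operand getters; the attribute and region accessors in the base are compiled once, which matches the compile-time concern described in the commit message.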
diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp
new file mode 100644
index 0000000000000..a3efb34889f41
--- /dev/null
+++ b/mlir/unittests/IR/AdaptorTest.cpp
@@ -0,0 +1,63 @@
+//===- AdaptorTest.cpp - Adaptor unit tests -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace test;
+
+using testing::ElementsAre;
+
+TEST(Adaptor, GenericAdaptorsOperandAccess) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+  Builder builder(&context);
+
+  // Has normal and Variadic arguments.
+  MixedNormalVariadicOperandOp::FoldAdaptor a({});
+  {
+    SmallVector<int> v = {0, 1, 2, 3, 4};
+    MixedNormalVariadicOperandOp::GenericAdaptor<ArrayRef<int>> b(v);
+    EXPECT_THAT(b.getInput1(), ElementsAre(0, 1));
+    EXPECT_EQ(b.getInput2(), 2);
+    EXPECT_THAT(b.getInput3(), ElementsAre(3, 4));
+  }
+
+  // Has optional arguments.
+  OIListSimple::FoldAdaptor c({}, nullptr);
+  {
+    // Optional arguments return the default constructed value if not present.
+    // Using optional instead of plain int here to differentiate absence of
+    // value from the value 0.
+    SmallVector<std::optional<int>> v = {0, 4};
+    OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d(
+        v, builder.getDictionaryAttr({builder.getNamedAttr(
+               "operand_segment_sizes",
+               builder.getDenseI32ArrayAttr({1, 0, 1}))}));
+    EXPECT_EQ(d.getArg0(), 0);
+    EXPECT_EQ(d.getArg1(), std::nullopt);
+    EXPECT_EQ(d.getArg2(), 4);
+  }
+
+  // Has VariadicOfVariadic arguments.
+  FormatVariadicOfVariadicOperand::FoldAdaptor e({});
+  {
+    SmallVector<int> v = {0, 1, 2, 3, 4};
+    FormatVariadicOfVariadicOperand::GenericAdaptor<ArrayRef<int>> f(
+        v, builder.getDictionaryAttr({builder.getNamedAttr(
+               "operand_segments", builder.getDenseI32ArrayAttr({3, 2, 0}))}));
+    SmallVector<ArrayRef<int>> operand = f.getOperand();
+    ASSERT_EQ(operand.size(), (std::size_t)3);
+    EXPECT_THAT(operand[0], ElementsAre(0, 1, 2));
+    EXPECT_THAT(operand[1], ElementsAre(3, 4));
+    EXPECT_THAT(operand[2], ElementsAre());
+  }
+}

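The `MixedNormalVariadicOperandOp` expectations above follow from how ODS splits one flat operand range among several variadic groups. Assuming that op carries the `SameVariadicOperandSize` trait (not visible in this diff), each variadic group receives an equal share of the operands left over after the fixed ones. The plain C++17 sketch below shows that split. `getOperandIndexAndLength` and `getOperandGroup` are hypothetical helpers loosely modeled on the generated `getODSOperandIndexAndLength`, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: maps a static ODS operand index to (start, length) in
// the flat operand range, assuming all variadic groups have the same size.
std::pair<std::size_t, std::size_t>
getOperandIndexAndLength(const std::vector<bool> &isVariadic,
                         std::size_t index, std::size_t numOperands) {
  std::size_t numVariadic = 0;
  for (bool v : isVariadic)
    numVariadic += v;
  std::size_t numFixed = isVariadic.size() - numVariadic;
  // Every variadic group gets an equal share of the non-fixed operands.
  std::size_t variadicSize =
      numVariadic ? (numOperands - numFixed) / numVariadic : 0;
  // Skip over all earlier groups: fixed ones are 1 wide, variadic ones
  // are variadicSize wide.
  std::size_t start = 0;
  for (std::size_t i = 0; i < index; ++i)
    start += isVariadic[i] ? variadicSize : 1;
  return {start, isVariadic[index] ? variadicSize : 1};
}

// Slices any random-access range regardless of its element type, which is
// what lets the same getter logic serve Values, Attributes or plain ints.
template <typename RangeT>
RangeT getOperandGroup(const RangeT &operands,
                       const std::vector<bool> &isVariadic, std::size_t index) {
  auto [start, length] =
      getOperandIndexAndLength(isVariadic, index, operands.size());
  return RangeT(operands.begin() + start, operands.begin() + start + length);
}
```

With five operands and the layout `{variadic, fixed, variadic}`, each variadic group gets `(5 - 1) / 2 = 2` operands, giving `{0, 1}`, `{2}` and `{3, 4}`, as the test expects.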
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 904ea13447bb0..41964ab0d30af 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRIRTests
+  AdaptorTest.cpp
   AttributeTest.cpp
   BlockAndValueMapping.cpp
   DialectTest.cpp