[Mlir-commits] [mlir] b0774e5 - [mlir][ods] ODS ops get an `extraClassDefinition`

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 5 17:43:31 PST 2022


Author: Mogball
Date: 2022-01-06T01:43:26Z
New Revision: b0774e5f500b5bb68451ee3f0590035d0f6e4e54

URL: https://github.com/llvm/llvm-project/commit/b0774e5f500b5bb68451ee3f0590035d0f6e4e54
DIFF: https://github.com/llvm/llvm-project/commit/b0774e5f500b5bb68451ee3f0590035d0f6e4e54.diff

LOG: [mlir][ods] ODS ops get an `extraClassDefinition`

Extra definitions are placed in the generated source file for each op class. The substitution `$cppClass` is replaced by the op's C++ class name.
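The substitution step can be sketched in plain C++. This assumes a simple search-and-replace is close enough for illustration. `substituteCppClass` is a hypothetical helper, not an MLIR API: the real generator uses `tgfmt` with an `FmtContext`, and this sketch does not reproduce its placeholder parsing exactly.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not part of MLIR): replaces every occurrence of
// `$cppClass` in `body` with `className`. MLIR performs this substitution with
// `tgfmt`/`FmtContext`; this naive search-and-replace only approximates it.
std::string substituteCppClass(const std::string &body,
                               const std::string &className) {
  assert(!className.empty() && "op class name must not be empty");
  static const std::string placeholder = "$cppClass";
  std::string result;
  std::string::size_type pos = 0;
  while (true) {
    std::string::size_type hit = body.find(placeholder, pos);
    if (hit == std::string::npos) {
      // No more placeholders: copy the remainder of the body verbatim.
      result.append(body, pos, std::string::npos);
      break;
    }
    // Copy the text before the placeholder, then the class name.
    result.append(body, pos, hit - pos);
    result += className;
    pos = hit + placeholder.size();
  }
  return result;
}
```

Applying the same body with `"AOp"` and then `"BOp"` yields two independent out-of-line definitions from one hand-written template, which is what lets a base class provide its method bodies only once.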

This is useful when declaring but not defining methods in TableGen base classes:

```
class BaseOp<string mnemonic>
    : Op<MyDialect, mnemonic, [DeclareOpInterfaceMethods<SomeInterface>]> {
  let extraClassDeclaration = [{
    // ZOp is declared at the bottom of the file and is incomplete here
    ZOp getParent();
  }];
  let extraClassDefinition = [{
    int $cppClass::someInterfaceMethod() {
      return someUtilityFunction(*this);
    }
    ZOp $cppClass::getParent() {
      return dyn_cast<ZOp>(this->getParentOp());
    }
  }];
}
```

Several things can prevent defining these functions inline, in the declaration. In this example, `ZOp`, from the same dialect, is still an incomplete type at the function declaration because op classes are declared in alphabetical order. Functions may also be too large to be worth inlining, may require dependencies that would create cyclic includes, or may call a templated utility function that one would rather not expose in a header. If the functions are not defined inline, then inheriting from the base class N times means that each function must be defined N times, once per derived op. With `extraClassDefinition`, they only need to be defined once.
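The incomplete-type case can be illustrated in plain C++. This sketch uses simplified, hypothetical stand-ins `AOp` and `ZOp` (not generated MLIR classes) and models only the ordering problem described above: `ZOp` is forward-declared but not yet defined when `AOp` declares `getParent()`, so the body has to be written later, which is the role `extraClassDefinition` plays in the generated source file.

```cpp
#include <string>
#include <utility>

// Hypothetical stand-ins for generated op classes; real MLIR ops derive from
// mlir::Op<...>. Only the declaration-ordering problem is modeled here.
class ZOp; // Forward declaration: ZOp is incomplete until its definition below.

class AOp {
public:
  // Declaring a function that returns an incomplete type is allowed...
  ZOp getParent() const;
  // ...but defining it inline here would not compile, because returning a
  // ZOp by value requires ZOp's complete definition.
};

class ZOp {
public:
  explicit ZOp(std::string name) : name(std::move(name)) {}
  const std::string &getName() const { return name; }

private:
  std::string name;
};

// Out-of-line definition, analogous to code emitted from
// `extraClassDefinition` into the generated source file: ZOp is complete here.
ZOp AOp::getParent() const { return ZOp("z"); }
```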

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D115783

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Class.h
    mlir/include/mlir/TableGen/Operator.h
    mlir/lib/TableGen/Class.cpp
    mlir/lib/TableGen/Operator.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/OpClass.cpp
    mlir/tools/mlir-tblgen/OpClass.h
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index ec9e6fdc80ddb..1e1abdc20d2f7 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -964,6 +964,16 @@ Note that `extraClassDeclaration` is a mechanism intended for long-tail cases by
 power users; for not-yet-implemented widely-applicable cases, improving the
 infrastructure is preferable.
 
+### Extra definitions
+
+When defining base op classes in TableGen that are inherited many times by
+different ops, users may want to provide common definitions of utility and
+interface functions. However, many of these definitions may not be desirable or
+possible in `extraClassDeclaration`, which append them to the op's C++ class
+declaration. In these cases, users can add an `extraClassDefinition` to define
+code that is added to the generated source file inside the op's C++ namespace.
+The substitution `$cppClass` is replaced by the op's C++ class name.
+
 ### Generated C++ code
 
 [OpDefinitionsGen][OpDefinitionsGen] processes the op definition spec file and

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index f1a5446ad1f97..8e70f48440089 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2445,6 +2445,11 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
   // Additional code that will be added to the public part of the generated
   // C++ code of the op declaration.
   code extraClassDeclaration = ?;
+
+  // Additional code that will be added to the generated source file. The
+  // generated code is placed inside the op's C++ namespace. `$cppClass` is
+  // replaced by the op's C++ class name.
+  code extraClassDefinition = ?;
 }
 
 // Base class for ops with static/dynamic offset, sizes and strides

diff  --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 1f310fe1d0823..a8a710ff85fed 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -532,22 +532,32 @@ class VisibilityDeclaration
   Visibility visibility;
 };
 
-/// Unstructured extra class declarations, from TableGen definitions. The
-/// default visibility of extra class declarations is up to the owning class.
+/// Unstructured extra class declarations and definitions, from TableGen
+/// definitions. The default visibility of extra class declarations is up to the
+/// owning class.
 class ExtraClassDeclaration
     : public ClassDeclarationBase<ClassDeclaration::ExtraClassDeclaration> {
 public:
   /// Create an extra class declaration.
-  ExtraClassDeclaration(StringRef extraClassDeclaration)
-      : extraClassDeclaration(extraClassDeclaration) {}
+  ExtraClassDeclaration(StringRef extraClassDeclaration,
+                        StringRef extraClassDefinition = "")
+      : extraClassDeclaration(extraClassDeclaration),
+        extraClassDefinition(extraClassDefinition) {}
 
   /// Write the extra class declarations.
   void writeDeclTo(raw_indented_ostream &os) const override;
 
+  /// Write the extra class definitions.
+  void writeDefTo(raw_indented_ostream &os,
+                  StringRef namePrefix) const override;
+
 private:
   /// The string of the extra class declarations. It is re-indented before
   /// printed.
   StringRef extraClassDeclaration;
+  /// The string of the extra class definitions. It is re-indented before
+  /// printed.
+  StringRef extraClassDefinition;
 };
 
 /// A class used to emit C++ classes from Tablegen.  Contains a list of public

diff  --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 44f10440c1e33..ddfb7dd0178bf 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -235,6 +235,9 @@ class Operator {
   // Returns this op's extra class declaration code.
   StringRef getExtraClassDeclaration() const;
 
+  // Returns this op's extra class definition code.
+  StringRef getExtraClassDefinition() const;
+
   // Returns the Tablegen definition this operator was constructed from.
   // TODO: do not expose the TableGen record, this is a temporary solution to
   // OpEmitter requiring a Record because Operator does not provide enough

diff  --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
index 9b7124e2e3a59..a7c02d3ae543b 100644
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -260,6 +260,11 @@ void ExtraClassDeclaration::writeDeclTo(raw_indented_ostream &os) const {
   os.printReindented(extraClassDeclaration);
 }
 
+void ExtraClassDeclaration::writeDefTo(raw_indented_ostream &os,
+                                       StringRef namePrefix) const {
+  os.printReindented(extraClassDefinition);
+}
+
 //===----------------------------------------------------------------------===//
 // Class definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index f1c1fe5346661..cde617dcd30b3 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -128,6 +128,13 @@ StringRef Operator::getExtraClassDeclaration() const {
   return def.getValueAsString(attr);
 }
 
+StringRef Operator::getExtraClassDefinition() const {
+  constexpr auto attr = "extraClassDefinition";
+  if (def.isValueUnset(attr))
+    return {};
+  return def.getValueAsString(attr);
+}
+
 const llvm::Record &Operator::getDef() const { return def; }
 
 bool Operator::skipDefaultBuilders() const {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6fad11b85ad8c..28dbc271d72dc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -382,7 +382,10 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
 
   let extraClassDeclaration = [{
     /// Return the callee of this operation.
-    ::mlir::CallInterfaceCallable getCallableForCallee() {
+    ::mlir::CallInterfaceCallable getCallableForCallee();
+  }];
+  let extraClassDefinition = [{
+    ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
       return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
     }
   }];

diff  --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 9524dc9210b82..3512212272f4a 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -15,8 +15,10 @@ using namespace mlir::tblgen;
 // OpClass definitions
 //===----------------------------------------------------------------------===//
 
-OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
+OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
+                 std::string extraClassDefinition)
     : Class(name.str()), extraClassDeclaration(extraClassDeclaration),
+      extraClassDefinition(std::move(extraClassDefinition)),
       parent(addParent("::mlir::Op")) {
   parent.addTemplateParam(getClassName().str());
   declare<VisibilityDeclaration>(Visibility::Public);
@@ -30,5 +32,5 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
 void OpClass::finalize() {
   Class::finalize();
   declare<VisibilityDeclaration>(Visibility::Public);
-  declare<ExtraClassDeclaration>(extraClassDeclaration);
+  declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
 }

diff  --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h
index b0558a0e55139..6b90dd2c3a3a3 100644
--- a/mlir/tools/mlir-tblgen/OpClass.h
+++ b/mlir/tools/mlir-tblgen/OpClass.h
@@ -25,7 +25,8 @@ class OpClass : public Class {
   /// - inheritance of `print`
   /// - a type alias for the associated adaptor class
   ///
-  OpClass(StringRef name, StringRef extraClassDeclaration);
+  OpClass(StringRef name, StringRef extraClassDeclaration,
+          std::string extraClassDefinition);
 
   /// Add an op trait.
   void addTrait(Twine trait) { parent.addTemplateParam(trait.str()); }
@@ -39,6 +40,8 @@ class OpClass : public Class {
 private:
   /// Hand-written extra class declarations.
   StringRef extraClassDeclaration;
+  /// Hand-written extra class definitions.
+  std::string extraClassDefinition;
   /// The parent class, which also contains the traits to be inherited.
   ParentClass &parent;
 };

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index f024b90d33404..8511df9c54e60 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -557,10 +557,18 @@ static void genAttributeVerifier(
   }
 }
 
+/// Op extra class definitions have a `$cppClass` substitution that is to be
+/// replaced by the C++ class name.
+static std::string formatExtraDefinitions(const Operator &op) {
+  FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
+  return tgfmt(op.getExtraClassDefinition(), &ctx).str();
+}
+
 OpEmitter::OpEmitter(const Operator &op,
                      const StaticVerifierFunctionEmitter &staticVerifierEmitter)
     : def(op.getDef()), op(op),
-      opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
+      opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
+              formatExtraDefinitions(op)),
       staticVerifierEmitter(staticVerifierEmitter) {
   verifyCtx.withOp("(*this->getOperation())");
   verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
