[Mlir-commits] [mlir] ada0d41 - [mlir][ods] Allow filtering of ops

Jacques Pienaar llvmlistbot at llvm.org
Mon Jun 22 14:57:16 PDT 2020


Author: Jacques Pienaar
Date: 2020-06-22T14:56:54-07:00
New Revision: ada0d41dbc26f013e2741d1f9e0f164943342435

URL: https://github.com/llvm/llvm-project/commit/ada0d41dbc26f013e2741d1f9e0f164943342435
DIFF: https://github.com/llvm/llvm-project/commit/ada0d41dbc26f013e2741d1f9e0f164943342435.diff

LOG: [mlir][ods] Allow filtering of ops

Add an option to filter which ops OpDefinitionsGen runs on. This enables having multiple ops together in the same TD file while generating different CC files for them (useful if one wants to use multiclasses or split one dialect out into multiple different libraries). There is probably a more general query mechanism to be had here (e.g., split out all ops that don't have a verify method, or that are commutative), but filtering based on op name (e.g., test.a_op) seemed a reasonable start and didn't require inventing a query specification mechanism here.

Differential Revision: https://reviews.llvm.org/D82319

Added: 
    

Modified: 
    mlir/test/mlir-tblgen/op-decl.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
index f5bf03e57ce7..8d58d5bb7395 100644
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -1,4 +1,5 @@
 // RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck  %s
+// RUN: mlir-tblgen -gen-op-decls -op-regex="test.a_op" -I %S/../../include %s | FileCheck  %s --check-prefix=REDUCE
 
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -195,3 +196,5 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
 // CHECK-LABEL: _BOp declarations
 // CHECK: class _BOp : public Op<_BOp
 
+// REDUCE-LABEL: NS::AOp declarations
+// REDUCE-NOT: NS::BOp declarations

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 21dccd4f3d5a..6aa7b01dd89f 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -21,6 +21,8 @@
 #include "mlir/TableGen/SideEffects.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -32,6 +34,13 @@ using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
 
+cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
+
+static cl::opt<std::string>
+    opFilter("op-regex",
+             cl::desc("Regex of name of op's to filter (no filter if empty)"),
+             cl::cat(opDefGenCat));
+
 static const char *const tblgenNamePrefix = "tblgen_";
 static const char *const generatedArgName = "odsArg";
 static const char *const builderOpState = "odsState";
@@ -2081,10 +2090,37 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
       [&os]() { os << ",\n"; });
 }
 
+static std::string getOperationName(const Record &def) {
+  auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
+  auto opName = def.getValueAsString("opName");
+  if (prefix.empty())
+    return std::string(opName);
+  return std::string(llvm::formatv("{0}.{1}", prefix, opName));
+}
+
+static std::vector<Record *>
+getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
+                         StringRef className) {
+  Record *classDef = recordKeeper.getClass(className);
+  if (!classDef)
+    PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
+
+  llvm::Regex includeRegex(opFilter);
+  std::vector<Record *> defs;
+  for (const auto &def : recordKeeper.getDefs()) {
+    if (def.second->isSubClassOf(classDef)) {
+      if (opFilter.empty() || includeRegex.match(getOperationName(*def.second)))
+        defs.push_back(def.second.get());
+    }
+  }
+
+  return defs;
+}
+
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os);
 
-  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
   emitOpClasses(defs, os, /*emitDecl=*/true);
 
   return false;
@@ -2093,7 +2129,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Definitions", os);
 
-  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
   emitOpList(defs, os);
   emitOpClasses(defs, os, /*emitDecl=*/false);
 


        


More information about the Mlir-commits mailing list