[llvm] [mlir] [mlir][ods] Allow sharding of op definitions (PR #89423)

Jeff Niu via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 22 13:42:43 PDT 2024


https://github.com/Mogball updated https://github.com/llvm/llvm-project/pull/89423

>From f341dcf4775c468514483be7a972e1e182aea4f1 Mon Sep 17 00:00:00 2001
From: Mogball <jeffniu22 at gmail.com>
Date: Thu, 23 Jun 2022 20:44:59 +0000
Subject: [PATCH] [mlir][ods] Allow sharding of op definitions

Adds an option `-op-shard-count=N` to `mlir-tblgen -gen-op-defs` (and
`-gen-op-decls`) that divides the op class definitions and op list into N
segments, e.g.

```
// mlir-tblgen -gen-op-defs -op-shard-count=2

void FooDialect::initialize() {
  addOperations<
  >();
  addOperations<
  >();
}

```

When split across multiple source files, this can help significantly improve
dialect compile time for dialects with a large opset.

stack-info: PR: https://github.com/llvm/llvm-project/pull/89423, branch: users/mogball/pr_1
---
 mlir/CMakeLists.txt                           |   3 +
 mlir/cmake/modules/AddMLIR.cmake              |  38 ++++
 mlir/cmake/modules/CMakeLists.txt             |   2 +
 mlir/cmake/modules/MLIRConfig.cmake.in        |   1 +
 mlir/include/mlir/TableGen/CodeGenHelpers.h   |  12 +-
 mlir/lib/TableGen/CodeGenHelpers.cpp          |  15 +-
 mlir/test/mlir-tblgen/shard-op-defs.td        |  33 ++++
 mlir/tools/mlir-src-sharder/CMakeLists.txt    |  14 ++
 .../mlir-src-sharder/mlir-src-sharder.cpp     | 114 ++++++++++++
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 164 ++++++++++++++----
 mlir/tools/mlir-tblgen/OpGenHelpers.cpp       |  25 ++-
 mlir/tools/mlir-tblgen/OpGenHelpers.h         |   5 +
 .../llvm-project-overlay/mlir/BUILD.bazel     |   9 +
 .../llvm-project-overlay/mlir/tblgen.bzl      | 133 ++++++++++++++
 14 files changed, 519 insertions(+), 49 deletions(-)
 create mode 100644 mlir/test/mlir-tblgen/shard-op-defs.td
 create mode 100644 mlir/tools/mlir-src-sharder/CMakeLists.txt
 create mode 100644 mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp

diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 5c4301af040b47..4c0ef8387b8dff 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
 add_subdirectory(tools/mlir-linalg-ods-gen)
 add_subdirectory(tools/mlir-pdll)
 add_subdirectory(tools/mlir-tblgen)
+add_subdirectory(tools/mlir-src-sharder)
 set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
 set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
 set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")
 
 add_subdirectory(include/mlir)
 add_subdirectory(lib)
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 1d2ed748bc2f13..afb74fb2d00025 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
   tablegen(MLIR ${ARGV})
   set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
       PARENT_SCOPE)
+
+  # Get the current set of include paths for this td file.
+  cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
+  get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
+  list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
+  # Filter out any empty include items.
+  list(REMOVE_ITEM tblgen_includes "")
+
+  # Build the absolute path for the current input file.
+  if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+    set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+  else()
+    set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
+  endif()
+
+  # Append the includes used for this file to the tablegen_compile_commands
+  # file.
+  file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
+      "--- !FileInfo:\n"
+      "  filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
+      "  includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
+  )
 endfunction()
 
 # Clear out any pre-existing compile_commands file before processing. This
@@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
   add_dependencies(mlir-headers MLIR${dialect}IncGen)
 endfunction()
 
+# Declare sharded op decls/defs from <ops_target>.td; sets SHARDED_SRCS in the caller.
+function(add_sharded_ops ops_target shard_count)
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
+  mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
+  mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)  # Input for the sharder runs below.
+  math(EXPR last_shard "${shard_count} - 1")  # foreach(RANGE n) also visits n itself.
+  foreach(index RANGE ${last_shard})
+    list(APPEND SHARDED_SRCS ${ops_target}.${index}.cpp)
+    tablegen(MLIR_SRC_SHARDER ${ops_target}.${index}.cpp -op-shard-index=${index})
+    set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ops_target}.${index}.cpp)
+  endforeach()
+  add_public_tablegen_target(MLIR${ops_target}ShardGen)
+  set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
+endfunction()
+
 # Declare a dialect in the include directory
 function(add_mlir_interface interface)
   set(LLVM_TARGET_DEFINITIONS ${interface}.td)
diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt
index 8d2904ef46dfe8..3ac1c79b090ed6 100644
--- a/mlir/cmake/modules/CMakeLists.txt
+++ b/mlir/cmake/modules/CMakeLists.txt
@@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
 # Refer to the best host mlir-tbgen, which might be a host-optimized version
 set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
 set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")
 
 configure_file(
   ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
@@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
 # if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
 set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
 set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)
 
 configure_file(
   ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in
index d4da3cd98cce98..7076d94a32f2bc 100644
--- a/mlir/cmake/modules/MLIRConfig.cmake.in
+++ b/mlir/cmake/modules/MLIRConfig.cmake.in
@@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
 set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
 set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
 set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
 set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
 set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
 set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index dd17a44c889bbe..c263c69c53d1e3 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -99,8 +99,14 @@ class NamespaceEmitter {
 ///
 class StaticVerifierFunctionEmitter {
 public:
+  /// Create a constraint uniquer with a unique prefix derived from the record
+  /// keeper with an optional tag.
   StaticVerifierFunctionEmitter(raw_ostream &os,
-                                const llvm::RecordKeeper &records);
+                                const llvm::RecordKeeper &records,
+                                StringRef tag = "");
+
+  /// Collect and unique all the constraints used by operations.
+  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
   /// Collect and unique all compatible type, attribute, successor, and region
   /// constraints from the operations in the file and emit them at the top of
@@ -108,7 +114,7 @@ class StaticVerifierFunctionEmitter {
   ///
   /// Constraints that do not meet the restriction that they can only reference
   /// `$_self` and `$_op` are not uniqued.
-  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
+  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
   /// Unique all compatible type and attribute constraints from a pattern file
   /// and emit them at the top of the generated file.
@@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
   /// Emit pattern constraints.
   void emitPatternConstraints();
 
-  /// Collect and unique all the constraints used by operations.
-  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
   /// Collect and unique all pattern constraints.
   void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
 
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index d906de6b56afc0..59865146e20bc4 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -24,7 +24,8 @@ using namespace mlir::tblgen;
 
 /// Generate a unique label based on the current file name to prevent name
 /// collisions if multiple generated files are included at once.
-static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
+static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
+                                        StringRef tag) {
   // Use the input file name when generating a unique name.
   std::string inputFilename = records.getInputFilename();
 
@@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
   nameRef.consume_back(".td");
 
   // Sanitize any invalid characters.
-  std::string uniqueName;
+  std::string uniqueName(tag);
   for (char c : nameRef) {
     if (llvm::isAlnum(c) || c == '_')
       uniqueName.push_back(c);
@@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
 }
 
 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    raw_ostream &os, const llvm::RecordKeeper &records)
-    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
+    raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
+    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
 
 void StaticVerifierFunctionEmitter::emitOpConstraints(
-    ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
-  collectOpConstraints(opDefs);
-  if (emitDecl)
-    return;
-
+    ArrayRef<llvm::Record *> opDefs) {
   NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
   emitTypeConstraints();
   emitAttrConstraints();
diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td
new file mode 100644
index 00000000000000..84ac6b0fbe9ebe
--- /dev/null
+++ b/mlir/test/mlir-tblgen/shard-op-defs.td
@@ -0,0 +1,33 @@
+// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
+// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "test";
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []> 
+    : Op<Test_Dialect, mnemonic, traits>;
+
+def OpA : Test_Op<"a">;
+def OpB : Test_Op<"b">;
+def OpC : Test_Op<"c">;
+
+// DECLS: OpA
+// DECLS: OpB
+// DECLS: OpC
+// DECLS: registerTestDialectOperations(
+// DECLS: registerTestDialectOperations0(
+// DECLS: registerTestDialectOperations1(
+
+// DEFS-LABEL: GET_OP_DEFS_0
+// DEFS: void test::registerTestDialectOperations(
+// DEFS: void test::registerTestDialectOperations0(
+// DEFS: OpAAdaptor
+// DEFS: OpBAdaptor
+
+// DEFS-LABEL: GET_OP_DEFS_1
+// DEFS: void test::registerTestDialectOperations1(
+// DEFS: OpCAdaptor
diff --git a/mlir/tools/mlir-src-sharder/CMakeLists.txt b/mlir/tools/mlir-src-sharder/CMakeLists.txt
new file mode 100644
index 00000000000000..4ef870b61124ad
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/CMakeLists.txt
@@ -0,0 +1,14 @@
+set(LLVM_LINK_COMPONENTS Support)
+set(LIBS MLIRSupport)
+
+add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
+  mlir-src-sharder.cpp
+
+  DEPENDS
+  ${LIBS}
+  )
+
+set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
+target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
+
+mlir_check_all_link_libraries(mlir-src-sharder)
diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
new file mode 100644
index 00000000000000..dc1e2939c7d25b
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
@@ -0,0 +1,114 @@
+//===- mlir-src-sharder.cpp - A tool for sharding generated source files --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+/// Create the dependency file requested by the `-d` option. Fails if the output
+/// filename is `-` (no `-o` given) or if `dependencyFile` cannot be opened.
+/// The file lists `outputFilename` as a target with no prerequisites; it exists
+/// for the benefit of the build system and is modeled after TableGen's option.
+static LogicalResult createDependencyFile(StringRef outputFilename,
+                                          StringRef dependencyFile) {
+  if (outputFilename == "-") {
+    llvm::errs() << "error: the option -d must be used together with -o\n";
+    return failure();
+  }
+
+  std::string errorMessage;
+  std::unique_ptr<llvm::ToolOutputFile> outputFile =
+      openOutputFile(dependencyFile, &errorMessage);
+  if (!outputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return failure();
+  }
+
+  outputFile->os() << outputFilename << ":\n";
+  outputFile->keep();
+  return success();
+}
+
+int main(int argc, char **argv) {
+  // FIXME: This is necessary because we link in TableGen, which defines its
+  // options as static variables.. some of which overlap with our options.
+  llvm::cl::ResetCommandLineParser();
+
+  llvm::cl::opt<unsigned> opShardIndex(
+      "op-shard-index", llvm::cl::desc("The current shard index"));
+  llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                           llvm::cl::desc("<input file>"),
+                                           llvm::cl::init("-"));
+  llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+  llvm::cl::list<std::string> includeDirs(
+      "I", llvm::cl::desc("Directory of include files"),
+      llvm::cl::value_desc("directory"), llvm::cl::Prefix);
+  llvm::cl::opt<std::string> dependencyFilename(
+      "d", llvm::cl::desc("Dependency filename"),
+      llvm::cl::value_desc("filename"), llvm::cl::init(""));
+  llvm::cl::opt<bool> writeIfChanged(
+      "write-if-changed",
+      llvm::cl::desc("Only write to the output file if it changed"));
+
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv);
+
+  // Open the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> inputFile =
+      openInputFile(inputFilename, &errorMessage);
+  if (!inputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  // Buffer the output: a `#define GET_OP_DEFS_<index>` line, then the input.
+  std::string outputStr;
+  llvm::raw_string_ostream os(outputStr);
+  os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
+     << inputFile->getBuffer();
+
+  // Determine whether we need to write the output file.
+  bool shouldWriteOutput = true;
+  if (writeIfChanged) {
+    // Only update the real output file if there are any differences. This
+    // prevents recompilation of all the files depending on it if there aren't
+    // any.
+    if (auto existingOrErr =
+            llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
+      if (std::move(existingOrErr.get())->getBuffer() == os.str())
+        shouldWriteOutput = false;
+  }
+
+  // Populate the output file if necessary.
+  if (shouldWriteOutput) {
+    std::unique_ptr<llvm::ToolOutputFile> outputFile =
+        openOutputFile(outputFilename, &errorMessage);
+    if (!outputFile) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+    outputFile->os() << os.str();
+    outputFile->keep();
+  }
+
+  // Always write the depfile, even if the main output hasn't changed. If it's
+  // missing, Ninja considers the output dirty.
+  if (!dependencyFilename.empty())
+    if (failed(createDependencyFile(outputFilename, dependencyFilename)))
+      return 1;
+
+  return 0;
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 53ed5cb7c043ec..63fe5a80990746 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4303,32 +4303,15 @@ void OpOperandAdaptorEmitter::emitDef(
   emitter.adaptor.writeDefTo(os);
 }
 
-// Emits the opcode enum and op classes.
-static void emitOpClasses(const RecordKeeper &recordKeeper,
-                          const std::vector<Record *> &defs, raw_ostream &os,
-                          bool emitDecl) {
-  // First emit forward declaration for each class, this allows them to refer
-  // to each others in traits for example.
-  if (emitDecl) {
-    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
-    os << "#undef GET_OP_FWD_DEFINES\n";
-    for (auto *def : defs) {
-      Operator op(*def);
-      NamespaceEmitter emitter(os, op.getCppNamespace());
-      os << "class " << op.getCppClassName() << ";\n";
-    }
-    os << "#endif\n\n";
-  }
-
-  IfDefScope scope("GET_OP_CLASSES", os);
+/// Emit the class declarations or definitions for the given op defs.
+static void
+emitOpClasses(const RecordKeeper &recordKeeper,
+              const std::vector<Record *> &defs, raw_ostream &os,
+              const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+              bool emitDecl) {
   if (defs.empty())
     return;
 
-  // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
-  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-  staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
-
   for (auto *def : defs) {
     Operator op(*def);
     if (emitDecl) {
@@ -4358,34 +4341,145 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
   }
 }
 
-// Emits a comma-separated list of the ops.
-static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
-  IfDefScope scope("GET_OP_LIST", os);
+/// Emit forward declarations and class declarations for the provided ops.
+static void emitOpClassDecls(const RecordKeeper &recordKeeper,
+                             const std::vector<Record *> &defs,
+                             raw_ostream &os) {
+  // First emit a forward declaration for each class; this allows them to refer
+  // to each other, in traits for example. These sit outside GET_OP_CLASSES.
+  for (auto *def : defs) {
+    Operator op(*def);
+    NamespaceEmitter emitter(os, op.getCppNamespace());
+    os << "class " << op.getCppClassName() << ";\n";
+  }
+
+  // Emit the op class declarations, guarded by `GET_OP_CLASSES`.
+  IfDefScope scope("GET_OP_CLASSES", os);
+  if (defs.empty())
+    return;
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
+  staticVerifierEmitter.collectOpConstraints(defs);
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/true);
+}
+
+/// Emit the definitions for the provided op classes.
+static void emitOpClassDefs(const RecordKeeper &recordKeeper,
+                            ArrayRef<Record *> defs, raw_ostream &os,
+                            StringRef constraintPrefix = "") {
+  if (defs.empty())
+    return;
+
+  // Generate all of the locally instantiated methods first.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
+                                                      constraintPrefix);
+  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+  staticVerifierEmitter.collectOpConstraints(defs);
+  staticVerifierEmitter.emitOpConstraints(defs);
 
-  interleave(
-      // TODO: We are constructing the Operator wrapper instance just for
-      // getting it's qualified class name here. Reduce the overhead by having a
-      // lightweight version of Operator class just for that purpose.
-      defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
-      [&os]() { os << ",\n"; });
+  // Emit the classes.
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/false);
 }
 
+/// Emit op declarations for all op records.
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os, recordKeeper);
 
   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
+  emitOpClassDecls(recordKeeper, defs, os);
+
+  // If we are generating sharded op definitions, emit the sharded op
+  // registration hooks.
+  SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+  shardOpDefinitions(defs, shardedDefs);
+  if (defs.empty() || shardedDefs.size() <= 1)
+    return false;
+
+  Dialect dialect = Operator(defs.front()).getDialect();
+  NamespaceEmitter ns(os, dialect);
+
+  const char *const opRegistrationHook =
+      "void register{0}Operations{1}({2}::{0} *dialect);\n";
+  os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
+                dialect.getCppNamespace());
+  for (unsigned i = 0; i < shardedDefs.size(); ++i) {
+    os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
+                  dialect.getCppNamespace());
+  }
 
   return false;
 }
 
+/// Generate the op registration hook and the op class definitions for a shard
+/// of ops. Everything is emitted under the `GET_OP_DEFS_<shardIndex>` guard.
+static void emitOpDefShard(const RecordKeeper &recordKeeper,
+                           ArrayRef<Record *> defs, const Dialect &dialect,
+                           unsigned shardIndex, unsigned shardCount,
+                           raw_ostream &os) {
+  std::string shardGuard = "GET_OP_DEFS_";
+  std::string indexStr = std::to_string(shardIndex);
+  shardGuard += indexStr;
+  IfDefScope scope(shardGuard, os);
+
+  // The first shard also defines the hook that calls every per-shard hook.
+  const char *const opRegistrationHook =
+      "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
+  if (shardIndex == 0) {
+    os << formatv(opRegistrationHook, dialect.getCppNamespace(),
+                  dialect.getCppClassName(), "");
+    for (unsigned i = 0; i < shardCount; ++i) {
+      os << formatv("  {0}::register{1}Operations{2}(dialect);\n",
+                    dialect.getCppNamespace(), dialect.getCppClassName(), i);
+    }
+    os << "}\n";
+  }
+
+  // Generate the per-shard op registration hook.
+  os << formatv(opCommentHeader, dialect.getCppClassName(),
+                "Op Registration Hook")
+     << formatv(opRegistrationHook, dialect.getCppNamespace(),
+                dialect.getCppClassName(), shardIndex);
+  for (Record *def : defs) {
+    os << formatv("  ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
+                  Operator(def).getQualCppClassName());
+  }
+  os << "}\n";
+
+  // Generate the per-shard op definitions.
+  emitOpClassDefs(recordKeeper, defs, os, indexStr);
+}
+
+/// Emit op definitions for all op records.
 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Definitions", os, recordKeeper);
 
   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpList(defs, os);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
+  SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+  shardOpDefinitions(defs, shardedDefs);
+
+  // If no shard was requested, emit the regular op list and class definitions.
+  if (shardedDefs.size() == 1) {
+    {
+      IfDefScope scope("GET_OP_LIST", os);
+      interleave(
+          defs, os,
+          [&](Record *def) { os << Operator(def).getQualCppClassName(); },
+          ",\n");
+    }
+    {
+      IfDefScope scope("GET_OP_CLASSES", os);
+      emitOpClassDefs(recordKeeper, defs, os);
+    }
+    return false;
+  }
 
+  if (defs.empty())
+    return false;
+  Dialect dialect = Operator(defs.front()).getDialect();
+  for (auto [idx, value] : llvm::enumerate(shardedDefs)) {
+    emitOpDefShard(recordKeeper, value, dialect, idx, shardedDefs.size(), os);
+  }
   return false;
 }
 
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index 7fd34df8460d39..c2a2423a240269 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -31,6 +31,10 @@ static cl::opt<std::string> opExcFilter(
     "op-exclude-regex",
     cl::desc("Regex of name of op's to exclude (no filter if empty)"),
     cl::cat(opDefGenCat));
+static cl::opt<unsigned> opShardCount(
+    "op-shard-count",
+    cl::desc("The number of shards into which the op classes will be divided"),
+    cl::cat(opDefGenCat), cl::init(1));
 
 static std::string getOperationName(const Record &def) {
   auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
@@ -79,4 +83,23 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
   reserved.insert("issubclass");
   reserved.insert("type");
   return reserved.contains(str);
-}
\ No newline at end of file
+}
+
+void mlir::tblgen::shardOpDefinitions(
+    ArrayRef<llvm::Record *> defs,
+    SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs) {
+  assert(opShardCount > 0 && "expected a positive shard count");
+  if (opShardCount <= 1) { // `<=`: never divide by zero in no-assert builds.
+    shardedDefs.push_back(defs);
+    return;
+  }
+  // Split evenly; the first `numMissing` shards take one extra definition.
+  unsigned minShardSize = defs.size() / opShardCount;
+  unsigned numMissing = defs.size() - minShardSize * opShardCount;
+  shardedDefs.reserve(opShardCount);
+  for (unsigned i = 0, start = 0; i < opShardCount; ++i) {
+    unsigned size = minShardSize + (i < numMissing);
+    shardedDefs.push_back(defs.slice(start, size));
+    start += size;
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 3dcff14d1221ee..1b43d5d3ce3a7d 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
 #define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
 
+#include "mlir/Support/LLVM.h"
 #include "llvm/TableGen/Record.h"
 #include <vector>
 
@@ -28,6 +29,10 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
 /// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
 bool isPythonReserved(llvm::StringRef str);
 
+/// Shard the op definitions into the number of shards set by "op-shard-count".
+void shardOpDefinitions(ArrayRef<llvm::Record *> defs,
+                        SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs);
+
 } // namespace tblgen
 } // namespace mlir
 
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 6c732b8f134901..00a40019ac2e7b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9727,6 +9727,15 @@ cc_binary(
     ],
 )
 
+cc_binary(
+    name = "mlir-src-sharder",
+    srcs = ["tools/mlir-src-sharder/mlir-src-sharder.cpp"],
+    deps = [
+        ":Support",
+        "//llvm:Support",
+    ],
+)
+
 cc_binary(
     name = "mlir-linalg-ods-yaml-gen",
     srcs = [
diff --git a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl
index fdf6a57107ac34..e45ba1fe0ef721 100644
--- a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl
+++ b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl
@@ -432,3 +432,136 @@ def gentbl_cc_library(
         copts = copts,
         **kwargs
     )
+
+def _gentbl_shard_impl(ctx):  # Rule implementation: one sharder run per shard.
+    sharder_args = ctx.actions.args()
+    sharder_args.add(ctx.file.src_file)
+    sharder_args.add("-op-shard-index", ctx.attr.index)
+    sharder_args.add("-o", ctx.outputs.out.path)
+    ctx.actions.run(
+        executable = ctx.executable.sharder,
+        arguments = [sharder_args],
+        inputs = [ctx.file.src_file],
+        outputs = [ctx.outputs.out],
+        mnemonic = "ShardGenerate",
+        use_default_shell_env = True,
+    )
+
+gentbl_shard_rule = rule(
+    _gentbl_shard_impl,
+    doc = "Generates one shard source file by running the sharder on a template source file.",
+    output_to_genfiles = True,
+    attrs = {
+        "index": attr.int(mandatory = True, doc = "Index of the shard to generate, passed as -op-shard-index."),
+        "sharder": attr.label(
+            doc = "The source file sharder executable to run.",
+            executable = True,
+            cfg = "exec",
+        ),
+        "src_file": attr.label(
+            doc = "The source file template given to the sharder as input.",
+            allow_single_file = True,
+            mandatory = True,
+        ),
+        "out": attr.output(
+            doc = "The shard source file written by the sharder.",
+            mandatory = True,
+        ),
+    },
+)
+
+def gentbl_sharded_ops(
+        name,
+        tblgen,
+        sharder,
+        td_file,
+        shard_count,
+        src_file,
+        src_out,
+        hdr_out,
+        test = False,
+        includes = [],
+        strip_include_prefix = None,
+        deps = []):
+    """Generate sharded op declarations and definitions.
+
+    This special build rule shards op definitions in a TableGen file and generates multiple copies
+    of a template source file for including and compiling each shard. The rule defines a filegroup
+    consisting of the source shards, the generated source file, and the generated header file.
+
+    Args:
+      name: The name of the filegroup.
+      tblgen: The binary used to produce the output.
+      sharder: The source file sharder to use.
+      td_file: The primary table definitions file.
+      shard_count: The number of op definition shards to produce.
+      src_file: The source file template.
+      src_out: The generated source file.
+      hdr_out: The generated header file.
+      test: Whether this is a test target; forwarded as testonly to the generated targets.
+      includes: See gentbl_rule.includes
+      strip_include_prefix: Attribute to pass through to cc_library.
+      deps: See gentbl_rule.deps
+    """
+    cc_lib_name = name + "__gentbl_cc_lib"
+    gentbl_cc_library(
+        name = cc_lib_name,
+        strip_include_prefix = strip_include_prefix,
+        includes = includes,
+        tbl_outs = [
+            (
+                [
+                    "-gen-op-defs",
+                    "-op-shard-count=" + str(shard_count),
+                ],
+                src_out,
+            ),
+            (
+                [
+                    "-gen-op-decls",
+                    "-op-shard-count=" + str(shard_count),
+                ],
+                hdr_out,
+            ),
+        ],
+        tblgen = tblgen,
+        td_file = td_file,
+        test = test,
+        deps = deps,
+    )
+    all_files = [hdr_out, src_out]
+    for i in range(0, shard_count):
+        out_file = "shard_copy_" + str(i) + "_" + src_file
+        gentbl_shard_rule(
+            index = i,
+            name = name + "__src_shard" + str(i),
+            testonly = test,
+            out = out_file,
+            sharder = sharder,
+            src_file = src_file,
+        )
+        all_files.append(out_file)
+    native.filegroup(name = name, srcs = all_files, testonly = test)  # Shards may be testonly.
+
+def gentbl_sharded_op_defs(name, source_file, shard_count):
+    """Generates multiple copies of a source file that includes sharded op definitions.
+
+    Args:
+      name: The name of the rule.
+      source_file: The source to copy; copy i is prefixed with `#define GET_OP_DEFS_<i>`.
+      shard_count: The number of shards.
+
+    Returns:
+      A list of the copied filenames to be included in the dialect library.
+    """
+    copies = []
+    for i in range(0, shard_count):
+        out_file = "shard_copy_" + str(i) + "_" + source_file
+        copies.append(out_file)
+        native.genrule(
+            name = name + "_shard_" + str(i),
+            srcs = [source_file],
+            outs = [out_file],
+            cmd = "(echo '#define GET_OP_DEFS_" + str(i) + "' && cat $(SRCS)) > $(OUTS)",
+        )
+    return copies



More information about the llvm-commits mailing list