[llvm] [mlir] [mlir][ods] Allow sharding of op definitions (PR #89411)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 19 09:24:29 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jeff Niu (Mogball)
<details>
<summary>Changes</summary>
Adds an option `-op-shard-count=N` to `mlir-tblgen -gen-op-defs` that divides the op class definitions and the op list into N segments, e.g.:
```
// mlir-tblgen -gen-op-defs -op-shard-count=2
void FooDialect::initialize() {
addOperations<
>();
addOperations<
>();
}
```
When the segments are compiled as separate source files, this can significantly reduce compile time for dialects with a large op set.
---
Patch is 30.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89411.diff
14 Files Affected:
- (modified) mlir/CMakeLists.txt (+3)
- (modified) mlir/cmake/modules/AddMLIR.cmake (+38)
- (modified) mlir/cmake/modules/CMakeLists.txt (+2)
- (modified) mlir/cmake/modules/MLIRConfig.cmake.in (+1)
- (modified) mlir/include/mlir/TableGen/CodeGenHelpers.h (+8-4)
- (modified) mlir/lib/TableGen/CodeGenHelpers.cpp (+6-9)
- (added) mlir/test/mlir-tblgen/shard-op-defs.td (+33)
- (added) mlir/tools/mlir-src-sharder/CMakeLists.txt (+14)
- (added) mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp (+114)
- (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+130-35)
- (modified) mlir/tools/mlir-tblgen/OpGenHelpers.cpp (+24-1)
- (modified) mlir/tools/mlir-tblgen/OpGenHelpers.h (+5)
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+9)
- (modified) utils/bazel/llvm-project-overlay/mlir/tblgen.bzl (+133)
``````````diff
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 5c4301af040b47..4c0ef8387b8dff 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
add_subdirectory(tools/mlir-linalg-ods-gen)
add_subdirectory(tools/mlir-pdll)
add_subdirectory(tools/mlir-tblgen)
+add_subdirectory(tools/mlir-src-sharder)
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")
add_subdirectory(include/mlir)
add_subdirectory(lib)
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 1d2ed748bc2f13..afb74fb2d00025 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
tablegen(MLIR ${ARGV})
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)
+
+ # Get the current set of include paths for this td file.
+ cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
+ get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
+ list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
+ # Filter out any empty include items.
+ list(REMOVE_ITEM tblgen_includes "")
+
+ # Build the absolute path for the current input file.
+ if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+ set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+ else()
+ set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
+ endif()
+
+ # Append the includes used for this file to the tablegen_compile_commands
+ # file.
+ file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
+ "--- !FileInfo:\n"
+ " filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
+ " includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
+ )
endfunction()
# Clear out any pre-existing compile_commands file before processing. This
@@ -149,6 +171,31 @@ function(add_mlir_dialect dialect dialect_namespace)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
endfunction()
+# Declare sharded dialect operation declarations and definitions.
+#
+# Generates `${ops_target}.h.inc` and `${ops_target}.cpp.inc` from
+# `${ops_target}.td` with `-op-shard-count=${shard_count}`, then uses
+# mlir-src-sharder to derive one `${ops_target}.<index>.cpp` from
+# `${ops_target}.cpp` for each index in [0, shard_count). The generated file
+# names are appended to the caller's SHARDED_SRCS.
+function(add_sharded_ops ops_target shard_count)
+ set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
+ mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
+ mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
+ set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
+ # `foreach(... RANGE <stop>)` includes <stop>, so stop at shard_count - 1 to
+ # emit exactly shard_count shards (GET_OP_DEFS_0 .. GET_OP_DEFS_{N-1}).
+ math(EXPR last_shard_index "${shard_count} - 1")
+ foreach(index RANGE ${last_shard_index})
+ set(SHARDED_SRC ${ops_target}.${index}.cpp)
+ list(APPEND SHARDED_SRCS ${SHARDED_SRC})
+ tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
+ set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
+ endforeach()
+ add_public_tablegen_target(MLIR${ops_target}ShardGen)
+ set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
+endfunction()
+
# Declare a dialect in the include directory
function(add_mlir_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt
index 8d2904ef46dfe8..3ac1c79b090ed6 100644
--- a/mlir/cmake/modules/CMakeLists.txt
+++ b/mlir/cmake/modules/CMakeLists.txt
@@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# Refer to the best host mlir-tbgen, which might be a host-optimized version
set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
@@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in
index d4da3cd98cce98..7076d94a32f2bc 100644
--- a/mlir/cmake/modules/MLIRConfig.cmake.in
+++ b/mlir/cmake/modules/MLIRConfig.cmake.in
@@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index dd17a44c889bbe..c263c69c53d1e3 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -99,8 +99,14 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
+ /// Create a constraint uniquer with a unique prefix derived from the record
+ /// keeper with an optional tag.
StaticVerifierFunctionEmitter(raw_ostream &os,
- const llvm::RecordKeeper &records);
+ const llvm::RecordKeeper &records,
+ StringRef tag = "");
+
+ /// Collect and unique all the constraints used by operations.
+ void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all compatible type, attribute, successor, and region
/// constraints from the operations in the file and emit them at the top of
@@ -108,7 +114,7 @@ class StaticVerifierFunctionEmitter {
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
- void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
+ void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
@@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
/// Emit pattern constraints.
void emitPatternConstraints();
- /// Collect and unique all the constraints used by operations.
- void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all pattern constraints.
void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index d906de6b56afc0..59865146e20bc4 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -24,7 +24,8 @@ using namespace mlir::tblgen;
/// Generate a unique label based on the current file name to prevent name
/// collisions if multiple generated files are included at once.
-static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
+static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
+ StringRef tag) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();
@@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
nameRef.consume_back(".td");
// Sanitize any invalid characters.
- std::string uniqueName;
+ std::string uniqueName(tag);
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
uniqueName.push_back(c);
@@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
}
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
- raw_ostream &os, const llvm::RecordKeeper &records)
- : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
+ raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
+ : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
void StaticVerifierFunctionEmitter::emitOpConstraints(
- ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
- collectOpConstraints(opDefs);
- if (emitDecl)
- return;
-
+ ArrayRef<llvm::Record *> opDefs) {
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td
new file mode 100644
index 00000000000000..84ac6b0fbe9ebe
--- /dev/null
+++ b/mlir/test/mlir-tblgen/shard-op-defs.td
@@ -0,0 +1,33 @@
+// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
+// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "test";
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []>
+ : Op<Test_Dialect, mnemonic, traits>;
+
+def OpA : Test_Op<"a">;
+def OpB : Test_Op<"b">;
+def OpC : Test_Op<"c">;
+
+// DECLS: OpA
+// DECLS: OpB
+// DECLS: OpC
+// DECLS: registerTestDialectOperations(
+// DECLS: registerTestDialectOperations0(
+// DECLS: registerTestDialectOperations1(
+
+// DEFS-LABEL: GET_OP_DEFS_0
+// DEFS: void test::registerTestDialectOperations(
+// DEFS: void test::registerTestDialectOperations0(
+// DEFS: OpAAdaptor
+// DEFS: OpBAdaptor
+
+// DEFS-LABEL: GET_OP_DEFS_1
+// DEFS: void test::registerTestDialectOperations1(
+// DEFS: OpCAdaptor
diff --git a/mlir/tools/mlir-src-sharder/CMakeLists.txt b/mlir/tools/mlir-src-sharder/CMakeLists.txt
new file mode 100644
index 00000000000000..4ef870b61124ad
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/CMakeLists.txt
@@ -0,0 +1,14 @@
+set(LLVM_LINK_COMPONENTS Support)
+set(LIBS MLIRSupport)
+
+add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
+ mlir-src-sharder.cpp
+
+ DEPENDS
+ ${LIBS}
+ )
+
+set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
+target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
+
+mlir_check_all_link_libraries(mlir-src-sharder)
diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
new file mode 100644
index 00000000000000..dc1e2939c7d25b
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
@@ -0,0 +1,114 @@
+//===- mlir-src-sharder.cpp - A tool for sharding generated source files --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+/// Create a dependency file for `-d` option.
+///
+/// This functionality is generally only for the benefit of the build system,
+/// and is modeled after the same option in TableGen.
+static LogicalResult createDependencyFile(StringRef outputFilename,
+ StringRef dependencyFile) {
+ if (outputFilename == "-") {
+ llvm::errs() << "error: the option -d must be used together with -o\n";
+ return failure();
+ }
+
+ std::string errorMessage;
+ std::unique_ptr<llvm::ToolOutputFile> outputFile =
+ openOutputFile(dependencyFile, &errorMessage);
+ if (!outputFile) {
+ llvm::errs() << errorMessage << "\n";
+ return failure();
+ }
+
+ outputFile->os() << outputFilename << ":\n";
+ outputFile->keep();
+ return success();
+}
+
+int main(int argc, char **argv) {
+ // FIXME: This is necessary because we link in TableGen, which defines its
+ // options as static variables.. some of which overlap with our options.
+ llvm::cl::ResetCommandLineParser();
+
+ llvm::cl::opt<unsigned> opShardIndex(
+ "op-shard-index", llvm::cl::desc("The current shard index"));
+ llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+ llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+ llvm::cl::opt<std::string> outputFilename(
+ "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+ llvm::cl::init("-"));
+ llvm::cl::list<std::string> includeDirs(
+ "I", llvm::cl::desc("Directory of include files"),
+ llvm::cl::value_desc("directory"), llvm::cl::Prefix);
+ llvm::cl::opt<std::string> dependencyFilename(
+ "d", llvm::cl::desc("Dependency filename"),
+ llvm::cl::value_desc("filename"), llvm::cl::init(""));
+ llvm::cl::opt<bool> writeIfChanged(
+ "write-if-changed",
+ llvm::cl::desc("Only write to the output file if it changed"));
+
+ llvm::InitLLVM y(argc, argv);
+ llvm::cl::ParseCommandLineOptions(argc, argv);
+
+ // Open the input file.
+ std::string errorMessage;
+ std::unique_ptr<llvm::MemoryBuffer> inputFile =
+ openInputFile(inputFilename, &errorMessage);
+ if (!inputFile) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+
+ // Write the output to a buffer.
+ std::string outputStr;
+ llvm::raw_string_ostream os(outputStr);
+ os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
+ << inputFile->getBuffer();
+
+ // Determine whether we need to write the output file.
+ bool shouldWriteOutput = true;
+ if (writeIfChanged) {
+ // Only update the real output file if there are any differences. This
+ // prevents recompilation of all the files depending on it if there aren't
+ // any.
+ if (auto existingOrErr =
+ llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
+ if (std::move(existingOrErr.get())->getBuffer() == os.str())
+ shouldWriteOutput = false;
+ }
+
+ // Populate the output file if necessary.
+ if (shouldWriteOutput) {
+ std::unique_ptr<llvm::ToolOutputFile> outputFile =
+ openOutputFile(outputFilename, &errorMessage);
+ if (!outputFile) {
+ llvm::errs() << errorMessage << "\n";
+ return 1;
+ }
+ outputFile->os() << os.str();
+ outputFile->keep();
+ }
+
+ // Always write the depfile, even if the main output hasn't changed. If it's
+ // missing, Ninja considers the output dirty.
+ if (!dependencyFilename.empty())
+ if (failed(createDependencyFile(outputFilename, dependencyFilename)))
+ return 1;
+
+ return 0;
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 53ed5cb7c043ec..ff26b2a61662e2 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4303,32 +4303,15 @@ void OpOperandAdaptorEmitter::emitDef(
emitter.adaptor.writeDefTo(os);
}
-// Emits the opcode enum and op classes.
-static void emitOpClasses(const RecordKeeper &recordKeeper,
- const std::vector<Record *> &defs, raw_ostream &os,
- bool emitDecl) {
- // First emit forward declaration for each class, this allows them to refer
- // to each others in traits for example.
- if (emitDecl) {
- os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
- os << "#undef GET_OP_FWD_DEFINES\n";
- for (auto *def : defs) {
- Operator op(*def);
- NamespaceEmitter emitter(os, op.getCppNamespace());
- os << "class " << op.getCppClassName() << ";\n";
- }
- os << "#endif\n\n";
- }
-
- IfDefScope scope("GET_OP_CLASSES", os);
+/// Emit the class declarations or definitions for the given op defs.
+static void
+emitOpClasses(const RecordKeeper &recordKeeper,
+ const std::vector<Record *> &defs, raw_ostream &os,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool emitDecl) {
if (defs.empty())
return;
- // Generate all of the locally instantiated methods first.
- StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
- os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
- staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
-
for (auto *def : defs) {
Operator op(*def);
if (emitDecl) {
@@ -4358,34 +4341,146 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
}
}
-// Emits a comma-separated list of the ops.
-static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
- IfDefScope scope("GET_OP_LIST", os);
+/// Emit the declarations for the provided op classes.
+static void emitOpClassDecls(const RecordKeeper &recordKeeper,
+ const std::vector<Record *> &defs,
+ raw_ostream &os) {
+ // First emit forward declaration for each class, this allows them to refer
+ // to each others in traits for example.
+ for (auto *def : defs) {
+ Operator op(*def);
+ NamespaceEmitter emitter(os, op.getCppNamespace());
+ os << "class " << op.getCppClassName() << ";\n";
+ }
+
+ // Emit the op class declarations.
+ IfDefScope scope("GET_OP_CLASSES", os);
+ if (defs.empty())
+ return;
+ StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
+ staticVerifierEmitter.collectOpConstraints(defs);
+ emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+ /*emitDecl=*/true);
+}
+
+/// Emit the definitions for the provided op classes.
+static void emitOpClassDefs(const RecordKeeper &recordKeeper,
+ ArrayRef<Record *> defs, raw_ostream &os,
+ StringRef constraintPrefix = "") {
+ if (defs.empty())
+ return;
+
+ // Generate all of the locally instantiated methods first.
+ StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
+ constraintPrefix);
+ os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+ staticVerifierEmitter.collectOpConstraints(defs);
+ staticVerifierEmitter.emitOpConstraints(defs);
- interleave(
- // TODO: We are constructing the Operator wrapper instance just for
- // getting it's qualified class name here. Reduce the overhead by having a
- // lightweight version of Operator class just for that purpose.
- defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
- [&os]() { os << ",\n"; });
+ // Emit the classes.
+ emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+ /*emitDecl=*/false);
}
+/// Emit op declarations for all op records.
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Declarations", os, recordKeeper);
std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
- emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
+ emitOpClassDecls(recordKeeper, defs, os);
+
+ // If we are generating sharded op definitions, emit the sharded op
+ // registration hooks.
+ SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+ shardOpDefinitions(defs, shardedDefs);
+ if (defs.empty() || shardedDefs.size() <= 1)
+ return false;
+
+ Dialect dialect = Operator(defs.front()).getDialect();
+ NamespaceEmitter ns(os, dialect);
+
+ const char *const opRegistrationHook =
+ "void register{0}Operations{1}({2}::{0} *dialect);\n";
+ os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
+ dialect.getCppNamespace());
+ for (unsigned i = 0; i < shardedDefs.size(); ++i) {
+ os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
+ dialect.getCppNamespace());
+ }
return false;
}
+/// Generate the dialect op registration hook and the op class definitions for a
+/// shard of ops.
+static void emitOpDefShard(const RecordKeeper &recordKeeper,
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/89411
More information about the llvm-commits
mailing list