[llvm] [mlir] [mlir] Update comment about `propertiesAttr` (NFC) (PR #89633)
Jeff Niu via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 22 10:09:23 PDT 2024
https://github.com/Mogball updated https://github.com/llvm/llvm-project/pull/89633
>From caff32bff994c9490d0a37de1ec2f8c69d820498 Mon Sep 17 00:00:00 2001
From: Mogball <jeffniu22 at gmail.com>
Date: Thu, 23 Jun 2022 20:44:59 +0000
Subject: [PATCH 1/3] [mlir][ods] Allow sharding of op definitions
Adds an option `-op-shard-count=N` to `mlir-tblgen -gen-op-defs` (and `-gen-op-decls`) that divides the
op class definitions and op list into N segments, e.g.
```
// mlir-tblgen -gen-op-defs -op-shard-count=2
void FooDialect::initialize() {
addOperations<
>();
addOperations<
>();
}
```
When the shards are compiled as separate source files, this can significantly reduce
compile time for dialects with a large opset.
stack-info: PR: https://github.com/llvm/llvm-project/pull/89423, branch: users/mogball/pr_1
---
mlir/CMakeLists.txt | 3 +
mlir/cmake/modules/AddMLIR.cmake | 38 ++++
mlir/cmake/modules/CMakeLists.txt | 2 +
mlir/cmake/modules/MLIRConfig.cmake.in | 1 +
mlir/include/mlir/TableGen/CodeGenHelpers.h | 12 +-
mlir/lib/TableGen/CodeGenHelpers.cpp | 15 +-
mlir/test/mlir-tblgen/shard-op-defs.td | 33 ++++
mlir/tools/mlir-src-sharder/CMakeLists.txt | 14 ++
.../mlir-src-sharder/mlir-src-sharder.cpp | 114 ++++++++++++
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 164 ++++++++++++++----
mlir/tools/mlir-tblgen/OpGenHelpers.cpp | 25 ++-
mlir/tools/mlir-tblgen/OpGenHelpers.h | 5 +
.../llvm-project-overlay/mlir/BUILD.bazel | 9 +
.../llvm-project-overlay/mlir/tblgen.bzl | 133 ++++++++++++++
14 files changed, 519 insertions(+), 49 deletions(-)
create mode 100644 mlir/test/mlir-tblgen/shard-op-defs.td
create mode 100644 mlir/tools/mlir-src-sharder/CMakeLists.txt
create mode 100644 mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 5c4301af040b47..4c0ef8387b8dff 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
add_subdirectory(tools/mlir-linalg-ods-gen)
add_subdirectory(tools/mlir-pdll)
add_subdirectory(tools/mlir-tblgen)
+add_subdirectory(tools/mlir-src-sharder)
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
+set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")
add_subdirectory(include/mlir)
add_subdirectory(lib)
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 1d2ed748bc2f13..afb74fb2d00025 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
tablegen(MLIR ${ARGV})
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
PARENT_SCOPE)
+
+ # Get the current set of include paths for this td file.
+ cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
+ get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
+ list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
+ # Filter out any empty include items.
+ list(REMOVE_ITEM tblgen_includes "")
+
+ # Build the absolute path for the current input file.
+ if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+ set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
+ else()
+ set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
+ endif()
+
+ # Append the includes used for this file to the tablegen_compile_commands
+ # file.
+ file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
+ "--- !FileInfo:\n"
+ " filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
+ " includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
+ )
endfunction()
# Clear out any pre-existing compile_commands file before processing. This
@@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
endfunction()
+# Declare sharded dialect operation declarations and definitions.
+#
+# Runs mlir-tblgen on `${ops_target}.td` to produce `${ops_target}.h.inc` and
+# `${ops_target}.cpp.inc` with the ops split into `shard_count` shards. Then runs
+# mlir-src-sharder on `${ops_target}.cpp` once per shard. This produces
+# `${ops_target}.<index>.cpp` for each index in [0, shard_count), each one
+# prefixed with `#define GET_OP_DEFS_<index>`.
+#
+# Outputs:
+# - The shard sources are returned in `SHARDED_SRCS` in the caller's scope.
+#   NOTE: they are appended to any `SHARDED_SRCS` value already visible there.
+# - `MLIR${ops_target}ShardGen` is declared via add_public_tablegen_target.
+function(add_sharded_ops ops_target shard_count)
+  if("${shard_count}" LESS 1)
+    message(FATAL_ERROR
+      "add_sharded_ops: shard_count must be a positive integer, got '${shard_count}'")
+  endif()
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
+  mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
+  mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
+  set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
+  # `foreach(RANGE <stop>)` iterates over [0, stop] *inclusively*. Stop at
+  # shard_count - 1 so exactly `shard_count` shard sources are generated.
+  math(EXPR last_shard_index "${shard_count} - 1")
+  foreach(index RANGE ${last_shard_index})
+    set(SHARDED_SRC ${ops_target}.${index}.cpp)
+    list(APPEND SHARDED_SRCS ${SHARDED_SRC})
+    tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
+    set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
+  endforeach()
+  add_public_tablegen_target(MLIR${ops_target}ShardGen)
+  set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
+endfunction()
+
# Declare a dialect in the include directory
function(add_mlir_interface interface)
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt
index 8d2904ef46dfe8..3ac1c79b090ed6 100644
--- a/mlir/cmake/modules/CMakeLists.txt
+++ b/mlir/cmake/modules/CMakeLists.txt
@@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# Refer to the best host mlir-tbgen, which might be a host-optimized version
set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
@@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
# if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
+set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in
index d4da3cd98cce98..7076d94a32f2bc 100644
--- a/mlir/cmake/modules/MLIRConfig.cmake.in
+++ b/mlir/cmake/modules/MLIRConfig.cmake.in
@@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
+set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
index dd17a44c889bbe..c263c69c53d1e3 100644
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -99,8 +99,14 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
+ /// Create a constraint uniquer with a unique prefix derived from the record
+ /// keeper with an optional tag.
StaticVerifierFunctionEmitter(raw_ostream &os,
- const llvm::RecordKeeper &records);
+ const llvm::RecordKeeper &records,
+ StringRef tag = "");
+
+ /// Collect and unique all the constraints used by operations.
+ void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all compatible type, attribute, successor, and region
/// constraints from the operations in the file and emit them at the top of
@@ -108,7 +114,7 @@ class StaticVerifierFunctionEmitter {
///
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
- void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
+ void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
@@ -177,8 +183,6 @@ class StaticVerifierFunctionEmitter {
/// Emit pattern constraints.
void emitPatternConstraints();
- /// Collect and unique all the constraints used by operations.
- void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all pattern constraints.
void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index d906de6b56afc0..59865146e20bc4 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -24,7 +24,8 @@ using namespace mlir::tblgen;
/// Generate a unique label based on the current file name to prevent name
/// collisions if multiple generated files are included at once.
-static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
+static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
+ StringRef tag) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();
@@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
nameRef.consume_back(".td");
// Sanitize any invalid characters.
- std::string uniqueName;
+ std::string uniqueName(tag);
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
uniqueName.push_back(c);
@@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
}
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
- raw_ostream &os, const llvm::RecordKeeper &records)
- : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
+ raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
+ : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
void StaticVerifierFunctionEmitter::emitOpConstraints(
- ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
- collectOpConstraints(opDefs);
- if (emitDecl)
- return;
-
+ ArrayRef<llvm::Record *> opDefs) {
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td
new file mode 100644
index 00000000000000..84ac6b0fbe9ebe
--- /dev/null
+++ b/mlir/test/mlir-tblgen/shard-op-defs.td
@@ -0,0 +1,33 @@
+// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
+// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS
+
+// Exercises sharded op generation with three ops split across two shards.
+// The first shard receives the leftover op, so shard 0 holds OpA and OpB and
+// shard 1 holds OpC. The declarations output must also declare the aggregate
+// registration hook and one registration hook per shard.
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "test";
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []>
+    : Op<Test_Dialect, mnemonic, traits>;
+
+def OpA : Test_Op<"a">;
+def OpB : Test_Op<"b">;
+def OpC : Test_Op<"c">;
+
+// DECLS: OpA
+// DECLS: OpB
+// DECLS: OpC
+// DECLS: registerTestDialectOperations(
+// DECLS: registerTestDialectOperations0(
+// DECLS: registerTestDialectOperations1(
+
+// DEFS-LABEL: GET_OP_DEFS_0
+// DEFS: void test::registerTestDialectOperations(
+// DEFS: void test::registerTestDialectOperations0(
+// DEFS: OpAAdaptor
+// DEFS: OpBAdaptor
+
+// DEFS-LABEL: GET_OP_DEFS_1
+// DEFS: void test::registerTestDialectOperations1(
+// DEFS: OpCAdaptor
diff --git a/mlir/tools/mlir-src-sharder/CMakeLists.txt b/mlir/tools/mlir-src-sharder/CMakeLists.txt
new file mode 100644
index 00000000000000..4ef870b61124ad
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/CMakeLists.txt
@@ -0,0 +1,14 @@
+# mlir-src-sharder: copies a source file with `#define GET_OP_DEFS_<index>`
+# prepended, so that each copy compiles a single shard of the generated op
+# definitions. It is declared with add_tablegen under the MLIR_SRC_SHARDER
+# project name, which is how add_sharded_ops invokes it through
+# `tablegen(MLIR_SRC_SHARDER ...)`.
+set(LLVM_LINK_COMPONENTS Support)
+set(LIBS MLIRSupport)
+
+add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
+  mlir-src-sharder.cpp
+
+  DEPENDS
+  ${LIBS}
+  )
+
+set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
+target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
+
+mlir_check_all_link_libraries(mlir-src-sharder)
diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
new file mode 100644
index 00000000000000..dc1e2939c7d25b
--- /dev/null
+++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
@@ -0,0 +1,114 @@
+//===- mlir-src-sharder.cpp - A tool for sharding generated source files --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+/// Create a dependency file for `-d` option.
+///
+/// This functionality is generally only for the benefit of the build system,
+/// and is modeled after the same option in TableGen.
+///
+/// Writes a single make-style rule, `<outputFilename>:`, with no prerequisites
+/// to `dependencyFile`. Prints an error to stderr and returns failure in two
+/// cases: the output goes to stdout ("-"), since the rule needs a real output
+/// path, or the dependency file cannot be opened.
+static LogicalResult createDependencyFile(StringRef outputFilename,
+                                          StringRef dependencyFile) {
+  if (outputFilename == "-") {
+    llvm::errs() << "error: the option -d must be used together with -o\n";
+    return failure();
+  }
+
+  std::string errorMessage;
+  std::unique_ptr<llvm::ToolOutputFile> outputFile =
+      openOutputFile(dependencyFile, &errorMessage);
+  if (!outputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return failure();
+  }
+
+  // The rule has no prerequisites. Presumably the build system already tracks
+  // the single input file itself — NOTE(review): confirm.
+  outputFile->os() << outputFilename << ":\n";
+  outputFile->keep();
+  return success();
+}
+
+/// Reads the input source file and writes it back out with
+/// `#define GET_OP_DEFS_<op-shard-index>` prepended. The copy is then expected
+/// to include the generated op definitions and compile only that shard.
+///
+/// Returns 0 on success. Returns 1 if the input, output, or dependency file
+/// could not be opened or written.
+int main(int argc, char **argv) {
+  // FIXME: This is necessary because we link in TableGen, which defines its
+  // options as static variables, some of which overlap with our options. Our
+  // own options are therefore declared as locals after this reset.
+  llvm::cl::ResetCommandLineParser();
+
+  llvm::cl::opt<unsigned> opShardIndex(
+      "op-shard-index", llvm::cl::desc("The current shard index"));
+  llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                           llvm::cl::desc("<input file>"),
+                                           llvm::cl::init("-"));
+  llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+  // Accepted but never read below. Presumably this exists so the tool accepts
+  // the include flags that generic tablegen build helpers pass to every
+  // tablegen-style tool — NOTE(review): confirm.
+  llvm::cl::list<std::string> includeDirs(
+      "I", llvm::cl::desc("Directory of include files"),
+      llvm::cl::value_desc("directory"), llvm::cl::Prefix);
+  llvm::cl::opt<std::string> dependencyFilename(
+      "d", llvm::cl::desc("Dependency filename"),
+      llvm::cl::value_desc("filename"), llvm::cl::init(""));
+  llvm::cl::opt<bool> writeIfChanged(
+      "write-if-changed",
+      llvm::cl::desc("Only write to the output file if it changed"));
+
+  llvm::InitLLVM y(argc, argv);
+  llvm::cl::ParseCommandLineOptions(argc, argv);
+
+  // Open the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> inputFile =
+      openInputFile(inputFilename, &errorMessage);
+  if (!inputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  // Write the output to a buffer: the shard-selecting define, followed by the
+  // unmodified input contents.
+  std::string outputStr;
+  llvm::raw_string_ostream os(outputStr);
+  os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
+     << inputFile->getBuffer();
+
+  // Determine whether we need to write the output file.
+  bool shouldWriteOutput = true;
+  if (writeIfChanged) {
+    // Only update the real output file if there are any differences. This
+    // prevents recompilation of all the files depending on it if there aren't
+    // any.
+    if (auto existingOrErr =
+            llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
+      if (std::move(existingOrErr.get())->getBuffer() == os.str())
+        shouldWriteOutput = false;
+  }
+
+  // Populate the output file if necessary.
+  if (shouldWriteOutput) {
+    std::unique_ptr<llvm::ToolOutputFile> outputFile =
+        openOutputFile(outputFilename, &errorMessage);
+    if (!outputFile) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+    outputFile->os() << os.str();
+    outputFile->keep();
+  }
+
+  // Always write the depfile, even if the main output hasn't changed. If it's
+  // missing, Ninja considers the output dirty.
+  if (!dependencyFilename.empty())
+    if (failed(createDependencyFile(outputFilename, dependencyFilename)))
+      return 1;
+
+  return 0;
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 53ed5cb7c043ec..63fe5a80990746 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4303,32 +4303,15 @@ void OpOperandAdaptorEmitter::emitDef(
emitter.adaptor.writeDefTo(os);
}
-// Emits the opcode enum and op classes.
-static void emitOpClasses(const RecordKeeper &recordKeeper,
- const std::vector<Record *> &defs, raw_ostream &os,
- bool emitDecl) {
- // First emit forward declaration for each class, this allows them to refer
- // to each others in traits for example.
- if (emitDecl) {
- os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
- os << "#undef GET_OP_FWD_DEFINES\n";
- for (auto *def : defs) {
- Operator op(*def);
- NamespaceEmitter emitter(os, op.getCppNamespace());
- os << "class " << op.getCppClassName() << ";\n";
- }
- os << "#endif\n\n";
- }
-
- IfDefScope scope("GET_OP_CLASSES", os);
+/// Emit the class declarations or definitions for the given op defs.
+static void
+emitOpClasses(const RecordKeeper &recordKeeper,
+ const std::vector<Record *> &defs, raw_ostream &os,
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+ bool emitDecl) {
if (defs.empty())
return;
- // Generate all of the locally instantiated methods first.
- StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
- os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
- staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
-
for (auto *def : defs) {
Operator op(*def);
if (emitDecl) {
@@ -4358,34 +4341,145 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
}
}
-// Emits a comma-separated list of the ops.
-static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
- IfDefScope scope("GET_OP_LIST", os);
+/// Emit the declarations for the provided op classes.
+///
+/// First emits a forward declaration for every op class. These are emitted
+/// outside any `#ifdef`, so they are visible whenever the generated header is
+/// included. NOTE(review): the previous generator guarded them with
+/// `GET_OP_CLASSES || GET_OP_FWD_DEFINES`. After that come the full class
+/// declarations, inside an IfDefScope on `GET_OP_CLASSES`.
+static void emitOpClassDecls(const RecordKeeper &recordKeeper,
+                             const std::vector<Record *> &defs,
+                             raw_ostream &os) {
+  // First emit forward declaration for each class, this allows them to refer
+  // to each others in traits for example.
+  for (auto *def : defs) {
+    Operator op(*def);
+    NamespaceEmitter emitter(os, op.getCppNamespace());
+    os << "class " << op.getCppClassName() << ";\n";
+  }
+
+  // Emit the op class declarations.
+  IfDefScope scope("GET_OP_CLASSES", os);
+  if (defs.empty())
+    return;
+  // Constraints are only collected here, never emitted. The declarations just
+  // need the uniqued constraint information.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
+  staticVerifierEmitter.collectOpConstraints(defs);
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/true);
+}
+
+/// Emit the definitions for the provided op classes.
+///
+/// Emits the uniqued static verifier helpers ("Local Utility Method
+/// Definitions") followed by the class definitions. Emits nothing for an empty
+/// `defs`. The caller is responsible for the surrounding `#ifdef` scope.
+///
+/// `constraintPrefix` is forwarded as the StaticVerifierFunctionEmitter tag. It
+/// is prepended to the unique label used for the emitted constraint helpers,
+/// presumably to keep helper names distinct between op shards.
+static void emitOpClassDefs(const RecordKeeper &recordKeeper,
+                            ArrayRef<Record *> defs, raw_ostream &os,
+                            StringRef constraintPrefix = "") {
+  if (defs.empty())
+    return;
+
+  // Generate all of the locally instantiated methods first.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
+                                                      constraintPrefix);
+  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+  staticVerifierEmitter.collectOpConstraints(defs);
+  staticVerifierEmitter.emitOpConstraints(defs);
-  interleave(
-      // TODO: We are constructing the Operator wrapper instance just for
-      // getting it's qualified class name here. Reduce the overhead by having a
-      // lightweight version of Operator class just for that purpose.
-      defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
-      [&os]() { os << ",\n"; });
+  // Emit the classes, reusing the emitter so they reference the helpers
+  // emitted above.
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/false);
}
+/// Emit op declarations for all op records.
+///
+/// When `-op-shard-count` is greater than 1 and there are ops, this also
+/// declares the aggregate `register<Dialect>Operations` hook and one
+/// `register<Dialect>Operations<i>` hook per shard. These hook declarations
+/// are emitted after the `GET_OP_CLASSES` section, outside any `#ifdef`.
+/// NOTE(review): the dialect is taken from the first op, so this assumes all
+/// requested ops belong to the same dialect.
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Declarations", os, recordKeeper);
std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
+  emitOpClassDecls(recordKeeper, defs, os);
+
+  // If we are generating sharded op definitions, emit the sharded op
+  // registration hooks.
+  SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+  shardOpDefinitions(defs, shardedDefs);
+  if (defs.empty() || shardedDefs.size() <= 1)
+    return false;
+
+  Dialect dialect = Operator(defs.front()).getDialect();
+  NamespaceEmitter ns(os, dialect);
+
+  // {0}: dialect class name, {1}: shard suffix (empty for the aggregate hook),
+  // {2}: dialect C++ namespace.
+  const char *const opRegistrationHook =
+      "void register{0}Operations{1}({2}::{0} *dialect);\n";
+  os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
+                dialect.getCppNamespace());
+  for (unsigned i = 0; i < shardedDefs.size(); ++i) {
+    os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
+                  dialect.getCppNamespace());
+  }
return false;
}
+/// Generate the dialect op registration hook and the op class definitions for a
+/// shard of ops.
+///
+/// Everything is emitted inside an IfDefScope on `GET_OP_DEFS_<shardIndex>`,
+/// in three parts:
+///  1. For shard 0 only, the aggregate `register<Dialect>Operations`, which
+///     calls each of the `shardCount` per-shard hooks.
+///  2. This shard's `register<Dialect>Operations<shardIndex>`, which inserts
+///     every op in `defs`.
+///  3. The class definitions for `defs`, with the shard index passed as the
+///     constraint prefix.
+static void emitOpDefShard(const RecordKeeper &recordKeeper,
+                           ArrayRef<Record *> defs, const Dialect &dialect,
+                           unsigned shardIndex, unsigned shardCount,
+                           raw_ostream &os) {
+  std::string shardGuard = "GET_OP_DEFS_";
+  std::string indexStr = std::to_string(shardIndex);
+  shardGuard += indexStr;
+  IfDefScope scope(shardGuard, os);
+
+  // Emit the op registration hook in the first shard.
+  // {0}: dialect C++ namespace, {1}: dialect class name, {2}: shard suffix.
+  // `{{` is formatv's escape for a literal `{`.
+  const char *const opRegistrationHook =
+      "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
+  if (shardIndex == 0) {
+    os << formatv(opRegistrationHook, dialect.getCppNamespace(),
+                  dialect.getCppClassName(), "");
+    for (unsigned i = 0; i < shardCount; ++i) {
+      os << formatv("  {0}::register{1}Operations{2}(dialect);\n",
+                    dialect.getCppNamespace(), dialect.getCppClassName(), i);
+    }
+    os << "}\n";
+  }
+
+  // Generate the per-shard op registration hook.
+  os << formatv(opCommentHeader, dialect.getCppClassName(),
+                "Op Registration Hook")
+     << formatv(opRegistrationHook, dialect.getCppNamespace(),
+                dialect.getCppClassName(), shardIndex);
+  for (Record *def : defs) {
+    os << formatv("  ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
+                  Operator(def).getQualCppClassName());
+  }
+  os << "}\n";
+
+  // Generate the per-shard op definitions.
+  emitOpClassDefs(recordKeeper, defs, os, indexStr);
+}
+
+/// Emit op definitions for all op records.
+///
+/// The output depends on the shard count:
+/// - Shard count of 1: emits the comma-separated op list under `GET_OP_LIST`
+///   and the class definitions under `GET_OP_CLASSES`.
+/// - Otherwise: emits one `GET_OP_DEFS_<i>` section per shard (see
+///   emitOpDefShard), and nothing at all when no ops were requested.
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Op Definitions", os, recordKeeper);
std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpList(defs, os);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
+  SmallVector<ArrayRef<Record *>, 4> shardedDefs;
+  shardOpDefinitions(defs, shardedDefs);
+
+  // If no shard was requested, emit the regular op list and class definitions.
+  if (shardedDefs.size() == 1) {
+    {
+      IfDefScope scope("GET_OP_LIST", os);
+      interleave(
+          defs, os,
+          [&](Record *def) { os << Operator(def).getQualCppClassName(); },
+          ",\n");
+    }
+    {
+      IfDefScope scope("GET_OP_CLASSES", os);
+      emitOpClassDefs(recordKeeper, defs, os);
+    }
+    return false;
+  }
+  // The sharded path needs the dialect of the first op, so there is nothing to
+  // emit without ops. NOTE(review): assumes all ops share that dialect.
+  if (defs.empty())
+    return false;
+  Dialect dialect = Operator(defs.front()).getDialect();
+  for (auto [idx, value] : llvm::enumerate(shardedDefs)) {
+    emitOpDefShard(recordKeeper, value, dialect, idx, shardedDefs.size(), os);
+  }
return false;
}
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
index 7fd34df8460d39..c2a2423a240269 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -31,6 +31,10 @@ static cl::opt<std::string> opExcFilter(
"op-exclude-regex",
cl::desc("Regex of name of op's to exclude (no filter if empty)"),
cl::cat(opDefGenCat));
+static cl::opt<unsigned> opShardCount(
+ "op-shard-count",
+ cl::desc("The number of shards into which the op classes will be divided"),
+ cl::cat(opDefGenCat), cl::init(1));
static std::string getOperationName(const Record &def) {
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
@@ -79,4 +83,23 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
reserved.insert("issubclass");
reserved.insert("type");
return reserved.contains(str);
-}
\ No newline at end of file
+}
+
+/// Split `defs` into `-op-shard-count` contiguous slices, appended to
+/// `shardedDefs` in order.
+///
+/// Shard sizes differ by at most one: the first `defs.size() % opShardCount`
+/// shards each get one extra op. When there are fewer ops than shards, the
+/// trailing shards are empty. With a shard count of 1, all of `defs` forms a
+/// single shard.
+void mlir::tblgen::shardOpDefinitions(
+    ArrayRef<llvm::Record *> defs,
+    SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs) {
+  assert(opShardCount > 0 && "expected a positive shard count");
+  if (opShardCount == 1) {
+    shardedDefs.push_back(defs);
+    return;
+  }
+
+  unsigned minShardSize = defs.size() / opShardCount;
+  // Ops left over after every shard has received `minShardSize` ops.
+  unsigned numMissing = defs.size() - minShardSize * opShardCount;
+  shardedDefs.reserve(opShardCount);
+  for (unsigned i = 0, start = 0; i < opShardCount; ++i) {
+    // The first `numMissing` shards each absorb one leftover op.
+    unsigned size = minShardSize + (i < numMissing);
+    shardedDefs.push_back(defs.slice(start, size));
+    start += size;
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
index 3dcff14d1221ee..1b43d5d3ce3a7d 100644
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -13,6 +13,7 @@
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
+#include "mlir/Support/LLVM.h"
#include "llvm/TableGen/Record.h"
#include <vector>
@@ -28,6 +29,10 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
bool isPythonReserved(llvm::StringRef str);
+/// Shard the op definitions into the number of shards set by "op-shard-count".
+void shardOpDefinitions(ArrayRef<llvm::Record *> defs,
+ SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs);
+
} // namespace tblgen
} // namespace mlir
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 6c732b8f134901..00a40019ac2e7b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9727,6 +9727,15 @@ cc_binary(
],
)
+# Copies a source file with `#define GET_OP_DEFS_<index>` prepended. Intended
+# as the `sharder` for gentbl_sharded_ops in tblgen.bzl.
+cc_binary(
+    name = "mlir-src-sharder",
+    srcs = ["tools/mlir-src-sharder/mlir-src-sharder.cpp"],
+    deps = [
+        ":Support",
+        "//llvm:Support",
+    ],
+)
+
cc_binary(
name = "mlir-linalg-ods-yaml-gen",
srcs = [
diff --git a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl
index fdf6a57107ac34..e45ba1fe0ef721 100644
--- a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl
+++ b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl
@@ -432,3 +432,136 @@ def gentbl_cc_library(
copts = copts,
**kwargs
)
+
+def _gentbl_shard_impl(ctx):
+    """Runs `sharder` on `src_file` for shard `index`, writing `out`.
+
+    The action command line is:
+      <sharder> <src_file> -op-shard-index <index> -o <out>
+    """
+    args = ctx.actions.args()
+    args.add(ctx.file.src_file)
+    args.add("-op-shard-index", ctx.attr.index)
+    args.add("-o", ctx.outputs.out.path)
+    ctx.actions.run(
+        outputs = [ctx.outputs.out],
+        inputs = [ctx.file.src_file],
+        executable = ctx.executable.sharder,
+        arguments = [args],
+        use_default_shell_env = True,
+        mnemonic = "ShardGenerate",
+    )
+
+gentbl_shard_rule = rule(
+    _gentbl_shard_impl,
+    doc = "Runs `sharder` on `src_file` with `-op-shard-index <index>`, writing `out`.",
+    output_to_genfiles = True,
+    attrs = {
+        "index": attr.int(
+            mandatory = True,
+            doc = "The op definition shard index selected by the generated copy.",
+        ),
+        "sharder": attr.label(
+            doc = "The source sharder executable (e.g. mlir-src-sharder).",
+            executable = True,
+            cfg = "exec",
+            # The implementation always runs this executable; require it so a
+            # missing value is reported at analysis time instead of as a None
+            # executable.
+            mandatory = True,
+        ),
+        "src_file": attr.label(
+            doc = "The source file template to copy for this shard.",
+            allow_single_file = True,
+            mandatory = True,
+        ),
+        "out": attr.output(
+            doc = "The generated shard source file.",
+            mandatory = True,
+        ),
+    },
+)
+
+def gentbl_sharded_ops(
+        name,
+        tblgen,
+        sharder,
+        td_file,
+        shard_count,
+        src_file,
+        src_out,
+        hdr_out,
+        test = False,
+        includes = [],
+        strip_include_prefix = None,
+        deps = []):
+    """Generate sharded op declarations and definitions.
+
+    This special build rule shards op definitions in a TableGen file and generates multiple copies
+    of a template source file for including and compiling each shard. The rule defines a filegroup
+    consisting of the source shards, the generated source file, and the generated header file.
+
+    Args:
+      name: The name of the filegroup.
+      tblgen: The binary used to produce the output.
+      sharder: The source file sharder to use.
+      td_file: The primary table definitions file.
+      shard_count: The number of op definition shards to produce.
+      src_file: The source file template.
+      src_out: The generated source file.
+      hdr_out: The generated header file.
+      test: Whether this is a test target. Every generated target, including
+        the filegroup, is marked testonly in that case.
+      includes: See gentbl_rule.includes
+      deps: See gentbl_rule.deps
+      strip_include_prefix: Attribute to pass through to cc_library.
+    """
+    cc_lib_name = name + "__gentbl_cc_lib"
+    gentbl_cc_library(
+        name = cc_lib_name,
+        strip_include_prefix = strip_include_prefix,
+        includes = includes,
+        tbl_outs = [
+            (
+                [
+                    "-gen-op-defs",
+                    "-op-shard-count=" + str(shard_count),
+                ],
+                src_out,
+            ),
+            (
+                [
+                    "-gen-op-decls",
+                    "-op-shard-count=" + str(shard_count),
+                ],
+                hdr_out,
+            ),
+        ],
+        tblgen = tblgen,
+        td_file = td_file,
+        test = test,
+        deps = deps,
+    )
+    all_files = [hdr_out, src_out]
+    for i in range(0, shard_count):
+        out_file = "shard_copy_" + str(i) + "_" + src_file
+        gentbl_shard_rule(
+            index = i,
+            name = name + "__src_shard" + str(i),
+            testonly = test,
+            out = out_file,
+            sharder = sharder,
+            src_file = src_file,
+        )
+        all_files.append(out_file)
+
+    # The files above come from rules marked testonly when `test` is set, and a
+    # non-testonly target may not depend on testonly ones.
+    native.filegroup(name = name, srcs = all_files, testonly = test)
+
+def gentbl_sharded_op_defs(name, source_file, shard_count):
+    """Generates multiple copies of a source file that includes sharded op definitions.
+
+    Copy `i` is named `shard_copy_<i>_<source_file>`. It consists of the line
+    `#define GET_OP_DEFS_<i>` followed by the unmodified contents of
+    `source_file`.
+
+    Args:
+      name: The name of the rule.
+      source_file: The source to copy.
+      shard_count: The number of shards.
+
+    Returns:
+      A list of the copied filenames to be included in the dialect library.
+    """
+    copies = []
+    for i in range(0, shard_count):
+        out_file = "shard_copy_" + str(i) + "_" + source_file
+        copies.append(out_file)
+        native.genrule(
+            name = name + "_shard_" + str(i),
+            srcs = [source_file],
+            outs = [out_file],
+            # Copy the source verbatim with `cat`. `echo -e "...$(cat ...)"`
+            # is not portable across /bin/sh implementations, would expand
+            # backslash escapes in the copied C++ (e.g. "\n" in string
+            # literals), and strips trailing newlines.
+            cmd = "(echo '#define GET_OP_DEFS_" + str(i) + "' && cat $(SRCS)) > $(OUTS)",
+        )
+    return copies
>From c03dd33266bd7b59281f1a059ecc0144867c07cd Mon Sep 17 00:00:00 2001
From: Mogball <jeffniu22 at gmail.com>
Date: Thu, 23 Jun 2022 20:45:26 +0000
Subject: [PATCH 2/3] [mlir][test] Reorganize the test dialect
This PR massively reorganizes the Test dialect's source files. It moves
manually-written op hooks into `TestOpDefs.cpp`, moves format custom
directive parsers and printers into `TestFormatUtils`, adds missing
comment blocks, and moves the inclusion of generated source files for
types, attributes, enums, etc. into their own source files.
This should make the test dialect source code easier to navigate, and it
also speeds up compilation of the test dialect by putting generated source
files into separate compilation units.
This also sets up the test dialect to shard its op definitions, done in
the next PR.
stack-info: PR: https://github.com/llvm/llvm-project/pull/89424, branch: users/mogball/pr_2
---
.../TestDenseBackwardDataFlowAnalysis.cpp | 1 +
.../TestDenseForwardDataFlowAnalysis.cpp | 1 +
.../FuncToLLVM/TestConvertCallOp.cpp | 1 +
.../TestOneToNTypeConversionPass.cpp | 1 +
.../Dialect/Affine/TestReifyValueBounds.cpp | 1 +
.../lib/Dialect/DLTI/TestDataLayoutQuery.cpp | 2 +-
.../Func/TestDecomposeCallGraphTypes.cpp | 1 +
mlir/test/lib/Dialect/Test/CMakeLists.txt | 3 +
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 3 +-
mlir/test/lib/Dialect/Test/TestDialect.cpp | 1450 ++---------------
mlir/test/lib/Dialect/Test/TestDialect.h | 50 +-
.../Dialect/Test/TestDialectInterfaces.cpp | 1 +
.../test/lib/Dialect/Test/TestFormatUtils.cpp | 377 +++++
mlir/test/lib/Dialect/Test/TestFormatUtils.h | 211 +++
.../Test/TestFromLLVMIRTranslation.cpp | 1 +
mlir/test/lib/Dialect/Test/TestInterfaces.cpp | 2 +
mlir/test/lib/Dialect/Test/TestInterfaces.h | 2 +
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 1161 +++++++++++++
mlir/test/lib/Dialect/Test/TestOps.cpp | 18 +
mlir/test/lib/Dialect/Test/TestOps.h | 149 ++
mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp | 1 +
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 +
.../Dialect/Test/TestToLLVMIRTranslation.cpp | 1 +
mlir/test/lib/Dialect/Test/TestTraits.cpp | 2 +-
mlir/test/lib/Dialect/Test/TestTypes.cpp | 1 +
mlir/test/lib/Dialect/Test/TestTypes.h | 6 +-
mlir/test/lib/IR/TestBytecodeRoundtrip.cpp | 1 +
mlir/test/lib/IR/TestClone.cpp | 2 +-
mlir/test/lib/IR/TestSideEffects.cpp | 2 +-
mlir/test/lib/IR/TestSymbolUses.cpp | 2 +-
mlir/test/lib/IR/TestTypes.cpp | 1 +
mlir/test/lib/IR/TestVisitorsGeneric.cpp | 2 +-
mlir/test/lib/Pass/TestPassManager.cpp | 1 +
mlir/test/lib/Transforms/TestInlining.cpp | 1 +
.../Transforms/TestMakeIsolatedFromAbove.cpp | 1 +
mlir/unittests/IR/AdaptorTest.cpp | 1 +
mlir/unittests/IR/IRMapping.cpp | 1 +
mlir/unittests/IR/InterfaceAttachmentTest.cpp | 1 +
mlir/unittests/IR/InterfaceTest.cpp | 1 +
mlir/unittests/IR/OperationSupportTest.cpp | 1 +
mlir/unittests/IR/PatternMatchTest.cpp | 1 +
mlir/unittests/TableGen/OpBuildGen.cpp | 1 +
42 files changed, 2114 insertions(+), 1354 deletions(-)
create mode 100644 mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
create mode 100644 mlir/test/lib/Dialect/Test/TestFormatUtils.h
create mode 100644 mlir/test/lib/Dialect/Test/TestOpDefs.cpp
create mode 100644 mlir/test/lib/Dialect/Test/TestOps.cpp
create mode 100644 mlir/test/lib/Dialect/Test/TestOps.h
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index ca052392f2f5f2..65592a5c5d698b 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -12,6 +12,7 @@
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 29480f5ad63ee0..3f9ce2dc0bc50a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -12,6 +12,7 @@
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
index 5e17779660f392..f878a262512ee8 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 3c4067b35d8e5b..cc1af59c5e15bb 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index b098a5a23fd316..34513cd418e4c2 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
index 84f45b31603192..56f309f150ca5d 100644
--- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
+++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 10aba733bd5696..0d7dce2240f4cb 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
#include "mlir/IR/Builders.h"
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index d246c0492a3bd5..f63e4d330e6ac1 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -47,7 +47,10 @@ add_public_tablegen_target(MLIRTestOpsSyntaxIncGen)
add_mlir_library(MLIRTestDialect
TestAttributes.cpp
TestDialect.cpp
+ TestFormatUtils.cpp
TestInterfaces.cpp
+ TestOpDefs.cpp
+ TestOps.cpp
TestPatterns.cpp
TestTraits.cpp
TestTypes.cpp
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index d41d495c38e553..2cc051e664beec 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
@@ -244,7 +245,7 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
//===----------------------------------------------------------------------===//
#include "TestAttrInterfaces.cpp.inc"
-
+#include "TestOpEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a23ed89c4b04d1..77fd7e61bd3a06 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -7,8 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
-#include "TestAttributes.h"
-#include "TestInterfaces.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -39,17 +38,85 @@
#include "llvm/Support/Base64.h"
#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Interfaces/FoldInterfaces.h"
+#include "mlir/Reducer/ReductionPatternInterface.h"
+#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
#include <numeric>
#include <optional>
-// Include this before the using namespace lines below to
-// test that we don't have namespace dependencies.
+// Include this before the using namespace lines below to test that we don't
+// have namespace dependencies.
#include "TestOpsDialect.cpp.inc"
using namespace mlir;
using namespace test;
+//===----------------------------------------------------------------------===//
+// PropertiesWithCustomPrint
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
+ Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
+ if (!dict) {
+ emitError() << "expected DictionaryAttr to set TestProperties";
+ return failure();
+ }
+ auto label = dict.getAs<mlir::StringAttr>("label");
+ if (!label) {
+ emitError() << "expected StringAttr for key `label`";
+ return failure();
+ }
+ auto valueAttr = dict.getAs<IntegerAttr>("value");
+ if (!valueAttr) {
+ emitError() << "expected IntegerAttr for key `value`";
+ return failure();
+ }
+
+ prop.label = std::make_shared<std::string>(label.getValue());
+ prop.value = valueAttr.getValue().getSExtValue();
+ return success();
+}
+
+DictionaryAttr
+test::getPropertiesAsAttribute(MLIRContext *ctx,
+ const PropertiesWithCustomPrint &prop) {
+ SmallVector<NamedAttribute> attrs;
+ Builder b{ctx};
+ attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label)));
+ attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
+ return b.getDictionaryAttr(attrs);
+}
+
+llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
+ return llvm::hash_combine(prop.value, StringRef(*prop.label));
+}
+
+void test::customPrintProperties(OpAsmPrinter &p,
+ const PropertiesWithCustomPrint &prop) {
+ p.printKeywordOrString(*prop.label);
+ p << " is " << prop.value;
+}
+
+ParseResult test::customParseProperties(OpAsmParser &parser,
+ PropertiesWithCustomPrint &prop) {
+ std::string label;
+ if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") ||
+ parser.parseInteger(prop.value))
+ return failure();
+ prop.label = std::make_shared<std::string>(std::move(label));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// MyPropStruct
+//===----------------------------------------------------------------------===//
+
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
return StringAttr::get(ctx, content);
}
@@ -70,8 +137,8 @@ llvm::hash_code MyPropStruct::hash() const {
return hash_value(StringRef(content));
}
-static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
- MyPropStruct &prop) {
+LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
+ MyPropStruct &prop) {
StringRef str;
if (failed(reader.readString(str)))
return failure();
@@ -79,13 +146,71 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
return success();
}
-static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
- MyPropStruct &prop) {
+void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
+ MyPropStruct &prop) {
writer.writeOwnedString(prop.content);
}
-static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
- MutableArrayRef<int64_t> prop) {
+//===----------------------------------------------------------------------===//
+// VersionedProperties
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
+ function_ref<InFlightDiagnostic()> emitError) {
+ DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
+ if (!dict) {
+ emitError() << "expected DictionaryAttr to set VersionedProperties";
+ return failure();
+ }
+ auto value1Attr = dict.getAs<IntegerAttr>("value1");
+ if (!value1Attr) {
+ emitError() << "expected IntegerAttr for key `value1`";
+ return failure();
+ }
+ auto value2Attr = dict.getAs<IntegerAttr>("value2");
+ if (!value2Attr) {
+ emitError() << "expected IntegerAttr for key `value2`";
+ return failure();
+ }
+
+ prop.value1 = value1Attr.getValue().getSExtValue();
+ prop.value2 = value2Attr.getValue().getSExtValue();
+ return success();
+}
+
+DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
+ const VersionedProperties &prop) {
+ SmallVector<NamedAttribute> attrs;
+ Builder b{ctx};
+ attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
+ attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
+ return b.getDictionaryAttr(attrs);
+}
+
+llvm::hash_code test::computeHash(const VersionedProperties &prop) {
+ return llvm::hash_combine(prop.value1, prop.value2);
+}
+
+void test::customPrintProperties(OpAsmPrinter &p,
+ const VersionedProperties &prop) {
+ p << prop.value1 << " | " << prop.value2;
+}
+
+ParseResult test::customParseProperties(OpAsmParser &parser,
+ VersionedProperties &prop) {
+ if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
+ parser.parseInteger(prop.value2))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Support
+//===----------------------------------------------------------------------===//
+
+LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
+ MutableArrayRef<int64_t> prop) {
uint64_t size;
if (failed(reader.readVarInt(size)))
return failure();
@@ -101,45 +226,13 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
return success();
}
-static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
- ArrayRef<int64_t> prop) {
+void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
+ ArrayRef<int64_t> prop) {
writer.writeVarInt(prop.size());
for (auto elt : prop)
writer.writeVarInt(elt);
}
-static LogicalResult
-setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError);
-static DictionaryAttr
-getPropertiesAsAttribute(MLIRContext *ctx,
- const PropertiesWithCustomPrint &prop);
-static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
-static void customPrintProperties(OpAsmPrinter &p,
- const PropertiesWithCustomPrint &prop);
-static ParseResult customParseProperties(OpAsmParser &parser,
- PropertiesWithCustomPrint &prop);
-static LogicalResult
-setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError);
-static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx,
- const VersionedProperties &prop);
-static llvm::hash_code computeHash(const VersionedProperties &prop);
-static void customPrintProperties(OpAsmPrinter &p,
- const VersionedProperties &prop);
-static ParseResult customParseProperties(OpAsmParser &parser,
- VersionedProperties &prop);
-static ParseResult
-parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
- SmallVectorImpl<std::unique_ptr<Region>> &caseRegions);
-
-static void printSwitchCases(OpAsmPrinter &p, Operation *op,
- DenseI64ArrayAttr cases, RegionRange caseRegions);
-
-void test::registerTestDialect(DialectRegistry ®istry) {
- registry.insert<TestDialect>();
-}
-
//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//
@@ -196,9 +289,20 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) {
// TestDialect
//===----------------------------------------------------------------------===//
-static void testSideEffectOpGetEffect(
+void test::registerTestDialect(DialectRegistry ®istry) {
+ registry.insert<TestDialect>();
+}
+
+void test::testSideEffectOpGetEffect(
Operation *op,
- SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
+ SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
+ &effects) {
+ auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
+ if (!effectsAttr)
+ return;
+
+ effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
+}
// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
struct TestOpEffectInterfaceFallback
@@ -318,57 +422,6 @@ TestDialect::getOperationPrinter(Operation *op) const {
return {};
}
-//===----------------------------------------------------------------------===//
-// TypedAttrOp
-//===----------------------------------------------------------------------===//
-
-/// Parse an attribute with a given type.
-static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
- Attribute &attr) {
- return parser.parseAttribute(attr, type.getValue());
-}
-
-/// Print an attribute without its type.
-static void printAttrElideType(AsmPrinter &printer, Operation *op,
- TypeAttr type, Attribute attr) {
- printer.printAttributeWithoutType(attr);
-}
-
-//===----------------------------------------------------------------------===//
-// TestBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
- assert(index == 0 && "invalid successor index");
- return SuccessorOperands(getTargetOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestProducingBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
- assert(index <= 1 && "invalid successor index");
- if (index == 1)
- return SuccessorOperands(getFirstOperandsMutable());
- return SuccessorOperands(getSecondOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestProducingBranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
- assert(index <= 1 && "invalid successor index");
- if (index == 0)
- return SuccessorOperands(0, getSuccessOperandsMutable());
- return SuccessorOperands(1, getErrorOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// TestDialectCanonicalizerOp
-//===----------------------------------------------------------------------===//
-
static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
PatternRewriter &rewriter) {
@@ -381,1206 +434,3 @@ void TestDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add(&dialectCanonicalizationPattern);
}
-
-//===----------------------------------------------------------------------===//
-// TestCallOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // Check that the callee attribute was specified.
- auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
- if (!fnAttr)
- return emitOpError("requires a 'callee' symbol reference attribute");
- if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
- return emitOpError() << "'" << fnAttr.getValue()
- << "' does not reference a valid function";
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ConversionFuncOp
-//===----------------------------------------------------------------------===//
-
-ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
- OperationState &result) {
- auto buildFuncType =
- [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
- function_interface_impl::VariadicFlag,
- std::string &) { return builder.getFunctionType(argTypes, results); };
-
- return function_interface_impl::parseFunctionOp(
- parser, result, /*allowVariadic=*/false,
- getFunctionTypeAttrName(result.name), buildFuncType,
- getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
-}
-
-void ConversionFuncOp::print(OpAsmPrinter &p) {
- function_interface_impl::printFunctionOp(
- p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
- getArgAttrsAttrName(), getResAttrsAttrName());
-}
-
-//===----------------------------------------------------------------------===//
-// TestFoldToCallOp
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
- using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(FoldToCallOp op,
- PatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
- op.getCalleeAttr(), ValueRange());
- return success();
- }
-};
-} // namespace
-
-void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<FoldToCallOpPattern>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// Test IsolatedRegionOp - parse passthrough region arguments.
-//===----------------------------------------------------------------------===//
-
-ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
- OperationState &result) {
- // Parse the input operand.
- OpAsmParser::Argument argInfo;
- argInfo.type = parser.getBuilder().getIndexType();
- if (parser.parseOperand(argInfo.ssaName) ||
- parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
- return failure();
-
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
-}
-
-void IsolatedRegionOp::print(OpAsmPrinter &p) {
- p << ' ';
- p.printOperand(getOperand());
- p.shadowRegionArgs(getRegion(), getOperand());
- p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Test SSACFGRegionOp
-//===----------------------------------------------------------------------===//
-
-RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
- return RegionKind::SSACFG;
-}
-
-//===----------------------------------------------------------------------===//
-// Test GraphRegionOp
-//===----------------------------------------------------------------------===//
-
-RegionKind GraphRegionOp::getRegionKind(unsigned index) {
- return RegionKind::Graph;
-}
-
-//===----------------------------------------------------------------------===//
-// Test AffineScopeOp
-//===----------------------------------------------------------------------===//
-
-ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
-}
-
-void AffineScopeOp::print(OpAsmPrinter &p) {
- p << " ";
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// Test OptionalCustomAttrOp
-//===----------------------------------------------------------------------===//
-
-static OptionalParseResult parseOptionalCustomParser(AsmParser &p,
- IntegerAttr &result) {
- if (succeeded(p.parseOptionalKeyword("foo")))
- return p.parseAttribute(result);
- return {};
-}
-
-static void printOptionalCustomParser(AsmPrinter &p, Operation *,
- IntegerAttr result) {
- p << "foo ";
- p.printAttribute(result);
-}
-
-//===----------------------------------------------------------------------===//
-// ReifyBoundOp
-//===----------------------------------------------------------------------===//
-
-::mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
- if (getType() == "EQ")
- return ::mlir::presburger::BoundType::EQ;
- if (getType() == "LB")
- return ::mlir::presburger::BoundType::LB;
- if (getType() == "UB")
- return ::mlir::presburger::BoundType::UB;
- llvm_unreachable("invalid bound type");
-}
-
-LogicalResult ReifyBoundOp::verify() {
- if (isa<ShapedType>(getVar().getType())) {
- if (!getDim().has_value())
- return emitOpError("expected 'dim' attribute for shaped type variable");
- } else if (getVar().getType().isIndex()) {
- if (getDim().has_value())
- return emitOpError("unexpected 'dim' attribute for index variable");
- } else {
- return emitOpError("expected index-typed variable or shape type variable");
- }
- if (getConstant() && getScalable())
- return emitOpError("'scalable' and 'constant' are mutually exlusive");
- if (getScalable() != getVscaleMin().has_value())
- return emitOpError("expected 'vscale_min' if and only if 'scalable'");
- if (getScalable() != getVscaleMax().has_value())
- return emitOpError("expected 'vscale_min' if and only if 'scalable'");
- return success();
-}
-
-::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
- if (getDim().has_value())
- return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
- return ValueBoundsConstraintSet::Variable(getVar());
-}
-
-::mlir::ValueBoundsConstraintSet::ComparisonOperator
-CompareOp::getComparisonOperator() {
- if (getCmp() == "EQ")
- return ValueBoundsConstraintSet::ComparisonOperator::EQ;
- if (getCmp() == "LT")
- return ValueBoundsConstraintSet::ComparisonOperator::LT;
- if (getCmp() == "LE")
- return ValueBoundsConstraintSet::ComparisonOperator::LE;
- if (getCmp() == "GT")
- return ValueBoundsConstraintSet::ComparisonOperator::GT;
- if (getCmp() == "GE")
- return ValueBoundsConstraintSet::ComparisonOperator::GE;
- llvm_unreachable("invalid comparison operator");
-}
-
-::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
- if (!getLhsMap())
- return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
- SmallVector<Value> mapOperands(
- getVarOperands().slice(0, getLhsMap()->getNumInputs()));
- return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
-}
-
-::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
- int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
- if (!getRhsMap())
- return ValueBoundsConstraintSet::Variable(
- getVarOperands()[rhsOperandsBegin]);
- SmallVector<Value> mapOperands(
- getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
- return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
-}
-
-LogicalResult CompareOp::verify() {
- if (getCompose() && (getLhsMap() || getRhsMap()))
- return emitOpError(
- "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
- int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
- expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
- if (getVarOperands().size() != size_t(expectedNumOperands))
- return emitOpError("expected ")
- << expectedNumOperands << " operands, but got "
- << getVarOperands().size();
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test removing op with inner ops.
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct TestRemoveOpWithInnerOps
- : public OpRewritePattern<TestOpWithRegionPattern> {
- using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
-
- void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
-
- LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
- PatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
- return success();
- }
-};
-} // namespace
-
-void TestOpWithRegionPattern::getCanonicalizationPatterns(
- RewritePatternSet &results, MLIRContext *context) {
- results.add<TestRemoveOpWithInnerOps>(context);
-}
-
-OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
- return getOperand();
-}
-
-OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
-
-LogicalResult TestOpWithVariadicResultsAndFolder::fold(
- FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
- for (Value input : this->getOperands()) {
- results.push_back(input);
- }
- return success();
-}
-
-OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
- // Exercise the fact that an operation created with createOrFold should be
- // allowed to access its parent block.
- assert(getOperation()->getBlock() &&
- "expected that operation is not unlinked");
-
- if (adaptor.getOp() && !getProperties().attr) {
- // The folder adds "attr" if not present.
- getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
- return getResult();
- }
- return {};
-}
-
-OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
- int64_t sum = 0;
- if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
- sum += value.getValue().getSExtValue();
-
- for (Attribute attr : adaptor.getVariadic())
- if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
- sum += 2 * value.getValue().getSExtValue();
-
- for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
- for (Attribute attr : attrs)
- if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
- sum += 3 * value.getValue().getSExtValue();
-
- sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
-
- return IntegerAttr::get(getType(), sum);
-}
-
-LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
- MLIRContext *, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (operands[0].getType() != operands[1].getType()) {
- return emitOptionalError(location, "operand type mismatch ",
- operands[0].getType(), " vs ",
- operands[1].getType());
- }
- inferredReturnTypes.assign({operands[0].getType()});
- return success();
-}
-
-LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
- MLIRContext *, std::optional<Location> location,
- OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
- SmallVectorImpl<Type> &inferredReturnTypes) {
- if (adaptor.getX().getType() != adaptor.getY().getType()) {
- return emitOptionalError(location, "operand type mismatch ",
- adaptor.getX().getType(), " vs ",
- adaptor.getY().getType());
- }
- inferredReturnTypes.assign({adaptor.getX().getType()});
- return success();
-}
-
-// TODO: We should be able to only define either inferReturnType or
-// refineReturnType, currently only refineReturnType can be omitted.
-LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
- MLIRContext *context, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &returnTypes) {
- returnTypes.clear();
- return OpWithRefineTypeInterfaceOp::refineReturnTypes(
- context, location, operands, attributes, properties, regions,
- returnTypes);
-}
-
-LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
- MLIRContext *, std::optional<Location> location, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &returnTypes) {
- if (operands[0].getType() != operands[1].getType()) {
- return emitOptionalError(location, "operand type mismatch ",
- operands[0].getType(), " vs ",
- operands[1].getType());
- }
- // TODO: Add helper to make this more concise to write.
- if (returnTypes.empty())
- returnTypes.resize(1, nullptr);
- if (returnTypes[0] && returnTypes[0] != operands[0].getType())
- return emitOptionalError(location,
- "required first operand and result to match");
- returnTypes[0] = operands[0].getType();
- return success();
-}
-
-LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
- MLIRContext *context, std::optional<Location> location,
- ValueShapeRange operands, DictionaryAttr attributes,
- OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- // Create return type consisting of the last element of the first operand.
- auto operandType = operands.front().getType();
- auto sval = dyn_cast<ShapedType>(operandType);
- if (!sval)
- return emitOptionalError(location, "only shaped type operands allowed");
- int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
- auto type = IntegerType::get(context, 17);
-
- Attribute encoding;
- if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
- encoding = rankedTy.getEncoding();
- inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
- return success();
-}
-
-LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
- OpBuilder &builder, ValueRange operands,
- llvm::SmallVectorImpl<Value> &shapes) {
- shapes = SmallVector<Value, 1>{
- builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
- return success();
-}
-
-LogicalResult
-OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
- MLIRContext *context, std::optional<Location> location,
- OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- // Create return type consisting of the last element of the first operand.
- auto operandType = adaptor.getOperand1().getType();
- auto sval = dyn_cast<ShapedType>(operandType);
- if (!sval)
- return emitOptionalError(location, "only shaped type operands allowed");
- int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
- auto type = IntegerType::get(context, 17);
-
- Attribute encoding;
- if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
- encoding = rankedTy.getEncoding();
- inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
- return success();
-}
-
-LogicalResult
-OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
- OpBuilder &builder, ValueRange operands,
- llvm::SmallVectorImpl<Value> &shapes) {
- shapes = SmallVector<Value, 1>{
- builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
- return success();
-}
-
-LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
- OpBuilder &builder, ValueRange operands,
- llvm::SmallVectorImpl<Value> &shapes) {
- Location loc = getLoc();
- shapes.reserve(operands.size());
- for (Value operand : llvm::reverse(operands)) {
- auto rank = cast<RankedTensorType>(operand.getType()).getRank();
- auto currShape = llvm::to_vector<4>(
- llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
- return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
- }));
- shapes.push_back(builder.create<tensor::FromElementsOp>(
- getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
- currShape));
- }
- return success();
-}
-
-LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
- OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
- Location loc = getLoc();
- shapes.reserve(getNumOperands());
- for (Value operand : llvm::reverse(getOperands())) {
- auto tensorType = cast<RankedTensorType>(operand.getType());
- auto currShape = llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(0, tensorType.getRank()),
- [&](int64_t dim) -> OpFoldResult {
- return tensorType.isDynamicDim(dim)
- ? static_cast<OpFoldResult>(
- builder.createOrFold<tensor::DimOp>(loc, operand,
- dim))
- : static_cast<OpFoldResult>(
- builder.getIndexAttr(tensorType.getDimSize(dim)));
- }));
- shapes.emplace_back(std::move(currShape));
- }
- return success();
-}
-
-LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
- MLIRContext *context, std::optional<Location>, ValueRange operands,
- DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
- SmallVectorImpl<Type> &inferredReturnTypes) {
-
- Adaptor adaptor(operands, attributes, properties, regions);
- inferredReturnTypes.push_back(IntegerType::get(
- context, adaptor.getLhs() + adaptor.getProperties().rhs));
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test SideEffect interfaces
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// A test resource for side effects.
-struct TestResource : public SideEffects::Resource::Base<TestResource> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
-
- StringRef getName() final { return "<Test>"; }
-};
-} // namespace
-
-static void testSideEffectOpGetEffect(
- Operation *op,
- SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
- &effects) {
- auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
- if (!effectsAttr)
- return;
-
- effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
-}
-
-void SideEffectOp::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- // Check for an effects attribute on the op instance.
- ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
- if (!effectsAttr)
- return;
-
- // If there is one, it is an array of dictionary attributes that hold
- // information on the effects of this operation.
- for (Attribute element : effectsAttr) {
- DictionaryAttr effectElement = cast<DictionaryAttr>(element);
-
- // Get the specific memory effect.
- MemoryEffects::Effect *effect =
- StringSwitch<MemoryEffects::Effect *>(
- cast<StringAttr>(effectElement.get("effect")).getValue())
- .Case("allocate", MemoryEffects::Allocate::get())
- .Case("free", MemoryEffects::Free::get())
- .Case("read", MemoryEffects::Read::get())
- .Case("write", MemoryEffects::Write::get());
-
- // Check for a non-default resource to use.
- SideEffects::Resource *resource = SideEffects::DefaultResource::get();
- if (effectElement.get("test_resource"))
- resource = TestResource::get();
-
- // Check for a result to affect.
- if (effectElement.get("on_result"))
- effects.emplace_back(effect, getResult(), resource);
- else if (Attribute ref = effectElement.get("on_reference"))
- effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
- else
- effects.emplace_back(effect, resource);
- }
-}
-
-void SideEffectOp::getEffects(
- SmallVectorImpl<TestEffects::EffectInstance> &effects) {
- testSideEffectOpGetEffect(getOperation(), effects);
-}
-
-//===----------------------------------------------------------------------===//
-// StringAttrPrettyNameOp
-//===----------------------------------------------------------------------===//
-
-// This op has fancy handling of its SSA result name.
-ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
- OperationState &result) {
- // Add the result types.
- for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
- result.addTypes(parser.getBuilder().getIntegerType(32));
-
- if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
- return failure();
-
- // If the attribute dictionary contains no 'names' attribute, infer it from
- // the SSA name (if specified).
- bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
- return attr.getName() == "names";
- });
-
- // If there was no name specified, check to see if there was a useful name
- // specified in the asm file.
- if (hadNames || parser.getNumResults() == 0)
- return success();
-
- SmallVector<StringRef, 4> names;
- auto *context = result.getContext();
-
- for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
- auto resultName = parser.getResultName(i);
- StringRef nameStr;
- if (!resultName.first.empty() && !isdigit(resultName.first[0]))
- nameStr = resultName.first;
-
- names.push_back(nameStr);
- }
-
- auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
- result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
- return success();
-}
-
-void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
- // Note that we only need to print the "name" attribute if the asmprinter
- // result name disagrees with it. This can happen in strange cases, e.g.
- // when there are conflicts.
- bool namesDisagree = getNames().size() != getNumResults();
-
- SmallString<32> resultNameStr;
- for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
- resultNameStr.clear();
- llvm::raw_svector_ostream tmpStream(resultNameStr);
- p.printOperand(getResult(i), tmpStream);
-
- auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
- if (!expectedName ||
- tmpStream.str().drop_front() != expectedName.getValue()) {
- namesDisagree = true;
- }
- }
-
- if (namesDisagree)
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
- else
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
-}
-
-// We set the SSA name in the asm syntax to the contents of the name
-// attribute.
-void StringAttrPrettyNameOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
-
- auto value = getNames();
- for (size_t i = 0, e = value.size(); i != e; ++i)
- if (auto str = dyn_cast<StringAttr>(value[i]))
- if (!str.getValue().empty())
- setNameFn(getResult(i), str.getValue());
-}
-
-void CustomResultsNameOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- ArrayAttr value = getNames();
- for (size_t i = 0, e = value.size(); i != e; ++i)
- if (auto str = dyn_cast<StringAttr>(value[i]))
- if (!str.empty())
- setNameFn(getResult(i), str.getValue());
-}
-
-//===----------------------------------------------------------------------===//
-// ResultTypeWithTraitOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ResultTypeWithTraitOp::verify() {
- if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
- return success();
- return emitError("result type should have trait 'TestTypeTrait'");
-}
-
-//===----------------------------------------------------------------------===//
-// AttrWithTraitOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult AttrWithTraitOp::verify() {
- if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
- return success();
- return emitError("'attr' attribute should have trait 'TestAttrTrait'");
-}
-
-//===----------------------------------------------------------------------===//
-// RegionIfOp
-//===----------------------------------------------------------------------===//
-
-void RegionIfOp::print(OpAsmPrinter &p) {
- p << " ";
- p.printOperands(getOperands());
- p << ": " << getOperandTypes();
- p.printArrowTypeList(getResultTypes());
- p << " then ";
- p.printRegion(getThenRegion(),
- /*printEntryBlockArgs=*/true,
- /*printBlockTerminators=*/true);
- p << " else ";
- p.printRegion(getElseRegion(),
- /*printEntryBlockArgs=*/true,
- /*printBlockTerminators=*/true);
- p << " join ";
- p.printRegion(getJoinRegion(),
- /*printEntryBlockArgs=*/true,
- /*printBlockTerminators=*/true);
-}
-
-ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
- SmallVector<Type, 2> operandTypes;
-
- result.regions.reserve(3);
- Region *thenRegion = result.addRegion();
- Region *elseRegion = result.addRegion();
- Region *joinRegion = result.addRegion();
-
- // Parse operand, type and arrow type lists.
- if (parser.parseOperandList(operandInfos) ||
- parser.parseColonTypeList(operandTypes) ||
- parser.parseArrowTypeList(result.types))
- return failure();
-
- // Parse all attached regions.
- if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
- parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
- parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
- return failure();
-
- return parser.resolveOperands(operandInfos, operandTypes,
- parser.getCurrentLocation(), result.operands);
-}
-
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
- "invalid region index");
- return getOperands();
-}
-
-void RegionIfOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- // We always branch to the join region.
- if (!point.isParent()) {
- if (point != getJoinRegion())
- regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
- else
- regions.push_back(RegionSuccessor(getResults()));
- return;
- }
-
- // The then and else regions are the entry regions of this op.
- regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
- regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
-}
-
-void RegionIfOp::getRegionInvocationBounds(
- ArrayRef<Attribute> operands,
- SmallVectorImpl<InvocationBounds> &invocationBounds) {
- // Each region is invoked at most once.
- invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
-}
-
-//===----------------------------------------------------------------------===//
-// AnyCondOp
-//===----------------------------------------------------------------------===//
-
-void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
- SmallVectorImpl<RegionSuccessor> &regions) {
- // The parent op branches into the only region, and the region branches back
- // to the parent op.
- if (point.isParent())
- regions.emplace_back(&getRegion());
- else
- regions.emplace_back(getResults());
-}
-
-void AnyCondOp::getRegionInvocationBounds(
- ArrayRef<Attribute> operands,
- SmallVectorImpl<InvocationBounds> &invocationBounds) {
- invocationBounds.emplace_back(1, 1);
-}
-
-//===----------------------------------------------------------------------===//
-// LoopBlockOp
-//===----------------------------------------------------------------------===//
-
-void LoopBlockOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- regions.emplace_back(&getBody(), getBody().getArguments());
- if (point.isParent())
- return;
-
- regions.emplace_back((*this)->getResults());
-}
-
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody());
- return MutableOperandRange(getInitMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// LoopBlockTerminatorOp
-//===----------------------------------------------------------------------===//
-
-MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- if (point.isParent())
- return getExitArgMutable();
- return getNextIterArgMutable();
-}
-
-//===----------------------------------------------------------------------===//
-// SwitchWithNoBreakOp
-//===----------------------------------------------------------------------===//
-
-void TestNoTerminatorOp::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
-
-//===----------------------------------------------------------------------===//
-// SingleNoTerminatorCustomAsmOp
-//===----------------------------------------------------------------------===//
-
-ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
- OperationState &state) {
- Region *body = state.addRegion();
- if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
- return failure();
- return success();
-}
-
-void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
- printer.printRegion(
- getRegion(), /*printEntryBlockArgs=*/false,
- // This op has a single block without terminators. But explicitly mark
- // as not printing block terminators for testing.
- /*printBlockTerminators=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// TestVerifiersOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult TestVerifiersOp::verify() {
- if (!getRegion().hasOneBlock())
- return emitOpError("`hasOneBlock` trait hasn't been verified");
-
- Operation *definingOp = getInput().getDefiningOp();
- if (definingOp && failed(mlir::verify(definingOp)))
- return emitOpError("operand hasn't been verified");
-
- // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
- // loop.
- mlir::emitRemark(getLoc(), "success run of verifier");
-
- return success();
-}
-
-LogicalResult TestVerifiersOp::verifyRegions() {
- if (!getRegion().hasOneBlock())
- return emitOpError("`hasOneBlock` trait hasn't been verified");
-
- for (Block &block : getRegion())
- for (Operation &op : block)
- if (failed(mlir::verify(&op)))
- return emitOpError("nested op hasn't been verified");
-
- // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
- // loop.
- mlir::emitRemark(getLoc(), "success run of region verifier");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test InferIntRangeInterface
-//===----------------------------------------------------------------------===//
-
-void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
-}
-
-ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
- OperationState &result) {
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- // Parse the input argument
- OpAsmParser::Argument argInfo;
- argInfo.type = parser.getBuilder().getIndexType();
- if (failed(parser.parseArgument(argInfo)))
- return failure();
-
- // Parse the body region, and reuse the operand info as the argument info.
- Region *body = result.addRegion();
- return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
-}
-
-void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
- p.printOptionalAttrDict((*this)->getAttrs());
- p << ' ';
- p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
- /*omitType=*/true);
- p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
-}
-
-void TestWithBoundsRegionOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
- Value arg = getRegion().getArgument(0);
- setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
-}
-
-void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- const ConstantIntRanges &range = argRanges[0];
- APInt one(range.umin().getBitWidth(), 1);
- setResultRanges(getResult(),
- {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
- range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
-}
-
-void TestReflectBoundsOp::inferResultRanges(
- ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
- const ConstantIntRanges &range = argRanges[0];
- MLIRContext *ctx = getContext();
- Builder b(ctx);
- setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
- setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
- setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
- setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
- setResultRanges(getResult(), range);
-}
-
-OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
- // Just a simple fold for testing purposes that reads an operands constant
- // value and returns it.
- if (!attributes.empty())
- return attributes.front();
- return nullptr;
-}
-
-static LogicalResult
-setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError) {
- DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
- if (!dict) {
- emitError() << "expected DictionaryAttr to set TestProperties";
- return failure();
- }
- auto label = dict.getAs<mlir::StringAttr>("label");
- if (!label) {
- emitError() << "expected StringAttr for key `label`";
- return failure();
- }
- auto valueAttr = dict.getAs<IntegerAttr>("value");
- if (!valueAttr) {
- emitError() << "expected IntegerAttr for key `value`";
- return failure();
- }
-
- prop.label = std::make_shared<std::string>(label.getValue());
- prop.value = valueAttr.getValue().getSExtValue();
- return success();
-}
-
-static DictionaryAttr
-getPropertiesAsAttribute(MLIRContext *ctx,
- const PropertiesWithCustomPrint &prop) {
- SmallVector<NamedAttribute> attrs;
- Builder b{ctx};
- attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label)));
- attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
- return b.getDictionaryAttr(attrs);
-}
-
-static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
- return llvm::hash_combine(prop.value, StringRef(*prop.label));
-}
-
-static void customPrintProperties(OpAsmPrinter &p,
- const PropertiesWithCustomPrint &prop) {
- p.printKeywordOrString(*prop.label);
- p << " is " << prop.value;
-}
-
-static ParseResult customParseProperties(OpAsmParser &parser,
- PropertiesWithCustomPrint &prop) {
- std::string label;
- if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") ||
- parser.parseInteger(prop.value))
- return failure();
- prop.label = std::make_shared<std::string>(std::move(label));
- return success();
-}
-
-static ParseResult
-parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
- SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
- SmallVector<int64_t> caseValues;
- while (succeeded(p.parseOptionalKeyword("case"))) {
- int64_t value;
- Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
- if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
- return failure();
- caseValues.push_back(value);
- }
- cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
- return success();
-}
-
-static void printSwitchCases(OpAsmPrinter &p, Operation *op,
- DenseI64ArrayAttr cases, RegionRange caseRegions) {
- for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
- p.printNewline();
- p << "case " << value << ' ';
- p.printRegion(*region, /*printEntryBlockArgs=*/false);
- }
-}
-
-static LogicalResult
-setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
- function_ref<InFlightDiagnostic()> emitError) {
- DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
- if (!dict) {
- emitError() << "expected DictionaryAttr to set VersionedProperties";
- return failure();
- }
- auto value1Attr = dict.getAs<IntegerAttr>("value1");
- if (!value1Attr) {
- emitError() << "expected IntegerAttr for key `value1`";
- return failure();
- }
- auto value2Attr = dict.getAs<IntegerAttr>("value2");
- if (!value2Attr) {
- emitError() << "expected IntegerAttr for key `value2`";
- return failure();
- }
-
- prop.value1 = value1Attr.getValue().getSExtValue();
- prop.value2 = value2Attr.getValue().getSExtValue();
- return success();
-}
-
-static DictionaryAttr
-getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) {
- SmallVector<NamedAttribute> attrs;
- Builder b{ctx};
- attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1)));
- attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2)));
- return b.getDictionaryAttr(attrs);
-}
-
-static llvm::hash_code computeHash(const VersionedProperties &prop) {
- return llvm::hash_combine(prop.value1, prop.value2);
-}
-
-static void customPrintProperties(OpAsmPrinter &p,
- const VersionedProperties &prop) {
- p << prop.value1 << " | " << prop.value2;
-}
-
-static ParseResult customParseProperties(OpAsmParser &parser,
- VersionedProperties &prop) {
- if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() ||
- parser.parseInteger(prop.value2))
- return failure();
- return success();
-}
-
-static bool parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
- return parser.parseLSquare() || parser.parseInteger(value[0]) ||
- parser.parseComma() || parser.parseInteger(value[1]) ||
- parser.parseComma() || parser.parseInteger(value[2]) ||
- parser.parseRSquare();
-}
-
-static void printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
- ArrayRef<int64_t> value) {
- printer << '[' << value << ']';
-}
-
-static bool parseIntProperty(OpAsmParser &parser, int64_t &value) {
- return failed(parser.parseInteger(value));
-}
-
-static void printIntProperty(OpAsmPrinter &printer, Operation *op,
- int64_t value) {
- printer << value;
-}
-
-static bool parseSumProperty(OpAsmParser &parser, int64_t &second,
- int64_t first) {
- int64_t sum;
- auto loc = parser.getCurrentLocation();
- if (parser.parseInteger(second) || parser.parseEqual() ||
- parser.parseInteger(sum))
- return true;
- if (sum != second + first) {
- parser.emitError(loc, "Expected sum to equal first + second");
- return true;
- }
- return false;
-}
-
-static void printSumProperty(OpAsmPrinter &printer, Operation *op,
- int64_t second, int64_t first) {
- printer << second << " = " << (second + first);
-}
-
-//===----------------------------------------------------------------------===//
-// Tensor/Buffer Ops
-//===----------------------------------------------------------------------===//
-
-void ReadBufferOp::getEffects(
- SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
- &effects) {
- // The buffer operand is read.
- effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
- SideEffects::DefaultResource::get());
- // The buffer contents are dumped.
- effects.emplace_back(MemoryEffects::Write::get(),
- SideEffects::DefaultResource::get());
-}
-
-//===----------------------------------------------------------------------===//
-// Test Dataflow
-//===----------------------------------------------------------------------===//
-
-CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
- return getCallee();
-}
-
-void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- setCalleeAttr(callee.get<SymbolRefAttr>());
-}
-
-Operation::operand_range TestCallAndStoreOp::getArgOperands() {
- return getCalleeOperands();
-}
-
-MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
- return getCalleeOperandsMutable();
-}
-
-CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
- return getCallee();
-}
-
-void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- setCalleeAttr(callee.get<SymbolRefAttr>());
-}
-
-Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
- return getForwardedOperands();
-}
-
-MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
- return getForwardedOperandsMutable();
-}
-
-void TestStoreWithARegion::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- if (point.isParent())
- regions.emplace_back(&getBody(), getBody().front().getArguments());
- else
- regions.emplace_back();
-}
-
-void TestStoreWithALoopRegion::getSuccessorRegions(
- RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- // Both the operation itself and the region may be branching into the body or
- // back into the operation itself. It is possible for the operation not to
- // enter the body.
- regions.emplace_back(
- RegionSuccessor(&getBody(), getBody().front().getArguments()));
- regions.emplace_back();
-}
-
-LogicalResult
-TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader,
- ::mlir::OperationState &state) {
- auto &prop = state.getOrAddProperties<Properties>();
- if (::mlir::failed(reader.readAttribute(prop.dims)))
- return ::mlir::failure();
-
- // Check if we have a version. If not, assume we are parsing the current
- // version.
- auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
- if (succeeded(maybeVersion)) {
- // If version is less than 2.0, there is no additional attribute to parse.
- // We can materialize missing properties post parsing before verification.
- const auto *version =
- reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
- if ((version->major_ < 2)) {
- return success();
- }
- }
-
- if (::mlir::failed(reader.readAttribute(prop.modifier)))
- return ::mlir::failure();
- return ::mlir::success();
-}
-
-void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) {
- auto &prop = getProperties();
- writer.writeAttribute(prop.dims);
-
- auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
- if (succeeded(maybeVersion)) {
- // If version is less than 2.0, there is no additional attribute to write.
- const auto *version =
- reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
- if ((version->major_ < 2)) {
- llvm::outs() << "downgrading op properties...\n";
- return;
- }
- }
- writer.writeAttribute(prop.modifier);
-}
-
-::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
- ::mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
- uint64_t value1, value2 = 0;
- if (failed(reader.readVarInt(value1)))
- return failure();
-
- // Check if we have a version. If not, assume we are parsing the current
- // version.
- auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
- bool needToParseAnotherInt = true;
- if (succeeded(maybeVersion)) {
- // If version is less than 2.0, there is no additional attribute to parse.
- // We can materialize missing properties post parsing before verification.
- const auto *version =
- reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
- if ((version->major_ < 2))
- needToParseAnotherInt = false;
- }
- if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
- return failure();
-
- prop.value1 = value1;
- prop.value2 = value2;
- return success();
-}
-
-void TestOpWithVersionedProperties::writeToMlirBytecode(
- ::mlir::DialectBytecodeWriter &writer,
- const test::VersionedProperties &prop) {
- writer.writeVarInt(prop.value1);
- writer.writeVarInt(prop.value2);
-}
-
-#include "TestOpEnums.cpp.inc"
-#include "TestOpInterfaces.cpp.inc"
-#include "TestTypeInterfaces.cpp.inc"
-
-#define GET_OP_CLASSES
-#include "TestOps.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index d5b2fbeafc4104..c05e15fc642a25 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -43,19 +43,18 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
#include <memory>
namespace mlir {
-class DLTIDialect;
class RewritePatternSet;
-} // namespace mlir
+} // end namespace mlir
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
-#include "TestOpInterfaces.h.inc"
#include "TestOpsDialect.h.inc"
namespace test {
@@ -75,49 +74,8 @@ struct TestDialectVersion : public mlir::DialectVersion {
uint32_t minor_ = 0;
};
-// Define some classes to exercises the Properties feature.
-
-struct PropertiesWithCustomPrint {
- /// A shared_ptr to a const object is safe: it is equivalent to a value-based
- /// member. Here the label will be deallocated when the last operation
- /// refering to it is destroyed. However there is no pool-allocation: this is
- /// offloaded to the client.
- std::shared_ptr<const std::string> label;
- int value;
- bool operator==(const PropertiesWithCustomPrint &rhs) const {
- return value == rhs.value && *label == *rhs.label;
- }
-};
-class MyPropStruct {
-public:
- std::string content;
- // These three methods are invoked through the `MyStructProperty` wrapper
- // defined in TestOps.td
- mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
- static mlir::LogicalResult
- setFromAttr(MyPropStruct &prop, mlir::Attribute attr,
- llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
- llvm::hash_code hash() const;
- bool operator==(const MyPropStruct &rhs) const {
- return content == rhs.content;
- }
-};
-struct VersionedProperties {
- // For the sake of testing, assume that this object was associated to version
- // 1.2 of the test dialect when having only one int value. In the current
- // version 2.0, the property has two values. We also assume that the class is
- // upgrade-able if value2 = 0.
- int value1;
- int value2;
- bool operator==(const VersionedProperties &rhs) const {
- return value1 == rhs.value1 && value2 == rhs.value2;
- }
-};
} // namespace test
-#define GET_OP_CLASSES
-#include "TestOps.h.inc"
-
namespace test {
// Op deliberately defined in C++ code rather than ODS to test that C++
@@ -138,6 +96,10 @@ class ManualCppOpWithFold
void registerTestDialect(::mlir::DialectRegistry &registry);
void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns);
+void testSideEffectOpGetEffect(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<
+ mlir::SideEffects::EffectInstance<mlir::TestEffects::Effect>> &effects);
} // namespace test
#endif // MLIR_TESTDIALECT_H
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 66578b246afab1..a3a8913d5964c6 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
new file mode 100644
index 00000000000000..6e75dd39322810
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
@@ -0,0 +1,377 @@
+//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestFormatUtils.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperands
+//===----------------------------------------------------------------------===//
+
+/// Parses `operand[, optOperand] -> (varOperands...)`.
+ParseResult test::parseCustomDirectiveOperands(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
+  // Mandatory leading operand.
+  if (failed(parser.parseOperand(operand)))
+    return failure();
+
+  // A comma introduces the optional second operand.
+  const bool hasOptional = succeeded(parser.parseOptionalComma());
+  if (hasOptional) {
+    optOperand.emplace();
+    if (failed(parser.parseOperand(*optOperand)))
+      return failure();
+  }
+
+  // Trailing variadic list: `-> (...)`.
+  const bool parsedTail = succeeded(parser.parseArrow()) &&
+                          succeeded(parser.parseLParen()) &&
+                          succeeded(parser.parseOperandList(varOperands)) &&
+                          succeeded(parser.parseRParen());
+  return success(parsedTail);
+}
+
+/// Prints `operand[, optOperand] -> (varOperands...)`.
+void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
+                                        Value operand, Value optOperand,
+                                        OperandRange varOperands) {
+  printer << operand;
+  if (optOperand) {
+    printer << ", ";
+    printer << optOperand;
+  }
+  printer << " -> (";
+  printer << varOperands;
+  printer << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveResults
+//===----------------------------------------------------------------------===//
+
+/// Parses `: type[, optType] -> (varTypes...)`.
+ParseResult
+test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
+                                  Type &optOperandType,
+                                  SmallVectorImpl<Type> &varOperandTypes) {
+  // Leading `: type`.
+  if (failed(parser.parseColon()) || failed(parser.parseType(operandType)))
+    return failure();
+
+  // Optional second type, introduced by a comma.
+  if (succeeded(parser.parseOptionalComma()) &&
+      failed(parser.parseType(optOperandType)))
+    return failure();
+
+  // Trailing `-> (types...)`.
+  const bool parsedTail = succeeded(parser.parseArrow()) &&
+                          succeeded(parser.parseLParen()) &&
+                          succeeded(parser.parseTypeList(varOperandTypes)) &&
+                          succeeded(parser.parseRParen());
+  return success(parsedTail);
+}
+
+/// Prints ` : type[, optType] -> (varTypes...)`.
+void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
+                                       Type operandType, Type optOperandType,
+                                       TypeRange varOperandTypes) {
+  printer << " : ";
+  printer << operandType;
+  if (optOperandType) {
+    printer << ", ";
+    printer << optOperandType;
+  }
+  printer << " -> (";
+  printer << varOperandTypes;
+  printer << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveWithTypeRefs
+//===----------------------------------------------------------------------===//
+
+/// Parses `type_refs_capture` followed by the results form. The parsed types
+/// must equal the types passed in, which were already resolved from the
+/// op's operands. A mismatch is reported as a diagnostic.
+ParseResult test::parseCustomDirectiveWithTypeRefs(
+    OpAsmParser &parser, Type operandType, Type optOperandType,
+    const SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseKeyword("type_refs_capture"))
+    return failure();
+
+  auto loc = parser.getCurrentLocation();
+  Type operandType2, optOperandType2;
+  SmallVector<Type, 1> varOperandTypes2;
+  if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
+                                  varOperandTypes2))
+    return failure();
+
+  // Emit an error rather than failing silently, so the user learns why
+  // parsing stopped.
+  if (operandType != operandType2 || optOperandType != optOperandType2 ||
+      varOperandTypes != varOperandTypes2)
+    return parser.emitError(loc, "type_refs_capture types do not match the "
+                                 "referenced operand types");
+
+  return success();
+}
+
+/// Prints ` type_refs_capture ` followed by the results form.
+void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
+                                            Operation *op, Type operandType,
+                                            Type optOperandType,
+                                            TypeRange varOperandTypes) {
+  // The keyword comes first. The type list reuses the plain results printer.
+  printer << " type_refs_capture ";
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperandsAndTypes
+//===----------------------------------------------------------------------===//
+
+/// Parses the operands form immediately followed by the results form.
+ParseResult test::parseCustomDirectiveOperandsAndTypes(
+    OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
+    Type &operandType, Type &optOperandType,
+    SmallVectorImpl<Type> &varOperandTypes) {
+  // Operands first. Stop at the first failure.
+  if (failed(parseCustomDirectiveOperands(parser, operand, optOperand,
+                                          varOperands)))
+    return failure();
+  // Then their types.
+  return parseCustomDirectiveResults(parser, operandType, optOperandType,
+                                     varOperandTypes);
+}
+
+/// Prints the operands form immediately followed by the results form.
+void test::printCustomDirectiveOperandsAndTypes(
+    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
+    OperandRange varOperands, Type operandType, Type optOperandType,
+    TypeRange varOperandTypes) {
+  // Operand half.
+  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
+  // Type half.
+  printCustomDirectiveResults(printer, op, operandType, optOperandType,
+                              varOperandTypes);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveRegions
+//===----------------------------------------------------------------------===//
+
+/// Parses `region[, varRegion]`. At most one variadic region is parsed and
+/// appended to `varRegions`.
+ParseResult test::parseCustomDirectiveRegions(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+  if (parser.parseRegion(region))
+    return failure();
+  // Without a comma there is no variadic region.
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  auto varRegion = std::make_unique<Region>();
+  if (parser.parseRegion(*varRegion))
+    return failure();
+  varRegions.emplace_back(std::move(varRegion));
+  return success();
+}
+
+/// Prints the mandatory region. If any variadic regions exist, prints `, `
+/// followed by each of them.
+void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
+                                       Region &region,
+                                       MutableArrayRef<Region> varRegions) {
+  printer.printRegion(region);
+  if (varRegions.empty())
+    return;
+  printer << ", ";
+  // Use a distinct name so the loop variable does not shadow the `region`
+  // parameter.
+  // NOTE(review): multiple variadic regions are printed back to back with no
+  // separator, while the parser accepts only one. Confirm this is intended.
+  for (Region &varRegion : varRegions)
+    printer.printRegion(varRegion);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSuccessors
+//===----------------------------------------------------------------------===//
+
+/// Parses `successor[, varSuccessor]`. A parsed variadic successor is
+/// appended to `varSuccessors` twice.
+ParseResult
+test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
+                                     SmallVectorImpl<Block *> &varSuccessors) {
+  if (failed(parser.parseSuccessor(successor)))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    Block *varSuccessor = nullptr;
+    if (failed(parser.parseSuccessor(varSuccessor)))
+      return failure();
+    // NOTE(review): the same block is recorded twice. Presumably this makes
+    // one textual successor populate a two-element variadic list. Confirm
+    // against the op definition before changing.
+    varSuccessors.push_back(varSuccessor);
+    varSuccessors.push_back(varSuccessor);
+  }
+  return success();
+}
+
+/// Prints the successor and, if any variadic successors exist, only the
+/// first of them.
+void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
+                                          Block *successor,
+                                          SuccessorRange varSuccessors) {
+  printer << successor;
+  if (varSuccessors.empty())
+    return;
+  printer << ", " << varSuccessors.front();
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttributes
+//===----------------------------------------------------------------------===//
+
+/// Parses `attr[, optAttr]`, where both are integer attributes.
+ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser,
+                                                 IntegerAttr &attr,
+                                                 IntegerAttr &optAttr) {
+  if (failed(parser.parseAttribute(attr)))
+    return failure();
+  // Without a trailing comma there is no optional attribute.
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  return parser.parseAttribute(optAttr);
+}
+
+/// Prints `attribute[, optAttribute]`.
+void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
+                                          Attribute attribute,
+                                          Attribute optAttribute) {
+  printer << attribute;
+  if (!optAttribute)
+    return;
+  printer << ", " << optAttribute;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrDict
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional attribute dictionary into `attrs`.
+ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser,
+                                               NamedAttrList &attrs) {
+  ParseResult parsed = parser.parseOptionalAttrDict(attrs);
+  return parsed;
+}
+
+/// Prints the entries of `attrs` as an optional attribute dictionary.
+void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
+                                        DictionaryAttr attrs) {
+  ArrayRef<NamedAttribute> entries = attrs.getValue();
+  printer.printOptionalAttrDict(entries);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperandRef
+//===----------------------------------------------------------------------===//
+
+/// Parses an integer that refers to the optional operand. The integer must be
+/// 0 when `optOperand` is absent and nonzero when it is present.
+ParseResult test::parseCustomDirectiveOptionalOperandRef(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  auto loc = parser.getCurrentLocation();
+  int64_t operandCount = 0;
+  if (parser.parseInteger(operandCount))
+    return failure();
+  // A zero count means the operand must be absent; any nonzero count means it
+  // must be present. The old name `expectedOptionalOperand` meant the
+  // opposite of what it held.
+  bool expectOperand = operandCount != 0;
+  if (expectOperand != optOperand.has_value())
+    return parser.emitError(loc, "optional operand reference count does not "
+                                 "match the presence of the optional operand");
+  return success();
+}
+
+/// Prints `1` when `optOperand` is set and `0` otherwise.
+void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
+                                                  Operation *op,
+                                                  Value optOperand) {
+  if (optOperand)
+    printer << "1";
+  else
+    printer << "0";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperand
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional parenthesized operand `(operand)`.
+ParseResult test::parseCustomOptionalOperand(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
+  // No opening paren means there is no operand, which is not an error.
+  if (failed(parser.parseOptionalLParen()))
+    return success();
+  optOperand.emplace();
+  const bool parsed = succeeded(parser.parseOperand(*optOperand)) &&
+                      succeeded(parser.parseRParen());
+  return success(parsed);
+}
+
+/// Prints `(operand) ` when `optOperand` is set; prints nothing otherwise.
+void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
+                                      Value optOperand) {
+  if (!optOperand)
+    return;
+  printer << "(" << optOperand << ") ";
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSwitchCases
+//===----------------------------------------------------------------------===//
+
+/// Parses zero or more `case <int> <region>` entries. Each region is
+/// appended to `caseRegions`, and all case values are collected into the
+/// `cases` array attribute.
+ParseResult
+test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                       SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues;
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    int64_t value;
+    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
+    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
+  return success();
+}
+
+/// Prints each case on its own line as `case <value> <region>`, without entry
+/// block arguments.
+void test::printSwitchCases(OpAsmPrinter &p, Operation *op,
+                            DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [caseValue, caseRegion] :
+       llvm::zip(cases.asArrayRef(), caseRegions)) {
+    p.printNewline();
+    p << "case ";
+    p << caseValue;
+    p << ' ';
+    p.printRegion(*caseRegion, /*printEntryBlockArgs=*/false);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CustomUsingPropertyInCustom
+//===----------------------------------------------------------------------===//
+
+/// Parses exactly `[a, b, c]` into `value`. Returns true on failure.
+bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
+  if (parser.parseLSquare())
+    return true;
+  for (int i = 0; i < 3; ++i) {
+    // Elements after the first are comma-separated.
+    if (i > 0 && parser.parseComma())
+      return true;
+    if (parser.parseInteger(value[i]))
+      return true;
+  }
+  return static_cast<bool>(parser.parseRSquare());
+}
+
+/// Prints `value` inside square brackets. This is intended to round-trip
+/// through parseUsingPropertyInCustom.
+void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
+                                      ArrayRef<int64_t> value) {
+  printer << '[';
+  printer << value;
+  printer << ']';
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveIntProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses a single integer into `value`. Returns true on failure, matching the
+/// bool-returning property parsers in this file.
+bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) {
+  ParseResult parsed = parser.parseInteger(value);
+  return failed(parsed);
+}
+
+/// Prints the integer property with no surrounding punctuation.
+void test::printIntProperty(OpAsmPrinter &printer, Operation *op,
+                            int64_t value) {
+  printer << value;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSumProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses `second = sum` and checks that `sum == second + first`. Returns true
+/// on failure; a mismatched sum also emits a diagnostic.
+bool test::parseSumProperty(OpAsmParser &parser, int64_t &second,
+                            int64_t first) {
+  auto loc = parser.getCurrentLocation();
+  int64_t sum;
+  if (parser.parseInteger(second) || parser.parseEqual() ||
+      parser.parseInteger(sum))
+    return true;
+  if (sum == second + first)
+    return false;
+  parser.emitError(loc, "Expected sum to equal first + second");
+  return true;
+}
+
+/// Prints `second = <second + first>`, the form accepted by parseSumProperty.
+void test::printSumProperty(OpAsmPrinter &printer, Operation *op,
+                            int64_t second, int64_t first) {
+  const int64_t sum = second + first;
+  printer << second << " = " << sum;
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalCustomParser
+//===----------------------------------------------------------------------===//
+
+/// If the keyword `foo` is present, parses the integer attribute that follows
+/// it. Otherwise returns an empty result, meaning nothing was parsed.
+OptionalParseResult test::parseOptionalCustomParser(AsmParser &p,
+                                                    IntegerAttr &result) {
+  if (failed(p.parseOptionalKeyword("foo")))
+    return {};
+  return p.parseAttribute(result);
+}
+
+/// Prints `foo ` followed by the attribute, mirroring
+/// parseOptionalCustomParser.
+void test::printOptionalCustomParser(AsmPrinter &p, Operation *,
+                                     IntegerAttr result) {
+  AsmPrinter &printer = p;
+  printer << "foo ";
+  printer.printAttribute(result);
+}
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrElideType
+//===----------------------------------------------------------------------===//
+
+/// Parses an attribute whose type comes from `type` instead of being spelled
+/// in the input.
+ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type,
+                                     Attribute &attr) {
+  Type elidedType = type.getValue();
+  return parser.parseAttribute(attr, elidedType);
+}
+
+/// Prints `attr` without its type. The type is recovered from the separate
+/// `type` operand when parsing.
+void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
+                              Attribute attr) {
+  AsmPrinter &out = printer;
+  out.printAttributeWithoutType(attr);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
new file mode 100644
index 00000000000000..7e9cd834278e34
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
@@ -0,0 +1,211 @@
+//===- TestFormatUtils.h - MLIR Test Dialect Assembly Format Utilities ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTFORMATUTILS_H
+#define MLIR_TESTFORMATUTILS_H
+
+#include "mlir/IR/OpImplementation.h"
+
+namespace test {
+
+// Custom assembly-format directive hooks for the test dialect. Parsers that
+// return `mlir::ParseResult` follow the usual success/failure convention;
+// parsers that return `bool` return true on failure.
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperands
+//===----------------------------------------------------------------------===//
+
+/// Parses `operand[, optOperand] -> (varOperands...)`.
+mlir::ParseResult parseCustomDirectiveOperands(
+    mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &varOperands);
+
+/// Prints `operand[, optOperand] -> (varOperands...)`.
+void printCustomDirectiveOperands(mlir::OpAsmPrinter &printer,
+                                  mlir::Operation *, mlir::Value operand,
+                                  mlir::Value optOperand,
+                                  mlir::OperandRange varOperands);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveResults
+//===----------------------------------------------------------------------===//
+
+/// Parses `: type[, optType] -> (varTypes...)`.
+mlir::ParseResult
+parseCustomDirectiveResults(mlir::OpAsmParser &parser, mlir::Type &operandType,
+                            mlir::Type &optOperandType,
+                            llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+/// Prints ` : type[, optType] -> (varTypes...)`.
+void printCustomDirectiveResults(mlir::OpAsmPrinter &printer, mlir::Operation *,
+                                 mlir::Type operandType,
+                                 mlir::Type optOperandType,
+                                 mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveWithTypeRefs
+//===----------------------------------------------------------------------===//
+
+/// Parses `type_refs_capture` followed by the results form. Fails unless the
+/// parsed types equal the already-resolved types passed in.
+mlir::ParseResult parseCustomDirectiveWithTypeRefs(
+    mlir::OpAsmParser &parser, mlir::Type operandType,
+    mlir::Type optOperandType,
+    const llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+/// Prints ` type_refs_capture ` followed by the results form.
+void printCustomDirectiveWithTypeRefs(mlir::OpAsmPrinter &printer,
+                                      mlir::Operation *op,
+                                      mlir::Type operandType,
+                                      mlir::Type optOperandType,
+                                      mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOperandsAndTypes
+//===----------------------------------------------------------------------===//
+
+/// Parses the operands form immediately followed by the results form.
+mlir::ParseResult parseCustomDirectiveOperandsAndTypes(
+    mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &varOperands,
+    mlir::Type &operandType, mlir::Type &optOperandType,
+    llvm::SmallVectorImpl<mlir::Type> &varOperandTypes);
+
+/// Prints the operands form immediately followed by the results form.
+void printCustomDirectiveOperandsAndTypes(
+    mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::Value operand,
+    mlir::Value optOperand, mlir::OperandRange varOperands,
+    mlir::Type operandType, mlir::Type optOperandType,
+    mlir::TypeRange varOperandTypes);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveRegions
+//===----------------------------------------------------------------------===//
+
+/// Parses `region[, varRegion]`. At most one variadic region is parsed.
+mlir::ParseResult parseCustomDirectiveRegions(
+    mlir::OpAsmParser &parser, mlir::Region &region,
+    llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &varRegions);
+
+/// Prints the mandatory region. If any variadic regions exist, prints `, `
+/// followed by each of them.
+void printCustomDirectiveRegions(
+    mlir::OpAsmPrinter &printer, mlir::Operation *, mlir::Region &region,
+    llvm::MutableArrayRef<mlir::Region> varRegions);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSuccessors
+//===----------------------------------------------------------------------===//
+
+/// Parses `successor[, varSuccessor]`. A parsed variadic successor is
+/// appended to `varSuccessors` twice.
+mlir::ParseResult parseCustomDirectiveSuccessors(
+    mlir::OpAsmParser &parser, mlir::Block *&successor,
+    llvm::SmallVectorImpl<mlir::Block *> &varSuccessors);
+
+/// Prints the successor and, if present, only the first variadic successor.
+void printCustomDirectiveSuccessors(mlir::OpAsmPrinter &printer,
+                                    mlir::Operation *, mlir::Block *successor,
+                                    mlir::SuccessorRange varSuccessors);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttributes
+//===----------------------------------------------------------------------===//
+
+/// Parses `attr[, optAttr]` as integer attributes.
+mlir::ParseResult parseCustomDirectiveAttributes(mlir::OpAsmParser &parser,
+                                                 mlir::IntegerAttr &attr,
+                                                 mlir::IntegerAttr &optAttr);
+
+/// Prints `attribute[, optAttribute]`.
+void printCustomDirectiveAttributes(mlir::OpAsmPrinter &printer,
+                                    mlir::Operation *,
+                                    mlir::Attribute attribute,
+                                    mlir::Attribute optAttribute);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrDict
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional attribute dictionary into `attrs`.
+mlir::ParseResult parseCustomDirectiveAttrDict(mlir::OpAsmParser &parser,
+                                               mlir::NamedAttrList &attrs);
+
+/// Prints the entries of `attrs` as an optional attribute dictionary.
+void printCustomDirectiveAttrDict(mlir::OpAsmPrinter &printer,
+                                  mlir::Operation *op,
+                                  mlir::DictionaryAttr attrs);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperandRef
+//===----------------------------------------------------------------------===//
+
+/// Parses an integer that must be 0 when `optOperand` is absent and nonzero
+/// when it is present.
+mlir::ParseResult parseCustomDirectiveOptionalOperandRef(
+    mlir::OpAsmParser &parser,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand);
+
+/// Prints `1` if `optOperand` is set and `0` otherwise.
+void printCustomDirectiveOptionalOperandRef(mlir::OpAsmPrinter &printer,
+                                            mlir::Operation *op,
+                                            mlir::Value optOperand);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalOperand
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional parenthesized operand `(operand)`.
+mlir::ParseResult parseCustomOptionalOperand(
+    mlir::OpAsmParser &parser,
+    std::optional<mlir::OpAsmParser::UnresolvedOperand> &optOperand);
+
+/// Prints `(operand) ` when `optOperand` is set.
+void printCustomOptionalOperand(mlir::OpAsmPrinter &printer, mlir::Operation *,
+                                mlir::Value optOperand);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSwitchCases
+//===----------------------------------------------------------------------===//
+
+/// Parses zero or more `case <int> <region>` entries.
+mlir::ParseResult parseSwitchCases(
+    mlir::OpAsmParser &p, mlir::DenseI64ArrayAttr &cases,
+    llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &caseRegions);
+
+/// Prints each case on its own line as `case <value> <region>`.
+void printSwitchCases(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                      mlir::DenseI64ArrayAttr cases,
+                      mlir::RegionRange caseRegions);
+
+//===----------------------------------------------------------------------===//
+// CustomUsingPropertyInCustom
+//===----------------------------------------------------------------------===//
+
+/// Parses `[a, b, c]` into `value`. Returns true on failure.
+bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, int64_t value[3]);
+
+/// Prints `value` inside square brackets.
+void printUsingPropertyInCustom(mlir::OpAsmPrinter &printer,
+                                mlir::Operation *op,
+                                llvm::ArrayRef<int64_t> value);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveIntProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses a single integer. Returns true on failure.
+bool parseIntProperty(mlir::OpAsmParser &parser, int64_t &value);
+
+/// Prints the integer property.
+void printIntProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                      int64_t value);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveSumProperty
+//===----------------------------------------------------------------------===//
+
+/// Parses `second = sum` and checks `sum == first + second`. Returns true on
+/// failure.
+bool parseSumProperty(mlir::OpAsmParser &parser, int64_t &second,
+                      int64_t first);
+
+/// Prints `second = <second + first>`.
+void printSumProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                      int64_t second, int64_t first);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveOptionalCustomParser
+//===----------------------------------------------------------------------===//
+
+/// Parses `foo <integer-attr>` if the `foo` keyword is present. Otherwise
+/// returns an empty result.
+mlir::OptionalParseResult parseOptionalCustomParser(mlir::AsmParser &p,
+                                                    mlir::IntegerAttr &result);
+
+/// Prints `foo ` followed by `result`.
+void printOptionalCustomParser(mlir::AsmPrinter &p, mlir::Operation *,
+                               mlir::IntegerAttr result);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveAttrElideType
+//===----------------------------------------------------------------------===//
+
+/// Parses an attribute whose type is supplied by `type` instead of the input.
+mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser,
+                                     mlir::TypeAttr type,
+                                     mlir::Attribute &attr);
+
+/// Prints `attr` without its type.
+void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op,
+                        mlir::TypeAttr type, mlir::Attribute attr);
+
+} // namespace test
+
+#endif // MLIR_TESTFORMATUTILS_H
diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
index 3673d62bea2c94..dc6413b25707e3 100644
--- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
index 64ec82ecb24ff8..14099bb4bb16ba 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp
@@ -6,3 +6,5 @@ bool mlir::TestEffects::Effect::classof(
const mlir::SideEffects::Effect *effect) {
return isa<mlir::TestEffects::Concrete>(effect);
}
+
+#include "TestOpInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.h b/mlir/test/lib/Dialect/Test/TestInterfaces.h
index 3239584a93326d..d58d1aafbe66c2 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.h
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.h
@@ -34,4 +34,6 @@ struct Concrete : public Effect::Base<Concrete> {};
} // namespace TestEffects
} // namespace mlir
+#include "TestOpInterfaces.h.inc"
+
#endif // MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
new file mode 100644
index 00000000000000..7263774ca158eb
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -0,0 +1,1161 @@
+//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "TestOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// TestBranchOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
+  // The op has exactly one successor, and all target operands are forwarded
+  // to it.
+  assert(index == 0 && "invalid successor index");
+  auto forwarded = getTargetOperandsMutable();
+  return SuccessorOperands(forwarded);
+}
+
+//===----------------------------------------------------------------------===//
+// TestProducingBranchOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index <= 1 && "invalid successor index");
+  // Successor 1 forwards `firstOperands`; successor 0 forwards
+  // `secondOperands`.
+  // NOTE(review): this mapping looks crossed relative to the operand group
+  // names. It is kept as-is because tests may depend on it. Confirm against
+  // the op definition before changing.
+  return SuccessorOperands(index == 1 ? getFirstOperandsMutable()
+                                      : getSecondOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestInternalBranchOp
+//===----------------------------------------------------------------------===//
+
+SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index <= 1 && "invalid successor index");
+  // The error successor (index 1) takes one operand produced by the op itself
+  // in addition to its forwarded error operands.
+  if (index != 0)
+    return SuccessorOperands(1, getErrorOperandsMutable());
+  // The success successor (index 0) produces none.
+  return SuccessorOperands(0, getSuccessOperandsMutable());
+}
+
+//===----------------------------------------------------------------------===//
+// TestCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // The callee must be present as a flat symbol reference...
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+
+  // ...and must resolve, from this op, to a FunctionOpInterface op.
+  auto callee =
+      symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr);
+  if (callee)
+    return success();
+  return emitOpError() << "'" << fnAttr.getValue()
+                       << "' does not reference a valid function";
+}
+
+//===----------------------------------------------------------------------===//
+// FoldToCallOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Replaces a FoldToCallOp with a `func.call` to the same callee, passing no
+/// operands and producing no results.
+struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FoldToCallOp op,
+                                PatternRewriter &rewriter) const override {
+    auto callee = op.getCalleeAttr();
+    rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), callee,
+                                              ValueRange());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the single canonicalization pattern: rewrite into `func.call`.
+void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<FoldToCallOpPattern>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// IsolatedRegionOp - test parsing passthrough operands
+//===----------------------------------------------------------------------===//
+
+ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  // The single input is an `index`-typed operand. Its SSA name is reused as
+  // the name of the region's entry argument.
+  OpAsmParser::Argument argInfo;
+  argInfo.type = parser.getBuilder().getIndexType();
+  if (failed(parser.parseOperand(argInfo.ssaName)))
+    return failure();
+  if (failed(parser.resolveOperand(argInfo.ssaName, argInfo.type,
+                                   result.operands)))
+    return failure();
+
+  // Shadowing is enabled because the argument intentionally reuses the
+  // operand's name.
+  Region &body = *result.addRegion();
+  return parser.parseRegion(body, argInfo, /*enableNameShadowing=*/true);
+}
+
+void IsolatedRegionOp::print(OpAsmPrinter &p) {
+  // Print the operand, then make the region's entry argument reuse its name
+  // so the region is printed without explicit block arguments.
+  Value operand = getOperand();
+  Region &body = getRegion();
+  p << ' ';
+  p.printOperand(operand);
+  p.shadowRegionArgs(body, operand);
+  p << ' ';
+  p.printRegion(body, /*printEntryBlockArgs=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// SSACFGRegionOp
+//===----------------------------------------------------------------------===//
+
+RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
+  // Every region of this op is an SSA CFG region, regardless of index.
+  (void)index;
+  return RegionKind::SSACFG;
+}
+
+//===----------------------------------------------------------------------===//
+// GraphRegionOp
+//===----------------------------------------------------------------------===//
+
+RegionKind GraphRegionOp::getRegionKind(unsigned index) {
+  // Every region of this op is a graph region, regardless of index.
+  (void)index;
+  return RegionKind::Graph;
+}
+
+//===----------------------------------------------------------------------===//
+// AffineScopeOp
+//===----------------------------------------------------------------------===//
+
+ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
+  // The op has no operands; it is just a body region whose entry block takes
+  // no arguments. The third parseRegion argument is `enableNameShadowing`,
+  // not a list of argument types.
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, /*arguments=*/{},
+                            /*enableNameShadowing=*/false);
+}
+
+void AffineScopeOp::print(OpAsmPrinter &p) {
+  // Only the body region is printed; its entry block has no arguments to show.
+  p << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// TestRemoveOpWithInnerOps
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pattern that unconditionally erases any TestOpWithRegionPattern op.
+struct TestRemoveOpWithInnerOps
+    : public OpRewritePattern<TestOpWithRegionPattern> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Gives the pattern a fixed debug name.
+  void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
+
+  LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
+                                PatternRewriter &rewriter) const override {
+    // There is no match condition: every instance is erased.
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// TestOpWithRegionPattern
+//===----------------------------------------------------------------------===//
+
+/// Registers the pattern that erases this op during canonicalization.
+void TestOpWithRegionPattern::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<TestRemoveOpWithInnerOps>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithRegionFold
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
+  // Always folds to its operand, without inspecting the region.
+  Value operand = getOperand();
+  return operand;
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpConstant
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
+  // A constant folds to its own value attribute.
+  return getValue();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithVariadicResultsAndFolder
+//===----------------------------------------------------------------------===//
+
+LogicalResult TestOpWithVariadicResultsAndFolder::fold(
+    FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
+  // Forward every operand, in order, as a fold result.
+  // NOTE(review): assumes the operand count equals the result count. Confirm
+  // against the op definition.
+  for (Value operand : getOperands())
+    results.push_back(operand);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpInPlaceFold
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
+  // Exercise the fact that an operation created with createOrFold should be
+  // allowed to access its parent block.
+  assert(getOperation()->getBlock() &&
+         "expected that operation is not unlinked");
+
+  // The folder adds "attr" if not present, copying it from a constant
+  // integer operand. Returning this op's own result reports an in-place fold,
+  // so only do that when the property was actually set. Previously, a
+  // non-integer constant operand left "attr" null while still reporting
+  // success.
+  auto constOperand = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
+  if (constOperand && !getProperties().attr) {
+    getProperties().attr = constOperand;
+    return getResult();
+  }
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithInferTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Infers a single result whose type equals the (identical) types of the
+/// first two operands. Reports an error if those types differ.
+LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Inference may run on unverified operand lists, so check the count before
+  // indexing operands[0] and operands[1].
+  if (operands.size() < 2)
+    return emitOptionalError(location, "expected at least 2 operands");
+  if (operands[0].getType() != operands[1].getType()) {
+    return emitOptionalError(location, "operand type mismatch ",
+                             operands[0].getType(), " vs ",
+                             operands[1].getType());
+  }
+  inferredReturnTypes.assign({operands[0].getType()});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithShapedTypeInferTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Infers one 1-D result with an i17 element type. Its single extent is the
+/// first (outermost) dimension of the first operand, or dynamic when that
+/// operand is unranked. A ranked-tensor operand's encoding is carried over.
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes,
+    OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  auto shapedType = dyn_cast<ShapedType>(operands.front().getType());
+  if (!shapedType)
+    return emitOptionalError(location, "only shaped type operands allowed");
+
+  int64_t leadingDim = ShapedType::kDynamic;
+  if (shapedType.hasRank())
+    leadingDim = shapedType.getShape().front();
+
+  auto rankedTensorType = dyn_cast<RankedTensorType>(shapedType);
+  Attribute encoding =
+      rankedTensorType ? rankedTensorType.getEncoding() : Attribute();
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(
+      {leadingDim}, IntegerType::get(context, 17), encoding));
+  return success();
+}
+
+/// Reifies the result shape as exactly one value: the extent of dimension 0
+/// of the first operand.
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Value leadingExtent =
+      builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0);
+  shapes.clear();
+  shapes.push_back(leadingExtent);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithResultShapeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Reifies one shape value per operand, visiting operands in reverse order.
+/// Each shape is a 1-D index tensor holding the extents of that operand,
+/// which must be a ranked tensor.
+LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  for (Value operand : llvm::reverse(operands)) {
+    int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
+    SmallVector<Value, 4> extents;
+    for (int64_t dim = 0; dim < rank; ++dim)
+      extents.push_back(builder.createOrFold<tensor::DimOp>(loc, operand, dim));
+    auto shapeType = RankedTensorType::get({rank}, builder.getIndexType());
+    shapes.push_back(
+        builder.create<tensor::FromElementsOp>(loc, shapeType, extents));
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithResultShapePerDimInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Reifies per-dimension shapes, one entry per operand in reverse operand
+/// order. Static extents become index attributes; dynamic extents are
+/// materialized with tensor.dim.
+LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(getNumOperands());
+  for (Value operand : llvm::reverse(getOperands())) {
+    auto tensorType = cast<RankedTensorType>(operand.getType());
+    SmallVector<OpFoldResult, 4> dimSizes;
+    for (int64_t dim = 0, rank = tensorType.getRank(); dim < rank; ++dim) {
+      if (tensorType.isDynamicDim(dim))
+        dimSizes.emplace_back(
+            builder.createOrFold<tensor::DimOp>(loc, operand, dim));
+      else
+        dimSizes.emplace_back(
+            builder.getIndexAttr(tensorType.getDimSize(dim)));
+    }
+    shapes.emplace_back(std::move(dimSizes));
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// SideEffectOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A test resource for side effects. SideEffectOp::getEffects reports effects
+/// on this resource (instead of the default one) when an effect dictionary
+/// carries a `test_resource` entry.
+struct TestResource : public SideEffects::Resource::Base<TestResource> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
+
+ /// Name under which this resource is reported.
+ StringRef getName() final { return "<Test>"; }
+};
+} // namespace
+
+/// Reports the memory effects listed in the optional `effects` array
+/// attribute. Each entry is a dictionary with:
+///   - `effect`: one of "allocate", "free", "read", "write";
+///   - `test_resource` (optional): use TestResource instead of the default;
+///   - `on_result` (optional): the effect applies to the op's result;
+///   - `on_reference` (optional): a symbol reference the effect applies to.
+/// Without an `effects` attribute no effects are reported.
+void SideEffectOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  auto effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
+  if (!effectsAttr)
+    return;
+
+  for (DictionaryAttr effectDict : effectsAttr.getAsRange<DictionaryAttr>()) {
+    StringRef effectName =
+        cast<StringAttr>(effectDict.get("effect")).getValue();
+    // NOTE(review): there is no Default case, so an unknown effect name trips
+    // StringSwitch's "fell off the end" assertion.
+    MemoryEffects::Effect *effect =
+        StringSwitch<MemoryEffects::Effect *>(effectName)
+            .Case("allocate", MemoryEffects::Allocate::get())
+            .Case("free", MemoryEffects::Free::get())
+            .Case("read", MemoryEffects::Read::get())
+            .Case("write", MemoryEffects::Write::get());
+
+    SideEffects::Resource *resource = SideEffects::DefaultResource::get();
+    if (effectDict.get("test_resource"))
+      resource = TestResource::get();
+
+    if (effectDict.get("on_result")) {
+      effects.emplace_back(effect, getResult(), resource);
+      continue;
+    }
+    if (Attribute ref = effectDict.get("on_reference")) {
+      effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
+      continue;
+    }
+    effects.emplace_back(effect, resource);
+  }
+}
+
+/// Test-effect flavour of getEffects; fully delegated to the shared
+/// testSideEffectOpGetEffect helper.
+void SideEffectOp::getEffects(
+    SmallVectorImpl<TestEffects::EffectInstance> &effects) {
+  Operation *op = getOperation();
+  testSideEffectOpGetEffect(op, effects);
+}
+
+//===----------------------------------------------------------------------===//
+// StringAttrPrettyNameOp
+//===----------------------------------------------------------------------===//
+
+// This op has fancy handling of its SSA result name.
+/// Parses `attr-dict-with-keyword` and gives every result the type i32.
+/// This op has fancy handling of its SSA result names: when the parsed
+/// dictionary has no "names" entry, one is synthesized from the SSA result
+/// names written in the source. A name that starts with a digit is treated as
+/// unnamed and recorded as the empty string.
+ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  // Add the result types.
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
+    result.addTypes(parser.getBuilder().getIntegerType(32));
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // An explicit 'names' attribute takes precedence over the SSA names.
+  bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
+    return attr.getName() == "names";
+  });
+
+  // Nothing to infer when the names were given explicitly or there are no
+  // results.
+  if (hadNames || parser.getNumResults() == 0)
+    return success();
+
+  SmallVector<StringRef, 4> names;
+  auto *context = result.getContext();
+
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
+    auto resultName = parser.getResultName(i);
+    StringRef nameStr;
+    // isdigit requires a value representable as unsigned char; passing a
+    // negative plain `char` is undefined behavior, hence the cast.
+    if (!resultName.first.empty() &&
+        !isdigit(static_cast<unsigned char>(resultName.first[0])))
+      nameStr = resultName.first;
+
+    names.push_back(nameStr);
+  }
+
+  auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
+  result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
+  return success();
+}
+
+/// Prints the attribute dictionary, eliding "names" when it is redundant:
+/// that is only the case when every result is printed under exactly the name
+/// recorded for it. Names can disagree in unusual situations, e.g. when the
+/// printer must rename a result to resolve a conflict.
+void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
+  auto namesMatchPrintedNames = [&]() {
+    if (getNames().size() != getNumResults())
+      return false;
+    SmallString<32> printedName;
+    for (size_t i = 0, e = getNumResults(); i != e; ++i) {
+      printedName.clear();
+      llvm::raw_svector_ostream os(printedName);
+      p.printOperand(getResult(i), os);
+      auto expected = dyn_cast<StringAttr>(getNames()[i]);
+      // Skip the first printed character (the value sigil) before comparing.
+      if (!expected || os.str().drop_front() != expected.getValue())
+        return false;
+    }
+    return true;
+  };
+
+  if (namesMatchPrintedNames())
+    p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
+  else
+    p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
+}
+
+// We set the SSA name in the asm syntax to the contents of the name
+// attribute.
+/// We set the SSA name in the asm syntax to the contents of the "names"
+/// attribute: each non-empty string names the result at the same position.
+void StringAttrPrettyNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  auto value = getNames();
+  // The "names" array may have a different length than the result list
+  // (StringAttrPrettyNameOp::print also checks for that mismatch), so never
+  // index past the last result.
+  size_t e = value.size();
+  if (e > getNumResults())
+    e = getNumResults();
+  for (size_t i = 0; i != e; ++i)
+    if (auto str = dyn_cast<StringAttr>(value[i]))
+      if (!str.getValue().empty())
+        setNameFn(getResult(i), str.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// CustomResultsNameOp
+//===----------------------------------------------------------------------===//
+
+/// Uses each non-empty string in the "names" array as the SSA name of the
+/// result at the same position.
+void CustomResultsNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  ArrayAttr value = getNames();
+  // Guard against a "names" array longer than the result list so that
+  // getResult(i) is never called out of range.
+  size_t e = value.size();
+  if (e > getNumResults())
+    e = getNumResults();
+  for (size_t i = 0; i != e; ++i)
+    if (auto str = dyn_cast<StringAttr>(value[i]))
+      if (!str.empty())
+        setNameFn(getResult(i), str.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// ResultTypeWithTraitOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the first result type carries TestTypeTrait.
+LogicalResult ResultTypeWithTraitOp::verify() {
+  Type resultType = (*this)->getResultTypes()[0];
+  if (!resultType.hasTrait<TypeTrait::TestTypeTrait>())
+    return emitError("result type should have trait 'TestTypeTrait'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AttrWithTraitOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the `attr` attribute carries TestAttrTrait.
+LogicalResult AttrWithTraitOp::verify() {
+  Attribute attr = getAttr();
+  if (!attr.hasTrait<AttributeTrait::TestAttrTrait>())
+    return emitError("'attr' attribute should have trait 'TestAttrTrait'");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// RegionIfOp
+//===----------------------------------------------------------------------===//
+
+/// Prints the operands, their colon-prefixed types and the arrow list of
+/// result types. Then it prints the "then", "else" and "join" regions, each
+/// after its keyword, with entry-block arguments and terminators shown.
+void RegionIfOp::print(OpAsmPrinter &p) {
+  p << " ";
+  p.printOperands(getOperands());
+  p << ": " << getOperandTypes();
+  p.printArrowTypeList(getResultTypes());
+
+  std::pair<StringRef, Region *> sections[] = {{"then", &getThenRegion()},
+                                               {"else", &getElseRegion()},
+                                               {"join", &getJoinRegion()}};
+  for (auto &[keyword, region] : sections) {
+    p << " " << keyword << " ";
+    p.printRegion(*region, /*printEntryBlockArgs=*/true,
+                  /*printBlockTerminators=*/true);
+  }
+}
+
+/// Inverse of print(): parses the operand list, the colon-prefixed operand
+/// types, the arrow list of result types, and then the three keyword-prefixed
+/// regions in then/else/join order.
+ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
+  SmallVector<Type, 2> operandTypes;
+
+  result.regions.reserve(3);
+  // Braced initializers are evaluated left to right: then, else, join.
+  Region *regions[] = {result.addRegion(), result.addRegion(),
+                       result.addRegion()};
+
+  if (parser.parseOperandList(operands) ||
+      parser.parseColonTypeList(operandTypes) ||
+      parser.parseArrowTypeList(result.types))
+    return failure();
+
+  const StringRef keywords[] = {"then", "else", "join"};
+  for (unsigned i = 0; i < 3; ++i)
+    if (parser.parseKeyword(keywords[i]) ||
+        parser.parseRegion(*regions[i], {}, {}))
+      return failure();
+
+  return parser.resolveOperands(operands, operandTypes,
+                                parser.getCurrentLocation(), result.operands);
+}
+
+/// Both entry regions ("then" and "else") receive all of the op's operands.
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert((point == getThenRegion() || point == getElseRegion()) &&
+         "invalid region index");
+  return getOperands();
+}
+
+/// Control flow: the op enters either the "then" or the "else" region; both
+/// of those continue into the "join" region, and "join" returns to the op.
+void RegionIfOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // The then and else regions are the entry regions of this op.
+  if (point.isParent()) {
+    regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
+    regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
+    return;
+  }
+
+  // Leaving "join" returns to the parent op; leaving any other region
+  // branches to "join".
+  if (point == getJoinRegion())
+    regions.push_back(RegionSuccessor(getResults()));
+  else
+    regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
+}
+
+/// Each of the three regions (then, else, join) runs at most once.
+void RegionIfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  invocationBounds.assign(3, InvocationBounds(/*lb=*/0, /*ub=*/1));
+}
+
+//===----------------------------------------------------------------------===//
+// AnyCondOp
+//===----------------------------------------------------------------------===//
+
+/// The parent op branches into the only region, and the region branches back
+/// to the parent op.
+void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.emplace_back(&getRegion());
+  else
+    regions.emplace_back(getResults());
+}
+
+/// The single region runs exactly once.
+void AnyCondOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  invocationBounds.push_back(InvocationBounds(/*lb=*/1, /*ub=*/1));
+}
+
+//===----------------------------------------------------------------------===//
+// SingleBlockImplicitTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Compile-time checks that both trait-detection utilities recognize
+/// SingleBlockImplicitTerminatorOp as a single-block op with an implicit
+/// terminator.
+static_assert(
+ llvm::is_detected<OpTrait::has_implicit_terminator_t,
+ SingleBlockImplicitTerminatorOp>::value,
+ "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
+static_assert(OpTrait::hasSingleBlockImplicitTerminator<
+ SingleBlockImplicitTerminatorOp>::value,
+ "hasSingleBlockImplicitTerminator does not match "
+ "SingleBlockImplicitTerminatorOp");
+
+//===----------------------------------------------------------------------===//
+// SingleNoTerminatorCustomAsmOp
+//===----------------------------------------------------------------------===//
+
+/// The custom syntax is just a region that takes no arguments.
+ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
+                                                 OperationState &state) {
+  Region &body = *state.addRegion();
+  return parser.parseRegion(body, /*arguments=*/{},
+                            /*enableNameShadowing=*/false);
+}
+
+/// Prints only the region, without entry-block arguments. The op's single
+/// block has no terminator, but terminator printing is still explicitly
+/// disabled to exercise that flag.
+void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
+  const bool printEntryBlockArgs = false;
+  const bool printBlockTerminators = false;
+  printer.printRegion(getRegion(), printEntryBlockArgs, printBlockTerminators);
+}
+
+//===----------------------------------------------------------------------===//
+// TestVerifiersOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that trait verification (single block) and verification of the
+/// input's defining op already happened, then emits a success remark.
+LogicalResult TestVerifiersOp::verify() {
+  if (!getRegion().hasOneBlock())
+    return emitOpError("`hasOneBlock` trait hasn't been verified");
+
+  if (Operation *producer = getInput().getDefiningOp())
+    if (failed(mlir::verify(producer)))
+      return emitOpError("operand hasn't been verified");
+
+  // Use the free function: the member `emitRemark(msg)` would trigger an
+  // infinite verifier loop.
+  mlir::emitRemark(getLoc(), "success run of verifier");
+  return success();
+}
+
+/// Region verifier: re-checks the single-block trait and that every nested
+/// op verifies, then emits a success remark.
+LogicalResult TestVerifiersOp::verifyRegions() {
+  if (!getRegion().hasOneBlock())
+    return emitOpError("`hasOneBlock` trait hasn't been verified");
+
+  for (Operation &nested : getRegion().getOps())
+    if (failed(mlir::verify(&nested)))
+      return emitOpError("nested op hasn't been verified");
+
+  // Use the free function: the member `emitRemark(msg)` would trigger an
+  // infinite verifier loop.
+  mlir::emitRemark(getLoc(), "success run of region verifier");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test InferIntRangeInterface
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// TestWithBoundsOp
+
+/// The result's range is exactly the umin/umax/smin/smax bounds stored on the
+/// op.
+void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                         SetIntRangeFn setResultRanges) {
+  ConstantIntRanges bounds(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getResult(), bounds);
+}
+
+//===----------------------------------------------------------------------===//
+// TestWithBoundsRegionOp
+
+/// Parses `attr-dict %arg { body }`. The index-typed `%arg` becomes the
+/// body region's entry-block argument.
+ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  OpAsmParser::Argument regionArg;
+  regionArg.type = parser.getBuilder().getIndexType();
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      failed(parser.parseArgument(regionArg)))
+    return failure();
+
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, regionArg, /*enableNameShadowing=*/false);
+}
+
+/// Prints `attr-dict %arg { body }`. The region argument is printed up front
+/// without its type, so the region is printed without entry-block arguments.
+void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
+  Region &body = getRegion();
+  p.printOptionalAttrDict((*this)->getAttrs());
+  p << ' ';
+  p.printRegionArgument(body.getArgument(0), /*argAttrs=*/{},
+                        /*omitType=*/true);
+  p << ' ';
+  p.printRegion(body, /*printEntryBlockArgs=*/false);
+}
+
+/// Assigns the op's stored bounds to the body region's first argument.
+void TestWithBoundsRegionOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  ConstantIntRanges bounds(getUmin(), getUmax(), getSmin(), getSmax());
+  setResultRanges(getRegion().getArgument(0), bounds);
+}
+
+//===----------------------------------------------------------------------===//
+// TestIncrementOp
+
+/// The result range is the operand range shifted up by one. The unsigned and
+/// signed bounds saturate at their respective maxima.
+void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &input = argRanges[0];
+  APInt one(input.umin().getBitWidth(), 1);
+  APInt umin = input.umin().uadd_sat(one);
+  APInt umax = input.umax().uadd_sat(one);
+  APInt smin = input.smin().sadd_sat(one);
+  APInt smax = input.smax().sadd_sat(one);
+  setResultRanges(getResult(), ConstantIntRanges(umin, umax, smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// TestReflectBoundsOp
+
+/// Records the operand's inferred range on the op as index attributes. umin
+/// and umax are zero-extended, smin and smax sign-extended. The range is then
+/// passed through to the result unchanged.
+void TestReflectBoundsOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
+  const ConstantIntRanges &input = argRanges[0];
+  Builder builder(getContext());
+  setUminAttr(builder.getIndexAttr(input.umin().getZExtValue()));
+  setUmaxAttr(builder.getIndexAttr(input.umax().getZExtValue()));
+  setSminAttr(builder.getIndexAttr(input.smin().getSExtValue()));
+  setSmaxAttr(builder.getIndexAttr(input.smax().getSExtValue()));
+  setResultRanges(getResult(), input);
+}
+
+//===----------------------------------------------------------------------===//
+// ConversionFuncOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the op with the shared function-op syntax; variadic signatures are
+/// not allowed.
+ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  // The signature is always a plain FunctionType; the variadic flag and the
+  // error string are ignored.
+  auto makeFunctionType = [](Builder &builder, ArrayRef<Type> inputs,
+                             ArrayRef<Type> outputs,
+                             function_interface_impl::VariadicFlag,
+                             std::string &) {
+    return builder.getFunctionType(inputs, outputs);
+  };
+  OperationName opName = result.name;
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(opName),
+      makeFunctionType, getArgAttrsAttrName(opName),
+      getResAttrsAttrName(opName));
+}
+
+/// Mirror of parse(): prints through the shared function-op printer as a
+/// non-variadic function.
+void ConversionFuncOp::print(OpAsmPrinter &p) {
+  StringAttr typeAttrName = getFunctionTypeAttrName();
+  StringAttr argAttrsName = getArgAttrsAttrName();
+  StringAttr resAttrsName = getResAttrsAttrName();
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false,
+                                           typeAttrName, argAttrsName,
+                                           resAttrsName);
+}
+
+//===----------------------------------------------------------------------===//
+// ReifyBoundOp
+//===----------------------------------------------------------------------===//
+
+/// Maps the op's textual `type` ("EQ", "LB" or "UB") to a presburger bound
+/// kind; any other value is a programming error.
+mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
+  using mlir::presburger::BoundType;
+  auto kind = getType();
+  if (kind == "EQ")
+    return BoundType::EQ;
+  if (kind == "LB")
+    return BoundType::LB;
+  if (kind == "UB")
+    return BoundType::UB;
+  llvm_unreachable("invalid bound type");
+}
+
+/// Verifies that:
+///   - a shaped variable has a `dim` and an index variable has none;
+///   - `constant` and `scalable` are not both set;
+///   - `vscale_min` and `vscale_max` are each present exactly when `scalable`
+///     is set.
+LogicalResult ReifyBoundOp::verify() {
+  if (isa<ShapedType>(getVar().getType())) {
+    if (!getDim().has_value())
+      return emitOpError("expected 'dim' attribute for shaped type variable");
+  } else if (getVar().getType().isIndex()) {
+    if (getDim().has_value())
+      return emitOpError("unexpected 'dim' attribute for index variable");
+  } else {
+    return emitOpError("expected index-typed variable or shape type variable");
+  }
+  if (getConstant() && getScalable())
+    return emitOpError("'scalable' and 'constant' are mutually exlusive");
+  if (getScalable() != getVscaleMin().has_value())
+    return emitOpError("expected 'vscale_min' if and only if 'scalable'");
+  // Report the attribute that is actually checked here.
+  if (getScalable() != getVscaleMax().has_value())
+    return emitOpError("expected 'vscale_max' if and only if 'scalable'");
+  return success();
+}
+
+/// Returns the bounded variable: dimension `dim` of `var` when a dim is given
+/// (shaped variables), otherwise `var` itself.
+ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+  if (auto dim = getDim())
+    return ValueBoundsConstraintSet::Variable(getVar(), *dim);
+  return ValueBoundsConstraintSet::Variable(getVar());
+}
+
+//===----------------------------------------------------------------------===//
+// CompareOp
+//===----------------------------------------------------------------------===//
+
+/// Maps the textual `cmp` value ("EQ", "LT", "LE", "GT" or "GE") to the
+/// matching comparison operator; any other value is a programming error.
+ValueBoundsConstraintSet::ComparisonOperator
+CompareOp::getComparisonOperator() {
+  using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
+  auto cmp = getCmp();
+  if (cmp == "EQ")
+    return ComparisonOperator::EQ;
+  if (cmp == "LT")
+    return ComparisonOperator::LT;
+  if (cmp == "LE")
+    return ComparisonOperator::LE;
+  if (cmp == "GT")
+    return ComparisonOperator::GT;
+  if (cmp == "GE")
+    return ComparisonOperator::GE;
+  llvm_unreachable("invalid comparison operator");
+}
+
+/// Left-hand side: either the first var operand alone, or `lhs_map` applied
+/// to the leading var operands it consumes.
+mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+  auto lhsMap = getLhsMap();
+  auto varOperands = getVarOperands();
+  if (!lhsMap)
+    return ValueBoundsConstraintSet::Variable(varOperands[0]);
+  SmallVector<Value> mapOperands(varOperands.slice(0, lhsMap->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*lhsMap, mapOperands);
+}
+
+/// Right-hand side: its operands start right after those consumed by the
+/// left-hand side. It is either that single operand or `rhs_map` applied to
+/// the operands the map consumes.
+mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+  auto lhsMap = getLhsMap();
+  auto rhsMap = getRhsMap();
+  auto varOperands = getVarOperands();
+  int64_t rhsBegin = lhsMap ? lhsMap->getNumInputs() : 1;
+  if (!rhsMap)
+    return ValueBoundsConstraintSet::Variable(varOperands[rhsBegin]);
+  SmallVector<Value> mapOperands(
+      varOperands.slice(rhsBegin, rhsMap->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*rhsMap, mapOperands);
+}
+
+/// Verifies that:
+///   - `compose` is not combined with either map;
+///   - the number of var operands equals what the two sides consume (a map's
+///     input count, or 1 for a side without a map).
+LogicalResult CompareOp::verify() {
+  bool hasMap = getLhsMap() || getRhsMap();
+  if (getCompose() && hasMap)
+    return emitOpError(
+        "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+
+  auto operandsConsumedBy = [](auto map) -> int64_t {
+    return map ? map->getNumInputs() : 1;
+  };
+  int64_t expectedNumOperands =
+      operandsConsumedBy(getLhsMap()) + operandsConsumedBy(getRhsMap());
+  size_t actualNumOperands = getVarOperands().size();
+  if (actualNumOperands == size_t(expectedNumOperands))
+    return success();
+  return emitOpError("expected ")
+         << expectedNumOperands << " operands, but got " << actualNumOperands;
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpFoldWithFoldAdaptor
+//===----------------------------------------------------------------------===//
+
+/// Folds to a weighted sum of the adaptor's constants:
+///   - 1x getOp();
+///   - 2x each getVariadic() entry;
+///   - 3x each getVarOfVar() entry;
+///   - 4x the number of blocks in getBody().
+/// Missing or non-integer constants contribute nothing.
+OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
+  // Sign-extended value of an integer constant, or 0 for anything else.
+  auto constantOrZero = [](Attribute attr) -> int64_t {
+    auto intAttr = dyn_cast_or_null<IntegerAttr>(attr);
+    return intAttr ? intAttr.getValue().getSExtValue() : 0;
+  };
+
+  int64_t sum = constantOrZero(adaptor.getOp());
+  for (Attribute attr : adaptor.getVariadic())
+    sum += 2 * constantOrZero(attr);
+  for (ArrayRef<Attribute> group : adaptor.getVarOfVar())
+    for (Attribute attr : group)
+      sum += 3 * constantOrZero(attr);
+  sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
+
+  return IntegerAttr::get(getType(), sum);
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithInferTypeAdaptorInterfaceOp
+//===----------------------------------------------------------------------===//
+
+/// Adaptor-based variant: `x` and `y` must share a type, which becomes the
+/// single result type.
+LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location> location,
+    OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type xType = adaptor.getX().getType();
+  Type yType = adaptor.getY().getType();
+  if (xType != yType)
+    return emitOptionalError(location, "operand type mismatch ", xType, " vs ",
+                             yType);
+  inferredReturnTypes.assign({xType});
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithRefineTypeInterfaceOp
+//===----------------------------------------------------------------------===//
+
+// TODO: We should be able to only define either inferReturnType or
+// refineReturnType, currently only refineReturnType can be omitted.
+/// Infers result types by refining an initially empty type list.
+LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &returnTypes) {
+  // Start with no expected types so refinement purely computes them.
+  returnTypes.clear();
+  return OpWithRefineTypeInterfaceOp::refineReturnTypes(
+      context, location, operands, attributes, properties, regions,
+      returnTypes);
+}
+
+/// Requires operands 0 and 1 to share a type, then refines the single result
+/// type to it. A pre-existing, different result type is rejected.
+LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
+    MLIRContext *, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &returnTypes) {
+  Type lhsType = operands[0].getType();
+  Type rhsType = operands[1].getType();
+  if (lhsType != rhsType)
+    return emitOptionalError(location, "operand type mismatch ", lhsType,
+                             " vs ", rhsType);
+
+  // TODO: Add helper to make this more concise to write.
+  if (returnTypes.empty())
+    returnTypes.push_back(Type());
+  Type &resultType = returnTypes[0];
+  if (resultType && resultType != lhsType)
+    return emitOptionalError(location,
+                             "required first operand and result to match");
+  resultType = lhsType;
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// OpWithShapedTypeInferTypeAdaptorInterfaceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
+    MLIRContext *context, std::optional<Location> location,
+    OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // Adaptor-based variant. The result is 1-D with an i17 element type; its
+  // extent is operand1's first (outermost) dimension, or dynamic if operand1
+  // is unranked. A ranked-tensor operand's encoding is carried over.
+  auto shapedType = dyn_cast<ShapedType>(adaptor.getOperand1().getType());
+  if (!shapedType)
+    return emitOptionalError(location, "only shaped type operands allowed");
+
+  int64_t leadingDim = ShapedType::kDynamic;
+  if (shapedType.hasRank())
+    leadingDim = shapedType.getShape().front();
+
+  auto rankedTensorType = dyn_cast<RankedTensorType>(shapedType);
+  Attribute encoding =
+      rankedTensorType ? rankedTensorType.getEncoding() : Attribute();
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(
+      {leadingDim}, IntegerType::get(context, 17), encoding));
+  return success();
+}
+
+LogicalResult
+OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  // The reified shape is a single value: dimension 0 of the first operand.
+  Value leadingExtent =
+      builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0);
+  shapes.clear();
+  shapes.push_back(leadingExtent);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithPropertiesAndInferredType
+//===----------------------------------------------------------------------===//
+
+/// The single result is a signless integer whose width is the adaptor's
+/// `lhs` value plus the `rhs` property.
+LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
+    MLIRContext *context, std::optional<Location>, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  auto width = adaptor.getLhs() + adaptor.getProperties().rhs;
+  inferredReturnTypes.push_back(IntegerType::get(context, width));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopBlockOp
+//===----------------------------------------------------------------------===//
+
+/// The body is always a possible successor, entered with its region
+/// arguments. When branching from inside the body, control may also leave
+/// to the op's results.
+void LoopBlockOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  regions.emplace_back(&getBody(), getBody().getArguments());
+  if (point.isParent())
+    return;
+
+  regions.emplace_back((*this)->getResults());
+}
+
+/// The body is the only entry successor; it receives the `init` operands.
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  assert(point == getBody());
+  MutableOperandRange initOperands = getInitMutable();
+  return initOperands;
+}
+
+//===----------------------------------------------------------------------===//
+// LoopBlockTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Operands forwarded by the terminator: `exit_arg` when leaving to the
+/// parent op, `next_iter_arg` otherwise (presumably when looping back into
+/// the body).
+MutableOperandRange
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+  return point.isParent() ? getExitArgMutable() : getNextIterArgMutable();
+}
+
+//===----------------------------------------------------------------------===//
+// TestNoTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Reports no successor regions for any branch point.
+void TestNoTerminatorOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
+
+//===----------------------------------------------------------------------===//
+// ManualCppOpWithFold
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
+  // A simple fold for testing purposes. Return the constant of the first
+  // operand (null when that operand is not constant); with no operands at
+  // all, do not fold.
+  if (attributes.empty())
+    return nullptr;
+  return attributes.front();
+}
+
+//===----------------------------------------------------------------------===//
+// Tensor/Buffer Ops
+//===----------------------------------------------------------------------===//
+
+/// Declares a read of the `buffer` operand plus a write on the default
+/// resource that is not tied to any value.
+void ReadBufferOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  SideEffects::Resource *defaultResource = SideEffects::DefaultResource::get();
+  // The buffer operand is read.
+  effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
+                       defaultResource);
+  // Dumping the buffer contents is modelled as a value-less write.
+  effects.emplace_back(MemoryEffects::Write::get(), defaultResource);
+}
+
+//===----------------------------------------------------------------------===//
+// Test Dataflow
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// TestCallAndStoreOp
+
+/// The callee is the symbol reference returned by getCallee().
+CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
+  auto calleeRef = getCallee();
+  return calleeRef;
+}
+
+/// Replaces the callee; the callable is expected to hold a SymbolRefAttr.
+void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  auto symbol = callee.get<SymbolRefAttr>();
+  setCalleeAttr(symbol);
+}
+
+/// The call arguments are the operands returned by getCalleeOperands().
+Operation::operand_range TestCallAndStoreOp::getArgOperands() {
+  Operation::operand_range args = getCalleeOperands();
+  return args;
+}
+
+/// Mutable view of the same call-argument operands.
+MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
+  MutableOperandRange args = getCalleeOperandsMutable();
+  return args;
+}
+
+//===----------------------------------------------------------------------===//
+// TestCallOnDeviceOp
+
+/// The callee is the symbol reference returned by getCallee().
+CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
+  auto calleeRef = getCallee();
+  return calleeRef;
+}
+
+/// Replaces the callee; the callable is expected to hold a SymbolRefAttr.
+void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  auto symbol = callee.get<SymbolRefAttr>();
+  setCalleeAttr(symbol);
+}
+
+/// The call arguments are the operands returned by getForwardedOperands().
+Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
+  Operation::operand_range args = getForwardedOperands();
+  return args;
+}
+
+/// Mutable view of the same call-argument operands.
+MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
+  MutableOperandRange args = getForwardedOperandsMutable();
+  return args;
+}
+
+//===----------------------------------------------------------------------===//
+// TestStoreWithARegion
+
+/// From the parent op, control enters the body with its entry block's
+/// arguments. From the body, the successor is a default-constructed
+/// RegionSuccessor (presumably the parent op with no forwarded values —
+/// confirm against RegionSuccessor's default).
+void TestStoreWithARegion::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.emplace_back(&getBody(), getBody().front().getArguments());
+  else
+    regions.emplace_back();
+}
+
+//===----------------------------------------------------------------------===//
+// TestStoreWithALoopRegion
+
+/// From every branch point, both the body and a default-constructed
+/// RegionSuccessor (presumably the parent op) are possible successors.
+void TestStoreWithALoopRegion::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the region may be branching into the body or
+  // back into the operation itself. It is possible for the operation not to
+  // enter the body.
+  regions.emplace_back(
+      RegionSuccessor(&getBody(), getBody().front().getArguments()));
+  regions.emplace_back();
+}
+
+//===----------------------------------------------------------------------===//
+// TestVersionedOpA
+//===----------------------------------------------------------------------===//
+
+/// Reads `dims`, then `modifier`. Bytecode written for a test dialect version
+/// older than 2.0 does not contain `modifier`, so reading stops after `dims`.
+/// Without a recorded version, the current format is assumed.
+LogicalResult
+TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
+                                 mlir::OperationState &state) {
+  auto &prop = state.getOrAddProperties<Properties>();
+  if (mlir::failed(reader.readAttribute(prop.dims)))
+    return mlir::failure();
+
+  // Missing properties of pre-2.0 payloads can be materialized after parsing,
+  // before verification.
+  bool isPreV2 = false;
+  auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
+  if (succeeded(maybeVersion)) {
+    const auto *version =
+        reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
+    isPreV2 = version->major_ < 2;
+  }
+  if (isPreV2)
+    return success();
+
+  if (mlir::failed(reader.readAttribute(prop.modifier)))
+    return mlir::failure();
+  return mlir::success();
+}
+
+/// Writes `dims`, and `modifier` unless the target test dialect version is
+/// older than 2.0. In that case a downgrade notice is printed instead.
+void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
+  auto &prop = getProperties();
+  writer.writeAttribute(prop.dims);
+
+  bool isPreV2 = false;
+  auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
+  if (succeeded(maybeVersion)) {
+    const auto *version =
+        reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
+    isPreV2 = version->major_ < 2;
+  }
+  if (isPreV2) {
+    llvm::outs() << "downgrading op properties...\n";
+    return;
+  }
+  writer.writeAttribute(prop.modifier);
+}
+
+//===----------------------------------------------------------------------===//
+// TestOpWithVersionedProperties
+//===----------------------------------------------------------------------===//
+
+/// Reads `value1` and, unless the bytecode targets a test dialect version
+/// older than 2.0, `value2`. A value that is not read defaults to 0. Without a
+/// recorded version, the current format is assumed.
+mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
+    mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
+  uint64_t first = 0;
+  if (failed(reader.readVarInt(first)))
+    return failure();
+
+  bool hasSecond = true;
+  auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
+  if (succeeded(maybeVersion)) {
+    const auto *version =
+        reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
+    hasSecond = version->major_ >= 2;
+  }
+
+  uint64_t second = 0;
+  if (hasSecond && failed(reader.readVarInt(second)))
+    return failure();
+
+  prop.value1 = first;
+  prop.value2 = second;
+  return success();
+}
+
+/// Always writes both integers, regardless of the target dialect version.
+void TestOpWithVersionedProperties::writeToMlirBytecode(
+    mlir::DialectBytecodeWriter &writer,
+    const test::VersionedProperties &prop) {
+  const auto &first = prop.value1;
+  const auto &second = prop.value2;
+  writer.writeVarInt(first);
+  writer.writeVarInt(second);
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.cpp b/mlir/test/lib/Dialect/Test/TestOps.cpp
new file mode 100644
index 00000000000000..ce7e476be74e65
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOps.cpp
@@ -0,0 +1,18 @@
+//===- TestOps.cpp - MLIR Test Dialect Operations ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestOps.h"
+#include "TestDialect.h"
+#include "TestFormatUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+using namespace mlir;
+using namespace test;
+
+#define GET_OP_CLASSES
+#include "TestOps.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
new file mode 100644
index 00000000000000..f9925855bb9db6
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -0,0 +1,149 @@
+//===- TestOps.h - MLIR Test Dialect Operations ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TESTOPS_H
+#define MLIR_TESTOPS_H
+
+#include "TestAttributes.h"
+#include "TestInterfaces.h"
+#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/DLTI/Traits.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
+
+namespace test {
+class TestDialect;
+
+//===----------------------------------------------------------------------===//
+// TestResource
+//===----------------------------------------------------------------------===//
+
+/// A test resource for side effects.
+struct TestResource : public mlir::SideEffects::Resource::Base<TestResource> {
+  llvm::StringRef getName() final { return "<Test>"; }
+};
+
+//===----------------------------------------------------------------------===//
+// PropertiesWithCustomPrint
+//===----------------------------------------------------------------------===//
+
+/// Test properties holding a shared label and an integer value. The free
+/// functions declared after the struct provide its attribute conversion,
+/// hashing, and custom print/parse hooks.
+struct PropertiesWithCustomPrint {
+  /// A shared_ptr to a const object is safe: it is equivalent to a value-based
+  /// member. Here the label will be deallocated when the last operation
+  /// referring to it is destroyed. However there is no pool-allocation: this is
+  /// offloaded to the client.
+  std::shared_ptr<const std::string> label;
+  int value;
+  /// Two properties are equal when their values match and their labels either
+  /// are the same pointer (including both null) or point to equal strings. A
+  /// null label never compares equal to a non-null one.
+  bool operator==(const PropertiesWithCustomPrint &rhs) const {
+    if (value != rhs.value)
+      return false;
+    // Never dereference a null label (e.g. default-constructed properties).
+    if (label == rhs.label)
+      return true;
+    return label && rhs.label && *label == *rhs.label;
+  }
+};
+
+mlir::LogicalResult setPropertiesFromAttribute(
+    PropertiesWithCustomPrint &prop, mlir::Attribute attr,
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+mlir::DictionaryAttr
+getPropertiesAsAttribute(mlir::MLIRContext *ctx,
+                         const PropertiesWithCustomPrint &prop);
+llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
+void customPrintProperties(mlir::OpAsmPrinter &p,
+                           const PropertiesWithCustomPrint &prop);
+mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser,
+                                        PropertiesWithCustomPrint &prop);
+
+//===----------------------------------------------------------------------===//
+// MyPropStruct
+//===----------------------------------------------------------------------===//
+
+/// A property type wrapping a single string.
+class MyPropStruct {
+public:
+  std::string content;
+  // These three methods are invoked through the `MyStructProperty` wrapper
+  // defined in TestOps.td.
+  mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
+  static mlir::LogicalResult
+  setFromAttr(MyPropStruct &prop, mlir::Attribute attr,
+              llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+  llvm::hash_code hash() const;
+  bool operator==(const MyPropStruct &rhs) const {
+    return content == rhs.content;
+  }
+};
+
+mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader,
+                                         MyPropStruct &prop);
+void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer,
+                         MyPropStruct &prop);
+
+//===----------------------------------------------------------------------===//
+// VersionedProperties
+//===----------------------------------------------------------------------===//
+
+struct VersionedProperties {
+  // For the sake of testing, assume that this object was associated with
+  // version 1.2 of the test dialect when it held only one int value. In the
+  // current version 2.0, the property has two values. We also assume that the
+  // class is upgradable if value2 = 0.
+  int value1;
+  int value2;
+  bool operator==(const VersionedProperties &rhs) const {
+    return value1 == rhs.value1 && value2 == rhs.value2;
+  }
+};
+
+mlir::LogicalResult setPropertiesFromAttribute(
+    VersionedProperties &prop, mlir::Attribute attr,
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError);
+mlir::DictionaryAttr getPropertiesAsAttribute(mlir::MLIRContext *ctx,
+                                              const VersionedProperties &prop);
+llvm::hash_code computeHash(const VersionedProperties &prop);
+void customPrintProperties(mlir::OpAsmPrinter &p,
+                           const VersionedProperties &prop);
+mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser,
+                                        VersionedProperties &prop);
+
+//===----------------------------------------------------------------------===//
+// Bytecode Support
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader,
+                                         llvm::MutableArrayRef<int64_t> prop);
+void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer,
+                         llvm::ArrayRef<int64_t> prop);
+
+} // namespace test
+
+#define GET_OP_CLASSES
+#include "TestOps.h.inc"
+
+#endif // MLIR_TESTOPS_H
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
index 84e6a43655cacd..c376d6c73c6452 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp
@@ -8,6 +8,7 @@
#include "TestOpsSyntax.h"
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Base64.h"
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 76dc825fe44515..0c1731ba5f07c8 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index fa093cafcb0dc3..57e7d658fb501f 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index d9b67ef95ace83..031e1062dac76d 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 7a195eb25a3ba1..1593b6d7d7534b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -139,6 +139,7 @@ static void printBarString(AsmPrinter &printer, StringRef foo) {
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
+#include "TestTypeInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index b1b5921d8faddd..da5604944d5a3b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -31,11 +31,11 @@ class TestAttrWithFormatAttr;
/// FieldInfo represents a field in the StructType data type. It is used as a
/// parameter in TestTypeDefs.td.
struct FieldInfo {
- ::llvm::StringRef name;
- ::mlir::Type type;
+ llvm::StringRef name;
+ mlir::Type type;
// Custom allocation called from generated constructor code
- FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const {
+ FieldInfo allocateInto(mlir::TypeStorageAllocator &alloc) const {
return FieldInfo{alloc.copyInto(name), type};
}
};
diff --git a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
index e668224d343234..4894ad5294990a 100644
--- a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
+++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp
index 7b18f219b915f4..b742b316c77126 100644
--- a/mlir/test/lib/IR/TestClone.cpp
+++ b/mlir/test/lib/IR/TestClone.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp
index 09ad1363228243..8e13dd9751398c 100644
--- a/mlir/test/lib/IR/TestSideEffects.cpp
+++ b/mlir/test/lib/IR/TestSideEffects.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 0e1368f2e0ecaf..b470b15c533b57 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp
index 2bd63a48f77d1d..c6bce111d3ea7f 100644
--- a/mlir/test/lib/IR/TestTypes.cpp
+++ b/mlir/test/lib/IR/TestTypes.cpp
@@ -8,6 +8,7 @@
#include "TestTypes.h"
#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
index 00148df26e3512..4556671df0ba0b 100644
--- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 477b75916f80c8..2762e254903245 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp
index 9821179d05e891..223cc78dd1e21d 100644
--- a/mlir/test/lib/Transforms/TestInlining.cpp
+++ b/mlir/test/lib/Transforms/TestInlining.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
index 61e1fbcf3feaf3..82fa6cdb68d23c 100644
--- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
+++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp
index 66ce53bbbadec9..0a5fa8d3c475c3 100644
--- a/mlir/unittests/IR/AdaptorTest.cpp
+++ b/mlir/unittests/IR/AdaptorTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestOpsSyntax.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp
index 83627975006ee8..b88009d1e3c361 100644
--- a/mlir/unittests/IR/IRMapping.cpp
+++ b/mlir/unittests/IR/IRMapping.cpp
@@ -11,6 +11,7 @@
#include "gtest/gtest.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
using namespace mlir;
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index 58049a9969e3ab..b6066dd5685dc6 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -19,6 +19,7 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
#include "mlir/IR/OwningOpRef.h"
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 5ab4d9a106231a..42196b003e7dad 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -15,6 +15,7 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
using namespace mlir;
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 9d75615b39c0c1..f94dc784458077 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/OperationSupport.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/BitVector.h"
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
index 30b72618e45f0b..75d5228c82d99b 100644
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -10,6 +10,7 @@
#include "gtest/gtest.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "../../test/lib/Dialect/Test/TestOps.h"
using namespace mlir;
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index 52347dcabe0381..c83ac9088114ce 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "TestOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
>From 478b4e8e9523ab80e65f115274b2f2ae4423faf4 Mon Sep 17 00:00:00 2001
From: Mogball <jeff at modular.com>
Date: Mon, 22 Apr 2024 16:39:58 +0000
Subject: [PATCH 3/3] [mlir][test] Shard the Test Dialect (NFC)
This PR uses the new op sharding mechanism in tablegen to shard the test
dialect's op definitions. This breaks the definition of ops into
multiple source files, speeding up compile time of the test dialect
dramatically. This improves developer cycle times when iterating on the
test dialect.
stack-info: PR: https://github.com/llvm/llvm-project/pull/89628, branch: users/Mogball/stack/1
---
mlir/test/lib/Dialect/Test/CMakeLists.txt | 6 +++--
mlir/test/lib/Dialect/Test/TestDialect.cpp | 5 +---
mlir/test/lib/Dialect/Test/TestOps.cpp | 1 -
.../mlir/test/BUILD.bazel | 25 +++++++++++--------
4 files changed, 20 insertions(+), 17 deletions(-)
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f63e4d330e6ac1..fab89378093326 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -31,8 +31,6 @@ mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTestEnumDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestOps.td)
-mlir_tablegen(TestOps.h.inc -gen-op-decls)
-mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
mlir_tablegen(TestOpsDialect.h.inc -gen-dialect-decls -dialect=test)
mlir_tablegen(TestOpsDialect.cpp.inc -gen-dialect-defs -dialect=test)
mlir_tablegen(TestPatterns.inc -gen-rewriters)
@@ -43,6 +41,8 @@ mlir_tablegen(TestOpsSyntax.h.inc -gen-op-decls)
mlir_tablegen(TestOpsSyntax.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTestOpsSyntaxIncGen)
+# Shard the op definitions generated from TestOps.td into 20 pieces. This call
+# is expected to set SHARDED_SRCS and create the MLIRTestOpsShardGen target;
+# both are consumed by MLIRTestDialect below. Keep the shard count in sync with
+# `shard_count` of gentbl_sharded_ops in the Bazel build.
+add_sharded_ops(TestOps 20)
+
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestDialect
TestAttributes.cpp
@@ -56,6 +56,7 @@ add_mlir_library(MLIRTestDialect
TestTypes.cpp
TestOpsSyntax.cpp
TestDialectInterfaces.cpp
+ ${SHARDED_SRCS}
EXCLUDE_FROM_LIBMLIR
@@ -66,6 +67,7 @@ add_mlir_library(MLIRTestDialect
MLIRTestTypeDefIncGen
MLIRTestOpsIncGen
MLIRTestOpsSyntaxIncGen
+ MLIRTestOpsShardGen
LINK_LIBS PUBLIC
MLIRControlFlowInterfaces
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 77fd7e61bd3a06..bfb9592e638288 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -326,12 +326,9 @@ struct TestOpEffectInterfaceFallback
void TestDialect::initialize() {
registerAttributes();
registerTypes();
- addOperations<
-#define GET_OP_LIST
-#include "TestOps.cpp.inc"
- >();
registerOpsSyntax();
addOperations<ManualCppOpWithFold>();
+ registerTestDialectOperations(this);
registerDynamicOp(getDynamicGenericOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
diff --git a/mlir/test/lib/Dialect/Test/TestOps.cpp b/mlir/test/lib/Dialect/Test/TestOps.cpp
index ce7e476be74e65..47d5b1b19121ef 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOps.cpp
@@ -14,5 +14,4 @@
using namespace mlir;
using namespace test;
-#define GET_OP_CLASSES
#include "TestOps.cpp.inc"
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index dc5f4047c286db..b98f7eb5613af4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -4,7 +4,7 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("//llvm:lit_test.bzl", "package_path")
-load("//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
+load("//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_sharded_ops", "td_library")
package(
default_visibility = ["//visibility:public"],
@@ -151,14 +151,6 @@ gentbl_cc_library(
name = "TestOpsIncGen",
strip_include_prefix = "lib/Dialect/Test",
tbl_outs = [
- (
- ["-gen-op-decls"],
- "lib/Dialect/Test/TestOps.h.inc",
- ),
- (
- ["-gen-op-defs"],
- "lib/Dialect/Test/TestOps.cpp.inc",
- ),
(
[
"-gen-dialect-decls",
@@ -370,12 +362,25 @@ cc_library(
],
)
+# Generates the TestOps op declarations (TestOps.h.inc) and definitions
+# (TestOps.cpp.inc) from TestOps.td, split into `shard_count` shards using
+# mlir-src-sharder; `src_file` (TestOps.cpp) is presumably the template each
+# shard source is built from. The resulting sources are added to the srcs of
+# :TestDialect below. Keep shard_count in sync with add_sharded_ops(TestOps 20)
+# in lib/Dialect/Test/CMakeLists.txt.
+gentbl_sharded_ops(
+ name = "TestDialectOpSrcs",
+ hdr_out = "lib/Dialect/Test/TestOps.h.inc",
+ shard_count = 20,
+ sharder = "//mlir:mlir-src-sharder",
+ src_file = "lib/Dialect/Test/TestOps.cpp",
+ src_out = "lib/Dialect/Test/TestOps.cpp.inc",
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "lib/Dialect/Test/TestOps.td",
+ test = True,
+ deps = [":TestOpTdFiles"],
+)
+
cc_library(
name = "TestDialect",
srcs = glob(
["lib/Dialect/Test/*.cpp"],
exclude = ["lib/Dialect/Test/TestToLLVMIRTranslation.cpp"],
- ),
+ ) + [":TestDialectOpSrcs"],
hdrs = glob(["lib/Dialect/Test/*.h"]),
includes = [
"lib/Dialect/Test",
More information about the llvm-commits
mailing list