[mlir] [llvm] [mlir] Add config for PDL (PR #69927)
Jacques Pienaar via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 3 08:48:59 PST 2024
https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/69927
>From 7d3719843a1bc842caa0e540078feb973e4ae68c Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sun, 22 Oct 2023 09:33:40 -0700
Subject: [PATCH] [mlir] Add config for PDL
Make it possible to optionally disable PDL in pattern rewrites.
PDL is still enabled by default, and it is not optional in Bazel, so this
should be a no-op for most users while allowing others to disable it.
This currently only works with tests disabled. With tests enabled the build
still succeeds, but tests fail because there is no lit config yet to disable
the tests that depend on PDL rewrites.
---
mlir/CMakeLists.txt | 5 +-
mlir/examples/minimal-opt/README.md | 9 +-
mlir/include/mlir/Config/mlir-config.h.cmake | 3 +
.../Conversion/LLVMCommon/TypeConverter.h | 1 +
.../mlir/Dialect/Vector/IR/VectorOps.h | 1 +
mlir/include/mlir/IR/PDLPatternMatch.h.inc | 995 ++++++++++++++++++
mlir/include/mlir/IR/PatternMatch.h | 935 +---------------
.../mlir/Transforms/DialectConversion.h | 15 +
.../Bufferization/TransformOps/CMakeLists.txt | 1 -
mlir/lib/IR/CMakeLists.txt | 7 +
mlir/lib/IR/PDL/CMakeLists.txt | 7 +
mlir/lib/IR/PDL/PDLPatternMatch.cpp | 133 +++
mlir/lib/IR/PatternMatch.cpp | 119 +--
mlir/lib/Rewrite/ByteCode.h | 36 +
mlir/lib/Rewrite/CMakeLists.txt | 32 +-
mlir/lib/Rewrite/FrozenRewritePatternSet.cpp | 10 +-
mlir/lib/Rewrite/PatternApplicator.cpp | 4 +-
.../Transforms/Utils/DialectConversion.cpp | 3 +
mlir/test/CMakeLists.txt | 12 +-
mlir/test/lib/Transforms/CMakeLists.txt | 14 +-
mlir/tools/mlir-lsp-server/CMakeLists.txt | 6 +-
mlir/tools/mlir-opt/CMakeLists.txt | 8 +-
mlir/tools/mlir-opt/mlir-opt.cpp | 7 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 4 +
24 files changed, 1297 insertions(+), 1070 deletions(-)
create mode 100644 mlir/include/mlir/IR/PDLPatternMatch.h.inc
create mode 100644 mlir/lib/IR/PDL/CMakeLists.txt
create mode 100644 mlir/lib/IR/PDL/PDLPatternMatch.cpp
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index dcc068e4097c5c..3de8677aefe90b 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -133,6 +133,8 @@ set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL
"Statically link the nvptxlibrary instead of calling ptxas as a subprocess \
for compiling PTX to cubin")
+set(MLIR_ENABLE_PDL_IN_PATTERNMATCH 1 CACHE BOOL "Enable PDL in PatternMatch")
+
option(MLIR_INCLUDE_TESTS
"Generate build targets for the MLIR unit tests."
${LLVM_INCLUDE_TESTS})
@@ -178,10 +180,9 @@ include_directories( ${MLIR_INCLUDE_DIR})
# Adding tools/mlir-tblgen here as calling add_tablegen sets some variables like
# MLIR_TABLEGEN_EXE in PARENT_SCOPE which gets lost if that folder is included
# from another directory like tools
-add_subdirectory(tools/mlir-tblgen)
add_subdirectory(tools/mlir-linalg-ods-gen)
add_subdirectory(tools/mlir-pdll)
-
+add_subdirectory(tools/mlir-tblgen)
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_LINALG_ODS_YAML_GEN_TABLEGEN_EXE "${MLIR_LINALG_ODS_YAML_GEN_TABLEGEN_EXE}" CACHE INTERNAL "")
diff --git a/mlir/examples/minimal-opt/README.md b/mlir/examples/minimal-opt/README.md
index b8a455f7a79662..1bc54b8367cc5e 100644
--- a/mlir/examples/minimal-opt/README.md
+++ b/mlir/examples/minimal-opt/README.md
@@ -14,10 +14,10 @@ Below are some example measurements taken at the time of the LLVM 17 release,
using clang-14 on a X86 Ubuntu and [bloaty](https://github.com/google/bloaty).
| | Base | Os | Oz | Os LTO | Oz LTO |
-| :-----------------------------: | ------ | ------ | ------ | ------ | ------ |
-| `mlir-cat` | 1018kB | 836KB | 879KB | 697KB | 649KB |
-| `mlir-minimal-opt` | 1.54MB | 1.25MB | 1.29MB | 1.10MB | 1.00MB |
-| `mlir-minimal-opt-canonicalize` | 2.24MB | 1.81MB | 1.86MB | 1.62MB | 1.48MB |
+| :------------------------------: | ------ | ------ | ------ | ------ | ------ |
+| `mlir-cat` | 1024KB | 840KB | 885KB | 706KB | 657KB |
+| `mlir-minimal-opt` | 1.62MB | 1.32MB | 1.36MB | 1.17MB | 1.07MB |
+| `mlir-minimal-opt-canonicalize` | 1.83MB | 1.40MB | 1.45MB | 1.25MB | 1.14MB |
Base configuration:
@@ -32,6 +32,7 @@ cmake ../llvm/ -G Ninja \
-DCMAKE_CXX_COMPILER=clang++ \
-DLLVM_ENABLE_LLD=ON \
-DLLVM_ENABLE_BACKTRACES=OFF \
+ -DMLIR_ENABLE_PDL_IN_PATTERNMATCH=OFF \
-DCMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=-Wl,-icf=all
```
diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
index efa77b2e5ce5db..e152a36c0ce0cf 100644
--- a/mlir/include/mlir/Config/mlir-config.h.cmake
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -26,4 +26,7 @@
numeric seed that is passed to the random number generator. */
#cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}
+/* If set, enables PDL usage. */
+#cmakedefine01 MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
#endif
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 74f9c977b70286..e228229302cff4 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index a28b27e4e15816..4603953cb40fa5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -29,6 +29,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
// Pull in all enum type definitions and utility function declarations.
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
new file mode 100644
index 00000000000000..a215da8cb6431d
--- /dev/null
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -0,0 +1,995 @@
+//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_PDLPATTERNMATCH_H
+#define MLIR_IR_PDLPATTERNMATCH_H
+
+#include "mlir/Config/mlir-config.h"
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// PDL Patterns
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// PDLValue
+
+/// Storage type of byte-code interpreter values. These are passed to native
+/// constraint and rewrite functions as arguments.
+class PDLValue {
+public:
+ /// The underlying kind of a PDL value. The enumerator order must match the
+ /// type list used by `getKindOf()`, which maps a C++ type to its index.
+ enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
+
+ /// Construct a new PDL value. Attributes, types, and values are stored via
+ /// their opaque pointer; operations and ranges are stored as the given
+ /// pointer (no copy of the range is made).
+ PDLValue(const PDLValue &other) = default;
+ PDLValue(std::nullptr_t = nullptr) {}
+ PDLValue(Attribute value)
+ : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
+ PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
+ PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
+ PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
+ PDLValue(Value value)
+ : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
+ PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
+
+ /// Returns true if the kind of the held value is the one used for `T`.
+ /// Asserts that this value is non-null.
+ template <typename T>
+ bool isa() const {
+ assert(value && "isa<> used on a null value");
+ return kind == getKindOf<T>();
+ }
+
+ /// Attempt to dynamically cast this value to type `T`. If this value is not
+ /// an instance of `T`, returns a default-constructed `T` when `T` is
+ /// implicitly convertible to bool (e.g. a null `Operation *`), and
+ /// std::nullopt otherwise. Like `isa`, asserts that this value is non-null.
+ template <typename T,
+ typename ResultT = std::conditional_t<
+ std::is_convertible<T, bool>::value, T, std::optional<T>>>
+ ResultT dyn_cast() const {
+ return isa<T>() ? castImpl<T>() : ResultT();
+ }
+
+ /// Cast this value to type `T`, asserts if this value is not an instance of
+ /// `T`.
+ template <typename T>
+ T cast() const {
+ assert(isa<T>() && "expected value to be of type `T`");
+ return castImpl<T>();
+ }
+
+ /// Get an opaque pointer to the value.
+ const void *getAsOpaquePointer() const { return value; }
+
+ /// Return true if this value is non-null.
+ explicit operator bool() const { return value; }
+
+ /// Return the kind of this value.
+ Kind getKind() const { return kind; }
+
+ /// Print this value to the provided output stream.
+ void print(raw_ostream &os) const;
+
+ /// Print the specified value kind to an output stream.
+ static void print(raw_ostream &os, Kind kind);
+
+private:
+ /// Find the index of a given type in a range of other types.
+ template <typename...>
+ struct index_of_t;
+ template <typename T, typename... R>
+ struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
+ template <typename T, typename F, typename... R>
+ struct index_of_t<T, F, R...>
+ : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
+
+ /// Return the kind used for the given T. The type list order mirrors the
+ /// enumerators of `Kind`.
+ template <typename T>
+ static Kind getKindOf() {
+ return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
+ TypeRange, Value, ValueRange>::value);
+ }
+
+ /// The internal implementation of `cast`, that returns the underlying value
+ /// as the given type `T`. Attribute/Type/Value are rebuilt from the opaque
+ /// pointer, ranges are returned by copying the pointed-to range view, and
+ /// pointer types are returned directly.
+ template <typename T>
+ std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
+ castImpl() const {
+ return T::getFromOpaquePointer(value);
+ }
+ template <typename T>
+ std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
+ castImpl() const {
+ return *reinterpret_cast<T *>(const_cast<void *>(value));
+ }
+ template <typename T>
+ std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
+ return reinterpret_cast<T>(const_cast<void *>(value));
+ }
+
+ /// The internal opaque representation of a PDLValue.
+ const void *value{nullptr};
+ /// The kind of the opaque value. Defaults to Attribute for null values.
+ Kind kind{Kind::Attribute};
+};
+
+/// Print a PDLValue to the given stream.
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
+ value.print(os);
+ return os;
+}
+
+/// Print a PDLValue::Kind to the given stream.
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
+ PDLValue::print(os, kind);
+ return os;
+}
+
+//===----------------------------------------------------------------------===//
+// PDLResultList
+
+/// The class represents a list of PDL results, returned by a native rewrite
+/// method. It provides the mechanism with which to pass PDLValues back to the
+/// PDL bytecode.
+class PDLResultList {
+public:
+ /// Push a new Attribute value onto the result list.
+ void push_back(Attribute value) { results.push_back(value); }
+
+ /// Push a new Operation onto the result list.
+ void push_back(Operation *value) { results.push_back(value); }
+
+ /// Push a new Type onto the result list.
+ void push_back(Type value) { results.push_back(value); }
+
+ /// Push a new TypeRange onto the result list.
+ void push_back(TypeRange value) {
+ // The lifetime of a TypeRange can't be guaranteed, so we'll need to
+ // allocate a storage for it.
+ llvm::OwningArrayRef<Type> storage(value.size());
+ llvm::copy(value, storage.begin());
+ allocatedTypeRanges.emplace_back(std::move(storage));
+ typeRanges.push_back(allocatedTypeRanges.back());
+ results.push_back(&typeRanges.back());
+ }
+ /// Push the types of an operand/result range. Unlike the TypeRange overload,
+ /// no separate storage is allocated; the range itself is kept in
+ /// `typeRanges`.
+ void push_back(ValueTypeRange<OperandRange> value) {
+ typeRanges.push_back(value);
+ results.push_back(&typeRanges.back());
+ }
+ void push_back(ValueTypeRange<ResultRange> value) {
+ typeRanges.push_back(value);
+ results.push_back(&typeRanges.back());
+ }
+
+ /// Push a new Value onto the result list.
+ void push_back(Value value) { results.push_back(value); }
+
+ /// Push a new ValueRange onto the result list.
+ void push_back(ValueRange value) {
+ // The lifetime of a ValueRange can't be guaranteed, so we'll need to
+ // allocate a storage for it.
+ llvm::OwningArrayRef<Value> storage(value.size());
+ llvm::copy(value, storage.begin());
+ allocatedValueRanges.emplace_back(std::move(storage));
+ valueRanges.push_back(allocatedValueRanges.back());
+ results.push_back(&valueRanges.back());
+ }
+ /// Push an operand/result range. Unlike the ValueRange overload, no
+ /// separate storage is allocated; the range itself is kept in `valueRanges`.
+ void push_back(OperandRange value) {
+ valueRanges.push_back(value);
+ results.push_back(&valueRanges.back());
+ }
+ void push_back(ResultRange value) {
+ valueRanges.push_back(value);
+ results.push_back(&valueRanges.back());
+ }
+
+protected:
+ /// Create a new result list with the expected number of results.
+ PDLResultList(unsigned maxNumResults) {
+ // Reserve enough space for all of the results up front: `results` holds
+ // pointers to elements of `typeRanges`/`valueRanges`, so those vectors
+ // must not reallocate. This assumes no more than `maxNumResults` ranges
+ // are pushed. Separate counts per range type aren't worth it unless there
+ // are a "large" number of results.
+ typeRanges.reserve(maxNumResults);
+ valueRanges.reserve(maxNumResults);
+ }
+
+ /// The PDL results held by this list.
+ SmallVector<PDLValue> results;
+ /// Backing storage for the range values that entries of `results` point to.
+ SmallVector<TypeRange> typeRanges;
+ SmallVector<ValueRange> valueRanges;
+ /// Owned copies of ranges whose lifetime could not be guaranteed (e.g.
+ /// ranges created within the native function).
+ SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
+ SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLPatternConfig
+
+/// An individual configuration for a pattern, which can be accessed by native
+/// functions via the PDLPatternConfigSet. This allows for injecting additional
+/// configuration into PDL patterns that is specific to certain compilation
+/// flows.
+class PDLPatternConfig {
+public:
+ virtual ~PDLPatternConfig() = default;
+
+ /// Hooks that are invoked at the beginning and end of a rewrite of a matched
+ /// pattern. These can be used to set up any specific state necessary for
+ /// the rewrite. The default implementations do nothing.
+ virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
+ virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
+
+ /// Return the TypeID that represents this configuration.
+ TypeID getTypeID() const { return id; }
+
+protected:
+ /// Only subclasses can construct a configuration, supplying the TypeID that
+ /// identifies their configuration type.
+ PDLPatternConfig(TypeID id) : id(id) {}
+
+private:
+ /// The TypeID used for LLVM-style casting between configuration types.
+ TypeID id;
+};
+
+/// This class provides a base class for users implementing a type of pattern
+/// configuration. `T` is the configuration type; its TypeID identifies
+/// instances of it for casting.
+template <typename T>
+class PDLPatternConfigBase : public PDLPatternConfig {
+public:
+ /// Support LLVM style casting (isa/dyn_cast) by comparing TypeIDs.
+ static bool classof(const PDLPatternConfig *config) {
+ return config->getTypeID() == getConfigID();
+ }
+
+ /// Return the type id used for this configuration.
+ static TypeID getConfigID() { return TypeID::get<T>(); }
+
+protected:
+ /// Register this configuration under the TypeID of `T`.
+ PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
+};
+
+/// This class contains a set of configurations for a specific pattern.
+/// Configurations are uniqued by TypeID, meaning that only one configuration of
+/// each type is allowed.
+class PDLPatternConfigSet {
+public:
+ PDLPatternConfigSet() = default;
+
+ /// Construct a set with the given configurations. Each configuration is
+ /// forwarded (moved or copied) into storage owned by this set.
+ template <typename... ConfigsT>
+ PDLPatternConfigSet(ConfigsT &&...configs) {
+ (addConfig(std::forward<ConfigsT>(configs)), ...);
+ }
+
+ /// Get the configuration defined by the given type. Asserts that the
+ /// configuration of the provided type exists.
+ template <typename T>
+ const T &get() const {
+ const T *config = tryGet<T>();
+ assert(config && "configuration not found");
+ return *config;
+ }
+
+ /// Get the configuration defined by the given type, returns nullptr if the
+ /// configuration does not exist. Performs a linear scan over the set.
+ template <typename T>
+ const T *tryGet() const {
+ for (const auto &configIt : configs)
+ if (const T *config = dyn_cast<T>(configIt.get()))
+ return config;
+ return nullptr;
+ }
+
+ /// Notify the configurations within this set at the beginning or end of a
+ /// rewrite of a matched pattern.
+ void notifyRewriteBegin(PatternRewriter &rewriter) {
+ for (const auto &config : configs)
+ config->notifyRewriteBegin(rewriter);
+ }
+ void notifyRewriteEnd(PatternRewriter &rewriter) {
+ for (const auto &config : configs)
+ config->notifyRewriteEnd(rewriter);
+ }
+
+protected:
+ /// Add a configuration to the set. Asserts that no configuration of the
+ /// same type has already been added.
+ template <typename T>
+ void addConfig(T &&config) {
+ assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
+ configs.emplace_back(
+ std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
+ }
+
+ /// The set of configurations for this pattern. This uses a vector instead of
+ /// a map with the expectation that the number of configurations per set is
+ /// small (<= 1).
+ SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
+/// A generic PDL pattern constraint function. This function applies a
+/// constraint to a given set of opaque PDLValue entities. Returns success if
+/// the constraint held, failure otherwise.
+using PDLConstraintFunction =
+ std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+/// A native PDL rewrite function. This function performs a rewrite on the
+/// given set of values. Any results from this rewrite that should be passed
+/// back to PDL should be added to the provided result list. This method is only
+/// invoked when the corresponding match was successful. Returns failure if an
+/// invariant of the rewrite was broken (certain rewriters may recover from
+/// partial pattern application).
+using PDLRewriteFunction = std::function<LogicalResult(
+ PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+
+namespace detail {
+namespace pdl_function_builder {
+/// A utility variable that always resolves to false. Because its value depends
+/// on the template arguments, it is useful for static asserts that are always
+/// false but should only fire when a certain templated construct is actually
+/// instantiated. For example, if a templated function should never be called,
+/// the function could be defined as:
+///
+/// template <typename T>
+/// void foo() {
+/// static_assert(always_false<T>, "This function should never be called");
+/// }
+///
+template <class... T>
+constexpr bool always_false = false;
+
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Type Processing
+//===----------------------------------------------------------------------===//
+
+/// This struct provides a convenient way to determine how to process a given
+/// type as either a PDL parameter, or a result value. This allows for
+/// supporting complex types in constraint and rewrite functions, without
+/// requiring the user to hand-write the necessary glue code themselves.
+/// Specializations of this class should implement the following methods to
+/// enable support as a PDL argument or result type:
+///
+/// static LogicalResult verifyAsArg(
+/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
+/// size_t argIdx);
+///
+/// * This method verifies that the given PDLValue is valid for use as a
+/// value of `T`.
+///
+/// static T processAsArg(PDLValue pdlValue);
+///
+/// * This method processes the given PDLValue as a value of `T`.
+///
+/// static void processAsResult(PatternRewriter &, PDLResultList &results,
+/// const T &value);
+///
+/// * This method processes the given value of `T` as the result of a
+/// function invocation. The method should package the value into an
+/// appropriate form and append it to the given result list.
+///
+/// If the type `T` is based on a higher order value, consider using
+/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
+/// the implementation.
+///
+/// `Enable` is a SFINAE hook that lets partial specializations select whole
+/// families of types (e.g. every type deriving from Attribute).
+template <typename T, typename Enable = void>
+struct ProcessPDLValue;
+
+/// This struct provides a simplified model for processing types that are based
+/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
+/// allows for building the necessary processing functions on top of the base
+/// value instead of a PDLValue. Derived users should implement the following
+/// (which subsume the ProcessPDLValue variants):
+///
+/// static LogicalResult verifyAsArg(
+/// function_ref<LogicalResult(const Twine &)> errorFn,
+/// const BaseT &baseValue, size_t argIdx);
+///
+/// * This method verifies that the given base value is valid for use as a
+/// value of `T`.
+///
+/// static T processAsArg(BaseT baseValue);
+///
+/// * This method processes the given base value as a value of `T`.
+///
+template <typename T, typename BaseT>
+struct ProcessPDLValueBasedOn {
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+ PDLValue pdlValue, size_t argIdx) {
+ // Verify the PDLValue as `BaseT` first; only then is it safe to convert it
+ // to the base value and verify that as `T`.
+ if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
+ return failure();
+ return ProcessPDLValue<T>::verifyAsArg(
+ errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
+ }
+ static T processAsArg(PDLValue pdlValue) {
+ return ProcessPDLValue<T>::processAsArg(
+ ProcessPDLValue<BaseT>::processAsArg(pdlValue));
+ }
+
+ /// Explicitly add the expected parent API to ensure the parent class
+ /// implements the necessary API (and doesn't implicitly inherit it from
+ /// somewhere else). The default `verifyAsArg` accepts every base value;
+ /// `processAsArg(BaseT)` is declared only and must be provided by `T`'s
+ /// specialization.
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
+ size_t argIdx) {
+ return success();
+ }
+ static T processAsArg(BaseT baseValue);
+};
+
+/// This struct provides a simplified model for processing types that have
+/// "builtin" PDLValue support:
+/// * Attribute, Operation *, Type, TypeRange, Value, ValueRange
+template <typename T>
+struct ProcessBuiltinPDLValue {
+ /// Builtin kinds only need to be non-null to be accepted as arguments.
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+ PDLValue pdlValue, size_t argIdx) {
+ if (pdlValue)
+ return success();
+ return errorFn("expected a non-null value for argument " + Twine(argIdx) +
+ " of type: " + llvm::getTypeName<T>());
+ }
+
+ static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ T value) {
+ results.push_back(value);
+ }
+};
+
+/// This struct provides a simplified model for processing types that inherit
+/// from builtin PDLValue types. For example, derived attributes like
+/// IntegerAttr, derived types like IntegerType, derived operations like
+/// ModuleOp, Interfaces, etc.
+template <typename T, typename BaseT>
+struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
+ /// Verify that the base value is an instance of `T`.
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+ BaseT baseValue, size_t argIdx) {
+ return TypeSwitch<BaseT, LogicalResult>(baseValue)
+ .Case([&](T) { return success(); })
+ .Default([&](BaseT) {
+ return errorFn("expected argument " + Twine(argIdx) +
+ " to be of type: " + llvm::getTypeName<T>());
+ });
+ }
+ // Keep the inherited `PDLValue` overload visible; the declaration above
+ // would otherwise hide it.
+ using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
+
+ /// Convert an already-verified base value to `T`.
+ static T processAsArg(BaseT baseValue) {
+ return baseValue.template cast<T>();
+ }
+ // Same as above: re-expose the inherited `PDLValue` overload.
+ using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
+
+ /// Push the value through the matching PDLResultList::push_back overload.
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ T value) {
+ results.push_back(value);
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Attribute
+
+template <>
+struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
+/// Attribute subclasses are verified and cast from a generic Attribute.
+template <typename T>
+struct ProcessPDLValue<T,
+ std::enable_if_t<std::is_base_of<Attribute, T>::value>>
+ : public ProcessDerivedPDLValue<T, Attribute> {};
+
+/// Handling for value types that are passed to and from PDL as a StringAttr.
+template <>
+struct ProcessPDLValue<StringRef>
+ : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
+ static StringRef processAsArg(StringAttr value) { return value.getValue(); }
+ using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
+
+ static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
+ StringRef value) {
+ results.push_back(rewriter.getStringAttr(value));
+ }
+};
+template <>
+struct ProcessPDLValue<std::string>
+ : public ProcessPDLValueBasedOn<std::string, StringAttr> {
+ /// `std::string` is only supported as a result type; any use as an argument
+ /// is rejected at compile time by this (intentionally hiding) overload.
+ template <typename T>
+ static std::string processAsArg(T value) {
+ static_assert(always_false<T>,
+ "`std::string` arguments require a string copy, use "
+ "`StringRef` for string-like arguments instead");
+ return {};
+ }
+ static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
+ StringRef value) {
+ results.push_back(rewriter.getStringAttr(value));
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Operation
+
+template <>
+struct ProcessPDLValue<Operation *>
+ : public ProcessBuiltinPDLValue<Operation *> {};
+/// Handling for op classes (types deriving from OpState). Arguments are
+/// verified as `Operation *` and then narrowed with `cast<T>`, because the
+/// generic `baseValue.template cast<T>()` conversion does not apply to a raw
+/// `Operation *`.
+template <typename T>
+struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
+ : public ProcessDerivedPDLValue<T, Operation *> {
+ /// Declaring `processAsArg(Operation *)` below hides every inherited
+ /// `processAsArg` overload, including the `PDLValue` one that the argument
+ /// processing helpers call with `values[I]`. Re-expose them (mirroring the
+ /// StringRef specialization) so op types can be used as native
+ /// constraint/rewrite arguments.
+ using ProcessDerivedPDLValue<T, Operation *>::processAsArg;
+ static T processAsArg(Operation *value) { return cast<T>(value); }
+};
+
+//===----------------------------------------------------------------------===//
+// Type
+
+template <>
+struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
+/// Type subclasses (e.g. IntegerType) are verified and cast from a generic Type.
+template <typename T>
+struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
+ : public ProcessDerivedPDLValue<T, Type> {};
+
+//===----------------------------------------------------------------------===//
+// TypeRange
+
+template <>
+struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
+/// The following type-range flavors only define `processAsResult`, so they are
+/// supported as results of native functions but not as arguments.
+template <>
+struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ ValueTypeRange<OperandRange> types) {
+ results.push_back(types);
+ }
+};
+template <>
+struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ ValueTypeRange<ResultRange> types) {
+ results.push_back(types);
+ }
+};
+/// Small vectors of types are pushed as a TypeRange, which the result list
+/// copies into storage it owns.
+template <unsigned N>
+struct ProcessPDLValue<SmallVector<Type, N>> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ SmallVector<Type, N> values) {
+ results.push_back(TypeRange(values));
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Value
+
+/// Value is handled directly as a builtin PDLValue kind.
+template <>
+struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
+
+//===----------------------------------------------------------------------===//
+// ValueRange
+
+template <>
+struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
+};
+/// The following value-range flavors only define `processAsResult`, so they
+/// are supported as results of native functions but not as arguments.
+template <>
+struct ProcessPDLValue<OperandRange> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ OperandRange values) {
+ results.push_back(values);
+ }
+};
+template <>
+struct ProcessPDLValue<ResultRange> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ ResultRange values) {
+ results.push_back(values);
+ }
+};
+/// Small vectors of values are pushed as a ValueRange, which the result list
+/// copies into storage it owns.
+template <unsigned N>
+struct ProcessPDLValue<SmallVector<Value, N>> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ SmallVector<Value, N> values) {
+ results.push_back(ValueRange(values));
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Argument Handling
+//===----------------------------------------------------------------------===//
+
+/// Validate the given PDLValues match the constraints defined by the argument
+/// types of the given function. In the case of failure, a match failure
+/// diagnostic is emitted.
+/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
+/// does not currently preserve Constraint application ordering.
+template <typename PDLFnT, std::size_t... I>
+LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
+ std::index_sequence<I...>) {
+ using FnTraitsT = llvm::function_traits<PDLFnT>;
+
+ auto errorFn = [&](const Twine &msg) {
+ return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
+ };
+ // Parameter 0 of the native function is the PatternRewriter, so PDL value
+ // `I` maps to parameter `I + 1`. The `&&` fold stops at the first argument
+ // that fails verification.
+ return success(
+ (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+ verifyAsArg(errorFn, values[I], I)) &&
+ ...));
+}
+
+/// Assert that the given PDLValues match the constraints defined by the
+/// arguments of the given function. In the case of failure, a fatal error
+/// is emitted. The checks are only compiled in when
+/// LLVM_ENABLE_ABI_BREAKING_CHECKS is set; otherwise this is a no-op.
+template <typename PDLFnT, std::size_t... I>
+void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
+ std::index_sequence<I...>) {
+ // We only want to do verification in debug builds, same as with `assert`.
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ using FnTraitsT = llvm::function_traits<PDLFnT>;
+ // report_fatal_error does not return, so this lambda never returns a value.
+ auto errorFn = [&](const Twine &msg) -> LogicalResult {
+ llvm::report_fatal_error(msg);
+ };
+ (void)errorFn;
+ assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+ verifyAsArg(errorFn, values[I], I)) &&
+ ...));
+#endif
+ // Silence unused-parameter warnings when the checks are compiled out.
+ (void)values;
+}
+
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Results Handling
+//===----------------------------------------------------------------------===//
+
+/// Store a single result within the result list.
+/// NOTE(review): `T` must deduce to a non-reference type (i.e. callers pass
+/// rvalues), since ProcessPDLValue is only specialized for value types.
+template <typename T>
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results, T &&value) {
+ ProcessPDLValue<T>::processAsResult(rewriter, results,
+ std::forward<T>(value));
+ return success();
+}
+
+/// Store a std::pair<> as individual results within the result list.
+/// Processing stops at the first element that fails.
+template <typename T1, typename T2>
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ std::pair<T1, T2> &&pair) {
+ if (failed(processResults(rewriter, results, std::move(pair.first))) ||
+ failed(processResults(rewriter, results, std::move(pair.second))))
+ return failure();
+ return success();
+}
+
+/// Store a std::tuple<> as individual results within the result list.
+/// The `&&` fold stops at the first element that fails.
+template <typename... Ts>
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ std::tuple<Ts...> &&tuple) {
+ auto applyFn = [&](auto &&...args) {
+ return (succeeded(processResults(rewriter, results, std::move(args))) &&
+ ...);
+ };
+ return success(std::apply(applyFn, std::move(tuple)));
+}
+
+/// Handle LogicalResult propagation. The result itself is returned and
+/// nothing is added to the result list.
+inline LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ LogicalResult &&result) {
+ return result;
+}
+/// Propagate a failed FailureOr<T>; otherwise process the held value.
+template <typename T>
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ FailureOr<T> &&result) {
+ if (failed(result))
+ return failure();
+ return processResults(rewriter, results, std::move(*result));
+}
+
+//===----------------------------------------------------------------------===//
+// PDL Constraint Builder
+//===----------------------------------------------------------------------===//
+
+/// Process the arguments of a native constraint and invoke it.
+/// Each `values[I]` is converted to the function's parameter `I + 1`, since
+/// parameter 0 receives the PatternRewriter. The function's own return value
+/// is returned unchanged.
+template <typename PDLFnT, std::size_t... I,
+ typename FnTraitsT = llvm::function_traits<PDLFnT>>
+typename FnTraitsT::result_t
+processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
+ ArrayRef<PDLValue> values,
+ std::index_sequence<I...>) {
+ return fn(
+ rewriter,
+ (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
+ values[I]))...);
+}
+
+/// Build a constraint function from the given function `ConstraintFnT`, so
+/// users can write simpler, more direct constraint functions without handling
+/// the low-level PDL goop themselves.
+///
+/// This overload is selected when the callable already converts to
+/// PDLConstraintFunction; it is forwarded without any wrapping.
+template <typename ConstraintFnT>
+std::enable_if_t<std::is_convertible_v<ConstraintFnT, PDLConstraintFunction>,
+                 PDLConstraintFunction>
+buildConstraintFn(ConstraintFnT &&constraintFn) {
+  return std::forward<ConstraintFnT>(constraintFn);
+}
+/// Otherwise, generate a wrapper that unpacks the PDLValues into the
+/// parameter types `constraintFn` expects before invoking it.
+template <typename ConstraintFnT>
+std::enable_if_t<
+    !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
+    PDLConstraintFunction>
+buildConstraintFn(ConstraintFnT &&constraintFn) {
+  return [fn = std::forward<ConstraintFnT>(constraintFn)](
+             PatternRewriter &rewriter,
+             ArrayRef<PDLValue> values) -> LogicalResult {
+    // The leading PatternRewriter parameter is not a PDL value.
+    using Indices = std::make_index_sequence<
+        llvm::function_traits<ConstraintFnT>::num_args - 1>;
+    // A value of the wrong kind rejects the match instead of invoking `fn`.
+    if (succeeded(verifyAsArgs<ConstraintFnT>(rewriter, values, Indices{})))
+      return processArgsAndInvokeConstraint(fn, rewriter, values, Indices{});
+    return failure();
+  };
+}
+
+//===----------------------------------------------------------------------===//
+// PDL Rewrite Builder
+//===----------------------------------------------------------------------===//
+
+/// Process the arguments of a native rewrite and invoke it.
+/// This overload handles rewrites returning `void`: there is nothing to append
+/// to the result list, so the invocation always reports success.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
+                 LogicalResult>
+processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
+                            PDLResultList &, ArrayRef<PDLValue> values,
+                            std::index_sequence<I...>) {
+  // Parameter 0 of `fn` is the rewriter; PDL value I maps to parameter I + 1.
+  fn(rewriter, ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                   processAsArg(values[I])...);
+  return success();
+}
+/// This overload handles the case of return values, which need to be packaged
+/// into the result list via `processResults`.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
+                 LogicalResult>
+processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
+                            PDLResultList &results, ArrayRef<PDLValue> values,
+                            std::index_sequence<I...>) {
+  // `values` is otherwise unreferenced when the rewrite takes no PDL
+  // arguments (empty `I...`). Mark it used *before* returning rather than via
+  // an unreachable statement after the `return`.
+  (void)values;
+  // Parameter 0 of `fn` is the rewriter; PDL value I maps to parameter I + 1.
+  return processResults(
+      rewriter, results,
+      fn(rewriter, ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                       processAsArg(values[I])...));
+}
+
+/// Build a rewrite function from the given function `RewriteFnT`, so users
+/// can write simpler, more direct rewrite functions without handling the
+/// low-level PDL goop themselves.
+///
+/// This overload is selected when the callable already converts to
+/// PDLRewriteFunction; it is forwarded without any wrapping.
+template <typename RewriteFnT>
+std::enable_if_t<std::is_convertible_v<RewriteFnT, PDLRewriteFunction>,
+                 PDLRewriteFunction>
+buildRewriteFn(RewriteFnT &&rewriteFn) {
+  return std::forward<RewriteFnT>(rewriteFn);
+}
+/// Otherwise, generate a wrapper that unpacks the PDLValues into the
+/// parameter types `rewriteFn` expects before invoking it.
+template <typename RewriteFnT>
+std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
+                 PDLRewriteFunction>
+buildRewriteFn(RewriteFnT &&rewriteFn) {
+  return [fn = std::forward<RewriteFnT>(rewriteFn)](
+             PatternRewriter &rewriter, PDLResultList &results,
+             ArrayRef<PDLValue> values) {
+    // The leading PatternRewriter parameter is not a PDL value.
+    using Indices = std::make_index_sequence<
+        llvm::function_traits<RewriteFnT>::num_args - 1>;
+    // Unlike constraints, argument kinds are asserted rather than turned into
+    // a match failure.
+    assertArgs<RewriteFnT>(rewriter, values, Indices{});
+    return processArgsAndInvokeRewrite(fn, rewriter, results, values,
+                                       Indices{});
+  };
+}
+
+} // namespace pdl_function_builder
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
+/// This class contains all of the necessary data for a set of PDL patterns, or
+/// pattern rewrites specified in the form of the PDL dialect. The PDL module
+/// contained by this pattern may contain any number of `pdl.pattern`
+/// operations.
+class PDLPatternModule {
+public:
+  PDLPatternModule() = default;
+
+  /// Construct a PDL pattern with the given module and configurations.
+  PDLPatternModule(OwningOpRef<ModuleOp> module)
+      : pdlModule(std::move(module)) {}
+  template <typename... ConfigsT>
+  PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
+      : PDLPatternModule(std::move(module)) {
+    // Bundle the configs into a single set, attach it to the patterns of the
+    // module, and keep ownership of the set in `configs`.
+    auto configSet = std::make_unique<PDLPatternConfigSet>(
+        std::forward<ConfigsT>(patternConfigs)...);
+    attachConfigToPatterns(*pdlModule, *configSet);
+    configs.emplace_back(std::move(configSet));
+  }
+
+  /// Merge the state in `other` into this pattern module.
+  void mergeIn(PDLPatternModule &&other);
+
+  /// Return the internal PDL module of this pattern (null when this module is
+  /// default-constructed or has been cleared).
+  ModuleOp getModule() { return pdlModule.get(); }
+
+  /// Return the MLIR context of this pattern. Dereferences the held module, so
+  /// it must not be called when `getModule()` is null.
+  MLIRContext *getContext() { return getModule()->getContext(); }
+
+  //===--------------------------------------------------------------------===//
+  // Function Registry
+
+  /// Register a constraint function with PDL. A constraint function may be
+  /// specified in one of two ways:
+  ///
+  ///   * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
+  ///
+  ///   In this overload the arguments of the constraint function are passed via
+  ///   the low-level PDLValue form.
+  ///
+  ///   * `LogicalResult (PatternRewriter &, ValueTs... values)`
+  ///
+  ///   In this form the arguments of the constraint function are passed via the
+  ///   expected high level C++ type. In this form, the framework will
+  ///   automatically unwrap PDLValues and convert them to the expected ValueTs.
+  ///   For example, if the constraint function accepts a `Operation *`, the
+  ///   framework will automatically cast the input PDLValue. In the case of a
+  ///   `StringRef`, the framework will automatically unwrap the argument as a
+  ///   StringAttr and pass the underlying string value. To see the full list of
+  ///   supported types, or to see how to add handling for custom types, view
+  ///   the definition of `ProcessPDLValue` above.
+  void registerConstraintFunction(StringRef name,
+                                  PDLConstraintFunction constraintFn);
+  template <typename ConstraintFnT>
+  void registerConstraintFunction(StringRef name,
+                                  ConstraintFnT &&constraintFn) {
+    registerConstraintFunction(name,
+                               detail::pdl_function_builder::buildConstraintFn(
+                                   std::forward<ConstraintFnT>(constraintFn)));
+  }
+
+  /// Register a rewrite function with PDL. A rewrite function may be specified
+  /// in one of two ways:
+  ///
+  ///   * `LogicalResult (PatternRewriter &, PDLResultList &,
+  ///                     ArrayRef<PDLValue>)`
+  ///
+  ///   In this overload the arguments of the rewrite function are passed via
+  ///   the low-level PDLValue form, and the results are manually appended to
+  ///   the given result list. Returning failure signals that an invariant of
+  ///   the rewrite was broken.
+  ///
+  ///   * `ResultT (PatternRewriter &, ValueTs... values)`
+  ///
+  ///   In this form the arguments and result of the rewrite function are passed
+  ///   via the expected high level C++ type. In this form, the framework will
+  ///   automatically unwrap the PDLValues arguments and convert them to the
+  ///   expected ValueTs. It will also automatically handle the processing and
+  ///   packaging of the result value to the result list. For example, if the
+  ///   rewrite function takes a `Operation *`, the framework will automatically
+  ///   cast the input PDLValue. In the case of a `StringRef`, the framework
+  ///   will automatically unwrap the argument as a StringAttr and pass the
+  ///   underlying string value. In the reverse case, if the rewrite returns a
+  ///   StringRef or std::string, it will automatically package this as a
+  ///   StringAttr and append it to the result list. To see the full list of
+  ///   supported types, or to see how to add handling for custom types, view
+  ///   the definition of `ProcessPDLValue` above.
+  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
+  template <typename RewriteFnT>
+  void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
+    registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
+                                      std::forward<RewriteFnT>(rewriteFn)));
+  }
+
+  /// Return the set of the registered constraint functions.
+  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
+    return constraintFunctions;
+  }
+  /// Return a copy of the registered constraint functions. Unlike
+  /// `takeConfigs()`, this module's own map is left intact.
+  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
+    return constraintFunctions;
+  }
+  /// Return the set of the registered rewrite functions.
+  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
+    return rewriteFunctions;
+  }
+  /// Return a copy of the registered rewrite functions. Unlike
+  /// `takeConfigs()`, this module's own map is left intact.
+  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
+    return rewriteFunctions;
+  }
+
+  /// Move the registered pattern config sets out of this module.
+  SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
+    return std::move(configs);
+  }
+  /// Move the pattern-op -> config-set map out of this module.
+  DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
+    return std::move(configMap);
+  }
+
+  /// Clear out the patterns, functions, and pattern configurations within this
+  /// module.
+  void clear() {
+    pdlModule = nullptr;
+    constraintFunctions.clear();
+    rewriteFunctions.clear();
+    // `configMap` presumably maps pattern ops of the dropped `pdlModule` to
+    // entries of `configs` (see attachConfigToPatterns). Reset both so that no
+    // stale keys or orphaned config sets survive a clear.
+    configMap.clear();
+    configs.clear();
+  }
+
+private:
+  /// Attach the given pattern config set to the patterns defined within the
+  /// given module.
+  void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
+
+  /// The module containing the `pdl.pattern` operations.
+  OwningOpRef<ModuleOp> pdlModule;
+
+  /// The set of configuration sets referenced by patterns within `pdlModule`.
+  SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
+  DenseMap<Operation *, PDLPatternConfigSet *> configMap;
+
+  /// The external functions referenced from within the PDL module.
+  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
+  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
+};
+} // namespace mlir
+
+#else
+
+namespace mlir {
+// Stubs for when PDL in pattern rewrites is not enabled. They mirror just
+// enough of the PDL API for shared pattern code to compile; registration
+// requests are accepted and discarded.
+
+/// Placeholder PDL value: `dyn_cast<T>()` always yields a `T` initialized
+/// from nullptr.
+class PDLValue {
+public:
+  template <typename T>
+  T dyn_cast() const {
+    return nullptr;
+  }
+};
+/// Placeholder result list; carries no state.
+class PDLResultList {};
+using PDLConstraintFunction =
+    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+using PDLRewriteFunction = std::function<LogicalResult(
+    PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+
+/// Inert stand-in for the PDL pattern module. Any module handed to the
+/// constructor is discarded, registered functions are dropped, and asking for
+/// the context is a programming error.
+class PDLPatternModule {
+public:
+  PDLPatternModule() = default;
+  PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}
+
+  MLIRContext *getContext() {
+    llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
+  }
+
+  /// Merging and clearing are no-ops: there is never any state to touch.
+  void mergeIn(PDLPatternModule && /*other*/) {}
+  void clear() {}
+
+  /// Registration requests are silently ignored.
+  template <typename ConstraintFnT>
+  void registerConstraintFunction(StringRef /*name*/,
+                                  ConstraintFnT && /*constraintFn*/) {}
+  void registerRewriteFunction(StringRef /*name*/,
+                               PDLRewriteFunction /*rewriteFn*/) {}
+  template <typename RewriteFnT>
+  void registerRewriteFunction(StringRef /*name*/,
+                               RewriteFnT && /*rewriteFn*/) {}
+
+  /// Always returns an empty map.
+  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
+    return constraintFunctions;
+  }
+
+private:
+  /// Never populated; exists only so getConstraintFunctions() can return a
+  /// reference.
+  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
+};
+
+} // namespace mlir
+#endif
+
+#endif // MLIR_IR_PDLPATTERNMATCH_H
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 6625ef553eba21..9b4fa65bff49e1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -735,932 +735,12 @@ class PatternRewriter : public RewriterBase {
virtual bool canRecoverFromRewriteFailure() const { return false; }
};
-//===----------------------------------------------------------------------===//
-// PDL Patterns
-//===----------------------------------------------------------------------===//
-
-//===----------------------------------------------------------------------===//
-// PDLValue
-
-/// Storage type of byte-code interpreter values. These are passed to constraint
-/// functions as arguments.
-class PDLValue {
-public:
- /// The underlying kind of a PDL value.
- enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
-
- /// Construct a new PDL value.
- PDLValue(const PDLValue &other) = default;
- PDLValue(std::nullptr_t = nullptr) {}
- PDLValue(Attribute value)
- : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
- PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
- PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
- PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
- PDLValue(Value value)
- : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
- PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
-
- /// Returns true if the type of the held value is `T`.
- template <typename T>
- bool isa() const {
- assert(value && "isa<> used on a null value");
- return kind == getKindOf<T>();
- }
-
- /// Attempt to dynamically cast this value to type `T`, returns null if this
- /// value is not an instance of `T`.
- template <typename T,
- typename ResultT = std::conditional_t<
- std::is_convertible<T, bool>::value, T, std::optional<T>>>
- ResultT dyn_cast() const {
- return isa<T>() ? castImpl<T>() : ResultT();
- }
-
- /// Cast this value to type `T`, asserts if this value is not an instance of
- /// `T`.
- template <typename T>
- T cast() const {
- assert(isa<T>() && "expected value to be of type `T`");
- return castImpl<T>();
- }
-
- /// Get an opaque pointer to the value.
- const void *getAsOpaquePointer() const { return value; }
-
- /// Return if this value is null or not.
- explicit operator bool() const { return value; }
-
- /// Return the kind of this value.
- Kind getKind() const { return kind; }
-
- /// Print this value to the provided output stream.
- void print(raw_ostream &os) const;
-
- /// Print the specified value kind to an output stream.
- static void print(raw_ostream &os, Kind kind);
-
-private:
- /// Find the index of a given type in a range of other types.
- template <typename...>
- struct index_of_t;
- template <typename T, typename... R>
- struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
- template <typename T, typename F, typename... R>
- struct index_of_t<T, F, R...>
- : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
-
- /// Return the kind used for the given T.
- template <typename T>
- static Kind getKindOf() {
- return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
- TypeRange, Value, ValueRange>::value);
- }
-
- /// The internal implementation of `cast`, that returns the underlying value
- /// as the given type `T`.
- template <typename T>
- std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
- castImpl() const {
- return T::getFromOpaquePointer(value);
- }
- template <typename T>
- std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
- castImpl() const {
- return *reinterpret_cast<T *>(const_cast<void *>(value));
- }
- template <typename T>
- std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
- return reinterpret_cast<T>(const_cast<void *>(value));
- }
-
- /// The internal opaque representation of a PDLValue.
- const void *value{nullptr};
- /// The kind of the opaque value.
- Kind kind{Kind::Attribute};
-};
-
-inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
- value.print(os);
- return os;
-}
-
-inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
- PDLValue::print(os, kind);
- return os;
-}
-
-//===----------------------------------------------------------------------===//
-// PDLResultList
-
-/// The class represents a list of PDL results, returned by a native rewrite
-/// method. It provides the mechanism with which to pass PDLValues back to the
-/// PDL bytecode.
-class PDLResultList {
-public:
- /// Push a new Attribute value onto the result list.
- void push_back(Attribute value) { results.push_back(value); }
-
- /// Push a new Operation onto the result list.
- void push_back(Operation *value) { results.push_back(value); }
-
- /// Push a new Type onto the result list.
- void push_back(Type value) { results.push_back(value); }
-
- /// Push a new TypeRange onto the result list.
- void push_back(TypeRange value) {
- // The lifetime of a TypeRange can't be guaranteed, so we'll need to
- // allocate a storage for it.
- llvm::OwningArrayRef<Type> storage(value.size());
- llvm::copy(value, storage.begin());
- allocatedTypeRanges.emplace_back(std::move(storage));
- typeRanges.push_back(allocatedTypeRanges.back());
- results.push_back(&typeRanges.back());
- }
- void push_back(ValueTypeRange<OperandRange> value) {
- typeRanges.push_back(value);
- results.push_back(&typeRanges.back());
- }
- void push_back(ValueTypeRange<ResultRange> value) {
- typeRanges.push_back(value);
- results.push_back(&typeRanges.back());
- }
-
- /// Push a new Value onto the result list.
- void push_back(Value value) { results.push_back(value); }
-
- /// Push a new ValueRange onto the result list.
- void push_back(ValueRange value) {
- // The lifetime of a ValueRange can't be guaranteed, so we'll need to
- // allocate a storage for it.
- llvm::OwningArrayRef<Value> storage(value.size());
- llvm::copy(value, storage.begin());
- allocatedValueRanges.emplace_back(std::move(storage));
- valueRanges.push_back(allocatedValueRanges.back());
- results.push_back(&valueRanges.back());
- }
- void push_back(OperandRange value) {
- valueRanges.push_back(value);
- results.push_back(&valueRanges.back());
- }
- void push_back(ResultRange value) {
- valueRanges.push_back(value);
- results.push_back(&valueRanges.back());
- }
-
-protected:
- /// Create a new result list with the expected number of results.
- PDLResultList(unsigned maxNumResults) {
- // For now just reserve enough space for all of the results. We could do
- // separate counts per range type, but it isn't really worth it unless there
- // are a "large" number of results.
- typeRanges.reserve(maxNumResults);
- valueRanges.reserve(maxNumResults);
- }
-
- /// The PDL results held by this list.
- SmallVector<PDLValue> results;
- /// Memory used to store ranges held by the list.
- SmallVector<TypeRange> typeRanges;
- SmallVector<ValueRange> valueRanges;
- /// Memory allocated to store ranges in the result list whose lifetime was
- /// generated in the native function.
- SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
- SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
-};
-
-//===----------------------------------------------------------------------===//
-// PDLPatternConfig
-
-/// An individual configuration for a pattern, which can be accessed by native
-/// functions via the PDLPatternConfigSet. This allows for injecting additional
-/// configuration into PDL patterns that is specific to certain compilation
-/// flows.
-class PDLPatternConfig {
-public:
- virtual ~PDLPatternConfig() = default;
-
- /// Hooks that are invoked at the beginning and end of a rewrite of a matched
- /// pattern. These can be used to setup any specific state necessary for the
- /// rewrite.
- virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
- virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
-
- /// Return the TypeID that represents this configuration.
- TypeID getTypeID() const { return id; }
-
-protected:
- PDLPatternConfig(TypeID id) : id(id) {}
-
-private:
- TypeID id;
-};
-
-/// This class provides a base class for users implementing a type of pattern
-/// configuration.
-template <typename T>
-class PDLPatternConfigBase : public PDLPatternConfig {
-public:
- /// Support LLVM style casting.
- static bool classof(const PDLPatternConfig *config) {
- return config->getTypeID() == getConfigID();
- }
-
- /// Return the type id used for this configuration.
- static TypeID getConfigID() { return TypeID::get<T>(); }
-
-protected:
- PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
-};
-
-/// This class contains a set of configurations for a specific pattern.
-/// Configurations are uniqued by TypeID, meaning that only one configuration of
-/// each type is allowed.
-class PDLPatternConfigSet {
-public:
- PDLPatternConfigSet() = default;
-
- /// Construct a set with the given configurations.
- template <typename... ConfigsT>
- PDLPatternConfigSet(ConfigsT &&...configs) {
- (addConfig(std::forward<ConfigsT>(configs)), ...);
- }
-
- /// Get the configuration defined by the given type. Asserts that the
- /// configuration of the provided type exists.
- template <typename T>
- const T &get() const {
- const T *config = tryGet<T>();
- assert(config && "configuration not found");
- return *config;
- }
-
- /// Get the configuration defined by the given type, returns nullptr if the
- /// configuration does not exist.
- template <typename T>
- const T *tryGet() const {
- for (const auto &configIt : configs)
- if (const T *config = dyn_cast<T>(configIt.get()))
- return config;
- return nullptr;
- }
-
- /// Notify the configurations within this set at the beginning or end of a
- /// rewrite of a matched pattern.
- void notifyRewriteBegin(PatternRewriter &rewriter) {
- for (const auto &config : configs)
- config->notifyRewriteBegin(rewriter);
- }
- void notifyRewriteEnd(PatternRewriter &rewriter) {
- for (const auto &config : configs)
- config->notifyRewriteEnd(rewriter);
- }
-
-protected:
- /// Add a configuration to the set.
- template <typename T>
- void addConfig(T &&config) {
- assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
- configs.emplace_back(
- std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
- }
-
- /// The set of configurations for this pattern. This uses a vector instead of
- /// a map with the expectation that the number of configurations per set is
- /// small (<= 1).
- SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
-};
-
-//===----------------------------------------------------------------------===//
-// PDLPatternModule
-
-/// A generic PDL pattern constraint function. This function applies a
-/// constraint to a given set of opaque PDLValue entities. Returns success if
-/// the constraint successfully held, failure otherwise.
-using PDLConstraintFunction =
- std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
-/// A native PDL rewrite function. This function performs a rewrite on the
-/// given set of values. Any results from this rewrite that should be passed
-/// back to PDL should be added to the provided result list. This method is only
-/// invoked when the corresponding match was successful. Returns failure if an
-/// invariant of the rewrite was broken (certain rewriters may recover from
-/// partial pattern application).
-using PDLRewriteFunction = std::function<LogicalResult(
- PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
-
-namespace detail {
-namespace pdl_function_builder {
-/// A utility variable that always resolves to false. This is useful for static
-/// asserts that are always false, but only should fire in certain templated
-/// constructs. For example, if a templated function should never be called, the
-/// function could be defined as:
-///
-/// template <typename T>
-/// void foo() {
-/// static_assert(always_false<T>, "This function should never be called");
-/// }
-///
-template <class... T>
-constexpr bool always_false = false;
-
-//===----------------------------------------------------------------------===//
-// PDL Function Builder: Type Processing
-//===----------------------------------------------------------------------===//
-
-/// This struct provides a convenient way to determine how to process a given
-/// type as either a PDL parameter, or a result value. This allows for
-/// supporting complex types in constraint and rewrite functions, without
-/// requiring the user to hand-write the necessary glue code themselves.
-/// Specializations of this class should implement the following methods to
-/// enable support as a PDL argument or result type:
-///
-/// static LogicalResult verifyAsArg(
-/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
-/// size_t argIdx);
-///
-/// * This method verifies that the given PDLValue is valid for use as a
-/// value of `T`.
-///
-/// static T processAsArg(PDLValue pdlValue);
-///
-/// * This method processes the given PDLValue as a value of `T`.
-///
-/// static void processAsResult(PatternRewriter &, PDLResultList &results,
-/// const T &value);
-///
-/// * This method processes the given value of `T` as the result of a
-/// function invocation. The method should package the value into an
-/// appropriate form and append it to the given result list.
-///
-/// If the type `T` is based on a higher order value, consider using
-/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
-/// the implementation.
-///
-template <typename T, typename Enable = void>
-struct ProcessPDLValue;
-
-/// This struct provides a simplified model for processing types that are based
-/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
-/// allows for building the necessary processing functions on top of the base
-/// value instead of a PDLValue. Derived users should implement the following
-/// (which subsume the ProcessPDLValue variants):
-///
-/// static LogicalResult verifyAsArg(
-/// function_ref<LogicalResult(const Twine &)> errorFn,
-/// const BaseT &baseValue, size_t argIdx);
-///
-/// * This method verifies that the given PDLValue is valid for use as a
-/// value of `T`.
-///
-/// static T processAsArg(BaseT baseValue);
-///
-/// * This method processes the given base value as a value of `T`.
-///
-template <typename T, typename BaseT>
-struct ProcessPDLValueBasedOn {
- static LogicalResult
- verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
- PDLValue pdlValue, size_t argIdx) {
- // Verify the base class before continuing.
- if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
- return failure();
- return ProcessPDLValue<T>::verifyAsArg(
- errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
- }
- static T processAsArg(PDLValue pdlValue) {
- return ProcessPDLValue<T>::processAsArg(
- ProcessPDLValue<BaseT>::processAsArg(pdlValue));
- }
-
- /// Explicitly add the expected parent API to ensure the parent class
- /// implements the necessary API (and doesn't implicitly inherit it from
- /// somewhere else).
- static LogicalResult
- verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
- size_t argIdx) {
- return success();
- }
- static T processAsArg(BaseT baseValue);
-};
-
-/// This struct provides a simplified model for processing types that have
-/// "builtin" PDLValue support:
-/// * Attribute, Operation *, Type, TypeRange, ValueRange
-template <typename T>
-struct ProcessBuiltinPDLValue {
- static LogicalResult
- verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
- PDLValue pdlValue, size_t argIdx) {
- if (pdlValue)
- return success();
- return errorFn("expected a non-null value for argument " + Twine(argIdx) +
- " of type: " + llvm::getTypeName<T>());
- }
-
- static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- T value) {
- results.push_back(value);
- }
-};
-
-/// This struct provides a simplified model for processing types that inherit
-/// from builtin PDLValue types. For example, derived attributes like
-/// IntegerAttr, derived types like IntegerType, derived operations like
-/// ModuleOp, Interfaces, etc.
-template <typename T, typename BaseT>
-struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
- static LogicalResult
- verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
- BaseT baseValue, size_t argIdx) {
- return TypeSwitch<BaseT, LogicalResult>(baseValue)
- .Case([&](T) { return success(); })
- .Default([&](BaseT) {
- return errorFn("expected argument " + Twine(argIdx) +
- " to be of type: " + llvm::getTypeName<T>());
- });
- }
- using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
-
- static T processAsArg(BaseT baseValue) {
- return baseValue.template cast<T>();
- }
- using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
-
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- T value) {
- results.push_back(value);
- }
-};
-
-//===----------------------------------------------------------------------===//
-// Attribute
-
-template <>
-struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
-template <typename T>
-struct ProcessPDLValue<T,
- std::enable_if_t<std::is_base_of<Attribute, T>::value>>
- : public ProcessDerivedPDLValue<T, Attribute> {};
-
-/// Handling for various Attribute value types.
-template <>
-struct ProcessPDLValue<StringRef>
- : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
- static StringRef processAsArg(StringAttr value) { return value.getValue(); }
- using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
-
- static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
- StringRef value) {
- results.push_back(rewriter.getStringAttr(value));
- }
-};
-template <>
-struct ProcessPDLValue<std::string>
- : public ProcessPDLValueBasedOn<std::string, StringAttr> {
- template <typename T>
- static std::string processAsArg(T value) {
- static_assert(always_false<T>,
- "`std::string` arguments require a string copy, use "
- "`StringRef` for string-like arguments instead");
- return {};
- }
- static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
- StringRef value) {
- results.push_back(rewriter.getStringAttr(value));
- }
-};
-
-//===----------------------------------------------------------------------===//
-// Operation
-
-template <>
-struct ProcessPDLValue<Operation *>
- : public ProcessBuiltinPDLValue<Operation *> {};
-template <typename T>
-struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
- : public ProcessDerivedPDLValue<T, Operation *> {
- static T processAsArg(Operation *value) { return cast<T>(value); }
-};
-
-//===----------------------------------------------------------------------===//
-// Type
-
-template <>
-struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
-template <typename T>
-struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
- : public ProcessDerivedPDLValue<T, Type> {};
-
-//===----------------------------------------------------------------------===//
-// TypeRange
-
-template <>
-struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
-template <>
-struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- ValueTypeRange<OperandRange> types) {
- results.push_back(types);
- }
-};
-template <>
-struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- ValueTypeRange<ResultRange> types) {
- results.push_back(types);
- }
-};
-template <unsigned N>
-struct ProcessPDLValue<SmallVector<Type, N>> {
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- SmallVector<Type, N> values) {
- results.push_back(TypeRange(values));
- }
-};
-
-//===----------------------------------------------------------------------===//
-// Value
-
-template <>
-struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
-
-//===----------------------------------------------------------------------===//
-// ValueRange
-
-template <>
-struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
-};
-template <>
-struct ProcessPDLValue<OperandRange> {
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- OperandRange values) {
- results.push_back(values);
- }
-};
-template <>
-struct ProcessPDLValue<ResultRange> {
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- ResultRange values) {
- results.push_back(values);
- }
-};
-template <unsigned N>
-struct ProcessPDLValue<SmallVector<Value, N>> {
- static void processAsResult(PatternRewriter &, PDLResultList &results,
- SmallVector<Value, N> values) {
- results.push_back(ValueRange(values));
- }
-};
-
-//===----------------------------------------------------------------------===//
-// PDL Function Builder: Argument Handling
-//===----------------------------------------------------------------------===//
-
-/// Validate the given PDLValues match the constraints defined by the argument
-/// types of the given function. In the case of failure, a match failure
-/// diagnostic is emitted.
-/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
-/// does not currently preserve Constraint application ordering.
-template <typename PDLFnT, std::size_t... I>
-LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
- std::index_sequence<I...>) {
- using FnTraitsT = llvm::function_traits<PDLFnT>;
-
- auto errorFn = [&](const Twine &msg) {
- return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
- };
- return success(
- (succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
- verifyAsArg(errorFn, values[I], I)) &&
- ...));
-}
-
-/// Assert that the given PDLValues match the constraints defined by the
-/// arguments of the given function. In the case of failure, a fatal error
-/// is emitted.
-template <typename PDLFnT, std::size_t... I>
-void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
- std::index_sequence<I...>) {
- // We only want to do verification in debug builds, same as with `assert`.
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- using FnTraitsT = llvm::function_traits<PDLFnT>;
- auto errorFn = [&](const Twine &msg) -> LogicalResult {
- llvm::report_fatal_error(msg);
- };
- (void)errorFn;
- assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
- verifyAsArg(errorFn, values[I], I)) &&
- ...));
-#endif
- (void)values;
-}
-
-//===----------------------------------------------------------------------===//
-// PDL Function Builder: Results Handling
-//===----------------------------------------------------------------------===//
-
-/// Store a single result within the result list.
-template <typename T>
-static LogicalResult processResults(PatternRewriter &rewriter,
- PDLResultList &results, T &&value) {
- ProcessPDLValue<T>::processAsResult(rewriter, results,
- std::forward<T>(value));
- return success();
-}
-
-/// Store a std::pair<> as individual results within the result list.
-template <typename T1, typename T2>
-static LogicalResult processResults(PatternRewriter &rewriter,
- PDLResultList &results,
- std::pair<T1, T2> &&pair) {
- if (failed(processResults(rewriter, results, std::move(pair.first))) ||
- failed(processResults(rewriter, results, std::move(pair.second))))
- return failure();
- return success();
-}
-
-/// Store a std::tuple<> as individual results within the result list.
-template <typename... Ts>
-static LogicalResult processResults(PatternRewriter &rewriter,
- PDLResultList &results,
- std::tuple<Ts...> &&tuple) {
- auto applyFn = [&](auto &&...args) {
- return (succeeded(processResults(rewriter, results, std::move(args))) &&
- ...);
- };
- return success(std::apply(applyFn, std::move(tuple)));
-}
-
-/// Handle LogicalResult propagation.
-inline LogicalResult processResults(PatternRewriter &rewriter,
- PDLResultList &results,
- LogicalResult &&result) {
- return result;
-}
-template <typename T>
-static LogicalResult processResults(PatternRewriter &rewriter,
- PDLResultList &results,
- FailureOr<T> &&result) {
- if (failed(result))
- return failure();
- return processResults(rewriter, results, std::move(*result));
-}
-
-//===----------------------------------------------------------------------===//
-// PDL Constraint Builder
-//===----------------------------------------------------------------------===//
-
-/// Process the arguments of a native constraint and invoke it.
-template <typename PDLFnT, std::size_t... I,
- typename FnTraitsT = llvm::function_traits<PDLFnT>>
-typename FnTraitsT::result_t
-processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
- ArrayRef<PDLValue> values,
- std::index_sequence<I...>) {
- return fn(
- rewriter,
- (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
- values[I]))...);
-}
-
-/// Build a constraint function from the given function `ConstraintFnT`. This
-/// allows for enabling the user to define simpler, more direct constraint
-/// functions without needing to handle the low-level PDL goop.
-///
-/// If the constraint function is already in the correct form, we just forward
-/// it directly.
-template <typename ConstraintFnT>
-std::enable_if_t<
- std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
- PDLConstraintFunction>
-buildConstraintFn(ConstraintFnT &&constraintFn) {
- return std::forward<ConstraintFnT>(constraintFn);
-}
-/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
-/// we desire.
-template <typename ConstraintFnT>
-std::enable_if_t<
- !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
- PDLConstraintFunction>
-buildConstraintFn(ConstraintFnT &&constraintFn) {
- return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
- PatternRewriter &rewriter,
- ArrayRef<PDLValue> values) -> LogicalResult {
- auto argIndices = std::make_index_sequence<
- llvm::function_traits<ConstraintFnT>::num_args - 1>();
- if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
- return failure();
- return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
- argIndices);
- };
-}
-
-//===----------------------------------------------------------------------===//
-// PDL Rewrite Builder
-//===----------------------------------------------------------------------===//
-
-/// Process the arguments of a native rewrite and invoke it.
-/// This overload handles the case of no return values.
-template <typename PDLFnT, std::size_t... I,
- typename FnTraitsT = llvm::function_traits<PDLFnT>>
-std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
- LogicalResult>
-processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
- PDLResultList &, ArrayRef<PDLValue> values,
- std::index_sequence<I...>) {
- fn(rewriter,
- (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
- values[I]))...);
- return success();
-}
-/// This overload handles the case of return values, which need to be packaged
-/// into the result list.
-template <typename PDLFnT, std::size_t... I,
- typename FnTraitsT = llvm::function_traits<PDLFnT>>
-std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
- LogicalResult>
-processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
- PDLResultList &results, ArrayRef<PDLValue> values,
- std::index_sequence<I...>) {
- return processResults(
- rewriter, results,
- fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
- processAsArg(values[I]))...));
- (void)values;
-}
-
-/// Build a rewrite function from the given function `RewriteFnT`. This
-/// allows for enabling the user to define simpler, more direct rewrite
-/// functions without needing to handle the low-level PDL goop.
-///
-/// If the rewrite function is already in the correct form, we just forward
-/// it directly.
-template <typename RewriteFnT>
-std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
- PDLRewriteFunction>
-buildRewriteFn(RewriteFnT &&rewriteFn) {
- return std::forward<RewriteFnT>(rewriteFn);
-}
-/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
-/// we desire.
-template <typename RewriteFnT>
-std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
- PDLRewriteFunction>
-buildRewriteFn(RewriteFnT &&rewriteFn) {
- return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
- PatternRewriter &rewriter, PDLResultList &results,
- ArrayRef<PDLValue> values) {
- auto argIndices =
- std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
- 1>();
- assertArgs<RewriteFnT>(rewriter, values, argIndices);
- return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
- argIndices);
- };
-}
-
-} // namespace pdl_function_builder
-} // namespace detail
-
-//===----------------------------------------------------------------------===//
-// PDLPatternModule
-
-/// This class contains all of the necessary data for a set of PDL patterns, or
-/// pattern rewrites specified in the form of the PDL dialect. This PDL module
-/// contained by this pattern may contain any number of `pdl.pattern`
-/// operations.
-class PDLPatternModule {
-public:
- PDLPatternModule() = default;
-
- /// Construct a PDL pattern with the given module and configurations.
- PDLPatternModule(OwningOpRef<ModuleOp> module)
- : pdlModule(std::move(module)) {}
- template <typename... ConfigsT>
- PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
- : PDLPatternModule(std::move(module)) {
- auto configSet = std::make_unique<PDLPatternConfigSet>(
- std::forward<ConfigsT>(patternConfigs)...);
- attachConfigToPatterns(*pdlModule, *configSet);
- configs.emplace_back(std::move(configSet));
- }
-
- /// Merge the state in `other` into this pattern module.
- void mergeIn(PDLPatternModule &&other);
-
- /// Return the internal PDL module of this pattern.
- ModuleOp getModule() { return pdlModule.get(); }
-
- //===--------------------------------------------------------------------===//
- // Function Registry
-
- /// Register a constraint function with PDL. A constraint function may be
- /// specified in one of two ways:
- ///
- /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
- ///
- /// In this overload the arguments of the constraint function are passed via
- /// the low-level PDLValue form.
- ///
- /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
- ///
- /// In this form the arguments of the constraint function are passed via the
- /// expected high level C++ type. In this form, the framework will
- /// automatically unwrap PDLValues and convert them to the expected ValueTs.
- /// For example, if the constraint function accepts a `Operation *`, the
- /// framework will automatically cast the input PDLValue. In the case of a
- /// `StringRef`, the framework will automatically unwrap the argument as a
- /// StringAttr and pass the underlying string value. To see the full list of
- /// supported types, or to see how to add handling for custom types, view
- /// the definition of `ProcessPDLValue` above.
- void registerConstraintFunction(StringRef name,
- PDLConstraintFunction constraintFn);
- template <typename ConstraintFnT>
- void registerConstraintFunction(StringRef name,
- ConstraintFnT &&constraintFn) {
- registerConstraintFunction(name,
- detail::pdl_function_builder::buildConstraintFn(
- std::forward<ConstraintFnT>(constraintFn)));
- }
-
- /// Register a rewrite function with PDL. A rewrite function may be specified
- /// in one of two ways:
- ///
- /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
- ///
- /// In this overload the arguments of the constraint function are passed via
- /// the low-level PDLValue form, and the results are manually appended to
- /// the given result list.
- ///
- /// * `ResultT (PatternRewriter &, ValueTs... values)`
- ///
- /// In this form the arguments and result of the rewrite function are passed
- /// via the expected high level C++ type. In this form, the framework will
- /// automatically unwrap the PDLValues arguments and convert them to the
- /// expected ValueTs. It will also automatically handle the processing and
- /// packaging of the result value to the result list. For example, if the
- /// rewrite function takes a `Operation *`, the framework will automatically
- /// cast the input PDLValue. In the case of a `StringRef`, the framework
- /// will automatically unwrap the argument as a StringAttr and pass the
- /// underlying string value. In the reverse case, if the rewrite returns a
- /// StringRef or std::string, it will automatically package this as a
- /// StringAttr and append it to the result list. To see the full list of
- /// supported types, or to see how to add handling for custom types, view
- /// the definition of `ProcessPDLValue` above.
- void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
- template <typename RewriteFnT>
- void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
- registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
- std::forward<RewriteFnT>(rewriteFn)));
- }
-
- /// Return the set of the registered constraint functions.
- const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
- return constraintFunctions;
- }
- llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
- return constraintFunctions;
- }
- /// Return the set of the registered rewrite functions.
- const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
- return rewriteFunctions;
- }
- llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
- return rewriteFunctions;
- }
-
- /// Return the set of the registered pattern configs.
- SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
- return std::move(configs);
- }
- DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
- return std::move(configMap);
- }
-
- /// Clear out the patterns and functions within this module.
- void clear() {
- pdlModule = nullptr;
- constraintFunctions.clear();
- rewriteFunctions.clear();
- }
-
-private:
- /// Attach the given pattern config set to the patterns defined within the
- /// given module.
- void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
-
- /// The module containing the `pdl.pattern` operations.
- OwningOpRef<ModuleOp> pdlModule;
+} // namespace mlir
- /// The set of configuration sets referenced by patterns within `pdlModule`.
- SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
- DenseMap<Operation *, PDLPatternConfigSet *> configMap;
+// Optionally expose PDL pattern matching methods.
+#include "PDLPatternMatch.h.inc"
- /// The external functions referenced from within the PDL module.
- llvm::StringMap<PDLConstraintFunction> constraintFunctions;
- llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
-};
+namespace mlir {
//===----------------------------------------------------------------------===//
// RewritePatternSet
@@ -1679,8 +759,7 @@ class RewritePatternSet {
nativePatterns.emplace_back(std::move(pattern));
}
RewritePatternSet(PDLPatternModule &&pattern)
- : context(pattern.getModule()->getContext()),
- pdlPatterns(std::move(pattern)) {}
+ : context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
MLIRContext *getContext() const { return context; }
@@ -1853,6 +932,7 @@ class RewritePatternSet {
pattern->addDebugLabels(debugLabels);
nativePatterns.emplace_back(std::move(pattern));
}
+
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
@@ -1863,6 +943,9 @@ class RewritePatternSet {
MLIRContext *const context;
NativePatternListT nativePatterns;
+
+ // Patterns expressed with PDL. This will compile to a stub class when PDL is
+ // not enabled.
PDLPatternModule pdlPatterns;
};
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6de981d35c8c3a..c5725e9c856256 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -13,6 +13,7 @@
#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
+#include "mlir/Config/mlir-config.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
@@ -1015,6 +1016,7 @@ class ConversionTarget {
MLIRContext &ctx;
};
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//
@@ -1044,6 +1046,19 @@ class PDLConversionConfig final
/// Register the dialect conversion PDL functions with the given pattern set.
void registerConversionPDLFunctions(RewritePatternSet &patterns);
+#else
+
+// Stubs for when PDL in rewriting is not enabled; they keep these entry
+// points declared so callers compile in either configuration.
+
+/// No-op: with PDL support compiled out there are no conversion PDL functions
+/// to register on `patterns`.
+inline void registerConversionPDLFunctions(RewritePatternSet & /*patterns*/) {}
+
+/// Placeholder configuration: the type converter argument is accepted and
+/// discarded.
+class PDLConversionConfig final {
+public:
+  PDLConversionConfig(const TypeConverter *) {}
+};
+
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
index 7f7b348b17ae68..be5eb73b91229f 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRBufferizationTransformOps
MLIRFunctionInterfaces
MLIRLinalgDialect
MLIRParser
- MLIRPDLDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
)
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 21bba11b851171..a155b7c5ecade1 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -61,3 +61,10 @@ add_mlir_library(MLIRIR
LINK_LIBS PUBLIC
MLIRSupport
)
+
+# The PDL pattern-matching definitions live in their own library (see
+# PDL/CMakeLists.txt); only build it and link it into MLIRIR when PDL in
+# pattern rewrites is enabled.
+if(MLIR_ENABLE_PDL_IN_PATTERNMATCH)
+  add_subdirectory(PDL)
+  target_link_libraries(MLIRIR PUBLIC MLIRIRPDLPatternMatch)
+endif()
+
diff --git a/mlir/lib/IR/PDL/CMakeLists.txt b/mlir/lib/IR/PDL/CMakeLists.txt
new file mode 100644
index 00000000000000..08b7fe36fac096
--- /dev/null
+++ b/mlir/lib/IR/PDL/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Out-of-line definitions for the PDL portion of the pattern-matching API
+# (PDLValue printing and PDLPatternModule members).
+add_mlir_library(MLIRIRPDLPatternMatch
+  PDLPatternMatch.cpp
+
+  ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+  )
+
diff --git a/mlir/lib/IR/PDL/PDLPatternMatch.cpp b/mlir/lib/IR/PDL/PDLPatternMatch.cpp
new file mode 100644
index 00000000000000..da07cc462a5a13
--- /dev/null
+++ b/mlir/lib/IR/PDL/PDLPatternMatch.cpp
@@ -0,0 +1,133 @@
+//===- PDLPatternMatch.cpp - Base classes for PDL pattern match -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Iterators.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// PDLValue
+//===----------------------------------------------------------------------===//
+
+/// Print the held value to `os`, or "<NULL-PDLValue>" when nothing is held.
+/// Range kinds are printed as their comma-separated elements.
+void PDLValue::print(raw_ostream &os) const {
+  if (!value) {
+    os << "<NULL-PDLValue>";
+    return;
+  }
+
+  switch (kind) {
+  case Kind::Attribute:
+    os << cast<Attribute>();
+    return;
+  case Kind::Operation:
+    // Print the operation itself rather than its address.
+    os << *cast<Operation *>();
+    return;
+  case Kind::Type:
+    os << cast<Type>();
+    return;
+  case Kind::TypeRange:
+    llvm::interleaveComma(cast<TypeRange>(), os);
+    return;
+  case Kind::Value:
+    os << cast<Value>();
+    return;
+  case Kind::ValueRange:
+    llvm::interleaveComma(cast<ValueRange>(), os);
+    return;
+  }
+}
+
+/// Print the name of the given PDL value kind (e.g. "Attribute",
+/// "TypeRange") to `os`.
+void PDLValue::print(raw_ostream &os, Kind kind) {
+  StringRef kindName;
+  switch (kind) {
+  case Kind::Attribute:
+    kindName = "Attribute";
+    break;
+  case Kind::Operation:
+    kindName = "Operation";
+    break;
+  case Kind::Type:
+    kindName = "Type";
+    break;
+  case Kind::TypeRange:
+    kindName = "TypeRange";
+    break;
+  case Kind::Value:
+    kindName = "Value";
+    break;
+  case Kind::ValueRange:
+    kindName = "ValueRange";
+    break;
+  }
+  os << kindName;
+}
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+//===----------------------------------------------------------------------===//
+
+/// Merge `other` into this pattern module. If `other` holds no PDL module it
+/// is ignored entirely. Otherwise:
+/// * its native constraint/rewrite functions are moved over, and a name that
+///   is already registered here keeps its existing entry;
+/// * its pattern config sets and op -> config mapping are transferred;
+/// * its pattern operations are adopted as-is when this module has none, or
+///   spliced onto the end of this module's body otherwise.
+void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
+  if (!other.pdlModule)
+    return;
+
+  // Move over the native functions; registration never overwrites an
+  // existing name.
+  for (auto &entry : other.constraintFunctions)
+    registerConstraintFunction(entry.getKey(), std::move(entry.getValue()));
+  for (auto &entry : other.rewriteFunctions)
+    registerRewriteFunction(entry.getKey(), std::move(entry.getValue()));
+
+  // Take ownership of the config sets and inherit the op -> config mapping
+  // (ops that are already mapped here keep their current config).
+  for (auto &configSet : other.configs)
+    configs.emplace_back(std::move(configSet));
+  configMap.insert(other.configMap.begin(), other.configMap.end());
+
+  if (pdlModule) {
+    // Append the other module's pattern operations to our own body.
+    Block *body = pdlModule->getBody();
+    body->getOperations().splice(body->end(),
+                                 other.pdlModule->getBody()->getOperations());
+  } else {
+    // We have no patterns yet, so simply adopt the other module.
+    pdlModule = std::move(other.pdlModule);
+  }
+}
+
+/// Map every symbol operation nested in `module` to `configSet`.
+void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
+                                              PDLPatternConfigSet &configSet) {
+  // Keying on symbols avoids hardcoding specific operation names, since this
+  // code does not depend on the PDL dialect. The interface trait is checked
+  // instead of cast<SymbolOpInterface> because patterns may be optional
+  // symbols.
+  module->walk([&configSet, this](Operation *op) {
+    if (!op->hasTrait<SymbolOpInterface::Trait>())
+      return;
+    configMap[op] = &configSet;
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// Function Registry
+
+/// Register `constraintFn` under `name`. The first registration of a name
+/// wins, because several patterns may rely on the same constraint; later
+/// registrations of that name are ignored.
+void PDLPatternModule::registerConstraintFunction(
+    StringRef name, PDLConstraintFunction constraintFn) {
+  // TODO: Is it possible to diagnose when `name` is already registered to
+  // a function that is not equivalent to `constraintFn`?
+  (void)constraintFunctions.try_emplace(name, std::move(constraintFn));
+}
+
+/// Register `rewriteFn` under `name`. The first registration of a name wins,
+/// because several patterns may rely on the same rewrite; later registrations
+/// of that name are ignored.
+void PDLPatternModule::registerRewriteFunction(StringRef name,
+                                               PDLRewriteFunction rewriteFn) {
+  // TODO: Is it possible to diagnose when `name` is already registered to
+  // a function that is not equivalent to `rewriteFn`?
+  (void)rewriteFunctions.try_emplace(name, std::move(rewriteFn));
+}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5e9b9b2a810a4c..5e788cdb4897d3 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Config/mlir-config.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
@@ -97,124 +98,6 @@ LogicalResult RewritePattern::match(Operation *op) const {
/// Out-of-line vtable anchor.
void RewritePattern::anchor() {}
-//===----------------------------------------------------------------------===//
-// PDLValue
-//===----------------------------------------------------------------------===//
-
-void PDLValue::print(raw_ostream &os) const {
- if (!value) {
- os << "<NULL-PDLValue>";
- return;
- }
- switch (kind) {
- case Kind::Attribute:
- os << cast<Attribute>();
- break;
- case Kind::Operation:
- os << *cast<Operation *>();
- break;
- case Kind::Type:
- os << cast<Type>();
- break;
- case Kind::TypeRange:
- llvm::interleaveComma(cast<TypeRange>(), os);
- break;
- case Kind::Value:
- os << cast<Value>();
- break;
- case Kind::ValueRange:
- llvm::interleaveComma(cast<ValueRange>(), os);
- break;
- }
-}
-
-void PDLValue::print(raw_ostream &os, Kind kind) {
- switch (kind) {
- case Kind::Attribute:
- os << "Attribute";
- break;
- case Kind::Operation:
- os << "Operation";
- break;
- case Kind::Type:
- os << "Type";
- break;
- case Kind::TypeRange:
- os << "TypeRange";
- break;
- case Kind::Value:
- os << "Value";
- break;
- case Kind::ValueRange:
- os << "ValueRange";
- break;
- }
-}
-
-//===----------------------------------------------------------------------===//
-// PDLPatternModule
-//===----------------------------------------------------------------------===//
-
-void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
- // Ignore the other module if it has no patterns.
- if (!other.pdlModule)
- return;
-
- // Steal the functions and config of the other module.
- for (auto &it : other.constraintFunctions)
- registerConstraintFunction(it.first(), std::move(it.second));
- for (auto &it : other.rewriteFunctions)
- registerRewriteFunction(it.first(), std::move(it.second));
- for (auto &it : other.configs)
- configs.emplace_back(std::move(it));
- for (auto &it : other.configMap)
- configMap.insert(it);
-
- // Steal the other state if we have no patterns.
- if (!pdlModule) {
- pdlModule = std::move(other.pdlModule);
- return;
- }
-
- // Merge the pattern operations from the other module into this one.
- Block *block = pdlModule->getBody();
- block->getOperations().splice(block->end(),
- other.pdlModule->getBody()->getOperations());
-}
-
-void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
- PDLPatternConfigSet &configSet) {
- // Attach the configuration to the symbols within the module. We only add
- // to symbols to avoid hardcoding any specific operation names here (given
- // that we don't depend on any PDL dialect). We can't use
- // cast<SymbolOpInterface> here because patterns may be optional symbols.
- module->walk([&](Operation *op) {
- if (op->hasTrait<SymbolOpInterface::Trait>())
- configMap[op] = &configSet;
- });
-}
-
-//===----------------------------------------------------------------------===//
-// Function Registry
-
-void PDLPatternModule::registerConstraintFunction(
- StringRef name, PDLConstraintFunction constraintFn) {
- // TODO: Is it possible to diagnose when `name` is already registered to
- // a function that is not equivalent to `constraintFn`?
- // Allow existing mappings in the case multiple patterns depend on the same
- // constraint.
- constraintFunctions.try_emplace(name, std::move(constraintFn));
-}
-
-void PDLPatternModule::registerRewriteFunction(StringRef name,
- PDLRewriteFunction rewriteFn) {
- // TODO: Is it possible to diagnose when `name` is already registered to
- // a function that is not equivalent to `rewriteFn`?
- // Allow existing mappings in the case multiple patterns depend on the same
- // rewrite.
- rewriteFunctions.try_emplace(name, std::move(rewriteFn));
-}
-
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index 4d43fe636bd1f2..4aceac7ed3a4c9 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -16,6 +16,8 @@
#include "mlir/IR/PatternMatch.h"
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
namespace mlir {
namespace pdl_interp {
class RecordMatchOp;
@@ -224,4 +226,38 @@ class PDLByteCode {
} // namespace detail
} // namespace mlir
+#else
+
+namespace mlir::detail {
+
+/// Stub mutable state: with PDL compiled out there is no per-pattern state to
+/// clean up or update.
+class PDLByteCodeMutableState {
+public:
+  void cleanupAfterMatchAndRewrite() {}
+  void updatePatternBenefit(unsigned /*patternIndex*/,
+                            PatternBenefit /*benefit*/) {}
+};
+
+/// Placeholder pattern type referenced by `PDLByteCode::MatchResult` and
+/// `PDLByteCode::getPatterns()`.
+class PDLByteCodePattern : public Pattern {};
+
+/// Stub bytecode: it holds no patterns, never produces matches, and every
+/// rewrite request fails.
+class PDLByteCode {
+public:
+  struct MatchResult {
+    const PDLByteCodePattern *pattern = nullptr;
+    PatternBenefit benefit;
+  };
+
+  ArrayRef<PDLByteCodePattern> getPatterns() const { return {}; }
+
+  void initializeMutableState(PDLByteCodeMutableState & /*state*/) const {}
+
+  /// Leaves `matches` untouched.
+  void match(Operation * /*op*/, PatternRewriter & /*rewriter*/,
+             SmallVectorImpl<MatchResult> & /*matches*/,
+             PDLByteCodeMutableState & /*state*/) const {}
+
+  LogicalResult rewrite(PatternRewriter & /*rewriter*/,
+                        const MatchResult & /*match*/,
+                        PDLByteCodeMutableState & /*state*/) const {
+    return failure();
+  }
+};
+
+} // namespace mlir::detail
+
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
#endif // MLIR_REWRITE_BYTECODE_H_
diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt
index e0395be6cd6f59..a6c39406aa4b3e 100644
--- a/mlir/lib/Rewrite/CMakeLists.txt
+++ b/mlir/lib/Rewrite/CMakeLists.txt
@@ -1,5 +1,6 @@
+set(LLVM_OPTIONAL_SOURCES ByteCode.cpp)
+
add_mlir_library(MLIRRewrite
- ByteCode.cpp
FrozenRewritePatternSet.cpp
PatternApplicator.cpp
@@ -11,8 +12,31 @@ add_mlir_library(MLIRRewrite
LINK_LIBS PUBLIC
MLIRIR
- MLIRPDLDialect
- MLIRPDLInterpDialect
- MLIRPDLToPDLInterp
MLIRSideEffectInterfaces
)
+
+# PDL bytecode support (ByteCode.cpp) and the PDL dialect dependencies are
+# only built and linked when PDL in pattern rewrites is enabled.
+if(MLIR_ENABLE_PDL_IN_PATTERNMATCH)
+  add_mlir_library(MLIRRewritePDL
+    ByteCode.cpp
+
+    ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
+
+    DEPENDS mlir-generic-headers
+
+    LINK_LIBS PUBLIC
+      MLIRIR
+      MLIRPDLDialect
+      MLIRPDLInterpDialect
+      MLIRPDLToPDLInterp
+      MLIRSideEffectInterfaces
+    )
+
+  # FrozenRewritePatternSet.cpp (part of MLIRRewrite) includes the PDL dialect
+  # and PDL-to-PDLInterp conversion headers under this option, so link them
+  # directly in addition to the bytecode library.
+  target_link_libraries(MLIRRewrite PUBLIC
+    MLIRPDLDialect MLIRPDLInterpDialect MLIRPDLToPDLInterp MLIRRewritePDL)
+endif()
+
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index 43840d1e8cec2f..17fe02df9f66cd 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -8,8 +8,6 @@
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "ByteCode.h"
-#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
-#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -17,6 +15,11 @@
using namespace mlir;
+// Include the PDL rewrite support.
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+
static LogicalResult
convertPDLToPDLInterp(ModuleOp pdlModule,
DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
@@ -48,6 +51,7 @@ convertPDLToPDLInterp(ModuleOp pdlModule,
pdlModule.getBody()->walk(simplifyFn);
return success();
}
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// FrozenRewritePatternSet
@@ -121,6 +125,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
impl->nativeAnyOpPatterns.push_back(std::move(pat));
}
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
// Generate the bytecode for the PDL patterns if any were provided.
PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
ModuleOp pdlModule = pdlPatterns.getModule();
@@ -137,6 +142,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
pdlModule, pdlPatterns.takeConfigs(), configMap,
pdlPatterns.takeConstraintFunctions(),
pdlPatterns.takeRewriteFunctions());
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
}
FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 08d6ee618ac690..0064eb84aba84d 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -152,7 +152,6 @@ LogicalResult PatternApplicator::matchAndRewrite(
// Find the next pattern with the highest benefit.
const Pattern *bestPattern = nullptr;
unsigned *bestPatternIt = &opIt;
- const PDLByteCode::MatchResult *pdlMatch = nullptr;
/// Operation specific patterns.
if (opIt < opE)
@@ -164,6 +163,8 @@ LogicalResult PatternApplicator::matchAndRewrite(
bestPatternIt = &anyIt;
bestPattern = anyOpPatterns[anyIt];
}
+
+ const PDLByteCode::MatchResult *pdlMatch = nullptr;
/// PDL patterns.
if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
pdlMatches[pdlIt].benefit)) {
@@ -171,6 +172,7 @@ LogicalResult PatternApplicator::matchAndRewrite(
pdlMatch = &pdlMatches[pdlIt];
bestPattern = pdlMatch->pattern;
}
+
if (!bestPattern)
break;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4d2afe462b9281..85433d088dcbf0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -3312,6 +3313,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
return std::nullopt;
}
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//
@@ -3382,6 +3384,7 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
return std::move(remappedTypes);
});
}
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 3f312164cb1f35..7ec4c8f0963a2d 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -97,16 +97,13 @@ set(MLIR_TEST_DEPENDS
mlir-capi-ir-test
mlir-capi-llvm-test
mlir-capi-pass-test
- mlir-capi-pdl-test
mlir-capi-quant-test
mlir-capi-sparse-tensor-test
mlir-capi-transform-test
mlir-capi-translation-test
mlir-linalg-ods-yaml-gen
mlir-lsp-server
- mlir-pdll-lsp-server
mlir-opt
- mlir-pdll
mlir-query
mlir-reduce
mlir-tblgen
@@ -115,6 +112,12 @@ set(MLIR_TEST_DEPENDS
tblgen-to-irdl
)
+list(APPEND MLIR_TEST_DEPENDS  # PDL/PDLL tools; still appended unconditionally (not yet gated on MLIR_ENABLE_PDL_IN_PATTERNMATCH).
+  mlir-capi-pdl-test
+  mlir-pdll-lsp-server
+  mlir-pdll
+  )
+
# The native target may not be enabled, in this case we won't
# run tests that involves executing on the host: do not build
# useless binaries.
@@ -159,9 +162,10 @@ if(LLVM_BUILD_EXAMPLES)
toyc-ch3
toyc-ch4
toyc-ch5
+ )
+ list(APPEND MLIR_TEST_DEPENDS
transform-opt-ch2
transform-opt-ch3
- mlir-minimal-opt
)
if(MLIR_ENABLE_EXECUTION_ENGINE)
list(APPEND MLIR_TEST_DEPENDS
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index e032ce7200fbf8..2a3a8608db5442 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_OPTIONAL_SOURCES  # Only compiled when MLIR_ENABLE_PDL_IN_PATTERNMATCH is on.
+  TestDialectConversion.cpp)
+unset(MLIRTestTransformsPDLDep)  # Populated only in the PDL-enabled branch.
+unset(MLIRTestTransformsPDLSrc)  # Populated only in the PDL-enabled branch.
+if(MLIR_ENABLE_PDL_IN_PATTERNMATCH)
add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen
TestDialectConversion.pdll
TestDialectConversionPDLLPatterns.h.inc
@@ -6,17 +11,22 @@ add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen
${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
)
+ set(MLIRTestTransformsPDLSrc
+ TestDialectConversion.cpp)
+ set(MLIRTestTransformsPDLDep
+ MLIRTestDialectConversionPDLLPatternsIncGen)
+endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
- TestDialectConversion.cpp
TestInlining.cpp
TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
TestTopologicalSort.cpp
+ ${MLIRTestTransformsPDLSrc}
EXCLUDE_FROM_LIBMLIR
@@ -24,7 +34,7 @@ add_mlir_library(MLIRTestTransforms
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
DEPENDS
- MLIRTestDialectConversionPDLLPatternsIncGen
+ ${MLIRTestTransformsPDLDep}
LINK_LIBS PUBLIC
MLIRAnalysis
diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt
index e90ccf17af17f5..9664f6b94844e6 100644
--- a/mlir/tools/mlir-lsp-server/CMakeLists.txt
+++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt
@@ -21,10 +21,12 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestIR
MLIRTestPass
MLIRTestReducer
+ )
+ set(test_libs
+ ${test_libs}
MLIRTestRewrite
MLIRTestTransformDialect
- MLIRTestTransforms
- )
+ MLIRTestTransforms)
endif()
set(LIBS
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b6ada66d321880..15317a119c154c 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -38,16 +38,18 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestIR
MLIRTestOneToNTypeConversionPass
MLIRTestPass
- MLIRTestPDLL
MLIRTestReducer
- MLIRTestRewrite
- MLIRTestTransformDialect
MLIRTestTransforms
MLIRTilingInterfaceTestPasses
MLIRVectorTestPasses
MLIRTestVectorToSPIRV
MLIRLLVMTestPasses
)
+  list(APPEND test_libs  # Appended separately from the test_libs list above.
+    MLIRTestPDLL
+    MLIRTestRewrite
+    MLIRTestTransformDialect
+    )
endif()
set(LIBS
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f7a5b3183b50b1..bf8f3b7aa21d11 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -85,7 +85,9 @@ void registerTestDataLayoutQuery();
void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
void registerTestDialectConversionPasses();
+#endif
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestEmulateNarrowTypePass();
@@ -147,8 +149,8 @@ void registerTestNvgpuLowerings();
namespace test {
void registerTestDialect(DialectRegistry &);
-void registerTestTransformDialectExtension(DialectRegistry &);
void registerTestDynDialect(DialectRegistry &);
+void registerTestTransformDialectExtension(DialectRegistry &);
} // namespace test
#ifdef MLIR_INCLUDE_TESTS
@@ -260,6 +262,9 @@ void registerTestPasses() {
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestNvgpuLowerings();
mlir::test::registerTestWrittenToPass();
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+ mlir::test::registerTestDialectConversionPasses();
+#endif
}
#endif
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2a56b2d6f0373f..2a72bf965e544d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -35,6 +35,7 @@ expand_template(
substitutions = {
"#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS": "#define MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 0",
"#cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}": "/* #undef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED */",
+ "#cmakedefine01 MLIR_ENABLE_PDL_IN_PATTERNMATCH": "#define MLIR_ENABLE_PDL_IN_PATTERNMATCH 1",
},
template = "include/mlir/Config/mlir-config.h.cmake",
)
@@ -318,11 +319,13 @@ cc_library(
srcs = glob([
"lib/IR/*.cpp",
"lib/IR/*.h",
+ "lib/IR/PDL/*.cpp",
"lib/Bytecode/Reader/*.h",
"lib/Bytecode/Writer/*.h",
"lib/Bytecode/*.h",
]) + [
"lib/Bytecode/BytecodeOpInterface.cpp",
+ "include/mlir/IR/PDLPatternMatch.h.inc",
],
hdrs = glob([
"include/mlir/IR/*.h",
@@ -345,6 +348,7 @@ cc_library(
":BuiltinTypesIncGen",
":BytecodeOpInterfaceIncGen",
":CallOpInterfacesIncGen",
+ ":config",
":DataLayoutInterfacesIncGen",
":InferTypeOpInterfaceIncGen",
":OpAsmInterfaceIncGen",
More information about the llvm-commits
mailing list