[llvm-branch-commits] [mlir] abfd1a8 - [mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
River Riddle via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Dec 1 15:11:02 PST 2020
Author: River Riddle
Date: 2020-12-01T15:05:50-08:00
New Revision: abfd1a8b3bc5ad8516a83c3ae7ba9f16032525ad
URL: https://github.com/llvm/llvm-project/commit/abfd1a8b3bc5ad8516a83c3ae7ba9f16032525ad
DIFF: https://github.com/llvm/llvm-project/commit/abfd1a8b3bc5ad8516a83c3ae7ba9f16032525ad.diff
LOG: [mlir][PDL] Add support for PDL bytecode and expose PDL support to OwningRewritePatternList
PDL patterns are now supported via a new `PDLPatternModule` class. This class contains a ModuleOp with the pdl::PatternOp operations representing the patterns, as well as a collection of registered C++ functions for native constraints/creations/rewrites/etc. that may be invoked via the pdl patterns. Instances of this class are added to an OwningRewritePatternList in the same fashion as C++ RewritePatterns, i.e. via the `insert` method.
The PDL bytecode is an in-memory representation of the PDL interpreter dialect that can be efficiently interpreted/executed. The representation of the bytecode boils down to a code array (for opcodes, memory locations, etc.) and a memory buffer (for storing attributes, operations, values, and any other necessary data). The bytecode operations are effectively a 1-1 mapping to the PDLInterp dialect operations, with a few exceptions in cases where the in-memory representation of the bytecode can be more efficient than the MLIR representation. For example, a generic `AreEqual` bytecode op can be used to represent AreEqualOp, CheckAttributeOp, and CheckTypeOp.
The execution of the bytecode is split into two phases: matching and rewriting. When matching, all of the matched patterns are collected to avoid the overhead of re-running parts of the matcher. These matched patterns are then considered alongside the native C++ patterns, which rewrite immediately in-place via `RewritePattern::matchAndRewrite`, for the given root operation. When a PDL pattern is matched and has the highest benefit, it is passed back to the bytecode to execute its rewriter.
Differential Revision: https://reviews.llvm.org/D89107
Added:
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Rewrite/ByteCode.h
mlir/test/Rewrite/pdl-bytecode.mlir
mlir/test/lib/Rewrite/CMakeLists.txt
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Modified:
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/IR/BlockSupport.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/IR/StorageUniquerSupport.h
mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
mlir/include/mlir/Rewrite/PatternApplicator.h
mlir/lib/IR/Block.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Rewrite/CMakeLists.txt
mlir/lib/Rewrite/FrozenRewritePatternList.cpp
mlir/lib/Rewrite/PatternApplicator.cpp
mlir/test/lib/CMakeLists.txt
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index df49eb37b2a5..6b11c0dde809 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -108,7 +108,7 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
```mlir
// Apply `myConstraint` to the entities defined by `input`, `attr`, and
// `op`.
- pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+ pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest
```
}];
@@ -316,7 +316,7 @@ def PDLInterp_CheckTypeOp
Example:
```mlir
- pdl_interp.check_type %type is 0 -> ^matchDest, ^failureDest
+ pdl_interp.check_type %type is i32 -> ^matchDest, ^failureDest
```
}];
@@ -338,7 +338,7 @@ def PDLInterp_CreateAttributeOp
Example:
```mlir
- pdl_interp.create_attribute 10 : i64
+ %attr = pdl_interp.create_attribute 10 : i64
```
}];
@@ -369,7 +369,7 @@ def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
Example:
```mlir
- %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
+ %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute
```
}];
@@ -772,7 +772,7 @@ def PDLInterp_SwitchAttributeOp
Example:
```mlir
- pdl_interp.switch_attribute %attr to [10, true] -> ^10Dest, ^trueDest, ^defaultDest
+ pdl_interp.switch_attribute %attr to [10, true](^10Dest, ^trueDest) -> ^defaultDest
```
}];
let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues);
@@ -837,7 +837,7 @@ def PDLInterp_SwitchOperationNameOp
Example:
```mlir
- pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"] -> ^fooDest, ^barDest, ^defaultDest
+ pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"](^fooDest, ^barDest) -> ^defaultDest
```
}];
@@ -874,7 +874,7 @@ def PDLInterp_SwitchResultCountOp
Example:
```mlir
- pdl_interp.switch_result_count of %op to [0, 2] -> ^0Dest, ^2Dest, ^defaultDest
+ pdl_interp.switch_result_count of %op to [0, 2](^0Dest, ^2Dest) -> ^defaultDest
```
}];
diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h
index fc16effbba70..6cf2df9a1406 100644
--- a/mlir/include/mlir/IR/BlockSupport.h
+++ b/mlir/include/mlir/IR/BlockSupport.h
@@ -58,6 +58,7 @@ class SuccessorRange final
SuccessorRange, BlockOperand *, Block *, Block *, Block *> {
public:
using RangeBaseT::RangeBaseT;
+ SuccessorRange();
SuccessorRange(Block *block);
SuccessorRange(Operation *term);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 5b3c44868db2..3d5bc66ee9e2 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -69,6 +69,9 @@ class Operation final
/// Remove this operation from its parent block and delete it.
void erase();
+ /// Remove the operation from its parent block, but don't delete it.
+ void remove();
+
/// Create a deep copy of this operation, remapping any operands that use
/// values outside of the operation using the map that is provided (leaving
/// them alone if no entry is present). Replaces references to cloned
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 96d6d1194b60..74899c9565fe 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -349,7 +349,7 @@ class OperationName {
void *getAsOpaquePointer() const {
return static_cast<void *>(representation.getOpaqueValue());
}
- static OperationName getFromOpaquePointer(void *pointer);
+ static OperationName getFromOpaquePointer(const void *pointer);
private:
RepresentationUnion representation;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2158f09cc469..4fdc0878c590 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -10,6 +10,7 @@
#define MLIR_PATTERNMATCHER_H
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
namespace mlir {
@@ -225,6 +226,189 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
}
};
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// PDLValue
+
+/// Storage type of byte-code interpreter values. These are passed to constraint
+/// functions as arguments.
+class PDLValue {
+ /// The internal implementation type when the value is an Attribute,
+ /// Operation*, or Type. See `impl` below for more details.
+ using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>;
+
+public:
+ PDLValue(const PDLValue &other) = default;
+ PDLValue(std::nullptr_t = nullptr) : impl() {}
+ PDLValue(Attribute value) : impl(value) {}
+ PDLValue(Operation *value) : impl(value) {}
+ PDLValue(Type value) : impl(value) {}
+ PDLValue(Value value) : impl(value) {}
+
+ /// Returns true if the type of the held value is `T`.
+ template <typename T>
+ std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const {
+ return impl.is<Value>();
+ }
+ template <typename T>
+ std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const {
+ auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
+ return attrOpTypeImpl && attrOpTypeImpl.is<T>();
+ }
+
+ /// Attempt to dynamically cast this value to type `T`, returns null if this
+ /// value is not an instance of `T`.
+ template <typename T>
+ std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const {
+ return impl.dyn_cast<T>();
+ }
+ template <typename T>
+ std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const {
+ auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>();
+ // Select explicitly: `a && b` would yield a `bool`, which does not convert
+ // to `T` (Attribute, Operation *, or Type). `T()` is the null value of `T`.
+ return attrOpTypeImpl ? attrOpTypeImpl.dyn_cast<T>() : T();
+ }
+
+ /// Cast this value to type `T`, asserts if this value is not an instance of
+ /// `T`.
+ template <typename T>
+ std::enable_if_t<std::is_same<T, Value>::value, T> cast() const {
+ return impl.get<T>();
+ }
+ template <typename T>
+ std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const {
+ return impl.get<AttrOpTypeImplT>().get<T>();
+ }
+
+ /// Get an opaque pointer to the value.
+ void *getAsOpaquePointer() const { return impl.getOpaqueValue(); }
+
+ /// Print this value to the provided output stream.
+ void print(raw_ostream &os);
+
+private:
+ /// The internal opaque representation of a PDLValue. We use a nested
+ /// PointerUnion structure here because `Value` only has 1 low bit
+ /// available, whereas the remaining types all have 3.
+ llvm::PointerUnion<AttrOpTypeImplT, Value> impl;
+};
+
+/// Stream a PDLValue into `os` by delegating to PDLValue::print, returning
+/// `os` so that insertions can be chained.
+inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
+ value.print(os);
+ return os;
+}
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
+/// A generic PDL pattern constraint function. This function applies a
+/// constraint to a given set of opaque PDLValue entities. The second parameter
+/// is a set of constant value parameters specified in Attribute form, and the
+/// third is a PatternRewriter. Returns success if the constraint successfully
+/// held, failure otherwise.
+using PDLConstraintFunction = std::function<LogicalResult(
+ ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
+/// A native PDL creation function. This function creates a new PDLValue given
+/// a set of existing PDL values, a set of constant parameters specified in
+/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue.
+using PDLCreateFunction =
+ std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
+/// A native PDL rewrite function. This function rewrites the given root
+/// operation (first parameter) using the provided PDL values, the constant
+/// parameters specified in Attribute form, and the PatternRewriter. This
+/// method is only invoked when the corresponding match was successful.
+using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>,
+ ArrayAttr, PatternRewriter &)>;
+/// A PDL pattern constraint function over a single entity. This function
+/// applies a constraint to a given opaque PDLValue entity. The second parameter
+/// is a set of constant value parameters specified in Attribute form, and the
+/// third is a PatternRewriter. Returns success if the constraint successfully
+/// held, failure otherwise.
+using PDLSingleEntityConstraintFunction =
+ std::function<LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)>;
+
+/// This class contains all of the necessary data for a set of PDL patterns, or
+/// pattern rewrites specified in the form of the PDL dialect. This PDL module
+/// contained by this pattern may contain any number of `pdl.pattern`
+/// operations.
+class PDLPatternModule {
+public:
+ PDLPatternModule() = default;
+
+ /// Construct a PDL pattern with the given module.
+ PDLPatternModule(OwningModuleRef pdlModule)
+ : pdlModule(std::move(pdlModule)) {}
+
+ /// Merge the state in `other` into this pattern module.
+ void mergeIn(PDLPatternModule &&other);
+
+ /// Return the internal PDL module of this pattern.
+ ModuleOp getModule() { return pdlModule.get(); }
+
+ //===--------------------------------------------------------------------===//
+ // Function Registry
+
+ /// Register a constraint function.
+ void registerConstraintFunction(StringRef name,
+ PDLConstraintFunction constraintFn);
+ /// Register a single entity constraint function. The callable is wrapped so
+ /// that it receives the sole element of the value list.
+ template <typename SingleEntityFn>
+ std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
+ ArrayAttr, PatternRewriter &>::value>
+ registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
+ registerConstraintFunction(name, [=](ArrayRef<PDLValue> values,
+ ArrayAttr constantParams,
+ PatternRewriter &rewriter) {
+ assert(values.size() == 1 && "expected values to have a single entity");
+ return constraintFn(values[0], constantParams, rewriter);
+ });
+ }
+
+ /// Register a creation function.
+ void registerCreateFunction(StringRef name, PDLCreateFunction createFn);
+
+ /// Register a rewrite function.
+ void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
+
+ /// Return the set of the registered constraint functions.
+ const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
+ return constraintFunctions;
+ }
+ /// Take ownership of the registered constraint functions. They are moved
+ /// out of this module rather than copied.
+ llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
+ return std::move(constraintFunctions);
+ }
+ /// Return the set of the registered create functions.
+ const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const {
+ return createFunctions;
+ }
+ /// Take ownership of the registered create functions. They are moved out of
+ /// this module rather than copied.
+ llvm::StringMap<PDLCreateFunction> takeCreateFunctions() {
+ return std::move(createFunctions);
+ }
+ /// Return the set of the registered rewrite functions.
+ const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
+ return rewriteFunctions;
+ }
+ /// Take ownership of the registered rewrite functions. They are moved out of
+ /// this module rather than copied.
+ llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
+ return std::move(rewriteFunctions);
+ }
+
+ /// Clear out the patterns and functions within this module.
+ void clear() {
+ pdlModule = nullptr;
+ constraintFunctions.clear();
+ createFunctions.clear();
+ rewriteFunctions.clear();
+ }
+
+private:
+ /// The module containing the `pdl.pattern` operations.
+ OwningModuleRef pdlModule;
+
+ /// The external functions referenced from within the PDL module.
+ llvm::StringMap<PDLConstraintFunction> constraintFunctions;
+ llvm::StringMap<PDLCreateFunction> createFunctions;
+ llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
+};
+
//===----------------------------------------------------------------------===//
// PatternRewriter
//===----------------------------------------------------------------------===//
@@ -384,28 +568,28 @@ class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
//===----------------------------------------------------------------------===//
class OwningRewritePatternList {
- using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
+ using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
OwningRewritePatternList() = default;
- /// Construct a OwningRewritePatternList populated with the pattern `t` of
- /// type `T`.
- template <typename T>
- OwningRewritePatternList(T &&t) {
- patterns.emplace_back(std::make_unique<T>(std::forward<T>(t)));
+ /// Construct a OwningRewritePatternList populated with the given pattern.
+ OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) {
+ nativePatterns.emplace_back(std::move(pattern));
}
+ OwningRewritePatternList(PDLPatternModule &&pattern)
+ : pdlPatterns(std::move(pattern)) {}
+
+ /// Return the native patterns held in this list.
+ NativePatternListT &getNativePatterns() { return nativePatterns; }
- PatternListT::iterator begin() { return patterns.begin(); }
- PatternListT::iterator end() { return patterns.end(); }
- PatternListT::const_iterator begin() const { return patterns.begin(); }
- PatternListT::const_iterator end() const { return patterns.end(); }
- PatternListT::size_type size() const { return patterns.size(); }
- void clear() { patterns.clear(); }
+ /// Return the PDL patterns held in this list.
+ PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
- /// Take ownership of the patterns held by this list.
- std::vector<std::unique_ptr<RewritePattern>> takePatterns() {
- return std::move(patterns);
+ /// Clear out all of the held patterns in this list.
+ void clear() {
+ nativePatterns.clear();
+ pdlPatterns.clear();
}
//===--------------------------------------------------------------------===//
@@ -419,31 +603,53 @@ class OwningRewritePatternList {
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
OwningRewritePatternList &insert(ConstructorArg &&arg,
- ConstructorArgs &&... args) {
+ ConstructorArgs &&...args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{
- 0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
+ (void)std::initializer_list<int>{0, (insertImpl<Ts>(arg, args...), 0)...};
return *this;
}
/// Add an instance of each of the pattern types 'Ts'. Return a reference to
/// `this` for chaining insertions.
template <typename... Ts> OwningRewritePatternList &insert() {
- (void)std::initializer_list<int>{
- 0, (patterns.emplace_back(std::make_unique<Ts>()), 0)...};
+ (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...};
return *this;
}
- /// Add the given pattern to the pattern list.
- void insert(std::unique_ptr<RewritePattern> pattern) {
- patterns.emplace_back(std::move(pattern));
+ /// Add the given native pattern to the pattern list. Return a reference to
+ /// `this` for chaining insertions.
+ OwningRewritePatternList &insert(std::unique_ptr<RewritePattern> pattern) {
+ nativePatterns.emplace_back(std::move(pattern));
+ return *this;
+ }
+
+ /// Add the given PDL pattern to the pattern list. Return a reference to
+ /// `this` for chaining insertions.
+ OwningRewritePatternList &insert(PDLPatternModule &&pattern) {
+ pdlPatterns.mergeIn(std::move(pattern));
+ return *this;
}
private:
- PatternListT patterns;
+ /// Construct an instance of the native pattern type 'T' from `args` and
+ /// append it to the native pattern list. Only enabled when `T` derives from
+ /// RewritePattern.
+ template <typename T, typename... Args>
+ std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
+ insertImpl(Args &&...args) {
+ nativePatterns.emplace_back(
+ std::make_unique<T>(std::forward<Args>(args)...));
+ }
+ /// Construct a 'T' from `args` and merge it into the held PDL patterns. Only
+ /// enabled when `T` derives from PDLPatternModule.
+ template <typename T, typename... Args>
+ std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
+ insertImpl(Args &&...args) {
+ pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
+ }
+
+ NativePatternListT nativePatterns;
+ PDLPatternModule pdlPatterns;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index c0096bb6b233..719bb1a62f97 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -104,6 +104,12 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
}
+ /// Get an instance of the concrete type from a void pointer. A null `ptr`
+ /// yields a null instance; otherwise the base-class instance recovered by
+ /// `BaseT::getFromOpaquePointer` is cast to `ConcreteT`.
+ static ConcreteT getFromOpaquePointer(const void *ptr) {
+ return ptr ? BaseT::getFromOpaquePointer(ptr).template cast<ConcreteT>()
+ : nullptr;
+ }
+
protected:
/// Mutate the current storage instance. This will not change the unique key.
/// The arguments are forwarded to 'ConcreteT::mutate'.
diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
index fb2657d99232..c2335b9dd5a1 100644
--- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
+++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h
@@ -12,25 +12,40 @@
#include "mlir/IR/PatternMatch.h"
namespace mlir {
+namespace detail {
+class PDLByteCode;
+} // end namespace detail
+
/// This class represents a frozen set of patterns that can be processed by a
/// pattern applicator. This class is designed to enable caching pattern lists
/// such that they need not be continuously recomputed.
class FrozenRewritePatternList {
- using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
+ using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternList(OwningRewritePatternList &&patterns);
+ /// Move constructor and destructor are declared here and defined out of
+ /// line; `detail::PDLByteCode` is only forward-declared in this header and
+ /// is held through a std::unique_ptr.
+ FrozenRewritePatternList(FrozenRewritePatternList &&patterns);
+ ~FrozenRewritePatternList();
+
+ /// Return the native patterns held by this list.
+ iterator_range<llvm::pointee_iterator<NativePatternListT::const_iterator>>
+ getNativePatterns() const {
+ return llvm::make_pointee_range(nativePatterns);
+ }
- /// Return the patterns held by this list.
- iterator_range<llvm::pointee_iterator<PatternListT::const_iterator>>
- getPatterns() const {
- return llvm::make_pointee_range(patterns);
+ /// Return the compiled PDL bytecode held by this list. Returns null if
+ /// there are no PDL patterns within the list.
+ const detail::PDLByteCode *getPDLByteCode() const {
+ return pdlByteCode.get();
}
private:
- /// The patterns held by this list.
- std::vector<std::unique_ptr<RewritePattern>> patterns;
+ /// The set of native C++ rewrite patterns held by this list.
+ std::vector<std::unique_ptr<RewritePattern>> nativePatterns;
+
+ /// The bytecode containing the compiled PDL patterns.
+ std::unique_ptr<detail::PDLByteCode> pdlByteCode;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
index cb7794bab9fc..9d197175b47d 100644
--- a/mlir/include/mlir/Rewrite/PatternApplicator.h
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -19,6 +19,10 @@
namespace mlir {
class PatternRewriter;
+namespace detail {
+class PDLByteCodeMutableState;
+} // end namespace detail
+
/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
@@ -29,8 +33,8 @@ class PatternApplicator {
/// `impossibleToMatch`.
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
- explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList)
- : frozenPatternList(frozenPatternList) {}
+ explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList);
+ ~PatternApplicator();
/// Attempt to match and rewrite the given op with any pattern, allowing a
/// predicate to decide if a pattern can be applied or not, and hooks for if
@@ -60,16 +64,6 @@ class PatternApplicator {
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
private:
- /// Attempt to match and rewrite the given op with the given pattern, allowing
- /// a predicate to decide if a pattern can be applied or not, and hooks for if
- /// the pattern match was a success or failure.
- LogicalResult
- matchAndRewrite(Operation *op, const RewritePattern &pattern,
- PatternRewriter &rewriter,
- function_ref<bool(const Pattern &)> canApply,
- function_ref<void(const Pattern &)> onFailure,
- function_ref<LogicalResult(const Pattern &)> onSuccess);
-
/// The list that owns the patterns used within this applicator.
const FrozenRewritePatternList &frozenPatternList;
/// The set of patterns to match for each operation, stable sorted by benefit.
@@ -77,6 +71,8 @@ class PatternApplicator {
/// The set of patterns that may match against any operation type, stable
/// sorted by benefit.
SmallVector<const RewritePattern *, 1> anyOpPatterns;
+ /// The mutable state used during execution of the PDL bytecode.
+ std::unique_ptr<detail::PDLByteCodeMutableState> mutableByteCodeState;
};
} // end namespace mlir
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index b9ddabb80800..79e7daa12a7c 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -302,13 +302,15 @@ unsigned PredecessorIterator::getSuccessorIndex() const {
// SuccessorRange
//===----------------------------------------------------------------------===//
-SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) {
+/// Construct an empty successor range (null base, zero count).
+SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}
+
+/// Construct the range of successors of `block`'s terminator. The range stays
+/// empty if the block has no terminator or the terminator has no successors.
+SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() {
if (Operation *term = block->getTerminator())
if ((count = term->getNumSuccessors()))
base = term->getBlockOperands().data();
}
-SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) {
+/// Construct the range of successors of `term`; empty if it has none.
+SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
if ((count = term->getNumSuccessors()))
base = term->getBlockOperands().data();
}
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index e725dd87d93f..3037bf082d58 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -61,8 +61,9 @@ const AbstractOperation *OperationName::getAbstractOperation() const {
return representation.dyn_cast<const AbstractOperation *>();
}
-OperationName OperationName::getFromOpaquePointer(void *pointer) {
- return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
+/// Reconstruct an OperationName from the opaque pointer form produced by
+/// `getAsOpaquePointer`.
+OperationName OperationName::getFromOpaquePointer(const void *pointer) {
+ // `getFromOpaqueValue` is called with a `void *`, so constness is cast away
+ // here; this function never writes through `pointer`.
+ return OperationName(
+ RepresentationUnion::getFromOpaqueValue(const_cast<void *>(pointer)));
+}
//===----------------------------------------------------------------------===//
@@ -484,6 +485,12 @@ void Operation::erase() {
destroy();
}
+/// Remove the operation from its parent block, but don't delete it. This is
+/// a no-op when the operation is not currently in a block.
+void Operation::remove() {
+ if (Block *parent = getBlock())
+ parent->getOperations().remove(this);
+}
+
/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index edd5e7b9d6d7..6558fcf4606d 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -69,6 +69,84 @@ LogicalResult RewritePattern::match(Operation *op) const {
/// Out-of-line vtable anchor.
void RewritePattern::anchor() {}
+//===----------------------------------------------------------------------===//
+// PDLValue
+//===----------------------------------------------------------------------===//
+
+/// Print the held value to `os`. A null value prints as "<Null-PDLValue>";
+/// operations are printed by dereferencing the held Operation pointer.
+void PDLValue::print(raw_ostream &os) {
+ if (!impl) {
+ os << "<Null-PDLValue>";
+ return;
+ }
+ if (Value val = impl.dyn_cast<Value>()) {
+ os << val;
+ return;
+ }
+ // Not a Value, so the outer union must hold the Attribute/Operation/Type
+ // union.
+ AttrOpTypeImplT aotImpl = impl.get<AttrOpTypeImplT>();
+ if (Attribute attr = aotImpl.dyn_cast<Attribute>())
+ os << attr;
+ else if (Operation *op = aotImpl.dyn_cast<Operation *>())
+ os << *op;
+ else
+ os << aotImpl.get<Type>();
+}
+
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+//===----------------------------------------------------------------------===//
+
+/// Merge the patterns and native functions of `other` into this module. If
+/// `other` has no pattern module, it is ignored entirely.
+void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
+ // Ignore the other module if it has no patterns.
+ if (!other.pdlModule)
+ return;
+
+ // Steal the functions of the other module. These must be read from `other`:
+ // re-registering this module's own functions would drop `other`'s and trip
+ // the duplicate-name assertion in each register*Function. Merging (rather
+ // than assigning the maps) also keeps any functions already registered on
+ // this module even when it has no patterns yet.
+ for (auto &it : other.constraintFunctions)
+ registerConstraintFunction(it.first(), std::move(it.second));
+ for (auto &it : other.createFunctions)
+ registerCreateFunction(it.first(), std::move(it.second));
+ for (auto &it : other.rewriteFunctions)
+ registerRewriteFunction(it.first(), std::move(it.second));
+
+ // Adopt the other pattern module wholesale if we have no patterns yet.
+ if (!pdlModule) {
+ pdlModule = std::move(other.pdlModule);
+ return;
+ }
+
+ // Merge the pattern operations from the other module into this one. Our
+ // terminator is erased first; the spliced operations from `other` end with
+ // its own terminator, leaving exactly one at the end of the block.
+ Block *block = pdlModule->getBody();
+ block->getTerminator()->erase();
+ block->getOperations().splice(block->end(),
+ other.pdlModule->getBody()->getOperations());
+}
+
+//===----------------------------------------------------------------------===//
+// Function Registry
+
+/// Register `constraintFn` under `name`. Asserts that no constraint with the
+/// same name has been registered; in builds without assertions an existing
+/// entry is kept, since `try_emplace` does not overwrite.
+void PDLPatternModule::registerConstraintFunction(
+ StringRef name, PDLConstraintFunction constraintFn) {
+ auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
+ (void)it;
+ assert(it.second &&
+ "constraint with the given name has already been registered");
+}
+/// Register `createFn` under `name`; same duplicate handling as above.
+void PDLPatternModule::registerCreateFunction(StringRef name,
+ PDLCreateFunction createFn) {
+ auto it = createFunctions.try_emplace(name, std::move(createFn));
+ (void)it;
+ assert(it.second && "native create function with the given name has "
+ "already been registered");
+}
+/// Register `rewriteFn` under `name`; same duplicate handling as above.
+void PDLPatternModule::registerRewriteFunction(StringRef name,
+ PDLRewriteFunction rewriteFn) {
+ auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
+ (void)it;
+ assert(it.second && "native rewrite function with the given name has "
+ "already been registered");
+}
+
//===----------------------------------------------------------------------===//
// PatternRewriter
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
new file mode 100644
index 000000000000..ae5f322d2948
--- /dev/null
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -0,0 +1,1262 @@
+//===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements MLIR to byte-code generation and the interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ByteCode.h"
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "llvm/ADT/IntervalMap.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "pdl-bytecode"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodePattern
+//===----------------------------------------------------------------------===//
+
+/// Build the bytecode pattern described by `matchOp`, whose rewriter begins at
+/// `rewriterAddr` in the rewriter bytecode.
+PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
+                                              ByteCodeAddr rewriterAddr) {
+  // Collect the names of the operations the rewriter declares it generates.
+  SmallVector<StringRef, 8> generatedOps;
+  ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr();
+  if (generatedOpsAttr) {
+    for (StringRef opName : generatedOpsAttr.getAsValueRange<StringAttr>())
+      generatedOps.push_back(opName);
+  }
+
+  PatternBenefit benefit = matchOp.benefit();
+  MLIRContext *ctx = matchOp.getContext();
+
+  // A pattern without a root kind may match any operation type; otherwise it
+  // is tied to the specific root operation name.
+  Optional<StringRef> rootKind = matchOp.rootKind();
+  if (!rootKind)
+    return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
+                              MatchAnyOpTypeTag());
+  return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
+                            ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodeMutableState
+//===----------------------------------------------------------------------===//
+
+/// Overwrite the current benefit of one bytecode pattern. `patternIndex` is the
+/// position of the pattern within the range returned by
+/// `PDLByteCode::getPatterns`.
+void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
+                                                   PatternBenefit benefit) {
+  PatternBenefit &slot = currentPatternBenefits[patternIndex];
+  slot = benefit;
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode OpCodes
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// The instruction opcodes of the PDL bytecode. Each instruction in the
+/// bytecode stream begins with one of these values stored as a ByteCodeField.
+///
+/// NOTE: The declaration order is load-bearing. The generator computes the
+/// fixed-index opcodes arithmetically (`GetOperand0 + index` and
+/// `GetResult0 + index` for indices below 4), so GetOperand0..GetOperand3 and
+/// GetResult0..GetResult3 must remain contiguous and in order.
+enum OpCode : ByteCodeField {
+ /// Apply an externally registered constraint.
+ ApplyConstraint,
+ /// Apply an externally registered rewrite.
+ ApplyRewrite,
+ /// Check if two generic values are equal. This opcode is also emitted for
+ /// `pdl_interp.check_attribute` and `pdl_interp.check_type`.
+ AreEqual,
+ /// Unconditional branch.
+ Branch,
+ /// Compare the operand count of an operation with a constant.
+ CheckOperandCount,
+ /// Compare the name of an operation with a constant.
+ CheckOperationName,
+ /// Compare the result count of an operation with a constant.
+ CheckResultCount,
+ /// Invoke a native creation method.
+ CreateNative,
+ /// Create an operation.
+ CreateOperation,
+ /// Erase an operation.
+ EraseOp,
+ /// Terminate a matcher or rewrite sequence.
+ Finalize,
+ /// Get a specific attribute of an operation.
+ GetAttribute,
+ /// Get the type of an attribute.
+ GetAttributeType,
+ /// Get the defining operation of a value.
+ GetDefiningOp,
+ /// Get a specific operand of an operation. Indices 0-3 are encoded in the
+ /// opcode itself; GetOperandN carries the index as an extra operand.
+ GetOperand0,
+ GetOperand1,
+ GetOperand2,
+ GetOperand3,
+ GetOperandN,
+ /// Get a specific result of an operation. Indices 0-3 are encoded in the
+ /// opcode itself; GetResultN carries the index as an extra operand.
+ GetResult0,
+ GetResult1,
+ GetResult2,
+ GetResult3,
+ GetResultN,
+ /// Get the type of a value.
+ GetValueType,
+ /// Check if a generic value is not null.
+ IsNotNull,
+ /// Record a successful pattern match.
+ RecordMatch,
+ /// Replace an operation.
+ ReplaceOp,
+ /// Compare an attribute with a set of constants.
+ SwitchAttribute,
+ /// Compare the operand count of an operation with a set of constants.
+ SwitchOperandCount,
+ /// Compare the name of an operation with a set of constants.
+ SwitchOperationName,
+ /// Compare the result count of an operation with a set of constants.
+ SwitchResultCount,
+ /// Compare a type with a set of constants.
+ SwitchType,
+};
+
+/// The kind tag written into the bytecode before each entry of a generic PDL
+/// value list, so that the executor knows how to decode the entry.
+enum class PDLValueKind { Attribute, Operation, Type, Value };
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ByteCode Generation
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Generator
+
+namespace {
+struct ByteCodeWriter;
+
+/// This class represents the main generator for the pattern bytecode. It
+/// populates the bytecode buffers, pattern list, and uniqued-data table that
+/// are owned by the caller and passed in by reference.
+class Generator {
+public:
+ /// Each native constraint/create/rewrite function is assigned an index equal
+ /// to its position in the iteration order of its StringMap. The maps are not
+ /// modified here, so a caller that later walks them in the same order sees
+ /// matching positions.
+ Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
+ SmallVectorImpl<ByteCodeField> &matcherByteCode,
+ SmallVectorImpl<ByteCodeField> &rewriterByteCode,
+ SmallVectorImpl<PDLByteCodePattern> &patterns,
+ ByteCodeField &maxValueMemoryIndex,
+ llvm::StringMap<PDLConstraintFunction> &constraintFns,
+ llvm::StringMap<PDLCreateFunction> &createFns,
+ llvm::StringMap<PDLRewriteFunction> &rewriteFns)
+ : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
+ rewriterByteCode(rewriterByteCode), patterns(patterns),
+ maxValueMemoryIndex(maxValueMemoryIndex) {
+ for (auto it : llvm::enumerate(constraintFns))
+ constraintToMemIndex.try_emplace(it.value().first(), it.index());
+ for (auto it : llvm::enumerate(createFns))
+ nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
+ for (auto it : llvm::enumerate(rewriteFns))
+ externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
+ }
+
+ /// Generate the bytecode for the given PDL interpreter module.
+ void generate(ModuleOp module);
+
+ /// Return the memory index to use for the given value. The returned
+ /// reference is mutable so that callers can alias one value's slot to
+ /// another's.
+ ByteCodeField &getMemIndex(Value value) {
+ assert(valueToMemIndex.count(value) &&
+ "expected memory index to be assigned");
+ return valueToMemIndex[value];
+ }
+
+ /// Return an index to use when referring to the given data that is uniqued in
+ /// the MLIR context. The first request for a given opaque pointer appends it
+ /// to `uniquedData`. Its index is offset by `maxValueMemoryIndex`, which
+ /// places uniqued data after every value slot. This presumes
+ /// `maxValueMemoryIndex` has already been finalized when this is called.
+ template <typename T>
+ std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
+ getMemIndex(T val) {
+ const void *opaqueVal = val.getAsOpaquePointer();
+
+ // Get or insert a reference to this value.
+ auto it = uniquedDataToMemIndex.try_emplace(
+ opaqueVal, maxValueMemoryIndex + uniquedData.size());
+ if (it.second)
+ uniquedData.push_back(opaqueVal);
+ return it.first->second;
+ }
+
+private:
+ /// Allocate memory indices for the results of operations within the matcher
+ /// and rewriters.
+ void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
+
+ /// Generate the bytecode for the given operation.
+ void generate(Operation *op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
+
+ /// Mapping from value to its corresponding memory index.
+ DenseMap<Value, ByteCodeField> valueToMemIndex;
+
+ /// Mapping from the name of an externally registered rewrite to its index in
+ /// the bytecode registry.
+ llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
+
+ /// Mapping from the name of an externally registered constraint to its index
+ /// in the bytecode registry.
+ llvm::StringMap<ByteCodeField> constraintToMemIndex;
+
+ /// Mapping from the name of an externally registered creation method to its
+ /// index in the bytecode registry.
+ llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
+
+ /// Mapping from rewriter function name to the address at which that
+ /// rewriter's code begins within the rewriter bytecode.
+ llvm::StringMap<ByteCodeAddr> rewriterToAddr;
+
+ /// Mapping from a uniqued storage object to its memory index within
+ /// `uniquedData`.
+ DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
+
+ /// The current MLIR context.
+ MLIRContext *ctx;
+
+ /// Data of the ByteCode class to be populated.
+ /// NOTE: These members are initialized in declaration order by the
+ /// constructor's init list; keep the two in sync.
+ std::vector<const void *> &uniquedData;
+ SmallVectorImpl<ByteCodeField> &matcherByteCode;
+ SmallVectorImpl<ByteCodeField> &rewriterByteCode;
+ SmallVectorImpl<PDLByteCodePattern> &patterns;
+ ByteCodeField &maxValueMemoryIndex;
+};
+
+/// This class provides utilities for writing a bytecode stream. Values with a
+/// memory slot are written as their memory index; successor addresses are
+/// written as placeholders and recorded in `unresolvedSuccessorRefs` so the
+/// owner can patch them once block addresses are known.
+struct ByteCodeWriter {
+ ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
+ : bytecode(bytecode), generator(generator) {}
+
+ /// Append a field to the bytecode.
+ void append(ByteCodeField field) { bytecode.push_back(field); }
+
+ /// Append an address to the bytecode. An address spans exactly two fields
+ /// (checked by the static_assert) and is copied in its raw byte
+ /// representation.
+ void append(ByteCodeAddr field) {
+ static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
+ "unexpected ByteCode address size");
+
+ ByteCodeField fieldParts[2];
+ std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
+ bytecode.append({fieldParts[0], fieldParts[1]});
+ }
+
+ /// Append a successor range to the bytecode, the exact address will need to
+ /// be resolved later.
+ void append(SuccessorRange successors) {
+ // For each successor, record the offset of a zero placeholder address so
+ // that the real block address can be patched in once it is known.
+ for (Block *successor : successors) {
+ unresolvedSuccessorRefs[successor].push_back(bytecode.size());
+ append(ByteCodeAddr(0));
+ }
+ }
+
+ /// Append a range of values that will be read as generic PDLValues. The
+ /// encoding is the number of values, then for each value a PDLValueKind tag
+ /// followed by the value itself.
+ void appendPDLValueList(OperandRange values) {
+ bytecode.push_back(values.size());
+ for (Value value : values) {
+ // Append the type of the value in addition to the value itself.
+ PDLValueKind kind =
+ TypeSwitch<Type, PDLValueKind>(value.getType())
+ .Case<pdl::AttributeType>(
+ [](Type) { return PDLValueKind::Attribute; })
+ .Case<pdl::OperationType>(
+ [](Type) { return PDLValueKind::Operation; })
+ .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
+ .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
+ bytecode.push_back(static_cast<ByteCodeField>(kind));
+ append(value);
+ }
+ }
+
+ /// Detection helper: well-formed only when `T` provides a
+ /// `getAsOpaquePointer()` method. Such handle types (attributes, types,
+ /// values, operation names, identifiers, ...) are written as a memory index
+ /// rather than inline.
+ template <typename T, typename... Args>
+ using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
+
+ /// Append a value that will be stored in a memory slot and not inline within
+ /// the bytecode.
+ template <typename T>
+ std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
+ std::is_pointer<T>::value>
+ append(T value) {
+ bytecode.push_back(generator.getMemIndex(value));
+ }
+
+ /// Append a range of values, encoded as the element count followed by each
+ /// element appended individually.
+ template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
+ std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
+ append(T range) {
+ bytecode.push_back(llvm::size(range));
+ for (auto it : range)
+ append(it);
+ }
+
+ /// Append a variadic number of fields to the bytecode, in argument order.
+ template <typename FieldTy, typename Field2Ty, typename... FieldTys>
+ void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
+ append(field);
+ append(field2, fields...);
+ }
+
+ /// Successor references in the bytecode that have yet to be resolved: for
+ /// each target block, the offsets of the placeholder addresses to patch.
+ DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
+
+ /// The underlying bytecode buffer.
+ SmallVectorImpl<ByteCodeField> &bytecode;
+
+ /// The generator that assigns memory indices for appended values.
+ Generator &generator;
+};
+} // end anonymous namespace
+
+/// Lower a PDL interpreter module into matcher and rewriter bytecode.
+void Generator::generate(ModuleOp module) {
+  // Locate the matcher entry function and the module holding the rewriters.
+  FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
+      pdl_interp::PDLInterpDialect::getMatcherFunctionName());
+  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
+      pdl_interp::PDLInterpDialect::getRewriterModuleName());
+  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
+
+  // Memory slots must be assigned before any instruction referencing them is
+  // emitted.
+  allocateMemoryIndices(matcherFunc, rewriterModule);
+
+  // Emit the rewriters back-to-back, remembering where each one begins so
+  // that recorded matches can point at it.
+  ByteCodeWriter rewriterWriter(rewriterByteCode, *this);
+  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
+    ByteCodeAddr startAddr = rewriterByteCode.size();
+    rewriterToAddr.try_emplace(rewriterFunc.getName(), startAddr);
+    for (Operation &rewriteOp : rewriterFunc.getOps())
+      generate(&rewriteOp, rewriterWriter);
+  }
+  assert(rewriterWriter.unresolvedSuccessorRefs.empty() &&
+         "unexpected branches in rewriter function");
+
+  // Emit the matcher blocks in reverse post-order, noting where each block
+  // begins so that branch placeholders can be patched afterwards.
+  DenseMap<Block *, ByteCodeAddr> blockToAddr;
+  ByteCodeWriter matcherWriter(matcherByteCode, *this);
+  llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
+  for (Block *block : rpot) {
+    blockToAddr.try_emplace(block, matcherByteCode.size());
+    for (Operation &matchOp : *block)
+      generate(&matchOp, matcherWriter);
+  }
+
+  // Overwrite every placeholder with the start address of its target block.
+  for (auto &entry : matcherWriter.unresolvedSuccessorRefs) {
+    ByteCodeAddr targetAddr = blockToAddr[entry.first];
+    for (unsigned fixupOffset : entry.second)
+      std::memcpy(&matcherByteCode[fixupOffset], &targetAddr,
+                  sizeof(ByteCodeAddr));
+  }
+}
+
+/// Assign a memory slot to every value produced in the matcher and rewriters,
+/// and raise `maxValueMemoryIndex` to the number of slots required.
+void Generator::allocateMemoryIndices(FuncOp matcherFunc,
+                                      ModuleOp rewriterModule) {
+  // Rewriters use a simplistic allocation scheme that assigns a fresh index to
+  // each argument and each operation result, restarting at 0 per rewriter.
+  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
+    ByteCodeField index = 0;
+    for (BlockArgument arg : rewriterFunc.getArguments())
+      valueToMemIndex.try_emplace(arg, index++);
+    rewriterFunc.getBody().walk([&](Operation *op) {
+      for (Value result : op->getResults())
+        valueToMemIndex.try_emplace(result, index++);
+    });
+    if (index > maxValueMemoryIndex)
+      maxValueMemoryIndex = index;
+  }
+
+  // The matcher function uses a more sophisticated numbering that tries to
+  // minimize the number of memory indices assigned. This is done by determining
+  // a live range of the values within the matcher, then the allocation is just
+  // finding the minimal number of overlapping live ranges. This is essentially
+  // a simplified form of register allocation where we don't necessarily have a
+  // limited number of registers, but we still want to minimize the number used.
+  DenseMap<Operation *, ByteCodeField> opToIndex;
+  matcherFunc.getBody().walk([&](Operation *op) {
+    opToIndex.insert(std::make_pair(op, opToIndex.size()));
+  });
+
+  // Liveness info for each of the defs within the matcher.
+  using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
+  LivenessSet::Allocator allocator;
+  DenseMap<Value, LivenessSet> valueDefRanges;
+
+  // Assign the root operation being matched to slot 0.
+  BlockArgument rootOpArg = matcherFunc.getArgument(0);
+  valueToMemIndex[rootOpArg] = 0;
+
+  // Walk each of the blocks, computing the def interval that the value is used.
+  Liveness matcherLiveness(matcherFunc);
+  for (Block &block : matcherFunc.getBody()) {
+    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
+    assert(info && "expected liveness info for block");
+    auto processValue = [&](Value value, Operation *firstUseOrDef) {
+      // We don't need to process the root op argument, this value is always
+      // assigned to the first memory slot.
+      if (value == rootOpArg)
+        return;
+
+      // Set indices for the range of this block that the value is used.
+      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
+      defRangeIt->second.insert(
+          opToIndex[firstUseOrDef],
+          opToIndex[info->getEndOperation(value, firstUseOrDef)],
+          /*dummyValue*/ 0);
+    };
+
+    // Process the live-ins of this block.
+    for (Value liveIn : info->in())
+      processValue(liveIn, &block.front());
+
+    // Process any new defs within this block.
+    for (Operation &op : block)
+      for (Value result : op.getResults())
+        processValue(result, &op);
+  }
+
+  // Greedily allocate memory slots using the computed def live ranges. Slot 0
+  // holds the root operation, so `allocatedIndices[i]` tracks slot `i + 1`.
+  std::vector<LivenessSet> allocatedIndices;
+  for (auto &defIt : valueDefRanges) {
+    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
+    LivenessSet &defSet = defIt.second;
+
+    // Try to allocate to the first existing index whose live ranges do not
+    // overlap this def.
+    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
+      LivenessSet &existingIndex = existingIndexIt.value();
+      llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
+          defSet, existingIndex);
+      if (overlaps.valid())
+        continue;
+      // Union the range of the def within the existing index.
+      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
+        existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+      memIndex = existingIndexIt.index() + 1;
+      // Stop at the first fit. Continuing would also stamp this def's range
+      // onto every later free slot, needlessly blocking those slots for
+      // subsequent defs and inflating the amount of memory required.
+      break;
+    }
+
+    // If no existing index could be used, add a new one.
+    if (memIndex == 0) {
+      allocatedIndices.emplace_back(allocator);
+      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
+        allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
+      memIndex = allocatedIndices.size();
+    }
+  }
+
+  // Update the max number of indices (the +1 accounts for the root slot).
+  ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
+  if (numMatcherIndices > maxValueMemoryIndex)
+    maxValueMemoryIndex = numMatcherIndices;
+}
+
+/// Dispatch `op` to the `generate` overload for its concrete `pdl_interp`
+/// operation type. Any operation not listed here is a programming error.
+/// NOTE: Keep this type list in sync with the overload set declared in the
+/// Generator class.
+void Generator::generate(Operation *op, ByteCodeWriter &writer) {
+ TypeSwitch<Operation *>(op)
+ .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
+ pdl_interp::AreEqualOp, pdl_interp::BranchOp,
+ pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
+ pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
+ pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
+ pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
+ pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
+ pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
+ pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
+ pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
+ pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
+ pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
+ pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
+ pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
+ pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
+ [&](auto interpOp) { this->generate(interpOp, writer); })
+ .Default([](Operation *) {
+ llvm_unreachable("unknown `pdl_interp` operation");
+ });
+}
+
+// NOTE: The order of the fields appended by each `generate` overload below is
+// the bytecode wire format for that instruction; the executor reads the fields
+// back in the same order, so do not reorder the `append` arguments.
+
+/// Layout: ApplyConstraint, constraint function index, constant params, a PDL
+/// value list of the arguments, then a placeholder address per successor.
+void Generator::generate(pdl_interp::ApplyConstraintOp op,
+ ByteCodeWriter &writer) {
+ assert(constraintToMemIndex.count(op.name()) &&
+ "expected index for constraint function");
+ writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
+ op.constParamsAttr());
+ writer.appendPDLValueList(op.args());
+ writer.append(op.getSuccessors());
+}
+/// Layout: ApplyRewrite, rewrite function index, constant params, root
+/// operation, then a PDL value list of the arguments.
+void Generator::generate(pdl_interp::ApplyRewriteOp op,
+ ByteCodeWriter &writer) {
+ assert(externalRewriterToMemIndex.count(op.name()) &&
+ "expected index for rewrite function");
+ writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
+ op.constParamsAttr(), op.root());
+ writer.appendPDLValueList(op.args());
+}
+/// Layout: AreEqual, lhs, rhs, successors.
+void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
+}
+/// Layout: Branch, destination.
+void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::Branch, SuccessorRange(op));
+}
+/// Lowered onto the generic AreEqual instruction: attribute, expected constant
+/// attribute, successors.
+void Generator::generate(pdl_interp::CheckAttributeOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
+ op.getSuccessors());
+}
+/// Layout: CheckOperandCount, operation, expected count, successors.
+void Generator::generate(pdl_interp::CheckOperandCountOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
+ op.getSuccessors());
+}
+/// Layout: CheckOperationName, operation, expected OperationName (by memory
+/// index), successors.
+void Generator::generate(pdl_interp::CheckOperationNameOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::CheckOperationName, op.operation(),
+ OperationName(op.name(), ctx), op.getSuccessors());
+}
+/// Layout: CheckResultCount, operation, expected count, successors.
+void Generator::generate(pdl_interp::CheckResultCountOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
+ op.getSuccessors());
+}
+/// Lowered onto the generic AreEqual instruction: type value, expected type,
+/// successors.
+void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
+}
+/// No instruction is emitted: the result is aliased to the memory index of the
+/// uniqued constant attribute.
+void Generator::generate(pdl_interp::CreateAttributeOp op,
+ ByteCodeWriter &writer) {
+ // Simply repoint the memory index of the result to the constant.
+ getMemIndex(op.attribute()) = getMemIndex(op.value());
+}
+/// Layout: CreateNative, creation function index, result, constant params,
+/// then a PDL value list of the arguments.
+void Generator::generate(pdl_interp::CreateNativeOp op,
+ ByteCodeWriter &writer) {
+ assert(nativeCreateToMemIndex.count(op.name()) &&
+ "expected index for creation function");
+ writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
+ op.result(), op.constParamsAttr());
+ writer.appendPDLValueList(op.args());
+}
+/// Layout: CreateOperation, result operation, OperationName, operand list
+/// (count + values), attribute count, (name Identifier, attribute value) per
+/// attribute, then the result type list (count + types).
+void Generator::generate(pdl_interp::CreateOperationOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::CreateOperation, op.operation(),
+ OperationName(op.name(), ctx), op.operands());
+
+ // Add the attributes.
+ OperandRange attributes = op.attributes();
+ writer.append(static_cast<ByteCodeField>(attributes.size()));
+ for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
+ writer.append(
+ Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
+ std::get<1>(it));
+ }
+ writer.append(op.types());
+}
+/// No instruction is emitted: the result is aliased to the memory index of the
+/// uniqued constant type.
+void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
+ // Simply repoint the memory index of the result to the constant.
+ getMemIndex(op.result()) = getMemIndex(op.value());
+}
+/// Layout: EraseOp, operation.
+void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::EraseOp, op.operation());
+}
+/// Layout: Finalize.
+void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::Finalize);
+}
+/// Layout: GetAttribute, result attribute, operation, attribute name
+/// Identifier.
+void Generator::generate(pdl_interp::GetAttributeOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
+ Identifier::get(op.name(), ctx));
+}
+/// Layout: GetAttributeType, result type, attribute.
+void Generator::generate(pdl_interp::GetAttributeTypeOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::GetAttributeType, op.result(), op.value());
+}
+/// Layout: GetDefiningOp, result operation, value.
+void Generator::generate(pdl_interp::GetDefiningOpOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
+}
+/// Indices 0-3 are folded into the opcode (GetOperand0..GetOperand3); larger
+/// indices use GetOperandN followed by the index. Both forms are followed by
+/// the operation and the result value.
+void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
+ uint32_t index = op.index();
+ if (index < 4)
+ writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
+ else
+ writer.append(OpCode::GetOperandN, index);
+ writer.append(op.operation(), op.value());
+}
+/// Indices 0-3 are folded into the opcode (GetResult0..GetResult3); larger
+/// indices use GetResultN followed by the index. Both forms are followed by
+/// the operation and the result value.
+void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
+ uint32_t index = op.index();
+ if (index < 4)
+ writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
+ else
+ writer.append(OpCode::GetResultN, index);
+ writer.append(op.operation(), op.value());
+}
+/// Layout: GetValueType, result type, value.
+void Generator::generate(pdl_interp::GetValueTypeOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::GetValueType, op.result(), op.value());
+}
+/// No instruction is emitted: the result is aliased to the slot of the
+/// uniqued null Type.
+void Generator::generate(pdl_interp::InferredTypeOp op,
+ ByteCodeWriter &writer) {
+ // InferType maps to a null type as a marker for inferring a result type.
+ getMemIndex(op.type()) = getMemIndex(Type());
+}
+/// Layout: IsNotNull, value, successors.
+void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
+}
+/// Register a new bytecode pattern for `op`, then emit
+/// Layout: RecordMatch, pattern index, successor, matched operations (count +
+/// ops), rewriter inputs (count + values).
+void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
+  // Rewriters are generated before the matcher, so the referenced rewriter's
+  // address must already be known. Check it explicitly instead of letting
+  // `operator[]` silently default-insert address 0, mirroring the checks done
+  // for native constraint/create/rewrite function lookups.
+  StringRef rewriterName = op.rewriter().getLeafReference();
+  assert(rewriterToAddr.count(rewriterName) &&
+         "expected address for rewriter function");
+
+  ByteCodeField patternIndex = patterns.size();
+  patterns.emplace_back(
+      PDLByteCodePattern::create(op, rewriterToAddr[rewriterName]));
+  writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op),
+                op.matchedOps(), op.inputs());
+}
+/// Layout: ReplaceOp, operation, replacement values (count + values).
+void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
+}
+/// Layout: SwitchAttribute, attribute, case values attribute, successors.
+void Generator::generate(pdl_interp::SwitchAttributeOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
+ op.getSuccessors());
+}
+/// Layout: SwitchOperandCount, operation, case values attribute, successors.
+void Generator::generate(pdl_interp::SwitchOperandCountOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
+ op.getSuccessors());
+}
+/// Layout: SwitchOperationName, operation, case list (count + one
+/// OperationName per case, each by memory index), successors.
+void Generator::generate(pdl_interp::SwitchOperationNameOp op,
+ ByteCodeWriter &writer) {
+ auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
+ return OperationName(attr.cast<StringAttr>().getValue(), ctx);
+ });
+ writer.append(OpCode::SwitchOperationName, op.operation(), cases,
+ op.getSuccessors());
+}
+/// Layout: SwitchResultCount, operation, case values attribute, successors.
+void Generator::generate(pdl_interp::SwitchResultCountOp op,
+ ByteCodeWriter &writer) {
+ writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
+ op.getSuccessors());
+}
+/// Layout: SwitchType, type value, case values attribute, successors.
+void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
+ writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
+ op.getSuccessors());
+}
+
+//===----------------------------------------------------------------------===//
+// PDLByteCode
+//===----------------------------------------------------------------------===//
+
+/// Generate bytecode for `module` and take ownership of the native functions
+/// it may reference.
+PDLByteCode::PDLByteCode(ModuleOp module,
+                         llvm::StringMap<PDLConstraintFunction> constraintFns,
+                         llvm::StringMap<PDLCreateFunction> createFns,
+                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
+  // The generator indexes each native function by its position in the
+  // iteration order of its map.
+  Generator generator(module.getContext(), uniquedData, matcherByteCode,
+                      rewriterByteCode, patterns, maxValueMemoryIndex,
+                      constraintFns, createFns, rewriteFns);
+  generator.generate(module);
+
+  // Move the functions into flat lists, walking each map in the same order so
+  // that list positions agree with the indices assigned by the generator.
+  auto moveValuesInto = [](auto &nameToFn, auto &fnList) {
+    for (auto &entry : nameToFn)
+      fnList.push_back(std::move(entry.second));
+  };
+  moveValuesInto(constraintFns, constraintFunctions);
+  moveValuesInto(createFns, createFunctions);
+  moveValuesInto(rewriteFns, rewriteFunctions);
+}
+
+/// Prepare `state` for executing this bytecode: size its value memory to the
+/// number of slots the generator assigned, and seed the current benefit of
+/// each pattern with that pattern's initial benefit.
+void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
+  state.memory.resize(maxValueMemoryIndex, nullptr);
+
+  auto &benefits = state.currentPatternBenefits;
+  benefits.reserve(patterns.size());
+  for (unsigned i = 0, e = patterns.size(); i != e; ++i)
+    benefits.push_back(patterns[i].getBenefit());
+}
+
+//===----------------------------------------------------------------------===//
+// ByteCode Execution
+
+namespace {
+/// This class provides support for executing a bytecode stream.
+class ByteCodeExecutor {
+public:
+ ByteCodeExecutor(const ByteCodeField *curCodeIt,
+ MutableArrayRef<const void *> memory,
+ ArrayRef<const void *> uniquedMemory,
+ ArrayRef<ByteCodeField> code,
+ ArrayRef<PatternBenefit> currentPatternBenefits,
+ ArrayRef<PDLByteCodePattern> patterns,
+ ArrayRef<PDLConstraintFunction> constraintFunctions,
+ ArrayRef<PDLCreateFunction> createFunctions,
+ ArrayRef<PDLRewriteFunction> rewriteFunctions)
+ : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
+ code(code), currentPatternBenefits(currentPatternBenefits),
+ patterns(patterns), constraintFunctions(constraintFunctions),
+ createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
+
+ /// Start executing the code at the current bytecode index. `matches` is an
+ /// optional field provided when this function is executed in a matching
+ /// context.
+ void execute(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
+ Optional<Location> mainRewriteLoc = {});
+
+private:
+ /// Read a value from the bytecode buffer, optionally skipping a certain
+ /// number of prefix values. These methods always update the buffer to point
+ /// to the next field after the read data.
+ template <typename T = ByteCodeField>
+ T read(size_t skipN = 0) {
+ curCodeIt += skipN;
+ return readImpl<T>();
+ }
+ ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
+
+ /// Read a list of values from the bytecode buffer.
+ template <typename ValueT, typename T>
+ void readList(SmallVectorImpl<T> &list) {
+ list.clear();
+ for (unsigned i = 0, e = read(); i != e; ++i)
+ list.push_back(read<ValueT>());
+ }
+
+ /// Jump to a specific successor based on a predicate value.
+ void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
+ /// Jump to a specific successor based on a destination index.
+ void selectJump(size_t destIndex) {
+ curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
+ }
+
+ /// Handle a switch operation with the provided value and cases.
+ template <typename T, typename RangeT>
+ void handleSwitch(const T &value, RangeT &&cases) {
+ LLVM_DEBUG({
+ llvm::dbgs() << " * Value: " << value << "\n"
+ << " * Cases: ";
+ llvm::interleaveComma(cases, llvm::dbgs());
+ llvm::dbgs() << "\n\n";
+ });
+
+ // Check to see if the attribute value is within the case list. Jump to
+ // the correct successor index based on the result.
+ auto it = llvm::find(cases, value);
+ selectJump(it == cases.end() ? size_t(0) : ((it - cases.begin()) + 1));
+ }
+
+ /// Internal implementation of reading various data types from the bytecode
+ /// stream.
+ template <typename T>
+ const void *readFromMemory() {
+ size_t index = *curCodeIt++;
+
+ // If this type is an SSA value, it can only be stored in non-const memory.
+ if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
+ return memory[index];
+
+ // Otherwise, if this index is not inbounds it is uniqued.
+ return uniquedMemory[index - memory.size()];
+ }
+  /// Read a pointer-typed entry by reinterpreting the opaque pointer stored in
+  /// memory.
+  template <typename T>
+  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
+    return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
+  }
+  /// Read a class-typed entry (other than PDLValue) by rebuilding it from the
+  /// opaque pointer stored in memory.
+  template <typename T>
+  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
+                   T>
+  readImpl() {
+    return T(T::getFromOpaquePointer(readFromMemory<T>()));
+  }
+  /// Read a PDLValue, encoded as a `PDLValueKind` field followed by an entry of
+  /// that kind.
+  template <typename T>
+  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
+    switch (static_cast<PDLValueKind>(read())) {
+    case PDLValueKind::Attribute:
+      return read<Attribute>();
+    case PDLValueKind::Operation:
+      return read<Operation *>();
+    case PDLValueKind::Type:
+      return read<Type>();
+    case PDLValueKind::Value:
+      return read<Value>();
+    }
+    // Without this, control could fall off the end of a non-void function if
+    // the bytecode contained an unexpected kind.
+    llvm_unreachable("unhandled PDLValueKind");
+  }
+  /// Read a bytecode address, which is stored inline across two consecutive
+  /// bytecode fields.
+  template <typename T>
+  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
+    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
+                  "unexpected ByteCode address size");
+    ByteCodeAddr result;
+    std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
+    curCodeIt += 2;
+    return result;
+  }
+  /// Read a single inline bytecode field.
+  template <typename T>
+  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
+    return *curCodeIt++;
+  }
+
+ /// The current read position within `code`. It is advanced as entries are
+ /// read and reassigned when jumping to a successor.
+ const ByteCodeField *curCodeIt;
+
+ /// The current execution memory. Entries at indices below its size are
+ /// read from here; results of executed instructions are written here.
+ MutableArrayRef<const void *> memory;
+
+ /// References to ByteCode data necessary for execution.
+ /// `uniquedMemory` holds entries addressed past the end of `memory`.
+ ArrayRef<const void *> uniquedMemory;
+ /// The bytecode being executed; jump targets are indices into this array.
+ ArrayRef<ByteCodeField> code;
+ /// The current benefit of each pattern, indexed like `patterns`.
+ ArrayRef<PatternBenefit> currentPatternBenefits;
+ ArrayRef<PDLByteCodePattern> patterns;
+ /// Native functions, indexed by the bytecode's function index operands.
+ ArrayRef<PDLConstraintFunction> constraintFunctions;
+ ArrayRef<PDLCreateFunction> createFunctions;
+ ArrayRef<PDLRewriteFunction> rewriteFunctions;
+};
+} // end anonymous namespace
+
+/// Execute the bytecode from the current position of `curCodeIt` until a
+/// `Finalize` opcode is reached.
+///
+/// `matches` must be provided when running the matcher; `RecordMatch` appends
+/// to it. `mainRewriteLoc` must be provided when running a rewriter; it is the
+/// location used for operations created by `CreateOperation`. Execution also
+/// stops early if result type inference fails while creating an operation.
+void ByteCodeExecutor::execute(
+    PatternRewriter &rewriter,
+    SmallVectorImpl<PDLByteCode::MatchResult> *matches,
+    Optional<Location> mainRewriteLoc) {
+  while (true) {
+    OpCode opCode = static_cast<OpCode>(read());
+    switch (opCode) {
+    case ApplyConstraint: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
+      const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
+      ArrayAttr constParams = read<ArrayAttr>();
+      SmallVector<PDLValue, 16> args;
+      readList<PDLValue>(args);
+      LLVM_DEBUG({
+        llvm::dbgs() << " * Arguments: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
+      });
+
+      // Invoke the constraint and jump to the proper destination.
+      selectJump(succeeded(constraintFn(args, constParams, rewriter)));
+      break;
+    }
+    case ApplyRewrite: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
+      const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
+      ArrayAttr constParams = read<ArrayAttr>();
+      Operation *root = read<Operation *>();
+      SmallVector<PDLValue, 16> args;
+      readList<PDLValue>(args);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << " * Root: " << *root << "\n"
+                     << " * Arguments: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
+      });
+      rewriteFn(root, args, constParams, rewriter);
+      break;
+    }
+    case AreEqual: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
+      const void *lhs = read<const void *>();
+      const void *rhs = read<const void *>();
+
+      LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
+      selectJump(lhs == rhs);
+      break;
+    }
+    case Branch: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
+      curCodeIt = &code[read<ByteCodeAddr>()];
+      break;
+    }
+    case CheckOperandCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
+      Operation *op = read<Operation *>();
+      uint32_t expectedCount = read<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
+                              << " * Expected: " << expectedCount << "\n\n");
+      selectJump(op->getNumOperands() == expectedCount);
+      break;
+    }
+    case CheckOperationName: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
+      Operation *op = read<Operation *>();
+      OperationName expectedName = read<OperationName>();
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << " * Found: \"" << op->getName() << "\"\n"
+                 << " * Expected: \"" << expectedName << "\"\n\n");
+      selectJump(op->getName() == expectedName);
+      break;
+    }
+    case CheckResultCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
+      Operation *op = read<Operation *>();
+      uint32_t expectedCount = read<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
+                              << " * Expected: " << expectedCount << "\n\n");
+      selectJump(op->getNumResults() == expectedCount);
+      break;
+    }
+    case CreateNative: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
+      const PDLCreateFunction &createFn = createFunctions[read()];
+      ByteCodeField resultIndex = read();
+      ArrayAttr constParams = read<ArrayAttr>();
+      SmallVector<PDLValue, 16> args;
+      readList<PDLValue>(args);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << " * Arguments: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
+      });
+
+      PDLValue result = createFn(args, constParams, rewriter);
+      memory[resultIndex] = result.getAsOpaquePointer();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n");
+      break;
+    }
+    case CreateOperation: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
+      assert(mainRewriteLoc && "expected rewrite loc to be provided when "
+                               "executing the rewriter bytecode");
+
+      unsigned memIndex = read();
+      OperationState state(*mainRewriteLoc, read<OperationName>());
+      readList<Value>(state.operands);
+      for (unsigned i = 0, e = read(); i != e; ++i) {
+        Identifier name = read<Identifier>();
+        if (Attribute attr = read<Attribute>())
+          state.addAttribute(name, attr);
+      }
+
+      bool hasInferredTypes = false;
+      for (unsigned i = 0, e = read(); i != e; ++i) {
+        Type resultType = read<Type>();
+        hasInferredTypes |= !resultType;
+        state.types.push_back(resultType);
+      }
+
+      // Handle the case where the operation has inferred types.
+      if (hasInferredTypes) {
+        // Note: the interface pointer is deliberately not named `concept`,
+        // which is a reserved keyword in C++20.
+        const auto *abstractOp = state.name.getAbstractOperation();
+        assert(abstractOp &&
+               "expected a registered operation when inferring result types");
+        InferTypeOpInterface::Concept *inferInterface =
+            abstractOp->getInterface<InferTypeOpInterface>();
+        assert(inferInterface && "expected operation to implement "
+                                 "InferTypeOpInterface to infer result types");
+
+        // TODO: Handle failure.
+        SmallVector<Type, 2> inferredTypes;
+        if (failed(inferInterface->inferReturnTypes(
+                state.getContext(), state.location, state.operands,
+                state.attributes.getDictionary(state.getContext()),
+                state.regions, inferredTypes)))
+          return;
+
+        for (unsigned i = 0, e = state.types.size(); i != e; ++i)
+          if (!state.types[i])
+            state.types[i] = inferredTypes[i];
+      }
+      Operation *resultOp = rewriter.createOperation(state);
+      memory[memIndex] = resultOp;
+
+      LLVM_DEBUG({
+        llvm::dbgs() << " * Attributes: "
+                     << state.attributes.getDictionary(state.getContext())
+                     << "\n * Operands: ";
+        llvm::interleaveComma(state.operands, llvm::dbgs());
+        llvm::dbgs() << "\n * Result Types: ";
+        llvm::interleaveComma(state.types, llvm::dbgs());
+        llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n";
+      });
+      break;
+    }
+    case EraseOp: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
+      Operation *op = read<Operation *>();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n");
+      rewriter.eraseOp(op);
+      break;
+    }
+    case Finalize: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
+      return;
+    }
+    case GetAttribute: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
+      unsigned memIndex = read();
+      Operation *op = read<Operation *>();
+      Identifier attrName = read<Identifier>();
+      Attribute attr = op->getAttr(attrName);
+
+      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
+                              << " * Attribute: " << attrName << "\n"
+                              << " * Result: " << attr << "\n\n");
+      memory[memIndex] = attr.getAsOpaquePointer();
+      break;
+    }
+    case GetAttributeType: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
+      unsigned memIndex = read();
+      Attribute attr = read<Attribute>();
+      // Propagate a null type for a null attribute (e.g. one that
+      // GetAttribute did not find) instead of dereferencing null storage.
+      Type type = attr ? attr.getType() : Type();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
+                              << " * Result: " << type << "\n\n");
+      memory[memIndex] = type.getAsOpaquePointer();
+      break;
+    }
+    case GetDefiningOp: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
+      unsigned memIndex = read();
+      Value value = read<Value>();
+      Operation *op = value ? value.getDefiningOp() : nullptr;
+
+      // `op` is null when `value` is null or has no defining operation, so
+      // neither may be dereferenced unconditionally when printing.
+      LLVM_DEBUG({
+        llvm::dbgs() << " * Value: ";
+        if (value)
+          llvm::dbgs() << value;
+        else
+          llvm::dbgs() << "<null>";
+        llvm::dbgs() << "\n * Result: ";
+        if (op)
+          llvm::dbgs() << *op;
+        else
+          llvm::dbgs() << "<null>";
+        llvm::dbgs() << "\n\n";
+      });
+      memory[memIndex] = op;
+      break;
+    }
+    case GetOperand0:
+    case GetOperand1:
+    case GetOperand2:
+    case GetOperand3:
+    case GetOperandN: {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Executing GetOperand"
+                     << (opCode == GetOperandN ? Twine("N")
+                                               : Twine(opCode - GetOperand0))
+                     << ":\n";
+      });
+      unsigned index =
+          opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
+      Operation *op = read<Operation *>();
+      unsigned memIndex = read();
+      Value operand =
+          index < op->getNumOperands() ? op->getOperand(index) : Value();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
+                              << " * Index: " << index << "\n"
+                              << " * Result: " << operand << "\n\n");
+      memory[memIndex] = operand.getAsOpaquePointer();
+      break;
+    }
+    case GetResult0:
+    case GetResult1:
+    case GetResult2:
+    case GetResult3:
+    case GetResultN: {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Executing GetResult"
+                     << (opCode == GetResultN ? Twine("N")
+                                              : Twine(opCode - GetResult0))
+                     << ":\n";
+      });
+      unsigned index =
+          opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
+      Operation *op = read<Operation *>();
+      unsigned memIndex = read();
+      OpResult result =
+          index < op->getNumResults() ? op->getResult(index) : OpResult();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
+                              << " * Index: " << index << "\n"
+                              << " * Result: " << result << "\n\n");
+      memory[memIndex] = result.getAsOpaquePointer();
+      break;
+    }
+    case GetValueType: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
+      unsigned memIndex = read();
+      Value value = read<Value>();
+      // Propagate a null type for a null value instead of dereferencing it.
+      Type type = value ? value.getType() : Type();
+
+      LLVM_DEBUG({
+        llvm::dbgs() << " * Value: ";
+        if (value)
+          llvm::dbgs() << value;
+        else
+          llvm::dbgs() << "<null>";
+        llvm::dbgs() << "\n * Result: " << type << "\n\n";
+      });
+      memory[memIndex] = type.getAsOpaquePointer();
+      break;
+    }
+    case IsNotNull: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
+      const void *value = read<const void *>();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n");
+      selectJump(value != nullptr);
+      break;
+    }
+    case RecordMatch: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
+      assert(matches &&
+             "expected matches to be provided when executing the matcher");
+      unsigned patternIndex = read();
+      PatternBenefit benefit = currentPatternBenefits[patternIndex];
+      const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
+
+      // If the benefit of the pattern is impossible, skip the processing of the
+      // rest of the pattern.
+      if (benefit.isImpossibleToMatch()) {
+        LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n");
+        curCodeIt = dest;
+        break;
+      }
+
+      // Create a fused location containing the locations of each of the
+      // operations used in the match. This will be used as the location for
+      // created operations during the rewrite that don't already have an
+      // explicit location set.
+      unsigned numMatchLocs = read();
+      SmallVector<Location, 4> matchLocs;
+      matchLocs.reserve(numMatchLocs);
+      for (unsigned i = 0; i != numMatchLocs; ++i)
+        matchLocs.push_back(read<Operation *>()->getLoc());
+      Location matchLoc = rewriter.getFusedLoc(matchLocs);
+
+      LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
+                              << " * Location: " << matchLoc << "\n\n");
+      matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
+      readList<const void *>(matches->back().values);
+      curCodeIt = dest;
+      break;
+    }
+    case ReplaceOp: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
+      Operation *op = read<Operation *>();
+      SmallVector<Value, 16> args;
+      readList<Value>(args);
+
+      LLVM_DEBUG({
+        llvm::dbgs() << " * Operation: " << *op << "\n"
+                     << " * Values: ";
+        llvm::interleaveComma(args, llvm::dbgs());
+        llvm::dbgs() << "\n\n";
+      });
+      rewriter.replaceOp(op, args);
+      break;
+    }
+    case SwitchAttribute: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
+      Attribute value = read<Attribute>();
+      ArrayAttr cases = read<ArrayAttr>();
+      handleSwitch(value, cases);
+      break;
+    }
+    case SwitchOperandCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
+      Operation *op = read<Operation *>();
+      auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+      handleSwitch(op->getNumOperands(), cases);
+      break;
+    }
+    case SwitchOperationName: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
+      OperationName value = read<Operation *>()->getName();
+      size_t caseCount = read();
+
+      // The operation names are stored in-line, so to print them out for
+      // debugging purposes we need to read the array before executing the
+      // switch so that we can display all of the possible values.
+      LLVM_DEBUG({
+        const ByteCodeField *prevCodeIt = curCodeIt;
+        llvm::dbgs() << " * Value: " << value << "\n"
+                     << " * Cases: ";
+        llvm::interleaveComma(
+            llvm::map_range(llvm::seq<size_t>(0, caseCount),
+                            [&](size_t) { return read<OperationName>(); }),
+            llvm::dbgs());
+        llvm::dbgs() << "\n\n";
+        curCodeIt = prevCodeIt;
+      });
+
+      // Try to find the switch value within any of the cases. On a match, skip
+      // the remaining in-line case names so the successor table can be read.
+      size_t jumpDest = 0;
+      for (size_t i = 0; i != caseCount; ++i) {
+        if (read<OperationName>() == value) {
+          curCodeIt += (caseCount - i - 1);
+          jumpDest = i + 1;
+          break;
+        }
+      }
+      selectJump(jumpDest);
+      break;
+    }
+    case SwitchResultCount: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
+      Operation *op = read<Operation *>();
+      auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
+
+      LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
+      handleSwitch(op->getNumResults(), cases);
+      break;
+    }
+    case SwitchType: {
+      LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
+      Type value = read<Type>();
+      auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
+      handleSwitch(value, cases);
+      break;
+    }
+    }
+  }
+}
+
+/// Run the matcher bytecode with `op` as the root and collect every pattern it
+/// records into `matches`, ordered from highest to lowest benefit.
+void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
+                        SmallVectorImpl<MatchResult> &matches,
+                        PDLByteCodeMutableState &state) const {
+  // Memory slot 0 is reserved for the root operation being matched.
+  state.memory[0] = op;
+
+  // Matching always begins at the very first matcher instruction.
+  const ByteCodeField *matcherStart = matcherByteCode.data();
+  ByteCodeExecutor executor(matcherStart, state.memory, uniquedData,
+                            matcherByteCode, state.currentPatternBenefits,
+                            patterns, constraintFunctions, createFunctions,
+                            rewriteFunctions);
+  executor.execute(rewriter, &matches);
+
+  // Highest benefit first; the sort is stable, so matches with equal benefit
+  // keep the order in which they were recorded.
+  auto hasHigherBenefit = [](const MatchResult &lhs, const MatchResult &rhs) {
+    return lhs.benefit > rhs.benefit;
+  };
+  std::stable_sort(matches.begin(), matches.end(), hasHigherBenefit);
+}
+
+/// Execute the rewriter of the pattern recorded in `match`, using the values
+/// the matcher captured as its arguments.
+void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
+                          PDLByteCodeMutableState &state) const {
+  // The rewriter expects its arguments at the front of the memory buffer.
+  llvm::copy(match.values, state.memory.begin());
+
+  const ByteCodeField *rewriterStart =
+      &rewriterByteCode[match.pattern->getRewriterAddr()];
+  ByteCodeExecutor executor(rewriterStart, state.memory, uniquedData,
+                            rewriterByteCode, state.currentPatternBenefits,
+                            patterns, constraintFunctions, createFunctions,
+                            rewriteFunctions);
+  executor.execute(rewriter, /*matches=*/nullptr, match.location);
+}
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
new file mode 100644
index 000000000000..7126037f864a
--- /dev/null
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -0,0 +1,173 @@
+//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
+// The byte-code is constructed from the PDL Interpreter dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REWRITE_BYTECODE_H_
+#define MLIR_REWRITE_BYTECODE_H_
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace pdl_interp {
+class RecordMatchOp;
+} // end namespace pdl_interp
+
+namespace detail {
+class PDLByteCode;
+
+/// Use generic bytecode types. ByteCodeField is the type of each entry of the
+/// bytecode; it is 16 bits wide here rather than a literal 8-bit "byte".
+/// ByteCodeAddr is the type of indices into the bytecode, and is expected to
+/// occupy exactly two ByteCodeFields when stored inline (the executor checks
+/// this with a static_assert).
+using ByteCodeField = uint16_t;
+using ByteCodeAddr = uint32_t;
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodePattern
+//===----------------------------------------------------------------------===//
+
+/// A pattern held by the bytecode, together with the address at which its
+/// rewriter starts within the rewriter bytecode.
+class PDLByteCodePattern : public Pattern {
+public:
+  /// Create a pattern from the given `pdl_interp.record_match` operation, whose
+  /// rewriter begins at `rewriterAddr`.
+  static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
+                                   ByteCodeAddr rewriterAddr);
+
+  /// Returns where this pattern's rewriter begins in the rewriter bytecode.
+  ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
+
+private:
+  /// Forward `patternArgs` to the `Pattern` base and remember `addr`.
+  template <typename... Args>
+  PDLByteCodePattern(ByteCodeAddr addr, Args &&...patternArgs)
+      : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(addr) {}
+
+  /// Bytecode address of the first instruction of this pattern's rewriter.
+  ByteCodeAddr rewriterAddr;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLByteCodeMutableState
+//===----------------------------------------------------------------------===//
+
+/// This class contains the mutable state of a bytecode instance. This allows
+/// for a bytecode instance to be cached and reused across various different
+/// threads/drivers.
+class PDLByteCodeMutableState {
+public:
+  /// Initialize the state from a bytecode instance.
+  void initialize(PDLByteCode &bytecode);
+
+  /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
+  /// to the position of the pattern within the range returned by
+  /// `PDLByteCode::getPatterns`.
+  void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
+
+private:
+  /// Allow access to data fields.
+  friend class PDLByteCode;
+
+  /// The mutable block of memory used during the matching and rewriting phases
+  /// of the bytecode.
+  std::vector<const void *> memory;
+
+  /// The up-to-date benefits of the patterns held by the bytecode. The order
+  /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
+  std::vector<PatternBenefit> currentPatternBenefits;
+};
+
+//===----------------------------------------------------------------------===//
+// PDLByteCode
+//===----------------------------------------------------------------------===//
+
+/// The bytecode class is also the interpreter. It contains the matcher and
+/// rewriter bytecode, the uniqued data referenced by that bytecode, the
+/// patterns (with the addresses of their rewriters), and the registered native
+/// functions. The mutable execution memory is kept separately in a
+/// `PDLByteCodeMutableState`, so that one bytecode instance can be shared.
+class PDLByteCode {
+public:
+  /// Each successful match returns a MatchResult, which contains information
+  /// necessary to execute the rewriter and indicates the originating pattern.
+  struct MatchResult {
+    MatchResult(Location loc, const PDLByteCodePattern &pattern,
+                PatternBenefit benefit)
+        : location(loc), pattern(&pattern), benefit(benefit) {}
+
+    /// The fused location of the operations used in the match. It is used as
+    /// the location of operations created while running the rewriter.
+    Location location;
+    /// Memory values defined in the matcher that are passed to the rewriter.
+    SmallVector<const void *, 4> values;
+    /// The originating pattern that was matched. This is always non-null, but
+    /// represented with a pointer to allow for assignment.
+    const PDLByteCodePattern *pattern;
+    /// The current benefit of the pattern that was matched.
+    PatternBenefit benefit;
+  };
+
+  /// Create a ByteCode instance from the given module containing operations in
+  /// the PDL interpreter dialect.
+  PDLByteCode(ModuleOp module,
+              llvm::StringMap<PDLConstraintFunction> constraintFns,
+              llvm::StringMap<PDLCreateFunction> createFns,
+              llvm::StringMap<PDLRewriteFunction> rewriteFns);
+
+  /// Return the patterns held by the bytecode.
+  ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
+
+  /// Initialize the given state such that it can be used to execute the current
+  /// bytecode.
+  void initializeMutableState(PDLByteCodeMutableState &state) const;
+
+  /// Run the pattern matcher on the given root operation, collecting the
+  /// matched patterns in `matches`.
+  void match(Operation *op, PatternRewriter &rewriter,
+             SmallVectorImpl<MatchResult> &matches,
+             PDLByteCodeMutableState &state) const;
+
+  /// Run the rewriter of the given pattern that was previously matched in
+  /// `match`.
+  void rewrite(PatternRewriter &rewriter, const MatchResult &match,
+               PDLByteCodeMutableState &state) const;
+
+private:
+  /// Execute the given byte code starting at the provided instruction `inst`.
+  /// `matches` is an optional field provided when this function is executed in
+  /// a matching context.
+  /// NOTE(review): `match` and `rewrite` execute through `ByteCodeExecutor`, and
+  /// no definition of this member appears in the visible part of ByteCode.cpp;
+  /// confirm whether this declaration is still needed.
+  void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
+                       PDLByteCodeMutableState &state,
+                       SmallVectorImpl<MatchResult> *matches) const;
+
+  /// A vector containing pointers to uniqued data. The storage is intentionally
+  /// opaque such that we can store a wide range of data types. The types of
+  /// data stored here include:
+  ///  * Attribute, Identifier, OperationName, Type
+  std::vector<const void *> uniquedData;
+
+  /// A vector containing the generated bytecode for the matcher.
+  SmallVector<ByteCodeField, 64> matcherByteCode;
+
+  /// A vector containing the generated bytecode for all of the rewriters.
+  SmallVector<ByteCodeField, 64> rewriterByteCode;
+
+  /// The set of patterns contained within the bytecode.
+  SmallVector<PDLByteCodePattern, 32> patterns;
+
+  /// A set of user defined functions invoked via PDL.
+  std::vector<PDLConstraintFunction> constraintFunctions;
+  std::vector<PDLCreateFunction> createFunctions;
+  std::vector<PDLRewriteFunction> rewriteFunctions;
+
+  /// The maximum memory index used by a value.
+  ByteCodeField maxValueMemoryIndex = 0;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_REWRITE_BYTECODE_H_
diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt
index e37b9c31dab9..5822789cc916 100644
--- a/mlir/lib/Rewrite/CMakeLists.txt
+++ b/mlir/lib/Rewrite/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(MLIRRewrite
+ ByteCode.cpp
FrozenRewritePatternList.cpp
PatternApplicator.cpp
@@ -10,4 +11,8 @@ add_mlir_library(MLIRRewrite
LINK_LIBS PUBLIC
MLIRIR
+ MLIRPDL
+ MLIRPDLInterp
+ MLIRPDLToPDLInterp
+ MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
index d0e45184ac28..60f6dcea88f2 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
@@ -7,13 +7,71 @@
//===----------------------------------------------------------------------===//
#include "mlir/Rewrite/FrozenRewritePatternList.h"
+#include "ByteCode.h"
+#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
using namespace mlir;
+/// Lower the `pdl.pattern` operations held by `pdlModule` to the PDL
+/// interpreter dialect, erasing trivially dead operations before and after the
+/// lowering. A module without any `pdl.pattern` is left untouched.
+static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
+  // Nothing to lower if the module holds no PDL patterns.
+  auto patternOps = pdlModule.getOps<pdl::PatternOp>();
+  if (llvm::empty(patternOps))
+    return success();
+
+  // Erase trivially dead operations. The canonicalizer is deliberately not
+  // used here, as depending on it would create a cyclic dependency.
+  auto eraseIfTriviallyDead = [](Operation *op) {
+    // TODO: Add folding here if ever necessary.
+    if (isOpTriviallyDead(op))
+      op->erase();
+  };
+  pdlModule.getBody()->walk(eraseIfTriviallyDead);
+
+  // Run the PDL -> PDL interpreter lowering.
+  PassManager pdlPipeline(pdlModule.getContext());
+#ifdef NDEBUG
+  // Avoid paying for the verifier in release builds.
+  pdlPipeline.enableVerifier(false);
+#endif
+  pdlPipeline.addPass(createPDLToPDLInterpPass());
+  LogicalResult lowered = pdlPipeline.run(pdlModule);
+  if (failed(lowered))
+    return failure();
+
+  // Clean up once more now that the module has been lowered.
+  pdlModule.getBody()->walk(eraseIfTriviallyDead);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// FrozenRewritePatternList
//===----------------------------------------------------------------------===//
FrozenRewritePatternList::FrozenRewritePatternList(
OwningRewritePatternList &&patterns)
- : patterns(patterns.takePatterns()) {}
+ : nativePatterns(std::move(patterns.getNativePatterns())) {
+ PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
+
+ // Generate the bytecode for the PDL patterns if any were provided.
+ // Without a PDL module only the native patterns are kept, and `pdlByteCode`
+ // is left unset.
+ ModuleOp pdlModule = pdlPatterns.getModule();
+ if (!pdlModule)
+ return;
+ // Lower the PDL module to the PDL interpreter dialect; a failure here is
+ // treated as unrecoverable.
+ if (failed(convertPDLToPDLInterp(pdlModule)))
+ llvm::report_fatal_error(
+ "failed to lower PDL pattern module to the PDL Interpreter");
+
+ // Generate the pdl bytecode. The registered native constraint, create, and
+ // rewrite functions are taken from `pdlPatterns` and handed to the bytecode.
+ pdlByteCode = std::make_unique<detail::PDLByteCode>(
+ pdlModule, pdlPatterns.takeConstraintFunctions(),
+ pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
+}
+
+/// Take over the native patterns and the PDL bytecode (if any) of `patterns`.
+FrozenRewritePatternList::FrozenRewritePatternList(
+    FrozenRewritePatternList &&patterns)
+    : nativePatterns(std::move(patterns.nativePatterns)),
+      pdlByteCode(std::move(patterns.pdlByteCode)) {}
+
+FrozenRewritePatternList::~FrozenRewritePatternList() = default;
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 5d6ae51e8eeb..6f5e1f299f26 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -12,17 +12,36 @@
//===----------------------------------------------------------------------===//
#include "mlir/Rewrite/PatternApplicator.h"
+#include "ByteCode.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
+using namespace mlir::detail;
+
+/// Create an applicator over `frozenPatternList`. If the list carries PDL
+/// bytecode, a mutable execution state for that bytecode is created and
+/// initialized as well.
+PatternApplicator::PatternApplicator(
+    const FrozenRewritePatternList &frozenPatternList)
+    : frozenPatternList(frozenPatternList) {
+  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
+  if (!bytecode)
+    return;
+  mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
+  bytecode->initializeMutableState(*mutableByteCodeState);
+}
+PatternApplicator::~PatternApplicator() = default;
#define DEBUG_TYPE "pattern-match"
void PatternApplicator::applyCostModel(CostModel model) {
+ // Apply the cost model to the bytecode patterns first, and then the native
+ // patterns.
+ if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
+ for (auto it : llvm::enumerate(bytecode->getPatterns()))
+ mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
+ }
+
// Separate patterns by root kind to simplify lookup later on.
patterns.clear();
anyOpPatterns.clear();
- for (const auto &pat : frozenPatternList.getPatterns()) {
+ for (const auto &pat : frozenPatternList.getNativePatterns()) {
// If the pattern is always impossible to match, just ignore it.
if (pat.getBenefit().isImpossibleToMatch()) {
LLVM_DEBUG({
@@ -81,8 +100,12 @@ void PatternApplicator::applyCostModel(CostModel model) {
void PatternApplicator::walkAllPatterns(
function_ref<void(const Pattern &)> walk) {
- for (auto &it : frozenPatternList.getPatterns())
+ // Visit the native patterns first, then the PDL bytecode patterns (if any).
+ for (const Pattern &it : frozenPatternList.getNativePatterns())
walk(it);
+ if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
+ for (const Pattern &it : bytecode->getPatterns())
+ walk(it);
+ }
}
LogicalResult PatternApplicator::matchAndRewrite(
@@ -90,6 +113,14 @@ LogicalResult PatternApplicator::matchAndRewrite(
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
+  // Before checking native patterns, first match against the bytecode. This
+  // won't automatically perform any rewrites so there is no need to worry about
+  // conflicts.
+  SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
+  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
+  if (bytecode)
+    bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
+
// Check to see if there are patterns matching this specific operation type.
MutableArrayRef<const RewritePattern *> opPatterns;
auto patternIt = patterns.find(op->getName());
@@ -98,51 +129,50 @@ LogicalResult PatternApplicator::matchAndRewrite(
// Process the patterns for that match the specific operation type, and any
// operation type in an interleaved fashion.
- // FIXME: It'd be nice to just write an llvm::make_merge_range utility
- // and pass in a comparison function. That would make this code trivial.
auto opIt = opPatterns.begin(), opE = opPatterns.end();
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
- while (opIt != opE && anyIt != anyE) {
- // Try to match the pattern providing the most benefit.
- const RewritePattern *pattern;
- if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
- pattern = *(opIt++);
- else
- pattern = *(anyIt++);
+  auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end();
+  while (true) {
+    // Find the next pattern with the highest benefit. Ties favor operation
+    // specific patterns, then operation agnostic patterns, then PDL patterns.
+    const Pattern *bestPattern = nullptr;
+    const PDLByteCode::MatchResult *pdlMatch = nullptr;
+    bool isAnyOpPattern = false;
+    /// Operation specific patterns.
+    if (opIt != opE)
+      bestPattern = *opIt;
+    /// Operation agnostic patterns.
+    if (anyIt != anyE &&
+        (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit())) {
+      bestPattern = *anyIt;
+      isAnyOpPattern = true;
+    }
+    /// PDL patterns.
+    if (pdlIt != pdlE &&
+        (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
+      pdlMatch = &*pdlIt;
+      bestPattern = pdlMatch->pattern;
+    }
+    if (!bestPattern)
+      break;
+
+    // Advance only the list the selected pattern was taken from. Advancing a
+    // list whose candidate lost the benefit comparison would silently drop
+    // that candidate without ever attempting it.
+    if (pdlMatch)
+      ++pdlIt;
+    else if (isAnyOpPattern)
+      ++anyIt;
+    else
+      ++opIt;
- // Otherwise, try to match the generic pattern.
- if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
- onSuccess)))
- return success();
- }
- // If we break from the loop, then only one of the ranges can still have
- // elements. Loop over both without checking given that we don't need to
- // interleave anymore.
- for (const RewritePattern *pattern : llvm::concat<const RewritePattern *>(
- llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
- if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
- onSuccess)))
+    // Check that the pattern can be applied.
+    if (canApply && !canApply(*bestPattern))
+      continue;
+
+    // Try to match and rewrite this pattern. The patterns are sorted by
+    // benefit, so if we match we can immediately rewrite. For PDL patterns, the
+    // match has already been performed, we just need to rewrite.
+    rewriter.setInsertionPoint(op);
+    LogicalResult result = success();
+    if (pdlMatch) {
+      bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+    } else {
+      result = static_cast<const RewritePattern *>(bestPattern)
+                   ->matchAndRewrite(op, rewriter);
+    }
+    if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern))))
return success();
- }
- return failure();
-}
-LogicalResult PatternApplicator::matchAndRewrite(
- Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
- function_ref<bool(const Pattern &)> canApply,
- function_ref<void(const Pattern &)> onFailure,
- function_ref<LogicalResult(const Pattern &)> onSuccess) {
- // Check that the pattern can be applied.
- if (canApply && !canApply(pattern))
- return failure();
-
- // Try to match and rewrite this pattern. The patterns are sorted by
- // benefit, so if we match we can immediately rewrite.
- rewriter.setInsertionPoint(op);
- if (succeeded(pattern.matchAndRewrite(op, rewriter)))
- return success(!onSuccess || succeeded(onSuccess(pattern)));
-
- if (onFailure)
- onFailure(pattern);
+    // Perform any necessary cleanups.
+    if (onFailure)
+      onFailure(*bestPattern);
+  }
return failure();
}
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
new file mode 100644
index 000000000000..b2a22d0a8749
--- /dev/null
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -0,0 +1,785 @@
+// RUN: mlir-opt %s -test-pdl-bytecode-pass -split-input-file | FileCheck %s
+
+// Note: Tests here are written using the PDL Interpreter dialect to avoid
+// testing unnecessary aspects of the pattern compilation pipeline. These
+// tests are written such that we can focus solely on the lowering/execution
+// of the bytecode itself.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyConstraintOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.apply_constraint "multi_entity_constraint"(%root, %root : !pdl.operation, !pdl.operation) -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.replaced_by_pattern"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_constraint_1
+// CHECK: "test.replaced_by_pattern"
+module @ir attributes { test.apply_constraint_1 } {
+ "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyRewriteOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %operand = pdl_interp.get_operand 0 of %root
+ pdl_interp.apply_rewrite "rewriter"[42](%operand : !pdl.value) on %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_rewrite_1
+// CHECK: %[[INPUT:.*]] = "test.op_input"
+// CHECK-NOT: "test.op"
+// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]}
+module @ir attributes { test.apply_rewrite_1 } {
+ %input = "test.op_input"() : () -> i32
+ "test.op"(%input) : (i32) -> ()
+}
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::AreEqualOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %test_attr = pdl_interp.create_attribute unit
+ %attr = pdl_interp.get_attribute "test_attr" of %root
+ pdl_interp.are_equal %test_attr, %attr : !pdl.attribute -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.are_equal_1
+// CHECK: "test.success"
+module @ir attributes { test.are_equal_1 } {
+ "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::BranchOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end
+
+ ^pat1:
+ pdl_interp.branch ^pat2
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.branch_1
+// CHECK: "test.success"
+module @ir attributes { test.branch_1 } {
+ "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckAttributeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %attr = pdl_interp.get_attribute "test_attr" of %root
+ pdl_interp.check_attribute %attr is unit -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.check_attribute_1
+// CHECK: "test.success"
+module @ir attributes { test.check_attribute_1 } {
+ "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperandCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operand_count of %root is 1 -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.check_operand_count_1
+// CHECK: "test.op"() : () -> i32
+// CHECK: "test.success"
+module @ir attributes { test.check_operand_count_1 } {
+ %operand = "test.op"() : () -> i32
+ "test.op"(%operand) : (i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperationNameOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.check_operation_name_1
+// CHECK: "test.success"
+module @ir attributes { test.check_operation_name_1 } {
+ "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckResultCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_result_count of %root is 1 -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.check_result_count_1
+// CHECK: "test.success"() : () -> ()
+module @ir attributes { test.check_result_count_1 } {
+ "test.op"() : () -> i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckTypeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %attr = pdl_interp.get_attribute "test_attr" of %root
+ pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end
+
+ ^pat1:
+ %type = pdl_interp.get_attribute_type of %attr
+ pdl_interp.check_type %type is i32 -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.check_type_1
+// CHECK: "test.success"
+module @ir attributes { test.check_type_1 } {
+ "test.op"() { test_attr = 10 : i32 } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateAttributeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateNativeOp
+//===----------------------------------------------------------------------===//
+
+// -----
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_native "creator"(%root : !pdl.operation) : !pdl.operation
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.create_native_1
+// CHECK: "test.success"
+module @ir attributes { test.create_native_1 } {
+ "test.op"() : () -> ()
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateOperationOp
+//===----------------------------------------------------------------------===//
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateTypeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %attr = pdl_interp.get_attribute "test_attr" of %root
+ pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end
+
+ ^pat1:
+ %test_type = pdl_interp.create_type i32
+ %type = pdl_interp.get_attribute_type of %attr
+ pdl_interp.are_equal %type, %test_type : !pdl.type -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.create_type_1
+// CHECK: "test.success"
+module @ir attributes { test.create_type_1 } {
+ "test.op"() { test_attr = 0 : i32 } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::EraseOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::FinalizeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeTypeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetDefiningOpOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operand_count of %root is 5 -> ^pat1, ^end
+
+ ^pat1:
+ %operand0 = pdl_interp.get_operand 0 of %root
+ %operand4 = pdl_interp.get_operand 4 of %root
+ %defOp0 = pdl_interp.get_defining_op of %operand0
+ %defOp4 = pdl_interp.get_defining_op of %operand4
+ pdl_interp.are_equal %defOp0, %defOp4 : !pdl.operation -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_defining_op_1
+// CHECK: %[[OPERAND0:.*]] = "test.op"
+// CHECK: %[[OPERAND1:.*]] = "test.op"
+// CHECK: "test.success"
+// CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND1]])
+module @ir attributes { test.get_defining_op_1 } {
+ %operand = "test.op"() : () -> i32
+ %other_operand = "test.op"() : () -> i32
+ "test.op"(%operand, %operand, %operand, %operand, %operand) : (i32, i32, i32, i32, i32) -> ()
+ "test.op"(%operand, %operand, %operand, %operand, %other_operand) : (i32, i32, i32, i32, i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetOperandOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetResultOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end
+
+ ^pat1:
+ %result0 = pdl_interp.get_result 0 of %root
+ %result4 = pdl_interp.get_result 4 of %root
+ %result0_type = pdl_interp.get_value_type of %result0
+ %result4_type = pdl_interp.get_value_type of %result4
+ pdl_interp.are_equal %result0_type, %result4_type : !pdl.type -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.get_result_1
+// CHECK: "test.success"
+// CHECK: "test.op"() : () -> (i32, i32, i32, i32, i64)
+module @ir attributes { test.get_result_1 } {
+ %a:5 = "test.op"() : () -> (i32, i32, i32, i32, i32)
+ %b:5 = "test.op"() : () -> (i32, i32, i32, i32, i64)
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetValueTypeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::InferredTypeOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::IsNotNullOp
+//===----------------------------------------------------------------------===//
+
+// Fully tested within the tests for other operations.
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::RecordMatchOp
+//===----------------------------------------------------------------------===//
+
+// Check that the highest benefit pattern is selected.
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end
+
+ ^pat1:
+ pdl_interp.record_match @rewriters::@failure(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^pat2
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @failure(%root : !pdl.operation) {
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.record_match_1
+// CHECK: "test.success"
+module @ir attributes { test.record_match_1 } {
+ "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ReplaceOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %operand = pdl_interp.get_operand 0 of %root
+ pdl_interp.replace %root with (%operand)
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.replace_op_1
+// CHECK: %[[INPUT:.*]] = "test.op_input"
+// CHECK-NOT: "test.op"
+// CHECK: "test.op_consumer"(%[[INPUT]])
+module @ir attributes { test.replace_op_1 } {
+ %input = "test.op_input"() : () -> i32
+ %result = "test.op"(%input) : (i32) -> i32
+ "test.op_consumer"(%result) : (i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchAttributeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %attr = pdl_interp.get_attribute "test_attr" of %root
+ pdl_interp.switch_attribute %attr to [0, unit](^end, ^pat) -> ^end
+
+ ^pat:
+ %attr_2 = pdl_interp.get_attribute "test_attr_2" of %root
+ pdl_interp.switch_attribute %attr_2 to [0, unit](^end, ^end) -> ^pat2
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.switch_attribute_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_attribute_1 } {
+ "test.op"() { test_attr } : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperandCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.switch_operand_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end
+
+ ^pat:
+ pdl_interp.switch_operand_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.switch_operand_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_operand_1 } {
+ %input = "test.op_input"() : () -> i32
+ "test.op"(%input) : (i32) -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperationNameOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.switch_operation_name of %root to ["foo.op", "test.op"](^end, ^pat1) -> ^end
+
+ ^pat1:
+ pdl_interp.switch_operation_name of %root to ["foo.op", "bar.op"](^end, ^end) -> ^pat2
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.switch_operation_name_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_operation_name_1 } {
+ "test.op"() : () -> ()
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchResultCountOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.switch_result_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end
+
+ ^pat:
+ pdl_interp.switch_result_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.switch_result_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_result_1 } {
+ "test.op"() : () -> i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchTypeOp
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ %attr = pdl_interp.get_attribute "test_attr" of %root
+ pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end
+
+ ^pat1:
+ %type = pdl_interp.get_attribute_type of %attr
+ pdl_interp.switch_type %type to [i32, i64](^pat2, ^end) -> ^end
+
+ ^pat2:
+ pdl_interp.switch_type %type to [i16, i64](^end, ^end) -> ^pat3
+
+ ^pat3:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.create_operation "test.success"() -> ()
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.switch_type_1
+// CHECK: "test.success"
+module @ir attributes { test.switch_type_1 } {
+ "test.op"() { test_attr = 10 : i32 } : () -> ()
+}
diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt
index 0df357c8c355..9b156867702c 100644
--- a/mlir/test/lib/CMakeLists.txt
+++ b/mlir/test/lib/CMakeLists.txt
@@ -2,4 +2,5 @@ add_subdirectory(Dialect)
add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(Reducer)
+add_subdirectory(Rewrite)
add_subdirectory(Transforms)
diff --git a/mlir/test/lib/Rewrite/CMakeLists.txt b/mlir/test/lib/Rewrite/CMakeLists.txt
new file mode 100644
index 000000000000..fd5d5d586160
--- /dev/null
+++ b/mlir/test/lib/Rewrite/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Test-only library for mlir-opt; EXCLUDE_FROM_LIBMLIR keeps it out of libMLIR.so
+add_mlir_library(MLIRTestRewrite
+ TestPDLByteCode.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPass
+ MLIRSupport
+ MLIRTransformUtils
+ )
+
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
new file mode 100644
index 000000000000..3b23cb103675
--- /dev/null
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -0,0 +1,85 @@
+//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Custom PDL constraint on an operation value: succeeds iff it is "test.op".
+static LogicalResult
+customSingleEntityConstraint(PDLValue value, ArrayAttr, PatternRewriter &) {
+  StringRef name = value.cast<Operation *>()->getName().getStringRef();
+  return success(name == "test.op");
+}
+/// Multi-value variant of the constraint above; only `values[1]` is checked.
+static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
+                                                 ArrayAttr constantParams,
+                                                 PatternRewriter &rewriter) {
+  return customSingleEntityConstraint(values[1], constantParams, rewriter);
+}
+
+/// Custom creator invoked from PDL: creates and returns a "test.success" op.
+static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+                             PatternRewriter &rewriter) {
+  Location loc = args[0].cast<Operation *>()->getLoc();
+  return rewriter.createOperation(OperationState(loc, "test.success"));
+}
+
+/// Custom rewriter invoked from PDL: emits "test.success", then erases `root`.
+static void customRewriter(Operation *root, ArrayRef<PDLValue> args,
+                           ArrayAttr constantParams,
+                           PatternRewriter &rewriter) {
+  OperationState state(root->getLoc(), "test.success");
+  state.addAttribute("constantParams", constantParams);
+  state.addOperands(args[0].cast<Value>());
+  rewriter.createOperation(state);
+  rewriter.eraseOp(root);
+}
+
+namespace {
+/// Test pass applying the PDL patterns of module @patterns to module @ir.
+struct TestPDLByteCodePass
+    : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
+  void runOnOperation() final {
+    // Each test case consists of two modules: @patterns holds the PDL
+    // patterns, @ir holds the operations to rewrite; skip if either is absent.
+    ModuleOp module = getOperation();
+    auto patternModule = module.lookupSymbol<ModuleOp>("patterns");
+    auto irModule = module.lookupSymbol<ModuleOp>("ir");
+    if (!patternModule || !irModule)
+      return;
+
+    // Detach @patterns and bind the native callbacks its patterns reference.
+    patternModule.getOperation()->remove();
+    PDLPatternModule pdlPattern(patternModule);
+    pdlPattern.registerConstraintFunction("multi_entity_constraint",
+                                          customMultiEntityConstraint);
+    pdlPattern.registerConstraintFunction("single_entity_constraint",
+                                          customSingleEntityConstraint);
+    pdlPattern.registerCreateFunction("creator", customCreate);
+    pdlPattern.registerRewriteFunction("rewriter", customRewriter);
+
+    // Run the greedy driver over the IR module with the PDL patterns.
+    OwningRewritePatternList patternList(std::move(pdlPattern));
+    (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
+                                       std::move(patternList));
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Registers the pass as `-test-pdl-bytecode-pass` (called from mlir-opt).
+void registerTestPDLByteCodePass() {
+  PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass",
+                                        "Test PDL ByteCode functionality");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 8857bbe09eef..52e96dc44e0b 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -220,18 +220,21 @@ static void fillL1TilingAndMatmulToVectorPatterns(
FuncOp funcOp, StringRef startMarker,
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
MLIRContext *ctx = funcOp.getContext();
- patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
+ patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
LinalgMarker(Identifier::get(startMarker, ctx),
Identifier::get("L1", ctx))));
- patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
- ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
- LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
+ patternsVector.emplace_back(
+ std::make_unique<LinalgPromotionPattern<MatmulOp>>(
+ ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
+ LinalgMarker(Identifier::get("L1", ctx),
+ Identifier::get("VEC", ctx))));
- patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
- ctx, LinalgMarker(Identifier::get("VEC", ctx))));
+ patternsVector.emplace_back(
+ std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
+ ctx, LinalgMarker(Identifier::get("VEC", ctx))));
patternsVector.back()
.insert<LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<CopyOp>>(ctx);
@@ -437,7 +440,7 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
- stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
+ stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({768, 264, 768})
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index e8b0842a9e33..8bee2f5faa75 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -19,6 +19,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestIR
MLIRTestPass
MLIRTestReducer
+ MLIRTestRewrite
MLIRTestTransforms
)
endif()
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 4095cc21cbaf..67aa855092ef 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -86,6 +86,7 @@ void registerTestMemRefStrideCalculation();
void registerTestNumberOfBlockExecutionsPass();
void registerTestNumberOfOperationExecutionsPass();
void registerTestOpaqueLoc();
+void registerTestPDLByteCodePass();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
@@ -155,6 +156,7 @@ void registerTestPasses() {
test::registerTestNumberOfBlockExecutionsPass();
test::registerTestNumberOfOperationExecutionsPass();
test::registerTestOpaqueLoc();
+ test::registerTestPDLByteCodePass();
test::registerTestRecursiveTypesPass();
test::registerTestSCFUtilsPass();
test::registerTestSparsification();
More information about the llvm-branch-commits
mailing list