[Mlir-commits] [mlir] ea64828 - [mlir:PDL] Expand how native constraint/rewrite functions can be defined
River Riddle
llvmlistbot at llvm.org
Wed Apr 6 17:42:09 PDT 2022
Author: River Riddle
Date: 2022-04-06T17:41:59-07:00
New Revision: ea64828a10e304f8131e40dfef062173fe606e6a
URL: https://github.com/llvm/llvm-project/commit/ea64828a10e304f8131e40dfef062173fe606e6a
DIFF: https://github.com/llvm/llvm-project/commit/ea64828a10e304f8131e40dfef062173fe606e6a.diff
LOG: [mlir:PDL] Expand how native constraint/rewrite functions can be defined
This commit refactors the expected form of native constraint and rewrite
functions, and greatly reduces the complexity required of users when
defining a native function. Namely, this commit adds automatic processing
of the necessary PDLValue glue code, and allows users to define
constraint/rewrite functions using the C++ types that they actually want to
use.
As an example, let's look at a simple rewrite defined today:
```
static void rewriteFn(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
ValueRange operandValues = args[0].cast<ValueRange>();
TypeRange typeValues = args[1].cast<TypeRange>();
...
// Create an operation at some point and pass it back to PDL.
Operation *op = rewriter.create<SomeOp>(...);
results.push_back(op);
}
```
After this commit, that same rewrite could be defined as:
```
static Operation *rewriteFn(PatternRewriter &rewriter, ValueRange operandValues,
TypeRange typeValues) {
...
// Create an operation at some point and pass it back to PDL.
return rewriter.create<SomeOp>(...);
}
```
Differential Revision: https://reviews.llvm.org/D122086
Added:
Modified:
llvm/include/llvm/ADT/STLExtras.h
mlir/docs/PDLL.md
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Rewrite/pdl-bytecode.mlir
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 75d7175453583..fc461f8fe7e50 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -129,7 +129,7 @@ struct function_traits<ReturnType (ClassType::*)(Args...) const, false> {
/// Overload for class function types.
template <typename ClassType, typename ReturnType, typename... Args>
struct function_traits<ReturnType (ClassType::*)(Args...), false>
- : function_traits<ReturnType (ClassType::*)(Args...) const> {};
+ : public function_traits<ReturnType (ClassType::*)(Args...) const> {};
/// Overload for non-class function types.
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType (*)(Args...), false> {
@@ -143,6 +143,9 @@ struct function_traits<ReturnType (*)(Args...), false> {
template <size_t i>
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
+template <typename ReturnType, typename... Args>
+struct function_traits<ReturnType (*const)(Args...), false>
+ : public function_traits<ReturnType (*)(Args...)> {};
/// Overload for non-class function type references.
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType (&)(Args...), false>
diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md
index ab24a680d37b8..33293505d3714 100644
--- a/mlir/docs/PDLL.md
+++ b/mlir/docs/PDLL.md
@@ -1006,17 +1006,11 @@ External constraints are those registered explicitly with the `RewritePatternSet
the C++ PDL API. For example, the constraints above may be registered as:
```c++
-// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
-static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) {
- Value value = pdlValue.cast<Value>();
-
+static LogicalResult hasOneUseImpl(PatternRewriter &rewriter, Value value) {
return success(value.hasOneUse());
}
-static LogicalResult hasSameElementTypeImpl(ArrayRef<PDLValue> pdlValues,
- PatternRewriter &rewriter) {
- Value value1 = pdlValues[0].cast<Value>();
- Value value2 = pdlValues[1].cast<Value>();
-
+static LogicalResult hasSameElementTypeImpl(PatternRewriter &rewriter,
+ Value value1, Value value2) {
return success(value1.getType().cast<ShapedType>().getElementType() ==
value2.getType().cast<ShapedType>().getElementType());
}
@@ -1307,14 +1301,10 @@ External rewrites are those registered explicitly with the `RewritePatternSet` v
the C++ PDL API. For example, the rewrite above may be registered as:
```c++
-// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
-static void buildOpImpl(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
- PDLResultList &results) {
- Value value = args[0].cast<Value>();
-
+static Operation *buildOpImpl(PatternRewriter &rewriter, Value value) {
// insert special rewrite logic here.
Operation *resultOp = ...;
- results.push_back(resultOp);
+ return resultOp;
}
void registerNativeRewrite(RewritePatternSet &patterns) {
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 7f0253e59a32b..4e43b2677f4a8 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -68,18 +68,14 @@ def PDL_ApplyNativeRewriteOp
```mlir
// Apply a native rewrite method that returns an attribute.
- %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute
+ %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %attr1) : !pdl.attribute
```
```c++
// The native rewrite as defined in C++:
- static void myNativeFunc(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
- PDLResultList &results) {
- Value arg0 = args[0].cast<Value>();
- Value arg1 = args[1].cast<Value>();
-
- // Just push back the first param attribute.
- results.push_back(param0);
+ static Attribute myNativeFunc(PatternRewriter &rewriter, Value arg0, Attribute arg1) {
+ // Just return the second arg.
+ return arg1;
}
void registerNativeRewrite(PDLPatternModule &pdlModule) {
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1bd6187a78d48..f4c8863624740 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -409,7 +409,8 @@ class OpBuilder : public Builder {
/// Creates an operation with the given fields.
Operation *create(Location loc, StringAttr opName, ValueRange operands,
- TypeRange types, ArrayRef<NamedAttribute> attributes = {},
+ TypeRange types = {},
+ ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 11f85ee38bef8..478fa2ae97b1c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -386,6 +386,222 @@ class OpTraitRewritePattern : public RewritePattern {
benefit, context) {}
};
+//===----------------------------------------------------------------------===//
+// RewriterBase
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates the application of a rewrite on a set of IR,
+/// providing a way for clients to track mutations and create new operations.
+/// This class serves as a common API for IR mutation between pattern rewrites
+/// and non-pattern rewrites, and facilitates the development of shared
+/// IR transformation utilities.
+class RewriterBase : public OpBuilder, public OpBuilder::Listener {
+public:
+ /// Move the blocks that belong to "region" before the given position in
+ /// another region "parent". The two regions must be different. The caller
+ /// is responsible for creating or updating the operation transferring flow
+ /// of control to the region and passing it the correct block arguments.
+ virtual void inlineRegionBefore(Region &region, Region &parent,
+ Region::iterator before);
+ void inlineRegionBefore(Region &region, Block *before);
+
+ /// Clone the blocks that belong to "region" before the given position in
+ /// another region "parent". The two regions must be different. The caller is
+ /// responsible for creating or updating the operation transferring flow of
+ /// control to the region and passing it the correct block arguments.
+ virtual void cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before,
+ BlockAndValueMapping &mapping);
+ void cloneRegionBefore(Region &region, Region &parent,
+ Region::iterator before);
+ void cloneRegionBefore(Region &region, Block *before);
+
+ /// This method replaces the uses of the results of `op` with the values in
+ /// `newValues` when the provided `functor` returns true for a specific use.
+ /// The number of values in `newValues` is required to match the number of
+ /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
+ /// the uses of `op` were replaced. Note that in some rewriters, the given
+ /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
+ /// As such, the function should not capture by reference and instead use
+ /// value capture as necessary.
+ virtual void
+ replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
+ llvm::unique_function<bool(OpOperand &) const> functor);
+ void replaceOpWithIf(Operation *op, ValueRange newValues,
+ llvm::unique_function<bool(OpOperand &) const> functor) {
+ replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
+ std::move(functor));
+ }
+
+ /// This method replaces the uses of the results of `op` with the values in
+ /// `newValues` when a use is nested within the given `block`. The number of
+ /// values in `newValues` is required to match the number of results of `op`.
+ /// If all uses of this operation are replaced, the operation is erased.
+ void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
+ bool *allUsesReplaced = nullptr);
+
+ /// This method replaces the results of the operation with the specified list
+ /// of values. The number of provided values must match the number of results
+ /// of the operation.
+ virtual void replaceOp(Operation *op, ValueRange newValues);
+
+ /// Replaces the result op with a new op that is created without verification.
+ /// The result values of the two ops must be the same types.
+ template <typename OpTy, typename... Args>
+ OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
+ auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+ replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
+ return newOp;
+ }
+
+ /// This method erases an operation that is known to have no uses.
+ virtual void eraseOp(Operation *op);
+
+ /// This method erases all operations in a block.
+ virtual void eraseBlock(Block *block);
+
+ /// Merge the operations of block 'source' into the end of block 'dest'.
+ /// 'source's predecessors must either be empty or only contain 'dest'.
+ /// 'argValues' is used to replace the block arguments of 'source' after
+ /// merging.
+ virtual void mergeBlocks(Block *source, Block *dest,
+ ValueRange argValues = llvm::None);
+
+ // Merge the operations of block 'source' before the operation 'op'. Source
+ // block should not have existing predecessors or successors.
+ void mergeBlockBefore(Block *source, Operation *op,
+ ValueRange argValues = llvm::None);
+
+ /// Split the operations starting at "before" (inclusive) out of the given
+ /// block into a new block, and return it.
+ virtual Block *splitBlock(Block *block, Block::iterator before);
+
+ /// This method is used to notify the rewriter that an in-place operation
+ /// modification is about to happen. A call to this function *must* be
+ /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
+ /// This is a minor efficiency win (it avoids creating a new operation and
+ /// removing the old one) but also often allows simpler code in the client.
+ virtual void startRootUpdate(Operation *op) {}
+
+ /// This method is used to signal the end of a root update on the given
+ /// operation. This can only be called on operations that were provided to a
+ /// call to `startRootUpdate`.
+ virtual void finalizeRootUpdate(Operation *op) {}
+
+ /// This method cancels a pending root update. This can only be called on
+ /// operations that were provided to a call to `startRootUpdate`.
+ virtual void cancelRootUpdate(Operation *op) {}
+
+ /// This method is a utility wrapper around a root update of an operation. It
+ /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
+ /// callable.
+ template <typename CallableT>
+ void updateRootInPlace(Operation *root, CallableT &&callable) {
+ startRootUpdate(root);
+ callable();
+ finalizeRootUpdate(root);
+ }
+
+ /// Used to notify the rewriter that the IR failed to be rewritten because of
+ /// a match failure, and provide a callback to populate a diagnostic with the
+ /// reason why the failure occurred. This method allows for derived rewriters
+ /// to optionally hook into the reason why a rewrite failed, and display it to
+ /// users.
+ template <typename CallbackT>
+ std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
+ notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
+#ifndef NDEBUG
+ return notifyMatchFailure(loc,
+ function_ref<void(Diagnostic &)>(reasonCallback));
+#else
+ return failure();
+#endif
+ }
+ template <typename CallbackT>
+ std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
+ notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
+ return notifyMatchFailure(op->getLoc(),
+ function_ref<void(Diagnostic &)>(reasonCallback));
+ }
+ template <typename ArgT>
+ LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
+ return notifyMatchFailure(std::forward<ArgT>(arg),
+ [&](Diagnostic &diag) { diag << msg; });
+ }
+ template <typename ArgT>
+ LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
+ return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
+ }
+
+protected:
+ /// Initialize the builder with this rewriter as the listener.
+ explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
+ explicit RewriterBase(const OpBuilder &otherBuilder)
+ : OpBuilder(otherBuilder) {
+ setListener(this);
+ }
+ ~RewriterBase() override;
+
+ /// These are the callback methods that subclasses can choose to implement if
+ /// they would like to be notified about certain types of mutations.
+
+ /// Notify the rewriter that the specified operation is about to be replaced
+ /// with another set of operations. This is called before the uses of the
+ /// operation have been changed.
+ virtual void notifyRootReplaced(Operation *op) {}
+
+ /// This is called on an operation that a rewrite is removing, right before
+ /// the operation is deleted. At this point, the operation has zero uses.
+ virtual void notifyOperationRemoved(Operation *op) {}
+
+ /// Notify the rewriter that the pattern failed to match the given operation,
+ /// and provide a callback to populate a diagnostic with the reason why the
+ /// failure occurred. This method allows for derived rewriters to optionally
+ /// hook into the reason why a rewrite failed, and display it to users.
+ virtual LogicalResult
+ notifyMatchFailure(Location loc,
+ function_ref<void(Diagnostic &)> reasonCallback) {
+ return failure();
+ }
+
+private:
+ void operator=(const RewriterBase &) = delete;
+ RewriterBase(const RewriterBase &) = delete;
+
+ /// 'op' and 'newOp' are known to have the same number of results, replace the
+ /// uses of op with uses of newOp.
+ void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
+};
+
+//===----------------------------------------------------------------------===//
+// IRRewriter
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
+/// providing a way to keep track of the mutations made to the IR. This class
+/// should only be used in situations where another `RewriterBase` instance,
+/// such as a `PatternRewriter`, is not available.
+class IRRewriter : public RewriterBase {
+public:
+ explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+ explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
+};
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter
+//===----------------------------------------------------------------------===//
+
+/// A special type of `RewriterBase` that coordinates the application of a
+/// rewrite pattern on the current IR being matched, providing a way to keep
+/// track of any mutations made. This class should be used to perform all
+/// necessary IR mutations within a rewrite pattern, as the pattern driver may
+/// be tracking various state that would be invalidated when a mutation takes
+/// place.
+class PatternRewriter : public RewriterBase {
+public:
+ using RewriterBase::RewriterBase;
+};
+
//===----------------------------------------------------------------------===//
// PDLPatternModule
//===----------------------------------------------------------------------===//
@@ -587,291 +803,561 @@ class PDLResultList {
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
using PDLConstraintFunction =
- std::function<LogicalResult(ArrayRef<PDLValue>, PatternRewriter &)>;
+ std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful.
using PDLRewriteFunction =
- std::function<void(ArrayRef<PDLValue>, PatternRewriter &, PDLResultList &)>;
-
-/// This class contains all of the necessary data for a set of PDL patterns, or
-/// pattern rewrites specified in the form of the PDL dialect. This PDL module
-/// contained by this pattern may contain any number of `pdl.pattern`
-/// operations.
-class PDLPatternModule {
-public:
- PDLPatternModule() = default;
+ std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
- /// Construct a PDL pattern with the given module.
- PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
- : pdlModule(std::move(pdlModule)) {}
+namespace detail {
+namespace pdl_function_builder {
+/// A utility variable that always resolves to false. This is useful for static
+/// asserts that are always false, but only should fire in certain templated
+/// constructs. For example, if a templated function should never be called, the
+/// function could be defined as:
+///
+/// template <typename T>
+/// void foo() {
+/// static_assert(always_false<T>, "This function should never be called");
+/// }
+///
+template <class... T>
+constexpr bool always_false = false;
- /// Merge the state in `other` into this pattern module.
- void mergeIn(PDLPatternModule &&other);
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Type Processing
+//===----------------------------------------------------------------------===//
- /// Return the internal PDL module of this pattern.
- ModuleOp getModule() { return pdlModule.get(); }
+/// This struct provides a convenient way to determine how to process a given
+/// type as either a PDL parameter, or a result value. This allows for
+/// supporting complex types in constraint and rewrite functions, without
+/// requiring the user to hand-write the necessary glue code themselves.
+/// Specializations of this class should implement the following methods to
+/// enable support as a PDL argument or result type:
+///
+/// static LogicalResult verifyAsArg(
+/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
+/// size_t argIdx);
+///
+/// * This method verifies that the given PDLValue is valid for use as a
+/// value of `T`.
+///
+/// static T processAsArg(PDLValue pdlValue);
+///
+/// * This method processes the given PDLValue as a value of `T`.
+///
+/// static void processAsResult(PatternRewriter &, PDLResultList &results,
+/// const T &value);
+///
+/// * This method processes the given value of `T` as the result of a
+/// function invocation. The method should package the value into an
+/// appropriate form and append it to the given result list.
+///
+/// If the type `T` is based on a higher order value, consider using
+/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
+/// the implementation.
+///
+template <typename T, typename Enable = void>
+struct ProcessPDLValue;
+
+/// This struct provides a simplified model for processing types that are based
+/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
+/// allows for building the necessary processing functions on top of the base
+/// value instead of a PDLValue. Derived users should implement the following
+/// (which subsume the ProcessPDLValue variants):
+///
+/// static LogicalResult verifyAsArg(
+/// function_ref<LogicalResult(const Twine &)> errorFn,
+/// const BaseT &baseValue, size_t argIdx);
+///
+/// * This method verifies that the given base value is valid for use as a
+/// value of `T`.
+///
+/// static T processAsArg(BaseT baseValue);
+///
+/// * This method processes the given base value as a value of `T`.
+///
+template <typename T, typename BaseT>
+struct ProcessPDLValueBasedOn {
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+ PDLValue pdlValue, size_t argIdx) {
+ // Verify the base class before continuing.
+ if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
+ return failure();
+ return ProcessPDLValue<T>::verifyAsArg(
+ errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
+ }
+ static T processAsArg(PDLValue pdlValue) {
+ return ProcessPDLValue<T>::processAsArg(
+ ProcessPDLValue<BaseT>::processAsArg(pdlValue));
+ }
- //===--------------------------------------------------------------------===//
- // Function Registry
+ /// Explicitly add the expected parent API to ensure the parent class
+ /// implements the necessary API (and doesn't implicitly inherit it from
+ /// somewhere else).
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
+ size_t argIdx) {
+ return success();
+ }
+ static T processAsArg(BaseT baseValue);
+};
- /// Register a constraint function.
- void registerConstraintFunction(StringRef name,
- PDLConstraintFunction constraintFn);
- /// Register a single entity constraint function.
- template <typename SingleEntityFn>
- std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
- PatternRewriter &>::value>
- registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
- registerConstraintFunction(
- name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
- ArrayRef<PDLValue> values, PatternRewriter &rewriter) {
- assert(values.size() == 1 &&
- "expected values to have a single entity");
- return constraintFn(values[0], rewriter);
- });
+/// This struct provides a simplified model for processing types that have
+/// "builtin" PDLValue support:
+/// * Attribute, Operation *, Type, TypeRange, ValueRange
+template <typename T>
+struct ProcessBuiltinPDLValue {
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+ PDLValue pdlValue, size_t argIdx) {
+ if (pdlValue)
+ return success();
+ return errorFn("expected a non-null value for argument " + Twine(argIdx) +
+ " of type: " + llvm::getTypeName<T>());
}
- /// Register a rewrite function.
- void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
+ static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ T value) {
+ results.push_back(value);
+ }
+};
- /// Return the set of the registered constraint functions.
- const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
- return constraintFunctions;
+/// This struct provides a simplified model for processing types that inherit
+/// from builtin PDLValue types. For example, derived attributes like
+/// IntegerAttr, derived types like IntegerType, derived operations like
+/// ModuleOp, Interfaces, etc.
+template <typename T, typename BaseT>
+struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
+ static LogicalResult
+ verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+ BaseT baseValue, size_t argIdx) {
+ return TypeSwitch<BaseT, LogicalResult>(baseValue)
+ .Case([&](T) { return success(); })
+ .Default([&](BaseT) {
+ return errorFn("expected argument " + Twine(argIdx) +
+ " to be of type: " + llvm::getTypeName<T>());
+ });
}
- llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
- return constraintFunctions;
+ static T processAsArg(BaseT baseValue) {
+ return baseValue.template cast<T>();
}
- /// Return the set of the registered rewrite functions.
- const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
- return rewriteFunctions;
- }
- llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
- return rewriteFunctions;
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ T value) {
+ results.push_back(value);
}
+};
- /// Clear out the patterns and functions within this module.
- void clear() {
- pdlModule = nullptr;
- constraintFunctions.clear();
- rewriteFunctions.clear();
+//===----------------------------------------------------------------------===//
+// Attribute
+
+template <>
+struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
+template <typename T>
+struct ProcessPDLValue<T,
+ std::enable_if_t<std::is_base_of<Attribute, T>::value>>
+ : public ProcessDerivedPDLValue<T, Attribute> {};
+
+/// Handling for various Attribute value types.
+template <>
+struct ProcessPDLValue<StringRef>
+ : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
+ static StringRef processAsArg(StringAttr value) { return value.getValue(); }
+ static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
+ StringRef value) {
+ results.push_back(rewriter.getStringAttr(value));
}
+};
+template <>
+struct ProcessPDLValue<std::string>
+ : public ProcessPDLValueBasedOn<std::string, StringAttr> {
+ template <typename T>
+ static std::string processAsArg(T value) {
+ static_assert(always_false<T>,
+ "`std::string` arguments require a string copy, use "
+ "`StringRef` for string-like arguments instead");
+ }
+ static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
+ StringRef value) {
+ results.push_back(rewriter.getStringAttr(value));
+ }
+};
-private:
- /// The module containing the `pdl.pattern` operations.
- OwningOpRef<ModuleOp> pdlModule;
-
- /// The external functions referenced from within the PDL module.
- llvm::StringMap<PDLConstraintFunction> constraintFunctions;
- llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
+//===----------------------------------------------------------------------===//
+// Operation
+
+template <>
+struct ProcessPDLValue<Operation *>
+ : public ProcessBuiltinPDLValue<Operation *> {};
+template <typename T>
+struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
+ : public ProcessDerivedPDLValue<T, Operation *> {
+ static T processAsArg(Operation *value) { return cast<T>(value); }
};
//===----------------------------------------------------------------------===//
-// RewriterBase
+// Type
+
+template <>
+struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
+template <typename T>
+struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
+ : public ProcessDerivedPDLValue<T, Type> {};
+
//===----------------------------------------------------------------------===//
+// TypeRange
+
+/// `TypeRange` is handled directly by the builtin PDL value processing.
+template <>
+struct ProcessPDLValue<TypeRange> : ProcessBuiltinPDLValue<TypeRange> {};
+
+/// The type ranges of operands and results are accepted only as native
+/// results: these specializations define `processAsResult` and nothing else.
+template <>
+struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              ValueTypeRange<OperandRange> operandTypes) {
+    results.push_back(operandTypes);
+  }
+};
+template <>
+struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              ValueTypeRange<ResultRange> resultTypes) {
+    results.push_back(resultTypes);
+  }
+};
-/// This class coordinates the application of a rewrite on a set of IR,
-/// providing a way for clients to track mutations and create new operations.
-/// This class serves as a common API for IR mutation between pattern rewrites
-/// and non-pattern rewrites, and facilitates the development of shared
-/// IR transformation utilities.
-class RewriterBase : public OpBuilder, public OpBuilder::Listener {
-public:
- /// Move the blocks that belong to "region" before the given position in
- /// another region "parent". The two regions must be different. The caller
- /// is responsible for creating or updating the operation transferring flow
- /// of control to the region and passing it the correct block arguments.
- virtual void inlineRegionBefore(Region &region, Region &parent,
- Region::iterator before);
- void inlineRegionBefore(Region &region, Block *before);
+//===----------------------------------------------------------------------===//
+// Value
- /// Clone the blocks that belong to "region" before the given position in
- /// another region "parent". The two regions must be different. The caller is
- /// responsible for creating or updating the operation transferring flow of
- /// control to the region and passing it the correct block arguments.
- virtual void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before,
- BlockAndValueMapping &mapping);
- void cloneRegionBefore(Region &region, Region &parent,
- Region::iterator before);
- void cloneRegionBefore(Region &region, Block *before);
+/// `Value` is handled directly by the builtin PDL value processing.
+template <>
+struct ProcessPDLValue<Value> : ProcessBuiltinPDLValue<Value> {};
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when the provided `functor` returns true for a specific use.
- /// The number of values in `newValues` is required to match the number of
- /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
- /// the uses of `op` were replaced. Note that in some rewriters, the given
- /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
- /// As such, the function should not capture by reference and instead use
- /// value capture as necessary.
- virtual void
- replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
- llvm::unique_function<bool(OpOperand &) const> functor);
- void replaceOpWithIf(Operation *op, ValueRange newValues,
- llvm::unique_function<bool(OpOperand &) const> functor) {
- replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
- std::move(functor));
+//===----------------------------------------------------------------------===//
+// ValueRange
+
+/// `ValueRange` is handled directly by `ProcessBuiltinPDLValue`.
+template <>
+struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
+};
+/// `OperandRange` and `ResultRange` are supported only as native results:
+/// these specializations define `processAsResult` (which pushes the range
+/// onto the result list) but no argument handling.
+template <>
+struct ProcessPDLValue<OperandRange> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ OperandRange values) {
+ results.push_back(values);
}
+};
+template <>
+struct ProcessPDLValue<ResultRange> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ ResultRange values) {
+ results.push_back(values);
+ }
+};
- /// This method replaces the uses of the results of `op` with the values in
- /// `newValues` when a use is nested within the given `block`. The number of
- /// values in `newValues` is required to match the number of results of `op`.
- /// If all uses of this operation are replaced, the operation is erased.
- void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
- bool *allUsesReplaced = nullptr);
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Argument Handling
+//===----------------------------------------------------------------------===//
- /// This method replaces the results of the operation with the specified list
- /// of values. The number of provided values must match the number of results
- /// of the operation.
- virtual void replaceOp(Operation *op, ValueRange newValues);
+/// Validate the given PDLValues match the constraints defined by the argument
+/// types of the given function. In the case of failure, a match failure
+/// diagnostic is emitted.
+///
+/// The number of provided values must equal the number of native arguments
+/// (`sizeof...(I)`). A mismatch is reported as a match failure instead of
+/// indexing `values` out of bounds.
+/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
+/// does not currently preserve Constraint application ordering.
+template <typename PDLFnT, std::size_t... I>
+LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
+                           std::index_sequence<I...>) {
+  using FnTraitsT = llvm::function_traits<PDLFnT>;
+
+  auto errorFn = [&](const Twine &msg) {
+    return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
+  };
+
+  // Guard the `values[I]` accesses below against an arity mismatch between
+  // the PDL call site and the registered C++ function.
+  if (values.size() != sizeof...(I))
+    return errorFn("expected " + Twine(sizeof...(I)) +
+                   " arguments to native constraint, but got " +
+                   Twine(values.size()));
+
+  // Verify each argument in order, skipping the rest after the first failure.
+  LogicalResult result = success();
+  (void)std::initializer_list<int>{
+      (result =
+           succeeded(result)
+               ? ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                     verifyAsArg(errorFn, values[I], I)
+               : failure(),
+       0)...};
+  return result;
+}
- /// Replaces the result op with a new op that is created without verification.
- /// The result values of the two ops must be the same types.
- template <typename OpTy, typename... Args>
- OpTy replaceOpWithNewOp(Operation *op, Args &&... args) {
- auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
- replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
- return newOp;
- }
+/// Assert that the given PDLValues match the constraints defined by the
+/// arguments of the given function. In the case of failure, a fatal error
+/// is emitted.
+///
+/// The checks are compiled in only when LLVM_ENABLE_ABI_BREAKING_CHECKS is set.
+/// They deliberately do not go through `assert`. In a build with `NDEBUG`
+/// defined but ABI-breaking checks forced on, `assert(...)` expands to
+/// `((void)0)`. That would leave the pack expansion without any reference to
+/// `I` (ill-formed) and the error lambda unused.
+template <typename PDLFnT, std::size_t... I>
+void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
+                std::index_sequence<I...>) {
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  using FnTraitsT = llvm::function_traits<PDLFnT>;
+
+  // `report_fatal_error` does not return; the declared return type only
+  // matches what `verifyAsArg` expects from its error callback.
+  auto errorFn = [&](const Twine &msg) -> LogicalResult {
+    llvm::report_fatal_error(msg);
+  };
+
+  // Guard the `values[I]` accesses below against an arity mismatch between
+  // the PDL call site and the registered C++ function.
+  if (values.size() != sizeof...(I))
+    (void)errorFn("expected " + Twine(sizeof...(I)) +
+                  " arguments to native rewrite, but got " +
+                  Twine(values.size()));
+
+  // `verifyAsArg` may report through `errorFn` itself; any failure it returns
+  // without doing so is still treated as fatal.
+  auto checkArg = [&](LogicalResult argResult, std::size_t argIdx) {
+    if (failed(argResult))
+      (void)errorFn("failed to verify argument " + Twine(argIdx) +
+                    " of native rewrite");
+  };
+  (void)checkArg; // Unused when the function takes no native arguments.
+  (void)std::initializer_list<int>{
+      (checkArg(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                    verifyAsArg(errorFn, values[I], I),
+                I),
+       0)...};
+#endif
+  (void)rewriter;
+  (void)values;
+}
- /// This method erases an operation that is known to have no uses.
- virtual void eraseOp(Operation *op);
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Results Handling
+//===----------------------------------------------------------------------===//
- /// This method erases all operations in a block.
- virtual void eraseBlock(Block *block);
+/// Append one native result value to the result list, dispatching on its C++
+/// type through the matching `ProcessPDLValue` specialization.
+template <typename T>
+static void processResults(PatternRewriter &rewriter, PDLResultList &results,
+                           T &&value) {
+  using ProcessorT = ProcessPDLValue<T>;
+  ProcessorT::processAsResult(rewriter, results, std::forward<T>(value));
+}
- /// Merge the operations of block 'source' into the end of block 'dest'.
- /// 'source's predecessors must either be empty or only contain 'dest`.
- /// 'argValues' is used to replace the block arguments of 'source' after
- /// merging.
- virtual void mergeBlocks(Block *source, Block *dest,
- ValueRange argValues = llvm::None);
+/// Flatten a `std::pair` result into two consecutive entries of the result
+/// list: `first`, then `second`.
+template <typename T1, typename T2>
+static void processResults(PatternRewriter &rewriter, PDLResultList &results,
+                           std::pair<T1, T2> &&pair) {
+  // Each member is moved out independently, so moving both is safe.
+  processResults(rewriter, results, std::move(pair.first));
+  processResults(rewriter, results, std::move(pair.second));
+}
- // Merge the operations of block 'source' before the operation 'op'. Source
- // block should not have existing predecessors or successors.
- void mergeBlockBefore(Block *source, Operation *op,
- ValueRange argValues = llvm::None);
+/// Flatten a `std::tuple` result into consecutive entries of the result list,
+/// in element order.
+template <typename... Ts>
+static void processResults(PatternRewriter &rewriter, PDLResultList &results,
+                           std::tuple<Ts...> &&tuple) {
+  auto appendAll = [&](auto &&...elements) {
+    // TODO: Use proper fold expressions when we have C++17. Until then, the
+    // braced initializer list guarantees left-to-right evaluation.
+    (void)std::initializer_list<int>{
+        (processResults(rewriter, results, std::move(elements)), 0)...};
+  };
+  llvm::apply_tuple(appendAll, std::move(tuple));
+}
- /// Split the operations starting at "before" (inclusive) out of the given
- /// block into a new block, and return it.
- virtual Block *splitBlock(Block *block, Block::iterator before);
+//===----------------------------------------------------------------------===//
+// PDL Constraint Builder
+//===----------------------------------------------------------------------===//
- /// This method is used to notify the rewriter that an in-place operation
- /// modification is about to happen. A call to this function *must* be
- /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
- /// This is a minor efficiency win (it avoids creating a new operation and
- /// removing the old one) but also often allows simpler code in the client.
- virtual void startRootUpdate(Operation *op) {}
+/// Process the arguments of a native constraint and invoke it.
+///
+/// PDL value `values[I]` is converted by the `ProcessPDLValue` specialization
+/// for native parameter `I + 1`; parameter 0 is the `PatternRewriter &`. The
+/// constraint's own result is returned unchanged.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+typename FnTraitsT::result_t
+processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
+                               ArrayRef<PDLValue> values,
+                               std::index_sequence<I...>) {
+  return fn(rewriter,
+            ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                processAsArg(values[I])...);
+}
- /// This method is used to signal the end of a root update on the given
- /// operation. This can only be called on operations that were provided to a
- /// call to `startRootUpdate`.
- virtual void finalizeRootUpdate(Operation *op) {}
+/// Build a constraint function from the given function `ConstraintFnT`, so
+/// that users can write constraints in terms of ordinary C++ parameter types
+/// instead of the low-level PDLValue form.
+///
+/// This overload is selected when `constraintFn` is already convertible to
+/// `PDLConstraintFunction`; it is forwarded unchanged.
+template <typename ConstraintFnT>
+std::enable_if_t<
+    std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
+    PDLConstraintFunction>
+buildConstraintFn(ConstraintFnT &&constraintFn) {
+  return std::forward<ConstraintFnT>(constraintFn);
+}
+/// Otherwise, wrap `constraintFn` in a lambda that verifies the incoming
+/// PDLValues against its parameter types, failing the match on a mismatch.
+/// On success it invokes the function with the unpacked arguments.
+template <typename ConstraintFnT>
+std::enable_if_t<
+    !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
+    PDLConstraintFunction>
+buildConstraintFn(ConstraintFnT &&constraintFn) {
+  return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
+             PatternRewriter &rewriter,
+             ArrayRef<PDLValue> values) -> LogicalResult {
+    // One index per native argument, excluding the leading PatternRewriter.
+    constexpr std::size_t numNativeArgs =
+        llvm::function_traits<ConstraintFnT>::num_args - 1;
+    auto argIndices = std::make_index_sequence<numNativeArgs>();
+    if (succeeded(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
+      return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
+                                            argIndices);
+    return failure();
+  };
+}
- /// This method cancels a pending root update. This can only be called on
- /// operations that were provided to a call to `startRootUpdate`.
- virtual void cancelRootUpdate(Operation *op) {}
+//===----------------------------------------------------------------------===//
+// PDL Rewrite Builder
+//===----------------------------------------------------------------------===//
- /// This method is a utility wrapper around a root update of an operation. It
- /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
- /// callable.
- template <typename CallableT>
- void updateRootInPlace(Operation *root, CallableT &&callable) {
- startRootUpdate(root);
- callable();
- finalizeRootUpdate(root);
- }
+/// Process the arguments of a native rewrite and invoke it.
+///
+/// Overload for rewrites returning `void`: nothing is appended, so the result
+/// list is ignored.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
+processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
+                            PDLResultList &, ArrayRef<PDLValue> values,
+                            std::index_sequence<I...>) {
+  // PDL value `I` feeds native parameter `I + 1` (parameter 0 is `rewriter`).
+  fn(rewriter, ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                   processAsArg(values[I])...);
+}
+/// Overload for rewrites with a return value. The value is unpacked into the
+/// result list by `processResults`.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
+processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
+                            PDLResultList &results, ArrayRef<PDLValue> values,
+                            std::index_sequence<I...>) {
+  auto &&nativeResult =
+      fn(rewriter, ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                       processAsArg(values[I])...);
+  // Forward with the original value category so overload selection in
+  // `processResults` matches passing the call expression directly.
+  processResults(rewriter, results,
+                 std::forward<decltype(nativeResult)>(nativeResult));
+}
- /// Used to notify the rewriter that the IR failed to be rewritten because of
- /// a match failure, and provide a callback to populate a diagnostic with the
- /// reason why the failure occurred. This method allows for derived rewriters
- /// to optionally hook into the reason why a rewrite failed, and display it to
- /// users.
- template <typename CallbackT>
- std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
- notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
-#ifndef NDEBUG
- return notifyMatchFailure(op,
- function_ref<void(Diagnostic &)>(reasonCallback));
-#else
- return failure();
-#endif
- }
- LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
- return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
- }
- LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
- return notifyMatchFailure(op, Twine(msg));
- }
+/// Build a rewrite function from the given function `RewriteFnT`, so that
+/// users can write rewrites in terms of ordinary C++ parameter and result
+/// types instead of the low-level PDLValue form.
+///
+/// This overload is selected when `rewriteFn` is already convertible to
+/// `PDLRewriteFunction`; it is forwarded unchanged.
+template <typename RewriteFnT>
+std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
+                 PDLRewriteFunction>
+buildRewriteFn(RewriteFnT &&rewriteFn) {
+  return std::forward<RewriteFnT>(rewriteFn);
+}
+/// Otherwise, wrap `rewriteFn` in a lambda. The lambda checks the incoming
+/// PDLValues via `assertArgs`, invokes the function with the unpacked
+/// arguments, and appends any returned value(s) to the result list.
+template <typename RewriteFnT>
+std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
+                 PDLRewriteFunction>
+buildRewriteFn(RewriteFnT &&rewriteFn) {
+  return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
+             PatternRewriter &rewriter, PDLResultList &results,
+             ArrayRef<PDLValue> values) {
+    // One index per native argument, excluding the leading PatternRewriter.
+    constexpr std::size_t numNativeArgs =
+        llvm::function_traits<RewriteFnT>::num_args - 1;
+    auto argIndices = std::make_index_sequence<numNativeArgs>();
+    assertArgs<RewriteFnT>(rewriter, values, argIndices);
+    processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
+                                argIndices);
+  };
+}
-protected:
- /// Initialize the builder with this rewriter as the listener.
- explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
- explicit RewriterBase(const OpBuilder &otherBuilder)
- : OpBuilder(otherBuilder) {
- setListener(this);
- }
- ~RewriterBase() override;
+} // namespace pdl_function_builder
+} // namespace detail
- /// These are the callback methods that subclasses can choose to implement if
- /// they would like to be notified about certain types of mutations.
+/// This class contains all of the necessary data for a set of PDL patterns, or
+/// pattern rewrites specified in the form of the PDL dialect. This PDL module
+/// contained by this pattern may contain any number of `pdl.pattern`
+/// operations.
+class PDLPatternModule {
+public:
+ PDLPatternModule() = default;
- /// Notify the rewriter that the specified operation is about to be replaced
- /// with another set of operations. This is called before the uses of the
- /// operation have been changed.
- virtual void notifyRootReplaced(Operation *op) {}
+ /// Construct a PDL pattern with the given module.
+ PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
+ : pdlModule(std::move(pdlModule)) {}
- /// This is called on an operation that a rewrite is removing, right before
- /// the operation is deleted. At this point, the operation has zero uses.
- virtual void notifyOperationRemoved(Operation *op) {}
+ /// Merge the state in `other` into this pattern module.
+ void mergeIn(PDLPatternModule &&other);
- /// Notify the rewriter that the pattern failed to match the given operation,
- /// and provide a callback to populate a diagnostic with the reason why the
- /// failure occurred. This method allows for derived rewriters to optionally
- /// hook into the reason why a rewrite failed, and display it to users.
- virtual LogicalResult
- notifyMatchFailure(Operation *op,
- function_ref<void(Diagnostic &)> reasonCallback) {
- return failure();
- }
+ /// Return the internal PDL module of this pattern.
+ ModuleOp getModule() { return pdlModule.get(); }
-private:
- void operator=(const RewriterBase &) = delete;
- RewriterBase(const RewriterBase &) = delete;
+ //===--------------------------------------------------------------------===//
+ // Function Registry
- /// 'op' and 'newOp' are known to have the same number of results, replace the
- /// uses of op with uses of newOp.
- void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
-};
+ /// Register a constraint function with PDL. A constraint function may be
+ /// specified in one of two ways:
+ ///
+ /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
+ ///
+ /// In this overload the arguments of the constraint function are passed via
+ /// the low-level PDLValue form.
+ ///
+ /// * `LogicalResult (PatternRewriter &, ValueTs... values)`
+ ///
+ /// In this form the arguments of the constraint function are passed via the
+ /// expected high level C++ type. In this form, the framework will
+ /// automatically unwrap PDLValues and convert them to the expected ValueTs.
+ /// For example, if the constraint function accepts an `Operation *`, the
+ /// framework will automatically cast the input PDLValue. In the case of a
+ /// `StringRef`, the framework will automatically unwrap the argument as a
+ /// StringAttr and pass the underlying string value. To see the full list of
+ /// supported types, or to see how to add handling for custom types, view
+ /// the definition of `ProcessPDLValue` above.
+ void registerConstraintFunction(StringRef name,
+ PDLConstraintFunction constraintFn);
+ template <typename ConstraintFnT>
+ void registerConstraintFunction(StringRef name,
+ ConstraintFnT &&constraintFn) {
+ registerConstraintFunction(name,
+ detail::pdl_function_builder::buildConstraintFn(
+ std::forward<ConstraintFnT>(constraintFn)));
+ }
-//===----------------------------------------------------------------------===//
-// IRRewriter
-//===----------------------------------------------------------------------===//
+ /// Register a rewrite function with PDL. A rewrite function may be specified
+ /// in one of two ways:
+ ///
+ /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
+ ///
+ /// In this overload the arguments of the rewrite function are passed via
+ /// the low-level PDLValue form, and the results are manually appended to
+ /// the given result list.
+ ///
+ /// * `ResultT (PatternRewriter &, ValueTs... values)`
+ ///
+ /// In this form the arguments and result of the rewrite function are passed
+ /// via the expected high level C++ type. In this form, the framework will
+ /// automatically unwrap the PDLValues arguments and convert them to the
+ /// expected ValueTs. It will also automatically handle the processing and
+ /// packaging of the result value to the result list. For example, if the
+ /// rewrite function takes an `Operation *`, the framework will automatically
+ /// cast the input PDLValue. In the case of a `StringRef`, the framework
+ /// will automatically unwrap the argument as a StringAttr and pass the
+ /// underlying string value. In the reverse case, if the rewrite returns a
+ /// StringRef or std::string, it will automatically package this as a
+ /// StringAttr and append it to the result list. To see the full list of
+ /// supported types, or to see how to add handling for custom types, view
+ /// the definition of `ProcessPDLValue` above.
+ void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
+ template <typename RewriteFnT>
+ void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
+ registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
+ std::forward<RewriteFnT>(rewriteFn)));
+ }
-/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
-/// providing a way to keep track of the mutations made to the IR. This class
-/// should only be used in situations where another `RewriterBase` instance,
-/// such as a `PatternRewriter`, is not available.
-class IRRewriter : public RewriterBase {
-public:
- explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
- explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
-};
+ /// Return the set of the registered constraint functions.
+ const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
+ return constraintFunctions;
+ }
+ /// Transfer the registered constraint functions to the caller. They are
+ /// moved out rather than copied, so this module's constraint function map
+ /// is left in a moved-from state and should not be relied upon afterwards.
+ llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
+ return std::move(constraintFunctions);
+ }
+ /// Return the set of the registered rewrite functions.
+ const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
+ return rewriteFunctions;
+ }
+ /// Transfer the registered rewrite functions to the caller. They are moved
+ /// out rather than copied, so this module's rewrite function map is left in
+ /// a moved-from state and should not be relied upon afterwards.
+ llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
+ return std::move(rewriteFunctions);
+ }
-//===----------------------------------------------------------------------===//
-// PatternRewriter
-//===----------------------------------------------------------------------===//
+ /// Clear out the patterns and functions within this module.
+ void clear() {
+ pdlModule = nullptr;
+ constraintFunctions.clear();
+ rewriteFunctions.clear();
+ }
-/// A special type of `RewriterBase` that coordinates the application of a
-/// rewrite pattern on the current IR being matched, providing a way to keep
-/// track of any mutations made. This class should be used to perform all
-/// necessary IR mutations within a rewrite pattern, as the pattern driver may
-/// be tracking various state that would be invalidated when a mutation takes
-/// place.
-class PatternRewriter : public RewriterBase {
-public:
- using RewriterBase::RewriterBase;
+private:
+ /// The module containing the `pdl.pattern` operations.
+ OwningOpRef<ModuleOp> pdlModule;
+
+ /// The external functions referenced from within the PDL module.
+ llvm::StringMap<PDLConstraintFunction> constraintFunctions;
+ llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5b409d695baed..d0be98a307d70 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -629,7 +629,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
- notifyMatchFailure(Operation *op,
+ notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
using PatternRewriter::notifyMatchFailure;
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 367f51ad601a8..c2dc41a81c6f8 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1340,7 +1340,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
});
// Invoke the constraint and jump to the proper destination.
- selectJump(succeeded(constraintFn(args, rewriter)));
+ selectJump(succeeded(constraintFn(rewriter, args)));
}
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1357,7 +1357,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
// Execute the rewrite function.
ByteCodeField numResults = read();
ByteCodeRewriteResultList results(numResults);
- rewriteFn(args, rewriter, results);
+ rewriteFn(rewriter, results, args);
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index c6937aa736a11..c3b5c957007e8 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -184,9 +184,9 @@ void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
.Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
};
os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
- << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
- "::mlir::PatternRewriter &rewriter"
- << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
+ << "PDLFn(::mlir::PatternRewriter &rewriter, "
+ << (isConstraint ? "" : "::mlir::PDLResultList &results, ")
+ << "::llvm::ArrayRef<::mlir::PDLValue> values) {\n";
const char *argumentInitStr = R"(
{0} {1} = {{};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bdeb0fa222b31..575b3cbd3c335 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1673,8 +1673,8 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
}
+/// Forward the match-failure reason callback to the conversion implementation,
+/// reporting it at `loc`.
LogicalResult ConversionPatternRewriter::notifyMatchFailure(
- Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
- return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
+ Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
+ return impl->notifyMatchFailure(loc, reasonCallback);
}
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7fd46c711db01..81b57c420a726 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -76,7 +76,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
- notifyMatchFailure(Operation *op,
+ notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
/// The low-level pattern applicator.
@@ -348,9 +348,9 @@ void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
}
LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
- Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
+ Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
- Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+ Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
});
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index d06c500241b0b..e1a8c6081d4e5 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -181,8 +181,9 @@ module @patterns {
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation) {
+ %attr = pdl_interp.apply_rewrite "str_creator" : !pdl.attribute
%type = pdl_interp.apply_rewrite "type_creator" : !pdl.type
- %newOp = pdl_interp.create_operation "test.success" -> (%type : !pdl.type)
+ %newOp = pdl_interp.create_operation "test.success" {"attr" = %attr} -> (%type : !pdl.type)
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -190,7 +191,7 @@ module @patterns {
}
// CHECK-LABEL: test.apply_rewrite_4
-// CHECK: "test.success"() : () -> f32
+// CHECK: "test.success"() {attr = "test.str"} : () -> f32
module @ir attributes { test.apply_rewrite_4 } {
"test.op"() : () -> ()
}
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 13465ba2865e0..daa1c371f27c9 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -14,53 +14,42 @@
using namespace mlir;
/// Custom constraint invoked from PDL.
-static LogicalResult customSingleEntityConstraint(PDLValue value,
- PatternRewriter &rewriter) {
- Operation *rootOp = value.cast<Operation *>();
+/// Succeeds only if `rootOp` is named `test.op`.
+static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
+ Operation *rootOp) {
return success(rootOp->getName().getStringRef() == "test.op");
}
-static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
- PatternRewriter &rewriter) {
- return customSingleEntityConstraint(values[1], rewriter);
+/// Applies `customSingleEntityConstraint` to the second operation argument
+/// (`rootCopy`) only; `root` is intentionally unused.
+static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
+ Operation *root,
+ Operation *rootCopy) {
+ return customSingleEntityConstraint(rewriter, rootCopy);
}
-static LogicalResult
-customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
- PatternRewriter &rewriter) {
- if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
- return failure();
- ValueRange operandValues = values[0].cast<ValueRange>();
- TypeRange typeValues = values[1].cast<TypeRange>();
+/// Succeeds only when exactly two operand values and two types are provided.
+/// NOTE(review): the explicit null-PDLValue check was dropped here; presumably
+/// the generated wrapper's argument verification now rejects null values -
+/// confirm.
+static LogicalResult customMultiEntityVariadicConstraint(
+ PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
if (operandValues.size() != 2 || typeValues.size() != 2)
return failure();
return success();
}
// Custom creator invoked from PDL.
-static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
- PDLResultList &results) {
- results.push_back(rewriter.create(
- OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
+/// Creates a `test.success` op at `op`'s location and returns it as the
+/// rewrite's single result.
+static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
+ return rewriter.create(OperationState(op->getLoc(), "test.success"));
}
-
-static void customVariadicResultCreate(ArrayRef<PDLValue> args,
- PatternRewriter &rewriter,
- PDLResultList &results) {
- Operation *root = args[0].cast<Operation *>();
- results.push_back(root->getOperands());
- results.push_back(root->getOperands().getTypes());
+/// Returns `root`'s operands and their types; the returned pair is unpacked
+/// into two separate PDL results.
+static auto customVariadicResultCreate(PatternRewriter &rewriter,
+ Operation *root) {
+ return std::make_pair(root->getOperands(), root->getOperands().getTypes());
+}
+/// Returns the `f32` type as the single rewrite result.
+static Type customCreateType(PatternRewriter &rewriter) {
+ return rewriter.getF32Type();
}
-static void customCreateType(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
- PDLResultList &results) {
- results.push_back(rewriter.getF32Type());
+/// Returns "test.str"; the `std::string` result is packaged into a
+/// `StringAttr` by the PDL function builder.
+static std::string customCreateStrAttr(PatternRewriter &rewriter) {
+ return "test.str";
}
/// Custom rewriter invoked from PDL.
-static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
- PDLResultList &results) {
- Operation *root = args[0].cast<Operation *>();
- OperationState successOpState(root->getLoc(), "test.success");
- successOpState.addOperands(args[1].cast<Value>());
- rewriter.create(successOpState);
+/// Creates a `test.success` op at `root`'s location with `input` as its
+/// operand, then erases `root`.
+static void customRewriter(PatternRewriter &rewriter, Operation *root,
+ Value input) {
+ rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
+ input);
rewriter.eraseOp(root);
}
@@ -117,6 +106,7 @@ struct TestPDLByteCodePass
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
+ pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
patternList.add(std::move(pdlPattern));
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
index 9f0ea1386322d..802958a3872f2 100644
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -43,8 +43,8 @@ Pattern => erase op<test.op3>;
// Check the generation of native constraints and rewrites.
-// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
-// CHECK-SAME: ::mlir::PatternRewriter &rewriter) {
+// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter,
+// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) {
// CHECK: ::mlir::Attribute attr = {};
// CHECK: if (values[0])
// CHECK: attr = values[0].cast<::mlir::Attribute>();
@@ -69,8 +69,8 @@ Pattern => erase op<test.op3>;
// CHECK-NOT: TestUnusedCst
-// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
-// CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) {
+// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results,
+// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) {
// CHECK: ::mlir::Attribute attr = {};
// CHECK: ::mlir::Operation * op = {};
// CHECK: ::mlir::Type type = {};
More information about the Mlir-commits
mailing list