[Mlir-commits] [mlir] 8c66344 - [mlir:PDL] Add support for DialectConversion with pattern configurations
River Riddle
llvmlistbot at llvm.org
Tue Nov 8 01:58:25 PST 2022
Author: River Riddle
Date: 2022-11-08T01:57:57-08:00
New Revision: 8c66344ee9f67f76b3cb6b3345a46345a2d3975a
URL: https://github.com/llvm/llvm-project/commit/8c66344ee9f67f76b3cb6b3345a46345a2d3975a
DIFF: https://github.com/llvm/llvm-project/commit/8c66344ee9f67f76b3cb6b3345a46345a2d3975a.diff
LOG: [mlir:PDL] Add support for DialectConversion with pattern configurations
Up until now PDL(L) has not supported dialect conversion because we had no
way of remapping values or integrating with type conversions. This commit
rectifies that by adding a new "pattern configuration" concept to PDL. This
essentially allows for attaching external configurations to patterns, which
can hook into pattern events (for now just the scope of a rewrite, but we
could also pass configs to native rewrites as well). This allows for injecting
the type converter into the conversion pattern rewriter.
Differential Revision: https://reviews.llvm.org/D133142
Added:
mlir/include/mlir/Transforms/DialectConversion.pdll
mlir/test/Transforms/test-dialect-conversion-pdll.mlir
mlir/test/lib/Transforms/TestDialectConversion.cpp
mlir/test/lib/Transforms/TestDialectConversion.pdll
mlir/test/lib/Transforms/lit.local.cfg
Modified:
mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h
mlir/include/mlir/IR/PatternMatch.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Rewrite/ByteCode.h
mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
mlir/lib/Rewrite/PatternApplicator.cpp
mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h b/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h
index 8e8517ec47fbd..54033ff1639c7 100644
--- a/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h
+++ b/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h
@@ -13,12 +13,15 @@
#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
#define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
+#include "mlir/Support/LLVM.h"
#include <memory>
namespace mlir {
class ModuleOp;
+class Operation;
template <typename OpT>
class OperationPass;
+class PDLPatternConfigSet;
#define GEN_PASS_DECL_CONVERTPDLTOPDLINTERP
#include "mlir/Conversion/Passes.h.inc"
@@ -26,6 +29,12 @@ class OperationPass;
/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass();
+/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
+/// `configMap` maps pattern operations to their configuration sets; it is
+/// updated in place during lowering and must outlive the returned pass.
+std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass(
+ DenseMap<Operation *, PDLPatternConfigSet *> &configMap);
+
} // namespace mlir
#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 151fab8bdebbf..e257b67ad9d8e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -600,10 +600,16 @@ class IRRewriter : public RewriterBase {
class PatternRewriter : public RewriterBase {
public:
using RewriterBase::RewriterBase;
+
+ /// A hook used to indicate if the pattern rewriter can recover from failure
+ /// during the rewrite stage of a pattern. For example, if the pattern
+ /// rewriter supports rollback, it may continue processing patterns even if
+ /// the IR was modified before the rewrite failed. Defaults to false.
+ virtual bool canRecoverFromRewriteFailure() const { return false; }
};
//===----------------------------------------------------------------------===//
-// PDLPatternModule
+// PDL Patterns
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
@@ -796,6 +802,108 @@ class PDLResultList {
SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
};
+//===----------------------------------------------------------------------===//
+// PDLPatternConfig
+
+/// An individual configuration for a pattern, which can be accessed by native
+/// functions via the PDLPatternConfigSet. This allows for injecting additional
+/// configuration into PDL patterns that is specific to certain compilation
+/// flows.
+class PDLPatternConfig {
+public:
+ virtual ~PDLPatternConfig() = default;
+
+ /// Hooks that are invoked at the beginning and end of a rewrite of a matched
+ /// pattern. These can be used to set up any specific state necessary for the
+ /// rewrite.
+ virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
+ virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
+
+ /// Return the TypeID that represents this configuration.
+ TypeID getTypeID() const { return id; }
+
+protected:
+ PDLPatternConfig(TypeID id) : id(id) {}
+
+private:
+ TypeID id;
+};
+
+/// This class provides a base class for users implementing a type of pattern
+/// configuration.
+template <typename T>
+class PDLPatternConfigBase : public PDLPatternConfig {
+public:
+ /// Support LLVM style casting.
+ static bool classof(const PDLPatternConfig *config) {
+ return config->getTypeID() == getConfigID();
+ }
+
+ /// Return the type id used for this configuration.
+ static TypeID getConfigID() { return TypeID::get<T>(); }
+
+protected:
+ PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
+};
+
+/// This class contains a set of configurations for a specific pattern.
+/// Configurations are uniqued by TypeID, meaning that only one configuration of
+/// each type is allowed (checked by an assertion in `addConfig`).
+class PDLPatternConfigSet {
+public:
+ PDLPatternConfigSet() = default;
+
+ /// Construct a set with the given configurations.
+ template <typename... ConfigsT>
+ PDLPatternConfigSet(ConfigsT &&...configs) {
+ (addConfig(std::forward<ConfigsT>(configs)), ...);
+ }
+
+ /// Get the configuration defined by the given type. Asserts that the
+ /// configuration of the provided type exists.
+ template <typename T>
+ const T &get() const {
+ const T *config = tryGet<T>();
+ assert(config && "configuration not found");
+ return *config;
+ }
+
+ /// Get the configuration defined by the given type, returns nullptr if the
+ /// configuration does not exist.
+ template <typename T>
+ const T *tryGet() const {
+ for (const auto &configIt : configs)
+ if (const T *config = dyn_cast<T>(configIt.get()))
+ return config;
+ return nullptr;
+ }
+
+ /// Notify the configurations within this set at the beginning or end of a
+ /// rewrite of a matched pattern.
+ void notifyRewriteBegin(PatternRewriter &rewriter) {
+ for (const auto &config : configs)
+ config->notifyRewriteBegin(rewriter);
+ }
+ void notifyRewriteEnd(PatternRewriter &rewriter) {
+ for (const auto &config : configs)
+ config->notifyRewriteEnd(rewriter);
+ }
+
+protected:
+ /// Add a configuration to the set.
+ template <typename T>
+ void addConfig(T &&config) {
+ assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
+ configs.emplace_back(
+ std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
+ }
+
+ /// The set of configurations for this pattern. This uses a vector instead of
+ /// a map with the expectation that the number of configurations per set is
+ /// small (<= 1).
+ SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
+};
+
//===----------------------------------------------------------------------===//
// PDLPatternModule
@@ -807,9 +915,11 @@ using PDLConstraintFunction =
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
-/// invoked when the corresponding match was successful.
-using PDLRewriteFunction =
- std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+/// invoked when the corresponding match was successful. Returns failure if an
+/// invariant of the rewrite was broken (certain rewriters may recover from
+/// partial pattern application).
+using PDLRewriteFunction = std::function<LogicalResult(
+ PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
namespace detail {
namespace pdl_function_builder {
@@ -1034,6 +1144,13 @@ struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
results.push_back(types);
}
};
+template <unsigned N>
+struct ProcessPDLValue<SmallVector<Type, N>> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ SmallVector<Type, N> values) {
+ results.push_back(TypeRange(values));
+ }
+};
//===----------------------------------------------------------------------===//
// Value
@@ -1061,6 +1178,13 @@ struct ProcessPDLValue<ResultRange> {
results.push_back(values);
}
};
+template <unsigned N>
+struct ProcessPDLValue<SmallVector<Value, N>> {
+ static void processAsResult(PatternRewriter &, PDLResultList &results,
+ SmallVector<Value, N> values) {
+ results.push_back(ValueRange(values));
+ }
+};
//===----------------------------------------------------------------------===//
// PDL Function Builder: Argument Handling
@@ -1111,28 +1235,49 @@ void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
/// Store a single result within the result list.
template <typename T>
-static void processResults(PatternRewriter &rewriter, PDLResultList &results,
- T &&value) {
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results, T &&value) {
ProcessPDLValue<T>::processAsResult(rewriter, results,
std::forward<T>(value));
+ return success();
}
/// Store a std::pair<> as individual results within the result list.
template <typename T1, typename T2>
-static void processResults(PatternRewriter &rewriter, PDLResultList &results,
- std::pair<T1, T2> &&pair) {
- processResults(rewriter, results, std::move(pair.first));
- processResults(rewriter, results, std::move(pair.second));
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ std::pair<T1, T2> &&pair) {
+ if (failed(processResults(rewriter, results, std::move(pair.first))) ||
+ failed(processResults(rewriter, results, std::move(pair.second))))
+ return failure();
+ return success();
}
/// Store a std::tuple<> as individual results within the result list.
template <typename... Ts>
-static void processResults(PatternRewriter &rewriter, PDLResultList &results,
- std::tuple<Ts...> &&tuple) {
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ std::tuple<Ts...> &&tuple) {
auto applyFn = [&](auto &&...args) {
- (processResults(rewriter, results, std::move(args)), ...);
+ return (succeeded(processResults(rewriter, results, std::move(args))) &&
+ ...);
};
- std::apply(applyFn, std::move(tuple));
+ return success(std::apply(applyFn, std::move(tuple)));
+}
+
+/// Propagate LogicalResult/FailureOr<T> results of native rewrite functions.
+inline LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ LogicalResult &&result) {
+ return result;
+}
+template <typename T>
+static LogicalResult processResults(PatternRewriter &rewriter,
+ PDLResultList &results,
+ FailureOr<T> &&result) {
+ if (failed(result))
+ return failure();
+ return processResults(rewriter, results, std::move(*result));
}
//===----------------------------------------------------------------------===//
@@ -1192,23 +1337,26 @@ buildConstraintFn(ConstraintFnT &&constraintFn) {
/// This overload handles the case of no return values.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
-std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
+std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
+ LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
fn(rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
+ return success();
}
/// This overload handles the case of return values, which need to be packaged
/// into the result list.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
-std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
+std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
+ LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &results, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
- processResults(
+ return processResults(
rewriter, results,
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
processAsArg(values[I]))...));
@@ -1240,14 +1388,17 @@ buildRewriteFn(RewriteFnT &&rewriteFn) {
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1>();
assertArgs<RewriteFnT>(rewriter, values, argIndices);
- processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
- argIndices);
+ return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
+ argIndices);
};
}
} // namespace pdl_function_builder
} // namespace detail
+//===----------------------------------------------------------------------===//
+// PDLPatternModule
+
/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
/// contained by this pattern may contain any number of `pdl.pattern`
@@ -1256,9 +1407,17 @@ class PDLPatternModule {
public:
PDLPatternModule() = default;
- /// Construct a PDL pattern with the given module.
- PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
- : pdlModule(std::move(pdlModule)) {}
+ /// Construct a PDL pattern with the given module and configurations.
+ PDLPatternModule(OwningOpRef<ModuleOp> module)
+ : pdlModule(std::move(module)) {}
+ template <typename... ConfigsT>
+ PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
+ : PDLPatternModule(std::move(module)) {
+ auto configSet = std::make_unique<PDLPatternConfigSet>(
+ std::forward<ConfigsT>(patternConfigs)...);
+ attachConfigToPatterns(*pdlModule, *configSet);
+ configs.emplace_back(std::move(configSet));
+ }
/// Merge the state in `other` into this pattern module.
void mergeIn(PDLPatternModule &&other);
@@ -1344,6 +1503,14 @@ class PDLPatternModule {
return rewriteFunctions;
}
+ /// Move out the registered pattern config sets and the op-to-config map.
+ SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
+ return std::move(configs);
+ }
+ DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
+ return std::move(configMap);
+ }
+
/// Clear out the patterns and functions within this module.
void clear() {
pdlModule = nullptr;
@@ -1352,9 +1519,17 @@ class PDLPatternModule {
}
private:
+ /// Attach the given pattern config set to the patterns defined within the
+ /// given module.
+ void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
+
/// The module containing the `pdl.pattern` operations.
OwningOpRef<ModuleOp> pdlModule;
+ /// The set of configuration sets referenced by patterns within `pdlModule`.
+ SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
+ DenseMap<Operation *, PDLPatternConfigSet *> configMap;
+
/// The external functions referenced from within the PDL module.
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6045b2237976e..59809492d836d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -574,6 +574,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//
+ /// Indicate that the conversion rewriter can recover from rewrite failure.
+ /// Recovery is supported via rollback, allowing for continued processing of
+ /// patterns even if a failure is encountered during the rewrite step.
+ bool canRecoverFromRewriteFailure() const override { return true; }
+
/// PatternRewriter hook for replacing the results of an operation when the
/// given functor returns true.
void replaceOpWithIf(
@@ -891,6 +896,36 @@ class ConversionTarget {
MLIRContext &ctx;
};
+//===----------------------------------------------------------------------===//
+// PDL Configuration
+//===----------------------------------------------------------------------===//
+
+/// A PDL configuration that is used to support dialect conversion
+/// functionality.
+class PDLConversionConfig final
+ : public PDLPatternConfigBase<PDLConversionConfig> {
+public:
+ explicit PDLConversionConfig(TypeConverter *converter)
+ : converter(converter) {}
+ ~PDLConversionConfig() final = default;
+
+ /// Return the type converter used by this configuration, which may be nullptr
+ /// if no type conversions are expected.
+ TypeConverter *getTypeConverter() const { return converter; }
+
+ /// Hooks that are invoked at the beginning and end of a rewrite of a matched
+ /// pattern.
+ void notifyRewriteBegin(PatternRewriter &rewriter) final;
+ void notifyRewriteEnd(PatternRewriter &rewriter) final;
+
+private:
+ /// An optional type converter to use for the pattern (not owned).
+ TypeConverter *converter;
+};
+
+/// Register the dialect conversion PDL functions with the given pattern set.
+void registerConversionPDLFunctions(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/DialectConversion.pdll b/mlir/include/mlir/Transforms/DialectConversion.pdll
new file mode 100644
index 0000000000000..9c6ce7a2d2328
--- /dev/null
+++ b/mlir/include/mlir/Transforms/DialectConversion.pdll
@@ -0,0 +1,30 @@
+//===- DialectConversion.pdll - DialectConversion PDLL Support -*- PDLL -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines various utilities for interacting with dialect conversion
+// within PDLL.
+//
+//===----------------------------------------------------------------------===//
+
+/// This rewrite returns the converted value of `value`, whose type is defined
+/// by the type converter specified in the `PDLConversionConfig` of the current
+/// pattern.
+Rewrite convertValue(value: Value) -> Value;
+
+/// This rewrite returns the converted values of `values`, whose type is defined
+/// by the type converter specified in the `PDLConversionConfig` of the current
+/// pattern.
+Rewrite convertValues(values: ValueRange) -> ValueRange;
+
+/// This rewrite returns the converted type of `type` as defined by the type
+/// converter specified in the `PDLConversionConfig` of the current pattern.
+Rewrite convertType(type: Type) -> Type;
+
+/// This rewrite returns the converted types of `types` as defined by the type
+/// converter specified in the `PDLConversionConfig` of the current pattern.
+Rewrite convertTypes(types: TypeRange) -> TypeRange;
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 301fa68e59d03..987e7a36ea890 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -37,7 +37,8 @@ namespace {
/// given module containing PDL pattern operations.
struct PatternLowering {
public:
- PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);
+ PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
+ DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
/// Generate code for matching and rewriting based on the pattern operations
/// within the module.
@@ -140,13 +141,19 @@ struct PatternLowering {
/// The set of operation values whose whose location will be used for newly
/// generated operations.
SetVector<Value> locOps;
+
+ /// A mapping between pattern operations and the corresponding configuration
+ /// set, or null if no configurations are being tracked.
+ DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
};
} // namespace
-PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
- ModuleOp rewriterModule)
+PatternLowering::PatternLowering(
+ pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
+ DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
: builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
- rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
+ rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
+ configMap(configMap) {}
void PatternLowering::lower(ModuleOp module) {
PredicateUniquer predicateUniquer;
@@ -589,10 +596,14 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock);
- builder.create<pdl_interp::RecordMatchOp>(
+ auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
failureBlockStack.back());
+
+ // Map the lowered match op to the config set of its parent pattern.
+ if (configMap)
+ configMap->try_emplace(matchOp, configMap->lookup(pattern));
}
SymbolRefAttr PatternLowering::generateRewriter(
@@ -922,7 +933,14 @@ void PatternLowering::generateOperationResultTypeRewriter(
namespace {
struct PDLToPDLInterpPass
: public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
+ PDLToPDLInterpPass() = default;
+ PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
+ PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
+ : configMap(&configMap) {}
void runOnOperation() final;
+
+ /// Optional, non-owned map from pattern operations to their config sets.
+ DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
};
} // namespace
@@ -946,15 +964,24 @@ void PDLToPDLInterpPass::runOnOperation() {
module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
// Generate the code for the patterns within the module.
- PatternLowering generator(matcherFunc, rewriterModule);
+ PatternLowering generator(matcherFunc, rewriterModule, configMap);
generator.lower(module);
// After generation, delete all of the pattern operations.
for (pdl::PatternOp pattern :
- llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
+ llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
+ // Drop the now dead config mappings.
+ if (configMap)
+ configMap->erase(pattern);
+
pattern.erase();
+ }
}
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
return std::make_unique<PDLToPDLInterpPass>();
}
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
+ DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
+ return std::make_unique<PDLToPDLInterpPass>(configMap);
+}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 494d90f304bdd..d2de65e7694ba 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -158,11 +158,15 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
if (!other.pdlModule)
return;
- // Steal the functions of the other module.
+ // Steal the functions and config of the other module.
for (auto &it : other.constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : other.rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
+ for (auto &it : other.configs)
+ configs.emplace_back(std::move(it));
+ for (auto &it : other.configMap)
+ configMap.insert(it);
// Steal the other state if we have no patterns.
if (!pdlModule) {
@@ -176,6 +180,18 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
other.pdlModule->getBody()->getOperations());
}
+void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
+ PDLPatternConfigSet &configSet) {
+ // Attach the configuration to the symbols within the module. We only add
+ // to symbols to avoid hardcoding any specific operation names here (given
+ // that we don't depend on any PDL dialect). We can't use
+ // cast<SymbolOpInterface> here because patterns may be optional symbols.
+ module->walk([&](Operation *op) {
+ if (op->hasTrait<SymbolOpInterface::Trait>())
+ configMap[op] = &configSet;
+ });
+}
+
//===----------------------------------------------------------------------===//
// Function Registry
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 388d6dca8f6e0..9cc51da9fcf33 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -34,21 +34,23 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
+ PDLPatternConfigSet *configSet,
ByteCodeAddr rewriterAddr) {
+ PatternBenefit benefit = matchOp.getBenefit();
+ MLIRContext *ctx = matchOp.getContext();
+
+ // Collect the set of generated operations.
SmallVector<StringRef, 8> generatedOps;
if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
generatedOps =
llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
- PatternBenefit benefit = matchOp.getBenefit();
- MLIRContext *ctx = matchOp.getContext();
-
// Check to see if this is pattern matches a specific operation type.
if (Optional<StringRef> rootKind = matchOp.getRootKind())
- return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
+ return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
generatedOps);
- return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
- generatedOps);
+ return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
+ benefit, ctx, generatedOps);
}
//===----------------------------------------------------------------------===//
@@ -194,14 +196,15 @@ class Generator {
ByteCodeField &maxValueRangeMemoryIndex,
ByteCodeField &maxLoopLevel,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
- llvm::StringMap<PDLRewriteFunction> &rewriteFns)
+ llvm::StringMap<PDLRewriteFunction> &rewriteFns,
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex),
maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
- maxLoopLevel(maxLoopLevel) {
+ maxLoopLevel(maxLoopLevel), configMap(configMap) {
for (const auto &it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (const auto &it : llvm::enumerate(rewriteFns))
@@ -328,6 +331,9 @@ class Generator {
ByteCodeField &maxTypeRangeMemoryIndex;
ByteCodeField &maxValueRangeMemoryIndex;
ByteCodeField &maxLoopLevel;
+
+ /// A map from pattern operations to their configuration sets.
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
};
/// This class provides utilities for writing a bytecode stream.
@@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
ByteCodeField patternIndex = patterns.size();
patterns.emplace_back(PDLByteCodePattern::create(
- op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
+ op, configMap.lookup(op),
+ rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
writer.append(OpCode::RecordMatch, patternIndex,
SuccessorRange(op.getOperation()), op.getMatchedOps());
writer.appendPDLValueList(op.getInputs());
-PDLByteCode::PDLByteCode(ModuleOp module,
- llvm::StringMap<PDLConstraintFunction> constraintFns,
- llvm::StringMap<PDLRewriteFunction> rewriteFns) {
+PDLByteCode::PDLByteCode(
+ ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
+ llvm::StringMap<PDLConstraintFunction> constraintFns,
+ llvm::StringMap<PDLRewriteFunction> rewriteFns)
+ : configs(std::move(configs)) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
- maxLoopLevel, constraintFns, rewriteFns);
+ maxLoopLevel, constraintFns, rewriteFns, configMap);
generator.generate(module);
// Initialize the external functions.
@@ -1076,14 +1086,15 @@ class ByteCodeExecutor {
/// Start executing the code at the current bytecode index. `matches` is an
/// optional field provided when this function is executed in a matching
/// context.
- void execute(PatternRewriter &rewriter,
- SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
- Optional<Location> mainRewriteLoc = {});
+ LogicalResult
+ execute(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
+ Optional<Location> mainRewriteLoc = {});
private:
/// Internal implementation of executing each of the bytecode commands.
void executeApplyConstraint(PatternRewriter &rewriter);
- void executeApplyRewrite(PatternRewriter &rewriter);
+ LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
void executeAreEqual();
void executeAreRangesEqual();
void executeBranch();
@@ -1345,7 +1356,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
selectJump(succeeded(constraintFn(rewriter, args)));
}
-void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
+LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
SmallVector<PDLValue, 16> args;
@@ -1359,7 +1370,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
// Execute the rewrite function.
ByteCodeField numResults = read();
ByteCodeRewriteResultList results(numResults);
- rewriteFn(rewriter, results, args);
- assert(results.getResults().size() == numResults &&
- "native PDL rewrite function returned unexpected number of results");
+ LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
+ assert((failed(rewriteResult) || results.getResults().size() == numResults) &&
+ "native PDL rewrite function returned unexpected number of results");
@@ -1395,6 +1406,13 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it));
+
+ // Process the result of the rewrite.
+ if (failed(rewriteResult)) {
+ LLVM_DEBUG(llvm::dbgs() << " - Failed\n");
+ return failure();
+ }
+ return success();
}
void ByteCodeExecutor::executeAreEqual() {
@@ -2017,10 +2035,10 @@ void ByteCodeExecutor::executeSwitchTypes() {
});
}
-void ByteCodeExecutor::execute(
- PatternRewriter &rewriter,
- SmallVectorImpl<PDLByteCode::MatchResult> *matches,
- Optional<Location> mainRewriteLoc) {
+LogicalResult
+ByteCodeExecutor::execute(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> *matches,
+ Optional<Location> mainRewriteLoc) {
while (true) {
// Print the location of the operation being executed.
LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
@@ -2031,7 +2049,8 @@ void ByteCodeExecutor::execute(
executeApplyConstraint(rewriter);
break;
case ApplyRewrite:
- executeApplyRewrite(rewriter);
+ if (failed(executeApplyRewrite(rewriter)))
+ return failure();
break;
case AreEqual:
executeAreEqual();
@@ -2078,7 +2097,7 @@ void ByteCodeExecutor::execute(
case Finalize:
executeFinalize();
LLVM_DEBUG(llvm::dbgs() << "\n");
- return;
+ return success();
case ForEach:
executeForEach();
break;
@@ -2166,8 +2185,6 @@ void ByteCodeExecutor::execute(
}
}
-/// Run the pattern matcher on the given root operation, collecting the matched
-/// patterns in `matches`.
void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<MatchResult> &matches,
PDLByteCodeMutableState &state) const {
@@ -2181,7 +2198,9 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
- executor.execute(rewriter, &matches);
+ LogicalResult executeResult = executor.execute(rewriter, &matches);
+ (void)executeResult;
+ assert(succeeded(executeResult) && "unexpected matcher execution failure");
// Order the found matches by benefit.
std::stable_sort(matches.begin(), matches.end(),
@@ -2190,9 +2209,13 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
});
}
-/// Run the rewriter of the given pattern on the root operation `op`.
-void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
- PDLByteCodeMutableState &state) const {
+LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
+ const MatchResult &match,
+ PDLByteCodeMutableState &state) const {
+ auto *configSet = match.pattern->getConfigSet();
+ if (configSet)
+ configSet->notifyRewriteBegin(rewriter);
+
// The arguments of the rewrite function are stored at the start of the
// memory buffer.
llvm::copy(match.values, state.memory.begin());
@@ -2204,5 +2227,24 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
rewriterByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
- executor.execute(rewriter, /*matches=*/nullptr, match.location);
+ LogicalResult result =
+ executor.execute(rewriter, /*matches=*/nullptr, match.location);
+
+ if (configSet)
+ configSet->notifyRewriteEnd(rewriter);
+
+ // If the rewrite failed, check if the pattern rewriter can recover. If it
+ // can, we can signal to the pattern applicator to keep trying patterns. If it
+ // doesn't, we need to bail. Bailing here should be fine, given that we have
+ // no means to propagate such a failure to the user, and it also indicates a
+ // bug in the user code (i.e. failable rewrites should not be used with
+ // pattern rewriters that don't support it).
+ if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
+ LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting\n");
+ llvm::report_fatal_error(
+ "Native PDL Rewrite failed, but the pattern "
+ "rewriter doesn't support recovery. Failable pattern rewrites should "
+ "not be used with pattern rewriters that do not support them.");
+ }
+ return result;
}
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index e423ff2a02352..4d43fe636bd1f 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -38,19 +38,27 @@ using OwningOpRange = llvm::OwningArrayRef<Operation *>;
class PDLByteCodePattern : public Pattern {
public:
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
+ PDLPatternConfigSet *configSet,
ByteCodeAddr rewriterAddr);
/// Return the bytecode address of the rewriter for this pattern.
ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
+ /// Return the configuration set for this pattern, or null if there is none.
+ PDLPatternConfigSet *getConfigSet() const { return configSet; }
+
private:
template <typename... Args>
- PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
- : Pattern(std::forward<Args>(patternArgs)...),
- rewriterAddr(rewriterAddr) {}
+ PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet,
+ Args &&...patternArgs)
+ : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr),
+ configSet(configSet) {}
/// The address of the rewriter for this pattern.
ByteCodeAddr rewriterAddr;
+
+ /// The optional config set for this pattern.
+ PDLPatternConfigSet *configSet;
};
//===----------------------------------------------------------------------===//
@@ -148,6 +156,8 @@ class PDLByteCode {
/// Create a ByteCode instance from the given module containing operations in
/// the PDL interpreter dialect.
PDLByteCode(ModuleOp module,
+ SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
llvm::StringMap<PDLConstraintFunction> constraintFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns);
@@ -165,9 +175,9 @@ class PDLByteCode {
PDLByteCodeMutableState &state) const;
/// Run the rewriter of the given pattern that was previously matched in
- /// `match`.
- void rewrite(PatternRewriter &rewriter, const MatchResult &match,
- PDLByteCodeMutableState &state) const;
+ /// `match`. Returns failure if an error was encountered during the rewrite.
+ LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
+ PDLByteCodeMutableState &state) const;
private:
/// Execute the given byte code starting at the provided instruction `inst`.
@@ -177,6 +187,9 @@ class PDLByteCode {
PDLByteCodeMutableState &state,
SmallVectorImpl<MatchResult> *matches) const;
+ /// The set of pattern configs referenced within the bytecode.
+ SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
+
/// A vector containing pointers to uniqued data. The storage is intentionally
/// opaque such that we can store a wide range of data types. The types of
/// data stored here include:
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
index 765782519ffd4..7b83d104befbc 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
@@ -16,7 +16,9 @@
using namespace mlir;
-static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
+static LogicalResult
+convertPDLToPDLInterp(ModuleOp pdlModule,
+ DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
// Skip the conversion if the module doesn't contain pdl.
if (pdlModule.getOps<pdl::PatternOp>().empty())
return success();
@@ -37,7 +39,7 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
// mode.
pdlPipeline.enableVerifier(false);
#endif
- pdlPipeline.addPass(createPDLToPDLInterpPass());
+ pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
if (failed(pdlPipeline.run(pdlModule)))
return failure();
@@ -123,13 +125,16 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
ModuleOp pdlModule = pdlPatterns.getModule();
if (!pdlModule)
return;
- if (failed(convertPDLToPDLInterp(pdlModule)))
+ DenseMap<Operation *, PDLPatternConfigSet *> configMap =
+ pdlPatterns.takeConfigMap();
+ if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
llvm::report_fatal_error(
"failed to lower PDL pattern module to the PDL Interpreter");
// Generate the pdl bytecode.
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
- pdlModule, pdlPatterns.takeConstraintFunctions(),
+ pdlModule, pdlPatterns.takeConfigs(), configMap,
+ pdlPatterns.takeConstraintFunctions(),
pdlPatterns.takeRewriteFunctions());
}
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 686b8e2330b0f..499a8506bc606 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -191,20 +191,21 @@ LogicalResult PatternApplicator::matchAndRewrite(
Operation *dumpRootOp = getDumpRootOp(op);
#endif
if (pdlMatch) {
- bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
- result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
+ result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
} else {
- const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
+ LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
+ << bestPattern->getDebugName() << "\"\n");
- LLVM_DEBUG(llvm::dbgs()
- << "Trying to match \"" << pattern->getDebugName() << "\"\n");
+ const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
result = pattern->matchAndRewrite(op, rewriter);
- LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
- << succeeded(result) << "\n");
- if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
- result = failure();
+ LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
+ << "\" result " << succeeded(result) << "\n");
}
+
+ // Process the result of the pattern application.
+ if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
+ result = failure();
if (succeeded(result)) {
LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
break;
diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index 4632533b5d43f..cf6770a2816d1 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -93,10 +93,12 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
os << "} // end namespace\n\n";
// Emit function to add the generated matchers to the pattern list.
- os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
- "::mlir::RewritePatternSet &patterns) {\n";
+ os << "template <typename... ConfigsT>\n"
+ "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
+ "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
for (const auto &name : patternNames)
- os << " patterns.add<" << name << ">(patterns.getContext());\n";
+ os << " patterns.add<" << name
+ << ">(patterns.getContext(), configs...);\n";
os << "}\n";
}
@@ -104,14 +106,15 @@ void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
StringSet<> &nativeFunctions) {
const char *patternClassStartStr = R"(
struct {0} : ::mlir::PDLPatternModule {{
- {0}(::mlir::MLIRContext *context)
+ template <typename... ConfigsT>
+ {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
: ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
)";
os << llvm::formatv(patternClassStartStr, patternName);
os << "R\"mlir(";
pattern->print(os, OpPrintingFlags().enableDebugInfo());
- os << "\n )mlir\", context)) {\n";
+ os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
// Register any native functions used within the pattern.
StringSet<> registeredNativeFunctions;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 61bc4ffbe6f28..616e43721cae7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3272,6 +3272,76 @@ auto ConversionTarget::getOpInfo(OperationName op) const
return llvm::None;
}
+//===----------------------------------------------------------------------===//
+// PDL Configuration
+//===----------------------------------------------------------------------===//
+
+void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
+ auto &rewriterImpl =
+ static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+ rewriterImpl.currentTypeConverter = getTypeConverter();
+}
+
+void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
+ auto &rewriterImpl =
+ static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+ rewriterImpl.currentTypeConverter = nullptr;
+}
+
+/// Remap the given values through the conversion rewriter, returning failure
+/// if any value could not be remapped.
+static FailureOr<SmallVector<Value>>
+pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
+ SmallVector<Value> mappedValues;
+ if (failed(rewriter.getRemappedValues(values, mappedValues)))
+ return failure();
+ return std::move(mappedValues);
+}
+
+void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
+ patterns.getPDLPatterns().registerRewriteFunction(
+ "convertValue",
+ [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
+ auto results = pdllConvertValues(
+ static_cast<ConversionPatternRewriter &>(rewriter), value);
+ if (failed(results))
+ return failure();
+ return results->front();
+ });
+ patterns.getPDLPatterns().registerRewriteFunction(
+ "convertValues", [](PatternRewriter &rewriter, ValueRange values) {
+ return pdllConvertValues(
+ static_cast<ConversionPatternRewriter &>(rewriter), values);
+ });
+ patterns.getPDLPatterns().registerRewriteFunction(
+ "convertType",
+ [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
+ auto &rewriterImpl =
+ static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+ if (TypeConverter *converter = rewriterImpl.currentTypeConverter) {
+ if (Type newType = converter->convertType(type))
+ return newType;
+ return failure();
+ }
+ return type;
+ });
+ patterns.getPDLPatterns().registerRewriteFunction(
+ "convertTypes",
+ [](PatternRewriter &rewriter,
+ TypeRange types) -> FailureOr<SmallVector<Type>> {
+ auto &rewriterImpl =
+ static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
+ TypeConverter *converter = rewriterImpl.currentTypeConverter;
+ if (!converter)
+ return SmallVector<Type>(types);
+
+ SmallVector<Type> remappedTypes;
+ if (failed(converter->convertTypes(types, remappedTypes)))
+ return failure();
+ return std::move(remappedTypes);
+ });
+}
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/test-dialect-conversion-pdll.mlir b/mlir/test/Transforms/test-dialect-conversion-pdll.mlir
new file mode 100644
index 0000000000000..97c8dfc2d83df
--- /dev/null
+++ b/mlir/test/Transforms/test-dialect-conversion-pdll.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-dialect-conversion-pdll | FileCheck %s
+
+// CHECK-LABEL: @TestSingleConversion
+func.func @TestSingleConversion() {
+ // CHECK: %[[CAST:.*]] = "test.cast"() : () -> f64
+ // CHECK-NEXT: "test.return"(%[[CAST]]) : (f64) -> ()
+ %result = "test.cast"() : () -> (i64)
+ "test.return"(%result) : (i64) -> ()
+}
+
+// CHECK-LABEL: @TestLingeringConversion
+func.func @TestLingeringConversion() -> i64 {
+ // CHECK: %[[ORIG_CAST:.*]] = "test.cast"() : () -> f64
+ // CHECK: %[[MATERIALIZE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ORIG_CAST]] : f64 to i64
+ // CHECK-NEXT: return %[[MATERIALIZE_CAST]] : i64
+ %result = "test.cast"() : () -> (i64)
+ return %result : i64
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 8672e87bc527c..0379dcd7a1968 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,8 +1,18 @@
+add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen
+ TestDialectConversion.pdll
+ TestDialectConversionPDLLPatterns.h.inc
+
+ EXTRA_INCLUDES
+ ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
+ ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
+ )
+
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
+ TestDialectConversion.cpp
TestInlining.cpp
TestIntRangeInference.cpp
TestTopologicalSort.cpp
@@ -12,8 +22,12 @@ add_mlir_library(MLIRTestTransforms
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+ DEPENDS
+ MLIRTestDialectConversionPDLLPatternsIncGen
+
LINK_LIBS PUBLIC
MLIRAnalysis
+ MLIRFuncDialect
MLIRInferIntRangeInterface
MLIRTestDialect
MLIRTransforms
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
new file mode 100644
index 0000000000000..996b7b9e28861
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -0,0 +1,96 @@
+//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace test;
+
+//===----------------------------------------------------------------------===//
+// Test PDLL Support
+//===----------------------------------------------------------------------===//
+
+#include "TestDialectConversionPDLLPatterns.h.inc"
+
+namespace {
+struct PDLLTypeConverter : public TypeConverter {
+ PDLLTypeConverter() {
+ addConversion(convertType);
+ addArgumentMaterialization(materializeCast);
+ addSourceMaterialization(materializeCast);
+ }
+
+ static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
+ // Convert I64 to F64.
+ if (t.isSignlessInteger(64)) {
+ results.push_back(FloatType::getF64(t.getContext()));
+ return success();
+ }
+
+ // Otherwise, convert the type directly.
+ results.push_back(t);
+ return success();
+ }
+ /// Hook for materializing a conversion.
+ static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
+ ValueRange inputs, Location loc) {
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+ .getResult(0);
+ }
+};
+
+struct TestDialectConversionPDLLPass
+ : public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass)
+
+ StringRef getArgument() const final { return "test-dialect-conversion-pdll"; }
+ StringRef getDescription() const final {
+ return "Test DialectConversion PDLL functionality";
+ }
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
+ }
+ LogicalResult initialize(MLIRContext *ctx) override {
+ // Build the pattern set within the `initialize` to avoid recompiling PDL
+ // patterns during each `runOnOperation` invocation.
+ RewritePatternSet patternList(ctx);
+ registerConversionPDLFunctions(patternList);
+ populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter));
+ patterns = std::move(patternList);
+ return success();
+ }
+
+ void runOnOperation() final {
+ mlir::ConversionTarget target(getContext());
+ target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
+ target.addDynamicallyLegalDialect<TestDialect>(
+ [this](Operation *op) { return converter.isLegal(op); });
+
+ if (failed(mlir::applyFullConversion(getOperation(), target, patterns)))
+ signalPassFailure();
+ }
+
+ FrozenRewritePatternSet patterns;
+ PDLLTypeConverter converter;
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestDialectConversionPasses() {
+ PassRegistration<TestDialectConversionPDLLPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.pdll b/mlir/test/lib/Transforms/TestDialectConversion.pdll
new file mode 100644
index 0000000000000..c29e852feeff3
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDialectConversion.pdll
@@ -0,0 +1,19 @@
+//===- TestDialectConversion.pdll - Test DialectConversion PDLL -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestOps.td"
+#include "mlir/Transforms/DialectConversion.pdll"
+
+/// Change the result type of a producer.
+// FIXME: We shouldn't need to specify arguments for the result cast.
+Pattern => replace op<test.cast>(args: ValueRange) -> (results: TypeRange)
+ with op<test.cast>(args) -> (convertTypes(results));
+
+/// Pass through test.return conversion.
+Pattern => replace op<test.return>(args: ValueRange)
+ with op<test.return>(convertValues(args));
diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg
new file mode 100644
index 0000000000000..8cfe5cd834f06
--- /dev/null
+++ b/mlir/test/lib/Transforms/lit.local.cfg
@@ -0,0 +1 @@
+config.suffixes.remove('.pdll')
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
index 4dae177f10993..f97530700b1d4 100644
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -5,18 +5,19 @@
// check that we handle overlap.
// CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule {
+// CHECK: template <typename... ConfigsT>
// CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
// CHECK: R"mlir(
// CHECK: pdl.pattern
// CHECK: operation "test.op"
-// CHECK: )mlir", context))
+// CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
// CHECK: struct NamedPattern : ::mlir::PDLPatternModule {
// CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
// CHECK: R"mlir(
// CHECK: pdl.pattern
// CHECK: operation "test.op2"
-// CHECK: )mlir", context))
+// CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
// CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule {
@@ -25,13 +26,13 @@
// CHECK: R"mlir(
// CHECK: pdl.pattern
// CHECK: operation "test.op3"
-// CHECK: )mlir", context))
+// CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
-// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns) {
-// CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext());
-// CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext());
-// CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext());
-// CHECK-NEXT: patterns.add<GeneratedPDLLPattern2>(patterns.getContext());
+// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {
+// CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext(), configs...);
+// CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext(), configs...);
+// CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext(), configs...);
+// CHECK-NEXT: patterns.add<GeneratedPDLLPattern2>(patterns.getContext(), configs...);
// CHECK-NEXT: }
Pattern => erase op<test.op>;
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 9eb0a47558dda..1e53d8519b223 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -76,6 +76,7 @@ void registerTestDataLayoutQuery();
void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
+void registerTestDialectConversionPasses();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestExpandMathPass();
@@ -170,6 +171,7 @@ void registerTestPasses() {
mlir::test::registerTestConstantFold();
mlir::test::registerTestControlFlowSink();
mlir::test::registerTestDiagnosticsPass();
+ mlir::test::registerTestDialectConversionPasses();
#if MLIR_CUDA_CONVERSIONS_ENABLED
mlir::test::registerTestGpuSerializeToCubinPass();
#endif
More information about the Mlir-commits
mailing list