[Mlir-commits] [mlir] 88bc24a - [mlir] Allow setting operation legality with an OperationName
River Riddle
llvmlistbot at llvm.org
Wed Apr 27 08:55:08 PDT 2022
Author: Mathieu Fehr
Date: 2022-04-27T08:54:51-07:00
New Revision: 88bc24a7e39ef3f85857c3d0857e0e5c93b50bbc
URL: https://github.com/llvm/llvm-project/commit/88bc24a7e39ef3f85857c3d0857e0e5c93b50bbc
DIFF: https://github.com/llvm/llvm-project/commit/88bc24a7e39ef3f85857c3d0857e0e5c93b50bbc.diff
LOG: [mlir] Allow setting operation legality with an OperationName
This is necessary to handle conversions of operations defined at runtime in extensible dialects.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D124353
Added:
mlir/test/Transforms/test-rewrite-dynamic-op.mlir
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index d0be98a307d70..feac434da68f5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -689,9 +689,12 @@ class ConversionTarget {
}
/// Register the given operations as legal.
+ void addLegalOp(OperationName op) {
+ setOpAction(op, LegalizationAction::Legal);
+ }
template <typename OpT>
void addLegalOp() {
- setOpAction<OpT>(LegalizationAction::Legal);
+ addLegalOp(OperationName(OpT::getOperationName(), &ctx));
}
template <typename OpT, typename OpT2, typename... OpTs>
void addLegalOp() {
@@ -701,11 +704,15 @@ class ConversionTarget {
/// Register the given operation as dynamically legal and set the dynamic
/// legalization callback to the one provided.
+ void addDynamicallyLegalOp(OperationName op,
+ const DynamicLegalityCallbackFn &callback) {
+ setOpAction(op, LegalizationAction::Dynamic);
+ setLegalityCallback(op, callback);
+ }
template <typename OpT>
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
- OperationName opName(OpT::getOperationName(), &ctx);
- setOpAction(opName, LegalizationAction::Dynamic);
- setLegalityCallback(opName, callback);
+ addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
+ callback);
}
template <typename OpT, typename OpT2, typename... OpTs>
void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
@@ -722,9 +729,12 @@ class ConversionTarget {
/// Register the given operation as illegal, i.e. this operation is known to
/// not be supported by this target.
+ void addIllegalOp(OperationName op) {
+ setOpAction(op, LegalizationAction::Illegal);
+ }
template <typename OpT>
void addIllegalOp() {
- setOpAction<OpT>(LegalizationAction::Illegal);
+ addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
}
template <typename OpT, typename OpT2, typename... OpTs>
void addIllegalOp() {
@@ -737,6 +747,8 @@ class ConversionTarget {
/// addition to the operation itself, all of the operations nested within are
/// also considered legal. An optional dynamic legality callback may be
/// provided to mark subsets of legal instances as recursively legal.
+ void markOpRecursivelyLegal(OperationName name,
+ const DynamicLegalityCallbackFn &callback);
template <typename OpT>
void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
OperationName opName(OpT::getOperationName(), &ctx);
@@ -840,11 +852,6 @@ class ConversionTarget {
/// Set the dynamic legality callback for the unknown ops.
void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
- /// Set the recursive legality callback for the given operation and mark the
- /// operation as recursively legal.
- void markOpRecursivelyLegal(OperationName name,
- const DynamicLegalityCallbackFn &callback);
-
/// The set of information that configures the legalization of an operation.
struct LegalizationInfo {
/// The legality action this operation was given.
diff --git a/mlir/test/Transforms/test-rewrite-dynamic-op.mlir b/mlir/test/Transforms/test-rewrite-dynamic-op.mlir
new file mode 100644
index 0000000000000..5a6269704062a
--- /dev/null
+++ b/mlir/test/Transforms/test-rewrite-dynamic-op.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -test-rewrite-dynamic-op | FileCheck %s
+
+// Test that `test.dynamic_one_operand_two_results` is replaced with
+// `test.dynamic_generic`.
+
+// CHECK-LABEL: func @rewrite_dynamic_op
+func @rewrite_dynamic_op(%arg0: i32) {
+ // CHECK-NEXT: %{{.*}}:2 = "test.dynamic_generic"(%arg0) : (i32) -> (i32, i32)
+ %0:2 = "test.dynamic_one_operand_two_results"(%arg0) : (i32) -> (i32, i32)
+ // CHECK-NEXT: return
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index ab722f69e55c7..09c7a12d96dfd 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -955,6 +955,60 @@ struct TestUnknownRootOpDriver
};
} // namespace
+//===----------------------------------------------------------------------===//
+// Test patterns that use operations and types defined at runtime
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Matches dynamic operations 'test.dynamic_one_operand_two_results' and
+/// replaces them with dynamic operations 'test.dynamic_generic'.
+struct RewriteDynamicOp : public RewritePattern {
+  explicit RewriteDynamicOp(MLIRContext *context)
+      : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
+                       context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    assert(op->getName().getStringRef() ==
+               "test.dynamic_one_operand_two_results" &&
+           "rewrite pattern should only match operations with the right name");
+
+    OperationState state(op->getLoc(), "test.dynamic_generic",
+                         op->getOperands(), op->getResultTypes(),
+                         op->getAttrs());
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+/// Pass that rewrites the dynamic test ops through a partial conversion.
+struct TestRewriteDynamicOpDriver
+    : public PassWrapper<TestRewriteDynamicOpDriver,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<TestDialect>();
+  }
+  StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
+  StringRef getDescription() const final {
+    return "Test rewriting on dynamic operations";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<RewriteDynamicOp>(&getContext());
+
+    ConversionTarget target(getContext());
+    target.addIllegalOp(
+        OperationName("test.dynamic_one_operand_two_results", &getContext()));
+    target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// Test type conversions
//===----------------------------------------------------------------------===//
@@ -1418,6 +1472,8 @@ void registerPatternsTestPass() {
PassRegistration<TestTypeConversionDriver>();
PassRegistration<TestTargetMaterializationWithNoUses>();
+ PassRegistration<TestRewriteDynamicOpDriver>();
+
PassRegistration<TestMergeBlocksPatternDriver>();
PassRegistration<TestSelectiveReplacementPatternDriver>();
}
More information about the Mlir-commits
mailing list