[Mlir-commits] [mlir] 88bc24a - [mlir] Allow setting operation legality with an OperationName

River Riddle llvmlistbot at llvm.org
Wed Apr 27 08:55:08 PDT 2022


Author: Mathieu Fehr
Date: 2022-04-27T08:54:51-07:00
New Revision: 88bc24a7e39ef3f85857c3d0857e0e5c93b50bbc

URL: https://github.com/llvm/llvm-project/commit/88bc24a7e39ef3f85857c3d0857e0e5c93b50bbc
DIFF: https://github.com/llvm/llvm-project/commit/88bc24a7e39ef3f85857c3d0857e0e5c93b50bbc.diff

LOG: [mlir] Allow setting operation legality with an OperationName

This is necessary to handle conversions of operations defined at runtime in extensible dialects.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D124353

Added: 
    mlir/test/Transforms/test-rewrite-dynamic-op.mlir

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index d0be98a307d70..feac434da68f5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -689,9 +689,12 @@ class ConversionTarget {
   }
 
   /// Register the given operations as legal.
+  void addLegalOp(OperationName op) {
+    setOpAction(op, LegalizationAction::Legal);
+  }
   template <typename OpT>
   void addLegalOp() {
-    setOpAction<OpT>(LegalizationAction::Legal);
+    addLegalOp(OperationName(OpT::getOperationName(), &ctx));
   }
   template <typename OpT, typename OpT2, typename... OpTs>
   void addLegalOp() {
@@ -701,11 +704,15 @@ class ConversionTarget {
 
   /// Register the given operation as dynamically legal and set the dynamic
   /// legalization callback to the one provided.
+  void addDynamicallyLegalOp(OperationName op,
+                             const DynamicLegalityCallbackFn &callback) {
+    setOpAction(op, LegalizationAction::Dynamic);
+    setLegalityCallback(op, callback);
+  }
   template <typename OpT>
   void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
-    OperationName opName(OpT::getOperationName(), &ctx);
-    setOpAction(opName, LegalizationAction::Dynamic);
-    setLegalityCallback(opName, callback);
+    addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
+                          callback);
   }
   template <typename OpT, typename OpT2, typename... OpTs>
   void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
@@ -722,9 +729,12 @@ class ConversionTarget {
 
   /// Register the given operation as illegal, i.e. this operation is known to
   /// not be supported by this target.
+  void addIllegalOp(OperationName op) {
+    setOpAction(op, LegalizationAction::Illegal);
+  }
   template <typename OpT>
   void addIllegalOp() {
-    setOpAction<OpT>(LegalizationAction::Illegal);
+    addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
   }
   template <typename OpT, typename OpT2, typename... OpTs>
   void addIllegalOp() {
@@ -737,6 +747,8 @@ class ConversionTarget {
   /// addition to the operation itself, all of the operations nested within are
   /// also considered legal. An optional dynamic legality callback may be
   /// provided to mark subsets of legal instances as recursively legal.
+  void markOpRecursivelyLegal(OperationName name,
+                              const DynamicLegalityCallbackFn &callback = {});
   template <typename OpT>
   void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
     OperationName opName(OpT::getOperationName(), &ctx);
@@ -840,11 +852,6 @@ class ConversionTarget {
   /// Set the dynamic legality callback for the unknown ops.
   void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
 
-  /// Set the recursive legality callback for the given operation and mark the
-  /// operation as recursively legal.
-  void markOpRecursivelyLegal(OperationName name,
-                              const DynamicLegalityCallbackFn &callback);
-
   /// The set of information that configures the legalization of an operation.
   struct LegalizationInfo {
     /// The legality action this operation was given.

diff  --git a/mlir/test/Transforms/test-rewrite-dynamic-op.mlir b/mlir/test/Transforms/test-rewrite-dynamic-op.mlir
new file mode 100644
index 0000000000000..5a6269704062a
--- /dev/null
+++ b/mlir/test/Transforms/test-rewrite-dynamic-op.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -test-rewrite-dynamic-op | FileCheck %s
+
+// Test that `test.dynamic_one_operand_two_results` is replaced with
+// `test.dynamic_generic`.
+
+// CHECK-LABEL: func @rewrite_dynamic_op
+func @rewrite_dynamic_op(%arg0: i32) {
+  // CHECK-NEXT: %{{.*}}:2 = "test.dynamic_generic"(%arg0) : (i32) -> (i32, i32)
+  %0:2 = "test.dynamic_one_operand_two_results"(%arg0) : (i32) -> (i32, i32)
+  // CHECK-NEXT: return
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index ab722f69e55c7..09c7a12d96dfd 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -955,6 +955,60 @@ struct TestUnknownRootOpDriver
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Test patterns that use operations defined at runtime
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This pattern matches 'test.dynamic_one_operand_two_results' operations and
+/// replaces them with 'test.dynamic_generic' operations.
+struct RewriteDynamicOp : public RewritePattern {
+  RewriteDynamicOp(MLIRContext *context)
+      : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1,
+                       context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    assert(op->getName().getStringRef() ==
+               "test.dynamic_one_operand_two_results" &&
+           "rewrite pattern should only match operations with the right name");
+
+    OperationState state(op->getLoc(), "test.dynamic_generic",
+                         op->getOperands(), op->getResultTypes(),
+                         op->getAttrs());
+    auto *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+struct TestRewriteDynamicOpDriver
+    : public PassWrapper<TestRewriteDynamicOpDriver,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<TestDialect>();
+  }
+  StringRef getArgument() const final { return "test-rewrite-dynamic-op"; }
+  StringRef getDescription() const final {
+    return "Test rewriting on dynamic operations";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<RewriteDynamicOp>(&getContext());
+
+    ConversionTarget target(getContext());
+    target.addIllegalOp(
+        OperationName("test.dynamic_one_operand_two_results", &getContext()));
+    target.addLegalOp(OperationName("test.dynamic_generic", &getContext()));
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Test type conversions
 //===----------------------------------------------------------------------===//
@@ -1418,6 +1472,8 @@ void registerPatternsTestPass() {
   PassRegistration<TestTypeConversionDriver>();
   PassRegistration<TestTargetMaterializationWithNoUses>();
 
+  PassRegistration<TestRewriteDynamicOpDriver>();
+
   PassRegistration<TestMergeBlocksPatternDriver>();
   PassRegistration<TestSelectiveReplacementPatternDriver>();
 }


        


More information about the Mlir-commits mailing list