[Mlir-commits] [mlir] c6828e0 - [mlir] Make ConversionTarget dynamic legality callbacks composable

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 12 03:06:28 PDT 2021


Author: Caitlyn Cano
Date: 2021-10-12T13:05:54+03:00
New Revision: c6828e0cea73f89a79db008da0e902c790cdee88

URL: https://github.com/llvm/llvm-project/commit/c6828e0cea73f89a79db008da0e902c790cdee88
DIFF: https://github.com/llvm/llvm-project/commit/c6828e0cea73f89a79db008da0e902c790cdee88.diff

LOG: [mlir] Make ConversionTarget dynamic legality callbacks composable

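* Change the dynamic legality callback signature from `bool(Operation *)` to
  `Optional<bool>(Operation *)`.
* `addDynamicallyLegalOp` now adds the callback to a chain instead of replacing
  the previously registered one. Dialect, unknown-op and recursive-legality
  callbacks are chained the same way.
* Callbacks are called from the most recently added to the oldest. If a callback
  returns an empty `Optional`, the next callback in the chain is called. If every
  callback returns an empty `Optional`, a dynamically legal operation is
  considered illegal, while a recursive-legality query treats the operation as
  recursively legal.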

Differential Revision: https://reviews.llvm.org/D110487

Added: 
    mlir/unittests/Transforms/CMakeLists.txt
    mlir/unittests/Transforms/DialectConversion.cpp

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/unittests/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9fe9690375c01..86d79e192a0dd 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -661,7 +661,7 @@ class ConversionTarget {
 
   /// The signature of the callback used to determine if an operation is
   /// dynamically legal on the target.
-  using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
+  using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
 
   ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
   virtual ~ConversionTarget() = default;
@@ -827,10 +827,10 @@ class ConversionTarget {
   /// The set of information that configures the legalization of an operation.
   struct LegalizationInfo {
     /// The legality action this operation was given.
-    LegalizationAction action;
+    LegalizationAction action = LegalizationAction::Illegal;
 
     /// If some legal instances of this operation may also be recursively legal.
-    bool isRecursivelyLegal;
+    bool isRecursivelyLegal = false;
 
     /// The legality callback if this operation is dynamically legal.
     DynamicLegalityCallbackFn legalityFn;

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3ee743b2c4758..7e937825845e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2681,7 +2681,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
 /// Register a legality action for the given operation.
 void ConversionTarget::setOpAction(OperationName op,
                                    LegalizationAction action) {
-  legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
+  legalOperations[op].action = action;
 }
 
 /// Register a legality action for the given dialects.
@@ -2710,8 +2710,11 @@ auto ConversionTarget::isLegal(Operation *op) const
   // Returns true if this operation instance is known to be legal.
   auto isOpLegal = [&] {
     // Handle dynamic legality either with the provided legality function.
-    if (info->action == LegalizationAction::Dynamic)
-      return info->legalityFn(op);
+    if (info->action == LegalizationAction::Dynamic) {
+      Optional<bool> result = info->legalityFn(op);
+      if (result)
+        return *result;
+    }
 
     // Otherwise, the operation is only legal if it was marked 'Legal'.
     return info->action == LegalizationAction::Legal;
@@ -2723,14 +2726,32 @@ auto ConversionTarget::isLegal(Operation *op) const
   LegalOpDetails legalityDetails;
   if (info->isRecursivelyLegal) {
     auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
-    if (legalityFnIt != opRecursiveLegalityFns.end())
-      legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
-    else
+    if (legalityFnIt != opRecursiveLegalityFns.end()) {
+      legalityDetails.isRecursivelyLegal =
+          legalityFnIt->second(op).getValueOr(true);
+    } else {
       legalityDetails.isRecursivelyLegal = true;
+    }
   }
   return legalityDetails;
 }
 
+static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
+    ConversionTarget::DynamicLegalityCallbackFn oldCallback,
+    ConversionTarget::DynamicLegalityCallbackFn newCallback) {
+  if (!oldCallback)
+    return newCallback;
+
+  auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
+                   Operation *op) -> Optional<bool> {
+    if (Optional<bool> result = newCl(op))
+      return *result;
+
+    return oldCl(op);
+  };
+  return chain;
+}
+
 /// Set the dynamic legality callback for the given operation.
 void ConversionTarget::setLegalityCallback(
     OperationName name, const DynamicLegalityCallbackFn &callback) {
@@ -2739,7 +2760,8 @@ void ConversionTarget::setLegalityCallback(
   assert(infoIt != legalOperations.end() &&
          infoIt->second.action == LegalizationAction::Dynamic &&
          "expected operation to already be marked as dynamically legal");
-  infoIt->second.legalityFn = callback;
+  infoIt->second.legalityFn =
+      composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
 }
 
 /// Set the recursive legality callback for the given operation and mark the
@@ -2752,7 +2774,8 @@ void ConversionTarget::markOpRecursivelyLegal(
          "expected operation to already be marked as legal");
   infoIt->second.isRecursivelyLegal = true;
   if (callback)
-    opRecursiveLegalityFns[name] = callback;
+    opRecursiveLegalityFns[name] = composeLegalityCallbacks(
+        std::move(opRecursiveLegalityFns[name]), callback);
   else
     opRecursiveLegalityFns.erase(name);
 }
@@ -2762,14 +2785,15 @@ void ConversionTarget::setLegalityCallback(
     ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
   assert(callback && "expected valid legality callback");
   for (StringRef dialect : dialects)
-    dialectLegalityFns[dialect] = callback;
+    dialectLegalityFns[dialect] = composeLegalityCallbacks(
+        std::move(dialectLegalityFns[dialect]), callback);
 }
 
 /// Set the dynamic legality callback for the unknown ops.
 void ConversionTarget::setLegalityCallback(
     const DynamicLegalityCallbackFn &callback) {
   assert(callback && "expected valid legality callback");
-  unknownLegalityFn = callback;
+  unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
 }
 
 /// Get the legalization information for the given operation.

diff  --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 45558b6d3dcee..c54313f84d23f 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -12,3 +12,4 @@ add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Rewrite)
 add_subdirectory(TableGen)
+add_subdirectory(Transforms)

diff  --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..9636f93835eb6
--- /dev/null
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRTransformsTests
+  DialectConversion.cpp
+)
+target_link_libraries(MLIRTransformsTests
+  PRIVATE
+  MLIRTransforms)

diff  --git a/mlir/unittests/Transforms/DialectConversion.cpp b/mlir/unittests/Transforms/DialectConversion.cpp
new file mode 100644
index 0000000000000..d3e7ff6909e4c
--- /dev/null
+++ b/mlir/unittests/Transforms/DialectConversion.cpp
@@ -0,0 +1,90 @@
+//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+static Operation *createOp(MLIRContext *context) {
+  context->allowUnregisteredDialects();
+  return Operation::create(UnknownLoc::get(context),
+                           OperationName("foo.bar", context), llvm::None,
+                           llvm::None, llvm::None, llvm::None, 0);
+}
+
+namespace {
+struct DummyOp {
+  static StringRef getOperationName() { return "foo.bar"; }
+};
+
+TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  int index = 0;
+  int callbackCalled1 = 0;
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
+    callbackCalled1 = ++index;
+    return true;
+  });
+
+  int callbackCalled2 = 0;
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
+    callbackCalled2 = ++index;
+    return llvm::None;
+  });
+
+  auto *op = createOp(&context);
+  EXPECT_TRUE(target.isLegal(op));
+  EXPECT_EQ(2, callbackCalled1);
+  EXPECT_EQ(1, callbackCalled2);
+  op->destroy();
+}
+
+TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  int index = 0;
+  int callbackCalled = 0;
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
+    callbackCalled = ++index;
+    return llvm::None;
+  });
+
+  auto *op = createOp(&context);
+  EXPECT_FALSE(target.isLegal(op));
+  EXPECT_EQ(1, callbackCalled);
+  op->destroy();
+}
+
+TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  int index = 0;
+  int callbackCalled1 = 0;
+  target.markUnknownOpDynamicallyLegal([&](Operation *) {
+    callbackCalled1 = ++index;
+    return true;
+  });
+
+  int callbackCalled2 = 0;
+  target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
+    callbackCalled2 = ++index;
+    return llvm::None;
+  });
+
+  auto *op = createOp(&context);
+  EXPECT_TRUE(target.isLegal(op));
+  EXPECT_EQ(2, callbackCalled1);
+  EXPECT_EQ(1, callbackCalled2);
+  op->destroy();
+}
+} // namespace


        


More information about the Mlir-commits mailing list