[Mlir-commits] [mlir] c6828e0 - [mlir] Make ConversionTarget dynamic legality callbacks composable
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 12 03:06:28 PDT 2021
Author: Caitlyn Cano
Date: 2021-10-12T13:05:54+03:00
New Revision: c6828e0cea73f89a79db008da0e902c790cdee88
URL: https://github.com/llvm/llvm-project/commit/c6828e0cea73f89a79db008da0e902c790cdee88
DIFF: https://github.com/llvm/llvm-project/commit/c6828e0cea73f89a79db008da0e902c790cdee88.diff
LOG: [mlir] Make ConversionTarget dynamic legality callbacks composable
* Change callback signature `bool(Operation *)` -> `Optional<bool>(Operation *)`
* addDynamicallyLegalOp now adds the callback to a chain instead of replacing the previous one
* If a callback returns an empty `Optional`, the next callback in the chain is called
Differential Revision: https://reviews.llvm.org/D110487
Added:
mlir/unittests/Transforms/CMakeLists.txt
mlir/unittests/Transforms/DialectConversion.cpp
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/unittests/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9fe9690375c01..86d79e192a0dd 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -661,7 +661,7 @@ class ConversionTarget {
/// The signature of the callback used to determine if an operation is
/// dynamically legal on the target.
- using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
+ using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
@@ -827,10 +827,10 @@ class ConversionTarget {
/// The set of information that configures the legalization of an operation.
struct LegalizationInfo {
/// The legality action this operation was given.
- LegalizationAction action;
+ LegalizationAction action = LegalizationAction::Illegal;
/// If some legal instances of this operation may also be recursively legal.
- bool isRecursivelyLegal;
+ bool isRecursivelyLegal = false;
/// The legality callback if this operation is dynamically legal.
DynamicLegalityCallbackFn legalityFn;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3ee743b2c4758..7e937825845e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2681,7 +2681,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
/// Register a legality action for the given operation.
void ConversionTarget::setOpAction(OperationName op,
LegalizationAction action) {
- legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
+ legalOperations[op].action = action;
}
/// Register a legality action for the given dialects.
@@ -2710,8 +2710,11 @@ auto ConversionTarget::isLegal(Operation *op) const
// Returns true if this operation instance is known to be legal.
auto isOpLegal = [&] {
// Handle dynamic legality either with the provided legality function.
- if (info->action == LegalizationAction::Dynamic)
- return info->legalityFn(op);
+ if (info->action == LegalizationAction::Dynamic) {
+ Optional<bool> result = info->legalityFn(op);
+ if (result)
+ return *result;
+ }
// Otherwise, the operation is only legal if it was marked 'Legal'.
return info->action == LegalizationAction::Legal;
@@ -2723,14 +2726,32 @@ auto ConversionTarget::isLegal(Operation *op) const
LegalOpDetails legalityDetails;
if (info->isRecursivelyLegal) {
auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
- if (legalityFnIt != opRecursiveLegalityFns.end())
- legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
- else
+ if (legalityFnIt != opRecursiveLegalityFns.end()) {
+ legalityDetails.isRecursivelyLegal =
+ legalityFnIt->second(op).getValueOr(true);
+ } else {
legalityDetails.isRecursivelyLegal = true;
+ }
}
return legalityDetails;
}
+/// Combine two dynamic legality callbacks into a single one.
+///
+/// When `oldCallback` is empty, `newCallback` is returned unchanged. Otherwise
+/// the returned callback asks `newCallback` first and only falls back to
+/// `oldCallback` when `newCallback` yields an empty `Optional`.
+static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
+    ConversionTarget::DynamicLegalityCallbackFn oldCallback,
+    ConversionTarget::DynamicLegalityCallbackFn newCallback) {
+  // Nothing was registered before: the new callback is the whole chain.
+  if (!oldCallback)
+    return newCallback;
+
+  // Both callbacks are moved into the closure so that it owns them.
+  return [previous = std::move(oldCallback), latest = std::move(newCallback)](
+             Operation *op) -> Optional<bool> {
+    // The most recently added callback gets the first say.
+    Optional<bool> verdict = latest(op);
+    if (verdict)
+      return verdict;
+
+    // It abstained, so defer to the previously registered callback.
+    return previous(op);
+  };
+}
+
/// Set the dynamic legality callback for the given operation.
void ConversionTarget::setLegalityCallback(
OperationName name, const DynamicLegalityCallbackFn &callback) {
@@ -2739,7 +2760,8 @@ void ConversionTarget::setLegalityCallback(
assert(infoIt != legalOperations.end() &&
infoIt->second.action == LegalizationAction::Dynamic &&
"expected operation to already be marked as dynamically legal");
- infoIt->second.legalityFn = callback;
+ infoIt->second.legalityFn =
+ composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
}
/// Set the recursive legality callback for the given operation and mark the
@@ -2752,7 +2774,8 @@ void ConversionTarget::markOpRecursivelyLegal(
"expected operation to already be marked as legal");
infoIt->second.isRecursivelyLegal = true;
if (callback)
- opRecursiveLegalityFns[name] = callback;
+ opRecursiveLegalityFns[name] = composeLegalityCallbacks(
+ std::move(opRecursiveLegalityFns[name]), callback);
else
opRecursiveLegalityFns.erase(name);
}
@@ -2762,14 +2785,15 @@ void ConversionTarget::setLegalityCallback(
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
for (StringRef dialect : dialects)
- dialectLegalityFns[dialect] = callback;
+ dialectLegalityFns[dialect] = composeLegalityCallbacks(
+ std::move(dialectLegalityFns[dialect]), callback);
}
/// Set the dynamic legality callback for the unknown ops.
void ConversionTarget::setLegalityCallback(
const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
- unknownLegalityFn = callback;
+ unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
}
/// Get the legalization information for the given operation.
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 45558b6d3dcee..c54313f84d23f 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -12,3 +12,4 @@ add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(Rewrite)
add_subdirectory(TableGen)
+add_subdirectory(Transforms)
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..9636f93835eb6
--- /dev/null
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRTransformsTests
+ DialectConversion.cpp
+)
+target_link_libraries(MLIRTransformsTests
+ PRIVATE
+ MLIRTransforms)
diff --git a/mlir/unittests/Transforms/DialectConversion.cpp b/mlir/unittests/Transforms/DialectConversion.cpp
new file mode 100644
index 0000000000000..d3e7ff6909e4c
--- /dev/null
+++ b/mlir/unittests/Transforms/DialectConversion.cpp
@@ -0,0 +1,90 @@
+//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+/// Create a detached `foo.bar` operation at an unknown location, with no
+/// results, operands, attributes, successors, or regions. Unregistered
+/// dialects are enabled on `context` first. The tests below release the
+/// returned operation with `destroy()`.
+static Operation *createOp(MLIRContext *context) {
+  context->allowUnregisteredDialects();
+
+  auto loc = UnknownLoc::get(context);
+  OperationName opName("foo.bar", context);
+  return Operation::create(loc, opName, llvm::None, llvm::None, llvm::None,
+                           llvm::None, 0);
+}
+
+namespace {
+/// Placeholder type that only exposes the operation name `foo.bar`, so it can
+/// be used as the template argument of `addDynamicallyLegalOp` in these tests.
+struct DummyOp {
+  static StringRef getOperationName() {
+    return StringRef("foo.bar");
+  }
+};
+
+// Two callbacks are registered for the same op. The test expects the one
+// registered last to run first and, since it returns an empty result, the one
+// registered first to run second and decide legality.
+TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  // Shared invocation counter; each callback records the position at which it
+  // ran (0 means it was never invoked).
+  int invocations = 0;
+  int callbackCalled1 = 0;
+  int callbackCalled2 = 0;
+
+  // Registered first: reports the op as legal.
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
+    ++invocations;
+    callbackCalled1 = invocations;
+    return true;
+  });
+
+  // Registered second: abstains by returning an empty Optional.
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
+    ++invocations;
+    callbackCalled2 = invocations;
+    return llvm::None;
+  });
+
+  Operation *op = createOp(&context);
+  EXPECT_TRUE(target.isLegal(op));
+  EXPECT_EQ(2, callbackCalled1);
+  EXPECT_EQ(1, callbackCalled2);
+  op->destroy();
+}
+
+// A single callback that returns an empty result is registered. The test
+// expects it to be invoked exactly once and the op not to be reported legal.
+TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  // Invocation counter and the position at which the callback ran.
+  int invocations = 0;
+  int callbackCalled = 0;
+
+  // The only callback abstains by returning an empty Optional.
+  target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
+    ++invocations;
+    callbackCalled = invocations;
+    return llvm::None;
+  });
+
+  Operation *op = createOp(&context);
+  EXPECT_FALSE(target.isLegal(op));
+  EXPECT_EQ(1, callbackCalled);
+  op->destroy();
+}
+
+// Same ordering check as for per-op callbacks, but with callbacks registered
+// through markUnknownOpDynamicallyLegal: the one registered last is expected
+// to run first, and the one registered first to run second and decide.
+TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
+  MLIRContext context;
+  ConversionTarget target(context);
+
+  // Shared invocation counter; each callback records the position at which it
+  // ran (0 means it was never invoked).
+  int invocations = 0;
+  int callbackCalled1 = 0;
+  int callbackCalled2 = 0;
+
+  // Registered first: reports the op as legal.
+  target.markUnknownOpDynamicallyLegal([&](Operation *) {
+    ++invocations;
+    callbackCalled1 = invocations;
+    return true;
+  });
+
+  // Registered second: abstains by returning an empty Optional.
+  target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
+    ++invocations;
+    callbackCalled2 = invocations;
+    return llvm::None;
+  });
+
+  Operation *op = createOp(&context);
+  EXPECT_TRUE(target.isLegal(op));
+  EXPECT_EQ(2, callbackCalled1);
+  EXPECT_EQ(1, callbackCalled2);
+  op->destroy();
+}
+} // namespace
More information about the Mlir-commits
mailing list