[Mlir-commits] [mlir] b7a4649 - [mlir] ConversionTarget legality callbacks refactoring
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jul 24 05:00:33 PDT 2021
Author: Butygin
Date: 2021-07-24T14:59:36+03:00
New Revision: b7a464989955e6374b39b518e317b59b510d4dc5
URL: https://github.com/llvm/llvm-project/commit/b7a464989955e6374b39b518e317b59b510d4dc5
DIFF: https://github.com/llvm/llvm-project/commit/b7a464989955e6374b39b518e317b59b510d4dc5.diff
LOG: [mlir] ConversionTarget legality callbacks refactoring
* Get rid of Optional<std::function>, as std::function already has a null state
* Add private setLegalityCallback function to set legality callback for unknown ops
* Get rid of the unknownOpsDynamicallyLegal flag; use the unknownLegalityFn state instead. This causes a behavior change when the user first calls markUnknownOpDynamicallyLegal with a callback and then without one, but I am not sure whether the original behavior was really a 'feature' or just an oversight in the original implementation.
Differential Revision: https://reviews.llvm.org/D105496
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b0de3a170e667..32945c2794e59 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -621,8 +621,7 @@ class ConversionTarget {
/// dynamically legal on the target.
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
- ConversionTarget(MLIRContext &ctx)
- : unknownOpsDynamicallyLegal(false), ctx(ctx) {}
+ ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
//===--------------------------------------------------------------------===//
@@ -739,18 +738,11 @@ class ConversionTarget {
setDialectAction(dialectNames, LegalizationAction::Dynamic);
}
template <typename... Args>
- void addDynamicallyLegalDialect(
- Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
+ void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback = {}) {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Dynamic);
if (callback)
- setLegalityCallback(dialectNames, *callback);
- }
- template <typename... Args>
- void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
- SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
- setDialectAction(dialectNames, LegalizationAction::Dynamic);
- setLegalityCallback(dialectNames, callback);
+ setLegalityCallback(dialectNames, callback);
}
/// Register unknown operations as dynamically legal. For operations(and
@@ -758,10 +750,11 @@ class ConversionTarget {
/// dynamically legal and invoke the given callback if valid or
/// 'isDynamicallyLegal'.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
- unknownOpsDynamicallyLegal = true;
- unknownLegalityFn = fn;
+ setLegalityCallback(fn);
+ }
+ void markUnknownOpDynamicallyLegal() {
+ setLegalityCallback([](Operation *) { return true; });
}
- void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; }
/// Register the operations of the given dialects as illegal, i.e.
/// operations of this dialect are not supported by the target.
@@ -805,6 +798,9 @@ class ConversionTarget {
void setLegalityCallback(ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback);
+ /// Set the dynamic legality callback for the unknown ops.
+ void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
+
/// Set the recursive legality callback for the given operation and mark the
/// operation as recursively legal.
void markOpRecursivelyLegal(OperationName name,
@@ -819,7 +815,7 @@ class ConversionTarget {
bool isRecursivelyLegal;
/// The legality callback if this operation is dynamically legal.
- Optional<DynamicLegalityCallbackFn> legalityFn;
+ DynamicLegalityCallbackFn legalityFn;
};
/// Get the legalization information for the given operation.
@@ -841,11 +837,7 @@ class ConversionTarget {
llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
/// An optional legality callback for unknown operations.
- Optional<DynamicLegalityCallbackFn> unknownLegalityFn;
-
- /// Flag indicating if unknown operations should be treated as dynamically
- /// legal.
- bool unknownOpsDynamicallyLegal;
+ DynamicLegalityCallbackFn unknownLegalityFn;
/// The current context this target applies to.
MLIRContext &ctx;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8e1e2cbcb7ee7..5fbd9dd9db60a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2672,7 +2672,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
/// Register a legality action for the given operation.
void ConversionTarget::setOpAction(OperationName op,
LegalizationAction action) {
- legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
+ legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
}
/// Register a legality action for the given dialects.
@@ -2703,8 +2703,7 @@ auto ConversionTarget::isLegal(Operation *op) const
// Handle dynamic legality either with the provided legality function, or
// the default hook on the derived instance.
if (info->action == LegalizationAction::Dynamic)
- return info->legalityFn ? (*info->legalityFn)(op)
- : isDynamicallyLegal(op);
+ return info->legalityFn ? info->legalityFn(op) : isDynamicallyLegal(op);
// Otherwise, the operation is only legal if it was marked 'Legal'.
return info->action == LegalizationAction::Legal;
@@ -2758,6 +2757,13 @@ void ConversionTarget::setLegalityCallback(
dialectLegalityFns[dialect] = callback;
}
+/// Set the dynamic legality callback for the unknown ops.
+void ConversionTarget::setLegalityCallback(
+ const DynamicLegalityCallbackFn &callback) {
+ assert(callback && "expected valid legality callback");
+ unknownLegalityFn = callback;
+}
+
/// Get the legalization information for the given operation.
auto ConversionTarget::getOpInfo(OperationName op) const
-> Optional<LegalizationInfo> {
@@ -2768,7 +2774,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
// Check for info for the parent dialect.
auto dialectIt = legalDialects.find(op.getDialectNamespace());
if (dialectIt != legalDialects.end()) {
- Optional<DynamicLegalityCallbackFn> callback;
+ DynamicLegalityCallbackFn callback;
auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
if (dialectFn != dialectLegalityFns.end())
callback = dialectFn->second;
@@ -2776,7 +2782,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
callback};
}
// Otherwise, check if we mark unknown operations as dynamic.
- if (unknownOpsDynamicallyLegal)
+ if (unknownLegalityFn)
return LegalizationInfo{LegalizationAction::Dynamic,
/*isRecursivelyLegal=*/false, unknownLegalityFn};
return llvm::None;
More information about the Mlir-commits
mailing list