[Mlir-commits] [mlir] b7a4649 - [mlir] ConversionTarget legality callbacks refactoring

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jul 24 05:00:33 PDT 2021


Author: Butygin
Date: 2021-07-24T14:59:36+03:00
New Revision: b7a464989955e6374b39b518e317b59b510d4dc5

URL: https://github.com/llvm/llvm-project/commit/b7a464989955e6374b39b518e317b59b510d4dc5
DIFF: https://github.com/llvm/llvm-project/commit/b7a464989955e6374b39b518e317b59b510d4dc5.diff

LOG: [mlir] ConversionTarget legality callbacks refactoring

* Get rid of Optional<std::function>, as std::function already has a null state
* Add a private setLegalityCallback overload to set the legality callback for unknown ops
* Get rid of the unknownOpsDynamicallyLegal flag and use the unknownLegalityFn state instead. This causes a behavior change when the user first calls markUnknownOpDynamicallyLegal with a callback and then without one, but I am not sure whether the original behavior was really a 'feature' or just an oversight in the original implementation.

Differential Revision: https://reviews.llvm.org/D105496

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b0de3a170e667..32945c2794e59 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -621,8 +621,7 @@ class ConversionTarget {
   /// dynamically legal on the target.
   using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
 
-  ConversionTarget(MLIRContext &ctx)
-      : unknownOpsDynamicallyLegal(false), ctx(ctx) {}
+  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
   virtual ~ConversionTarget() = default;
 
   //===--------------------------------------------------------------------===//
@@ -739,18 +738,11 @@ class ConversionTarget {
     setDialectAction(dialectNames, LegalizationAction::Dynamic);
   }
   template <typename... Args>
-  void addDynamicallyLegalDialect(
-      Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
+  void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback = {}) {
     SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
     setDialectAction(dialectNames, LegalizationAction::Dynamic);
     if (callback)
-      setLegalityCallback(dialectNames, *callback);
-  }
-  template <typename... Args>
-  void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
-    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
-    setDialectAction(dialectNames, LegalizationAction::Dynamic);
-    setLegalityCallback(dialectNames, callback);
+      setLegalityCallback(dialectNames, callback);
   }
 
   /// Register unknown operations as dynamically legal. For operations(and
@@ -758,10 +750,11 @@ class ConversionTarget {
   /// dynamically legal and invoke the given callback if valid or
   /// 'isDynamicallyLegal'.
   void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
-    unknownOpsDynamicallyLegal = true;
-    unknownLegalityFn = fn;
+    setLegalityCallback(fn);
+  }
+  void markUnknownOpDynamicallyLegal() {
+    setLegalityCallback([](Operation *) { return true; });
   }
-  void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; }
 
   /// Register the operations of the given dialects as illegal, i.e.
   /// operations of this dialect are not supported by the target.
@@ -805,6 +798,9 @@ class ConversionTarget {
   void setLegalityCallback(ArrayRef<StringRef> dialects,
                            const DynamicLegalityCallbackFn &callback);
 
+  /// Set the dynamic legality callback for the unknown ops.
+  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
+
   /// Set the recursive legality callback for the given operation and mark the
   /// operation as recursively legal.
   void markOpRecursivelyLegal(OperationName name,
@@ -819,7 +815,7 @@ class ConversionTarget {
     bool isRecursivelyLegal;
 
     /// The legality callback if this operation is dynamically legal.
-    Optional<DynamicLegalityCallbackFn> legalityFn;
+    DynamicLegalityCallbackFn legalityFn;
   };
 
   /// Get the legalization information for the given operation.
@@ -841,11 +837,7 @@ class ConversionTarget {
   llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
 
   /// An optional legality callback for unknown operations.
-  Optional<DynamicLegalityCallbackFn> unknownLegalityFn;
-
-  /// Flag indicating if unknown operations should be treated as dynamically
-  /// legal.
-  bool unknownOpsDynamicallyLegal;
+  DynamicLegalityCallbackFn unknownLegalityFn;
 
   /// The current context this target applies to.
   MLIRContext &ctx;

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8e1e2cbcb7ee7..5fbd9dd9db60a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2672,7 +2672,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
 /// Register a legality action for the given operation.
 void ConversionTarget::setOpAction(OperationName op,
                                    LegalizationAction action) {
-  legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
+  legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
 }
 
 /// Register a legality action for the given dialects.
@@ -2703,8 +2703,7 @@ auto ConversionTarget::isLegal(Operation *op) const
     // Handle dynamic legality either with the provided legality function, or
     // the default hook on the derived instance.
     if (info->action == LegalizationAction::Dynamic)
-      return info->legalityFn ? (*info->legalityFn)(op)
-                              : isDynamicallyLegal(op);
+      return info->legalityFn ? info->legalityFn(op) : isDynamicallyLegal(op);
 
     // Otherwise, the operation is only legal if it was marked 'Legal'.
     return info->action == LegalizationAction::Legal;
@@ -2758,6 +2757,13 @@ void ConversionTarget::setLegalityCallback(
     dialectLegalityFns[dialect] = callback;
 }
 
+/// Set the dynamic legality callback for the unknown ops.
+void ConversionTarget::setLegalityCallback(
+    const DynamicLegalityCallbackFn &callback) {
+  assert(callback && "expected valid legality callback");
+  unknownLegalityFn = callback;
+}
+
 /// Get the legalization information for the given operation.
 auto ConversionTarget::getOpInfo(OperationName op) const
     -> Optional<LegalizationInfo> {
@@ -2768,7 +2774,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
   // Check for info for the parent dialect.
   auto dialectIt = legalDialects.find(op.getDialectNamespace());
   if (dialectIt != legalDialects.end()) {
-    Optional<DynamicLegalityCallbackFn> callback;
+    DynamicLegalityCallbackFn callback;
     auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
     if (dialectFn != dialectLegalityFns.end())
       callback = dialectFn->second;
@@ -2776,7 +2782,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
                             callback};
   }
   // Otherwise, check if we mark unknown operations as dynamic.
-  if (unknownOpsDynamicallyLegal)
+  if (unknownLegalityFn)
     return LegalizationInfo{LegalizationAction::Dynamic,
                             /*isRecursivelyLegal=*/false, unknownLegalityFn};
   return llvm::None;


        


More information about the Mlir-commits mailing list