[Mlir-commits] [mlir] 7cc7998 - [mlir] Allow constant lambdas to be used as callbacks for `TypeConverter`

Mehdi Amini llvmlistbot at llvm.org
Tue Feb 2 10:26:55 PST 2021


Author: Vladislav Vinogradov
Date: 2021-02-02T18:26:45Z
New Revision: 7cc79984979014c86b6b4672ed0df93a74b2f218

URL: https://github.com/llvm/llvm-project/commit/7cc79984979014c86b6b4672ed0df93a74b2f218
DIFF: https://github.com/llvm/llvm-project/commit/7cc79984979014c86b6b4672ed0df93a74b2f218.diff

LOG: [mlir] Allow constant lambdas to be used as callbacks for `TypeConverter`

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95787

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ae2e2d73cf58..a67c993b6ab5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -103,8 +103,8 @@ class TypeConverter {
   ///       conversion function to perform the conversion.
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
+  template <typename FnT, typename T = typename llvm::function_traits<
+                              std::decay_t<FnT>>::template arg_t<0>>
   void addConversion(FnT &&callback) {
     registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
   }
@@ -124,8 +124,8 @@ class TypeConverter {
   ///
   /// This method registers a materialization that will be called when
   /// converting an illegal block argument type, to a legal type.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
+  template <typename FnT, typename T = typename llvm::function_traits<
+                              std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
     argumentMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
@@ -133,16 +133,16 @@ class TypeConverter {
   /// This method registers a materialization that will be called when
   /// converting a legal type to an illegal source type. This is used when
   /// conversions to an illegal type must persist beyond the main conversion.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
+  template <typename FnT, typename T = typename llvm::function_traits<
+                              std::decay_t<FnT>>::template arg_t<1>>
   void addSourceMaterialization(FnT &&callback) {
     sourceMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }
   /// This method registers a materialization that will be called when
   /// converting type from an illegal, or source, type to a legal type.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
+  template <typename FnT, typename T = typename llvm::function_traits<
+                              std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 7da02ed0c5c3..1b2fab124c86 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -492,8 +492,17 @@ struct TestTypeConverter : public TypeConverter {
   TestTypeConverter() {
     addConversion(convertType);
     addArgumentMaterialization(materializeCast);
-    addArgumentMaterialization(materializeOneToOneCast);
     addSourceMaterialization(materializeCast);
+
+    /// Materialize the cast for one-to-one conversion from i64 to f64.
+    const auto materializeOneToOneCast =
+        [](OpBuilder &builder, IntegerType resultType, ValueRange inputs,
+           Location loc) -> Optional<Value> {
+      if (resultType.getWidth() == 42 && inputs.size() == 1)
+        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+      return llvm::None;
+    };
+    addArgumentMaterialization(materializeOneToOneCast);
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -532,16 +541,6 @@ struct TestTypeConverter : public TypeConverter {
       return inputs[0];
     return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
   }
-
-  /// Materialize the cast for one-to-one conversion from i64 to f64.
-  static Optional<Value> materializeOneToOneCast(OpBuilder &builder,
-                                                 IntegerType resultType,
-                                                 ValueRange inputs,
-                                                 Location loc) {
-    if (resultType.getWidth() == 42 && inputs.size() == 1)
-      return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
-    return llvm::None;
-  }
 };
 
 struct TestLegalizePatternDriver


        


More information about the Mlir-commits mailing list