[Mlir-commits] [mlir] 9c5982e - [mlir] support recursive types in type conversion infra

Alex Zinenko llvmlistbot at llvm.org
Mon Nov 22 09:16:15 PST 2021


Author: Alex Zinenko
Date: 2021-11-22T18:16:02+01:00
New Revision: 9c5982ef8e95a0b5acdbd0d2599fbd87526abe2e

URL: https://github.com/llvm/llvm-project/commit/9c5982ef8e95a0b5acdbd0d2599fbd87526abe2e
DIFF: https://github.com/llvm/llvm-project/commit/9c5982ef8e95a0b5acdbd0d2599fbd87526abe2e.diff

LOG: [mlir] support recursive types in type conversion infra

MLIR supports recursive types but they could not be handled by the conversion
infrastructure directly as it would result in infinite recursion in
`convertType` for elemental types. Support this case by keeping the "call
stack" of nested type conversions in the TypeConverter class and by passing it
as an optional argument to the individual conversion callback. The callback can
then check if a specific type is present on the stack more than once to detect
and handle the recursive case.

This approach is preferred to the alternative approach of having a separate
callback dedicated to handling only the recursive case as the latter was
observed to introduce ~3% time overhead on a 50MB IR file even if it did not
contain recursive types.

This approach is also preferred to keeping a local stack in type converters
that need to handle recursive types as that would compose poorly in case of
out-of-tree or cross-project extensions.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D113579

Added: 
    

Modified: 
    mlir/docs/DialectConversion.md
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalize-type-conversion.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 4eca6d9e01734..394b15cda362b 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -307,6 +307,14 @@ class TypeConverter {
   ///       existing value are expected to be removed during conversion. If
   ///       `llvm::None` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
+  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
+  ///     - This form represents a 1-N type conversion supporting recursive
+  ///       types. The first two arguments and the return value are the same as
+  ///       for the regular 1-N form. The third argument is the
+  ///       "call stack" of the recursive conversion: it contains the list of
+  ///       types currently being converted, with the current type being the
+  ///       last one. If it is present more than once in the list, the
+  ///       conversion concerns a recursive type.
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT,

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2953a5d43c9ae..e66dbbc664b4e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -101,6 +101,14 @@ class TypeConverter {
   ///       existing value are expected to be removed during conversion. If
   ///       `llvm::None` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
+  ///   * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
+  ///     - This form represents a 1-N type conversion supporting recursive
+  ///       types. The first two arguments and the return value are the same as
+  ///       for the regular 1-N form. The third argument is the
+  ///       "call stack" of the recursive conversion: it contains the list of
+  ///       types currently being converted, with the current type being the
+  ///       last one. If it is present more than once in the list, the
+  ///       conversion concerns a recursive type.
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT, typename T = typename llvm::function_traits<
@@ -221,8 +229,8 @@ class TypeConverter {
   /// The signature of the callback used to convert a type. If the new set of
   /// types is empty, the type is removed and any usages of the existing value
   /// are expected to be removed during conversion.
-  using ConversionCallbackFn =
-      std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
+  using ConversionCallbackFn = std::function<Optional<LogicalResult>(
+      Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
 
   /// The signature of the callback used to materialize a conversion.
   using MaterializationCallbackFn =
@@ -240,28 +248,44 @@ class TypeConverter {
   template <typename T, typename FnT>
   std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
   wrapCallback(FnT &&callback) {
-    return wrapCallback<T>([callback = std::forward<FnT>(callback)](
-                               T type, SmallVectorImpl<Type> &results) {
-      if (Optional<Type> resultOpt = callback(type)) {
-        bool wasSuccess = static_cast<bool>(resultOpt.getValue());
-        if (wasSuccess)
-          results.push_back(resultOpt.getValue());
-        return Optional<LogicalResult>(success(wasSuccess));
-      }
-      return Optional<LogicalResult>();
-    });
-  }
-  /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)`
+    return wrapCallback<T>(
+        [callback = std::forward<FnT>(callback)](
+            T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
+          if (Optional<Type> resultOpt = callback(type)) {
+            bool wasSuccess = static_cast<bool>(resultOpt.getValue());
+            if (wasSuccess)
+              results.push_back(resultOpt.getValue());
+            return Optional<LogicalResult>(success(wasSuccess));
+          }
+          return Optional<LogicalResult>();
+        });
+  }
+  /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
+  /// &)`
   template <typename T, typename FnT>
-  std::enable_if_t<!llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
+  std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &>::value,
+                   ConversionCallbackFn>
+  wrapCallback(FnT &&callback) {
+    return wrapCallback<T>(
+        [callback = std::forward<FnT>(callback)](
+            T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
+          return callback(type, results);
+        });
+  }
+  /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
+  /// &, ArrayRef<Type>)`.
+  template <typename T, typename FnT>
+  std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &,
+                                      ArrayRef<Type>>::value,
+                   ConversionCallbackFn>
   wrapCallback(FnT &&callback) {
     return [callback = std::forward<FnT>(callback)](
-               Type type,
-               SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
+               Type type, SmallVectorImpl<Type> &results,
+               ArrayRef<Type> callStack) -> Optional<LogicalResult> {
       T derivedType = type.dyn_cast<T>();
       if (!derivedType)
         return llvm::None;
-      return callback(derivedType, results);
+      return callback(derivedType, results, callStack);
     };
   }
 
@@ -300,6 +324,10 @@ class TypeConverter {
   DenseMap<Type, Type> cachedDirectConversions;
   /// This cache stores the successful 1->N conversions, where N != 1.
   DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
+
+  /// Stores the types that are being converted in the case when convertType
+  /// is being called recursively to convert nested types.
+  SmallVector<Type, 2> conversionCallStack;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cc83ab353b06b..1d793f91da5fd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
@@ -2931,8 +2932,12 @@ LogicalResult TypeConverter::convertType(Type t,
   // Walk the added converters in reverse order to apply the most recently
   // registered first.
   size_t currentCount = results.size();
+  conversionCallStack.push_back(t);
+  auto popConversionCallStack =
+      llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
   for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
-    if (Optional<LogicalResult> result = converter(t, results)) {
+    if (Optional<LogicalResult> result =
+            converter(t, results, conversionCallStack)) {
       if (!succeeded(*result)) {
         cachedDirectConversions.try_emplace(t, nullptr);
         return failure();

diff  --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index cfb09f8f272ac..a42ed6184fe63 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -112,3 +112,12 @@ func @test_signature_conversion_no_converter() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @recursive_type_conversion
+func @recursive_type_conversion() {
+  // CHECK:  !test.test_rec<outer_converted_type, smpla>
+  "test.type_producer"() : () -> !test.test_rec<something, test_rec<something>>
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 8cc89c6ce6444..ed84b8cba3f68 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
+#include "TestTypes.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
@@ -924,10 +925,16 @@ struct TestTypeConversionProducer
   matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     Type resultType = op.getType();
+    Type convertedType = getTypeConverter()
+                             ? getTypeConverter()->convertType(resultType)
+                             : resultType;
     if (resultType.isa<FloatType>())
       resultType = rewriter.getF64Type();
     else if (resultType.isInteger(16))
       resultType = rewriter.getIntegerType(64);
+    else if (resultType.isa<test::TestRecursiveType>() &&
+             convertedType != resultType)
+      resultType = convertedType;
     else
       return failure();
 
@@ -1035,6 +1042,35 @@ struct TestTypeConversionDriver
       // Drop all integer types.
       return success();
     });
+    converter.addConversion(
+        // Convert a recursive self-referring type into a non-self-referring
+        // type named "outer_converted_type" that contains a SimpleAType.
+        [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
+            ArrayRef<Type> callStack) -> Optional<LogicalResult> {
+          // If the type is already converted, return it to indicate that it is
+          // legal.
+          if (type.getName() == "outer_converted_type") {
+            results.push_back(type);
+            return success();
+          }
+
+          // If the type is on the call stack more than once (it is there at
+          // least once because of the _current_ call, which is always the last
+          // element on the stack), we've hit the recursive case. Just return
+          // SimpleAType here to create a non-recursive type as a result.
+          if (llvm::is_contained(callStack.drop_back(), type)) {
+            results.push_back(test::SimpleAType::get(type.getContext()));
+            return success();
+          }
+
+          // Convert the body recursively.
+          auto result = test::TestRecursiveType::get(type.getContext(),
+                                                     "outer_converted_type");
+          if (failed(result.setBody(converter.convertType(type.getBody()))))
+            return failure();
+          results.push_back(result);
+          return success();
+        });
 
     /// Add the legal set of type materializations.
     converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@@ -1059,7 +1095,10 @@ struct TestTypeConversionDriver
     // Initialize the conversion target.
     mlir::ConversionTarget target(getContext());
     target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
-      return op.getType().isF64() || op.getType().isInteger(64);
+      auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
+      return op.getType().isF64() || op.getType().isInteger(64) ||
+             (recursiveType &&
+              recursiveType.getName() == "outer_converted_type");
     });
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       return converter.isSignatureLegal(op.getType()) &&


        


More information about the Mlir-commits mailing list