[Mlir-commits] [mlir] 0e36074 - [mlir][DialectConversion] Cache type conversions and add a few useful helpers

River Riddle llvmlistbot at llvm.org
Mon Jun 15 15:58:01 PDT 2020


Author: River Riddle
Date: 2020-06-15T15:57:43-07:00
New Revision: 0e360744f36c5fa8a74f3f9e1e539ec9d43e27ee

URL: https://github.com/llvm/llvm-project/commit/0e360744f36c5fa8a74f3f9e1e539ec9d43e27ee
DIFF: https://github.com/llvm/llvm-project/commit/0e360744f36c5fa8a74f3f9e1e539ec9d43e27ee.diff

LOG: [mlir][DialectConversion] Cache type conversions and add a few useful helpers

It is quite common for the same type to be converted many times throughout the conversion process, and there isn't any good reason why we aren't caching that result, especially given that we currently use identity conversion to signify legality. This revision also adds a few additional helpers to TypeConverter.

Differential Revision: https://reviews.llvm.org/D81679

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Transforms/TestBufferPlacement.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index bdb61d1a409a..f9d6671ae6fb 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -72,16 +72,16 @@ class TypeConverter {
     /// used if the new types are not intended to remap an existing input.
     void addInputs(ArrayRef<Type> types);
 
-    /// Remap an input of the original signature with a range of types in the
-    /// new signature.
-    void remapInput(unsigned origInputNo, unsigned newInputNo,
-                    unsigned newInputCount = 1);
-
     /// Remap an input of the original signature to another `replacement`
     /// value. This drops the original argument.
     void remapInput(unsigned origInputNo, Value replacement);
 
   private:
+    /// Remap an input of the original signature with a range of types in the
+    /// new signature.
+    void remapInput(unsigned origInputNo, unsigned newInputNo,
+                    unsigned newInputCount = 1);
+
     /// The remapping information for each of the original arguments.
     SmallVector<Optional<InputMapping>, 4> remappedInputs;
 
@@ -149,16 +149,29 @@ class TypeConverter {
   /// Return true if the given type is legal for this type converter, i.e. the
   /// type converts to itself.
   bool isLegal(Type type);
+  /// Return true if all of the given types are legal for this type converter.
+  template <typename RangeT>
+  std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
+                       !std::is_convertible<RangeT, Operation *>::value,
+                   bool>
+  isLegal(RangeT &&range) {
+    return llvm::all_of(range, [this](Type type) { return isLegal(type); });
+  }
+  /// Return true if the given operation has legal operand and result types.
+  bool isLegal(Operation *op);
 
   /// Return true if the inputs and outputs of the given function type are
   /// legal.
-  bool isSignatureLegal(FunctionType funcType);
+  bool isSignatureLegal(FunctionType ty);
 
   /// This method allows for converting a specific argument of a signature. It
   /// takes as inputs the original argument input number, type.
   /// On success, it populates 'result' with any new mappings.
   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
                                     SignatureConversion &result);
+  LogicalResult convertSignatureArgs(TypeRange types,
+                                     SignatureConversion &result,
+                                     unsigned origInputOffset = 0);
 
   /// This function converts the type signature of the given block, by invoking
   /// 'convertSignatureArg' for each argument. This function should return a
@@ -214,6 +227,8 @@ class TypeConverter {
   /// Register a type conversion.
   void registerConversion(ConversionCallbackFn callback) {
     conversions.emplace_back(std::move(callback));
+    cachedDirectConversions.clear();
+    cachedMultiConversions.clear();
   }
 
   /// Generate a wrapper for the given materialization callback. The callback
@@ -240,6 +255,13 @@ class TypeConverter {
 
   /// The list of registered materialization functions.
   SmallVector<MaterializationCallbackFn, 2> materializations;
+
+  /// A set of cached conversions to avoid recomputing in the common case.
+  /// Direct 1-1 conversions are the most common, so this cache stores the
+  /// successful 1-1 conversions as well as all failed conversions.
+  DenseMap<Type, Type> cachedDirectConversions;
+  /// This cache stores the successful 1->N conversions, where N != 1.
+  DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 1f983e802eab..23570625e688 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -120,10 +120,8 @@ struct ConvertLinalgOnTensorsToBuffers
     target.addLegalDialect<StandardOpsDialect>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
-    auto isIllegalType = [&](Type type) { return !converter.isLegal(type); };
     auto isLegalOperation = [&](Operation *op) {
-      return llvm::none_of(op->getOperandTypes(), isIllegalType) &&
-             llvm::none_of(op->getResultTypes(), isIllegalType);
+      return converter.isLegal(op);
     };
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
@@ -131,7 +129,7 @@ struct ConvertLinalgOnTensorsToBuffers
 
     // Mark Standard Return operations illegal as long as one operand is tensor.
     target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
-      return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);
+      return converter.isLegal(returnOp.getOperandTypes());
     });
 
     // Mark the function operation illegal as long as an argument is tensor.

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ff1ce3739d81..fdac287d5d69 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1742,11 +1742,35 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 /// This hooks allows for converting a type.
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) {
+  auto existingIt = cachedDirectConversions.find(t);
+  if (existingIt != cachedDirectConversions.end()) {
+    if (existingIt->second)
+      results.push_back(existingIt->second);
+    return success(existingIt->second != nullptr);
+  }
+  auto multiIt = cachedMultiConversions.find(t);
+  if (multiIt != cachedMultiConversions.end()) {
+    results.append(multiIt->second.begin(), multiIt->second.end());
+    return success();
+  }
+
   // Walk the added converters in reverse order to apply the most recently
   // registered first.
-  for (ConversionCallbackFn &converter : llvm::reverse(conversions))
-    if (Optional<LogicalResult> result = converter(t, results))
-      return *result;
+  size_t currentCount = results.size();
+  for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+    if (Optional<LogicalResult> result = converter(t, results)) {
+      if (!succeeded(*result)) {
+        cachedDirectConversions.try_emplace(t, nullptr);
+        return failure();
+      }
+      auto newTypes = ArrayRef<Type>(results).drop_front(currentCount);
+      if (newTypes.size() == 1)
+        cachedDirectConversions.try_emplace(t, newTypes.front());
+      else
+        cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes));
+      return success();
+    }
+  }
   return failure();
 }
 
@@ -1775,18 +1799,16 @@ LogicalResult TypeConverter::convertTypes(ArrayRef<Type> types,
 
 /// Return true if the given type is legal for this type converter, i.e. the
 /// type converts to itself.
-bool TypeConverter::isLegal(Type type) {
-  SmallVector<Type, 1> results;
-  return succeeded(convertType(type, results)) && results.size() == 1 &&
-         results.front() == type;
+bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
+/// Return true if the given operation has legal operand and result types.
+bool TypeConverter::isLegal(Operation *op) {
+  return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
 }
 
 /// Return true if the inputs and outputs of the given function type are
 /// legal.
-bool TypeConverter::isSignatureLegal(FunctionType funcType) {
-  return llvm::all_of(
-      llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
-      [this](Type type) { return isLegal(type); });
+bool TypeConverter::isSignatureLegal(FunctionType ty) {
+  return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
 }
 
 /// This hook allows for converting a specific argument of a signature.
@@ -1805,6 +1827,14 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
   result.addInputs(inputNo, convertedTypes);
   return success();
 }
+LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
+                                                  SignatureConversion &result,
+                                                  unsigned origInputOffset) {
+  for (unsigned i = 0, e = types.size(); i != e; ++i)
+    if (failed(convertSignatureArg(origInputOffset + i, types[i], result)))
+      return failure();
+  return success();
+}
 
 Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
                                            Location loc, Type resultType,
@@ -1815,6 +1845,17 @@ Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
   return nullptr;
 }
 
+/// This function converts the type signature of the given block, by invoking
+/// 'convertSignatureArg' for each argument. This function should return a valid
+/// conversion for the signature on success, None otherwise.
+auto TypeConverter::convertBlockSignature(Block *block)
+    -> Optional<SignatureConversion> {
+  SignatureConversion conversion(block->getNumArguments());
+  if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
+    return llvm::None;
+  return conversion;
+}
+
 /// Create a default conversion pattern that rewrites the type signature of a
 /// FuncOp.
 namespace {
@@ -1828,15 +1869,11 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
                   ConversionPatternRewriter &rewriter) const override {
     FunctionType type = funcOp.getType();
 
-    // Convert the original function arguments.
+    // Convert the original function types.
     TypeConverter::SignatureConversion result(type.getNumInputs());
-    for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
-      if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
-        return failure();
-
-    // Convert the original function results.
     SmallVector<Type, 1> convertedResults;
-    if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+    if (failed(converter.convertSignatureArgs(type.getInputs(), result)) ||
+        failed(converter.convertTypes(type.getResults(), convertedResults)))
       return failure();
 
     // Update the function signature in-place.
@@ -1859,19 +1896,6 @@ void mlir::populateFuncOpTypeConversionPattern(
   patterns.insert<FuncOpSignatureConversion>(ctx, converter);
 }
 
-/// This function converts the type signature of the given block, by invoking
-/// 'convertSignatureArg' for each argument. This function should return a valid
-/// conversion for the signature on success, None otherwise.
-auto TypeConverter::convertBlockSignature(Block *block)
-    -> Optional<SignatureConversion> {
-  SignatureConversion conversion(block->getNumArguments());
-  for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i)
-    if (failed(convertSignatureArg(i, block->getArgument(i).getType(),
-                                   conversion)))
-      return llvm::None;
-  return conversion;
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index dab8e196111a..2f9577c5b4b4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -314,10 +314,9 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
 
     // Convert the original entry arguments.
     TypeConverter::SignatureConversion result(entry->getNumArguments());
-    for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i)
-      if (failed(converter.convertSignatureArg(
-              i, entry->getArgument(i).getType(), result)))
-        return failure();
+    if (failed(
+            converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
+      return failure();
 
     // Convert the region signature and just drop the operation.
     rewriter.applySignatureConversion(&region, result);

diff  --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 3d0cc290e9fc..cbccb7dbb1a4 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -124,10 +124,8 @@ struct TestBufferPlacementPreparationPass
     target.addLegalDialect<StandardOpsDialect>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
-    auto isIllegalType = [&](Type type) { return !converter.isLegal(type); };
     auto isLegalOperation = [&](Operation *op) {
-      return llvm::none_of(op->getOperandTypes(), isIllegalType) &&
-             llvm::none_of(op->getResultTypes(), isIllegalType);
+      return converter.isLegal(op);
     };
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
         Optional<ConversionTarget::DynamicLegalityCallbackFn>(
@@ -135,14 +133,12 @@ struct TestBufferPlacementPreparationPass
 
     // Mark Standard Return operations illegal as long as one operand is tensor.
     target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
-      return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);
+      return converter.isLegal(returnOp.getOperandTypes());
     });
 
     // Mark Standard Call Operation illegal as long as it operates on tensor.
-    target.addDynamicallyLegalOp<mlir::CallOp>([&](mlir::CallOp callOp) {
-      return llvm::none_of(callOp.getOperandTypes(), isIllegalType) &&
-             llvm::none_of(callOp.getResultTypes(), isIllegalType);
-    });
+    target.addDynamicallyLegalOp<mlir::CallOp>(
+        [&](mlir::CallOp callOp) { return converter.isLegal(callOp); });
 
     // Mark the function whose arguments are in tensor-type illegal.
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {


        


More information about the Mlir-commits mailing list