[Mlir-commits] [mlir] 5c5dafc - [mlir] support materialization for 1-1 type conversions

Alex Zinenko llvmlistbot at llvm.org
Tue Jun 2 04:48:43 PDT 2020


Author: Alex Zinenko
Date: 2020-06-02T13:48:33+02:00
New Revision: 5c5dafc534ac80aad978f4092ff842457aab6d07

URL: https://github.com/llvm/llvm-project/commit/5c5dafc534ac80aad978f4092ff842457aab6d07
DIFF: https://github.com/llvm/llvm-project/commit/5c5dafc534ac80aad978f4092ff842457aab6d07.diff

LOG: [mlir] support materialization for 1-1 type conversions

Dialect conversion infrastructure supports 1->N type conversions by requiring
individual conversions to provide facilities to generate operations
retrofitting N values into 1 of the original type when N > 1. This
functionality can also be used to materialize explicit "cast"-like operations,
but it did not support 1->1 type conversions until now. Modify TypeConverter to
support materialization of cast operations for 1-1 conversions.

This also makes materialization specification more extensible following the
same pattern as type conversions. Instead of overloading a virtual function,
users or subclasses of TypeConverter can now register type-specific
materialization callbacks that will be called in order for the given type.

Differential Revision: https://reviews.llvm.org/D79729

Added: 
    

Modified: 
    mlir/docs/DialectConversion.md
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 1f9ec14b6a96..7995099636e9 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -217,16 +217,20 @@ class TypeConverter {
   template <typename ConversionFnT>
   void addConversion(ConversionFnT &&callback);
 
-  /// This hook allows for materializing a conversion from a set of types into
-  /// one result type by generating a cast operation of some kind. The generated
-  /// operation should produce one result, of 'resultType', with the provided
-  /// 'inputs' as operands. This hook must be overridden when a type conversion
+  /// Register a materialization function, which must be convertible to the
+  /// following form
+  ///   `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
+  /// where `T` is any subclass of `Type`. This function is responsible for
+  /// creating an operation, using the PatternRewriter and Location provided,
+  /// that "casts" a range of values into a single value of the given type `T`.
+  /// It must return a Value of the converted type on success, an `llvm::None`
+  /// if it failed but other materialization can be attempted, and `nullptr` on
+  /// unrecoverable failure. It will only be called for (sub)types of `T`.
+  /// Materialization functions must be provided when a type conversion
   /// results in more than one type, or if a type conversion may persist after
   /// the conversion has finished.
-  virtual Operation *materializeConversion(PatternRewriter &rewriter,
-                                           Type resultType,
-                                           ArrayRef<Value> inputs,
-                                           Location loc);
+  template <typename FnT>
+  void addMaterialization(FnT &&callback);
 };
 ```
 

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index c241de6ff6fe..a5c14139668b 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -122,11 +122,6 @@ class LLVMTypeConverter : public TypeConverter {
   /// pointers to memref descriptors for arguments.
   LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
 
-  /// Creates descriptor structs from individual values constituting them.
-  Operation *materializeConversion(PatternRewriter &rewriter, Type type,
-                                   ArrayRef<Value> values,
-                                   Location loc) override;
-
   /// Gets the LLVM representation of the index type. The returned type is an
   /// integer type with the size configured for this type converter.
   LLVM::LLVMType getIndexType();

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 31b5e04c9dbd..029c38f6a1c7 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -113,6 +113,25 @@ class TypeConverter {
     registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
   }
 
+  /// Register a materialization function, which must be convertible to the
+  /// following form:
+  ///   `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
+  /// where `T` is any subclass of `Type`. This function is responsible for
+  /// creating an operation, using the PatternRewriter and Location provided,
+  /// that "casts" a range of values into a single value of the given type `T`.
+  /// It must return a Value of the converted type on success, an `llvm::None`
+  /// if it failed but other materialization can be attempted, and `nullptr` on
+  /// unrecoverable failure. It will only be called for (sub)types of `T`.
+  /// Materialization functions must be provided when a type conversion
+  /// results in more than one type, or if a type conversion may persist after
+  /// the conversion has finished.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
+  void addMaterialization(FnT &&callback) {
+    registerMaterialization(
+        wrapMaterialization<T>(std::forward<FnT>(callback)));
+  }
+
   /// Convert the given type. This function should return failure if no valid
   /// conversion exists, success otherwise. If the new set of types is empty,
   /// the type is removed and any usages of the existing value are expected to
@@ -148,18 +167,10 @@ class TypeConverter {
   /// valid conversion for the signature on success, None otherwise.
   Optional<SignatureConversion> convertBlockSignature(Block *block);
 
-  /// This hook allows for materializing a conversion from a set of types into
-  /// one result type by generating a cast operation of some kind. The generated
-  /// operation should produce one result, of 'resultType', with the provided
-  /// 'inputs' as operands. This hook must be overridden when a type conversion
-  /// results in more than one type, or if a type conversion may persist after
-  /// the conversion has finished.
-  virtual Operation *materializeConversion(PatternRewriter &rewriter,
-                                           Type resultType,
-                                           ArrayRef<Value> inputs,
-                                           Location loc) {
-    llvm_unreachable("expected 'materializeConversion' to be overridden");
-  }
+  /// Materialize a conversion from a set of types into one result type by
+  /// generating a cast operation of some kind.
+  Value materializeConversion(PatternRewriter &rewriter, Location loc,
+                              Type resultType, ValueRange inputs);
 
 private:
   /// The signature of the callback used to convert a type. If the new set of
@@ -168,6 +179,9 @@ class TypeConverter {
   using ConversionCallbackFn =
       std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
 
+  using MaterializationCallbackFn = std::function<Optional<Value>(
+      PatternRewriter &, Type, ValueRange, Location)>;
+
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
   /// With callback of form: `Optional<Type>(T)`
@@ -204,8 +218,30 @@ class TypeConverter {
     conversions.emplace_back(std::move(callback));
   }
 
+  /// Generate a wrapper for the given materialization callback. The callback
+  /// may take any subclass of `Type` and the wrapper will check for the target
+  /// type to be of the expected class before calling the callback.
+  template <typename T, typename FnT>
+  MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
+    return [callback = std::forward<FnT>(callback)](
+               PatternRewriter &rewriter, Type resultType, ValueRange inputs,
+               Location loc) -> Optional<Value> {
+      if (T derivedType = resultType.dyn_cast<T>())
+        return callback(rewriter, derivedType, inputs, loc);
+      return llvm::None;
+    };
+  }
+
+  /// Register a materialization.
+  void registerMaterialization(MaterializationCallbackFn &&callback) {
+    materializations.emplace_back(std::move(callback));
+  }
+
   /// The set of registered conversion functions.
   SmallVector<ConversionCallbackFn, 4> conversions;
+
+  /// The list of registered materialization functions.
+  SmallVector<MaterializationCallbackFn, 2> materializations;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 8cc2315ddd15..4294e0024e79 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -150,6 +150,24 @@ LLVMTypeConverter::LLVMTypeConverter(
 
   // LLVMType is legal, so add a pass-through conversion.
   addConversion([](LLVM::LLVMType type) { return type; });
+
+  // Materialization for memrefs creates descriptor structs from individual
+  // values constituting them, when descriptors are used, i.e. more than one
+  // value represents a memref.
+  addMaterialization([&](PatternRewriter &rewriter,
+                         UnrankedMemRefType resultType, ValueRange inputs,
+                         Location loc) -> Optional<Value> {
+    if (inputs.size() == 1)
+      return llvm::None;
+    return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType,
+                                          inputs);
+  });
+  addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType,
+                         ValueRange inputs, Location loc) -> Optional<Value> {
+    if (inputs.size() == 1)
+      return llvm::None;
+    return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs);
+  });
 }
 
 /// Returns the MLIR context.
@@ -297,22 +315,6 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
   return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
 }
 
-/// Creates descriptor structs from individual values constituting them.
-Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter,
-                                                    Type type,
-                                                    ArrayRef<Value> values,
-                                                    Location loc) {
-  if (auto unrankedMemRefType = type.dyn_cast<UnrankedMemRefType>())
-    return UnrankedMemRefDescriptor::pack(rewriter, loc, *this,
-                                          unrankedMemRefType, values)
-        .getDefiningOp();
-
-  auto memRefType = type.dyn_cast<MemRefType>();
-  assert(memRefType && "1->N conversion is only supported for memrefs");
-  return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values)
-      .getDefiningOp();
-}
-
 // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which
 // contains:
 //   1. the pointer to the data buffer, followed by

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 0e2bf03f9639..ff1ce3739d81 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -305,27 +305,18 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
         // persist in the IR after conversion.
         if (!origArg.use_empty()) {
           rewriter.setInsertionPointToStart(newBlock);
-          auto *newOp = typeConverter->materializeConversion(
-              rewriter, origArg.getType(), llvm::None, loc);
-          origArg.replaceAllUsesWith(newOp->getResult(0));
+          Value newArg = typeConverter->materializeConversion(
+              rewriter, loc, origArg.getType(), llvm::None);
+          assert(newArg &&
+                 "Couldn't materialize a block argument after 1->0 conversion");
+          origArg.replaceAllUsesWith(newArg);
         }
         continue;
       }
 
-      // If mapping is 1-1, replace the remaining uses and drop the cast
-      // operation.
-      // FIXME(riverriddle) This should check that the result type and operand
-      // type are the same, otherwise it should force a conversion to be
-      // materialized.
-      if (argInfo->newArgSize == 1) {
-        origArg.replaceAllUsesWith(
-            mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
-        continue;
-      }
-
-      // Otherwise this is a 1->N value mapping.
+      // Otherwise this is a 1->1+ value mapping.
       Value castValue = argInfo->castValue;
-      assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
+      assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
 
       // If the argument is still used, replace it with the generated cast.
       if (!origArg.use_empty())
@@ -333,7 +324,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
 
       // If all users of the cast were removed, we can drop it. Otherwise, keep
       // the operation alive and let the user handle any remaining usages.
-      if (castValue.use_empty())
+      if (castValue.use_empty() && castValue.getDefiningOp())
         castValue.getDefiningOp()->erase();
     }
   }
@@ -389,22 +380,22 @@ Block *ArgConverter::applySignatureConversion(
       continue;
     }
 
-    // If this is a 1->1 mapping, then map the argument directly.
-    if (inputMap->size == 1) {
-      mapping.map(origArg, newArgs[inputMap->inputNo]);
-      info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
-      continue;
-    }
-
-    // Otherwise, this is a 1->N mapping. Call into the provided type converter
-    // to pack the new values.
+    // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
+    // to pack the new values. For 1->1 mappings, if there is no materialization
+    // provided, use the argument directly instead.
     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Operation *cast = typeConverter->materializeConversion(
-        rewriter, origArg.getType(), replArgs, loc);
-    assert(cast->getNumResults() == 1);
-    mapping.map(origArg, cast->getResult(0));
+    Value newArg;
+    if (typeConverter)
+      newArg = typeConverter->materializeConversion(
+          rewriter, loc, origArg.getType(), replArgs);
+    if (!newArg) {
+      assert(replArgs.size() == 1 &&
+             "couldn't materialize the result of 1->N conversion");
+      newArg = replArgs.front();
+    }
+    mapping.map(origArg, newArg);
     info.argInfo[i] =
-        ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
+        ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
 
   // Remove the original block from the region and return the new one.
@@ -1815,6 +1806,15 @@ LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
   return success();
 }
 
+Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
+                                           Location loc, Type resultType,
+                                           ValueRange inputs) {
+  for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
+    if (Optional<Value> result = fn(rewriter, resultType, inputs, loc))
+      return result.getValue();
+  return nullptr;
+}
+
 /// Create a default conversion pattern that rewrites the type signature of a
 /// FuncOp.
 namespace {

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 98f350053c04..284c38bad7b7 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -48,6 +48,13 @@ func @remap_input_1_to_N_remaining_use(%arg0: f32) {
   "work"(%arg0) : (f32) -> ()
 }
 
+// CHECK-LABEL: func @remap_materialize_1_to_1(%{{.*}}: i43)
+func @remap_materialize_1_to_1(%arg0: i42) {
+  // CHECK: %[[V:.*]] = "test.cast"(%arg0) : (i43) -> i42
+  // CHECK: "test.return"(%[[V]])
+  "test.return"(%arg0) : (i42) -> ()
+}
+
 // CHECK-LABEL: func @remap_input_to_self
 func @remap_input_to_self(%arg0: index) {
   // CHECK-NOT: test.cast

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index df3068cc6487..dab8e196111a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -477,7 +477,11 @@ struct TestNestedOpCreationUndoRewrite
 namespace {
 struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
-  TestTypeConverter() { addConversion(convertType); }
+  TestTypeConverter() {
+    addConversion(convertType);
+    addMaterialization(materializeCast);
+    addMaterialization(materializeOneToOneCast);
+  }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
     // Drop I16 types.
@@ -490,6 +494,12 @@ struct TestTypeConverter : public TypeConverter {
       return success();
     }
 
+    // Convert I42 to I43.
+    if (t.isInteger(42)) {
+      results.push_back(IntegerType::get(43, t.getContext()));
+      return success();
+    }
+
     // Split F32 into F16,F16.
     if (t.isF32()) {
       results.assign(2, FloatType::getF16(t.getContext()));
@@ -501,12 +511,24 @@ struct TestTypeConverter : public TypeConverter {
     return success();
   }
 
-  /// Override the hook to materialize a conversion. This is necessary because
-  /// we generate 1->N type mappings.
-  Operation *materializeConversion(PatternRewriter &rewriter, Type resultType,
-                                   ArrayRef<Value> inputs,
-                                   Location loc) override {
-    return rewriter.create<TestCastOp>(loc, resultType, inputs);
+  /// Hook for materializing a conversion. This is necessary because we generate
+  /// 1->N type mappings.
+  static Optional<Value> materializeCast(PatternRewriter &rewriter,
+                                         Type resultType, ValueRange inputs,
+                                         Location loc) {
+    if (inputs.size() == 1)
+      return inputs[0];
+    return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+  }
+
+  /// Materialize the cast back to i42 for the one-to-one i42 -> i43 conversion.
+  static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
+                                                 IntegerType resultType,
+                                                 ValueRange inputs,
+                                                 Location loc) {
+    if (resultType.getWidth() == 42 && inputs.size() == 1)
+      return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+    return llvm::None;
   }
 };
 


        


More information about the Mlir-commits mailing list