[Mlir-commits] [mlir] 4589dd9 - [mlir][DialectConversion] Enable deeper integration of type conversions

River Riddle llvmlistbot at llvm.org
Thu Jul 23 19:44:45 PDT 2020


Author: River Riddle
Date: 2020-07-23T19:40:31-07:00
New Revision: 4589dd924dfc43c846652b85825e291af0d7428a

URL: https://github.com/llvm/llvm-project/commit/4589dd924dfc43c846652b85825e291af0d7428a
DIFF: https://github.com/llvm/llvm-project/commit/4589dd924dfc43c846652b85825e291af0d7428a.diff

LOG: [mlir][DialectConversion] Enable deeper integration of type conversions

This revision adds support for much deeper integration of type conversions into the conversion process, and enables cast operations to be generated automatically when necessary. Type conversions are now largely managed automatically by the conversion infra when using a ConversionPattern with a provided TypeConverter. This removes the need for patterns to wrap values in type casts themselves, moving that burden to the infra. This makes it much easier to perform partial lowerings when type conversions are involved, as any lingering type conversions will be automatically resolved/legalized by the conversion infra.
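To make the "automatically generated cast" idea concrete, here is a minimal standalone C++ sketch of how operand remapping can pick an already-converted value and otherwise fall back to a target materialization. It is a toy model, not MLIR code: `ToyValue`, `ToyMapping`, `remapOperand`, and the string type names are hypothetical, and it ignores details of the real `remapValues` such as 1->N conversions and patterns without a type converter. The lookup mirrors the `desiredType` behavior of `ConversionValueMapping::lookupOrDefault` in the diff: walk the replacement chain and prefer the deepest value whose type is already legal.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>

// Toy value: an SSA-like id plus a type name (hypothetical, not mlir::Value).
struct ToyValue {
  int id;
  std::string type;
};

// Toy stand-in for ConversionValueMapping: value id -> value that replaced it.
struct ToyMapping {
  std::map<int, ToyValue> replacedBy;

  // Walk the replacement chain. With a desired type, return the deepest value
  // of that type; otherwise (or if none matches) return the leaf value.
  ToyValue lookupOrDefault(ToyValue from, const std::string &desiredType = "") const {
    std::optional<ToyValue> desired;
    while (true) {
      if (!desiredType.empty() && from.type == desiredType)
        desired = from;
      auto it = replacedBy.find(from.id);
      if (it == replacedBy.end())
        break;
      from = it->second;
    }
    return desired ? *desired : from;
  }
};

using ConvertTypeFn = std::function<std::optional<std::string>(const std::string &)>;
using MaterializeFn =
    std::function<std::optional<ToyValue>(const ToyValue &, const std::string &)>;

// Returns the operand a pattern should see, or std::nullopt for "fail to match".
std::optional<ToyValue> remapOperand(const ToyMapping &mapping, const ToyValue &operand,
                                     const ConvertTypeFn &convertType,
                                     const MaterializeFn &materializeTarget) {
  std::optional<std::string> legalType = convertType(operand.type);
  if (!legalType)
    return std::nullopt;  // No legal type exists for this operand.
  ToyValue newOperand = mapping.lookupOrDefault(operand, *legalType);
  if (newOperand.type == *legalType)
    return newOperand;  // A converted value of the legal type already exists.
  // Otherwise ask the target materialization to insert a cast.
  return materializeTarget(newOperand, *legalType);
}
```

If the chain already holds a value of the legal type, no cast is created; otherwise the cast callback runs, and an operand type with no legal conversion makes the pattern fail to match rather than receive an operand of the wrong type.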

To support this new integration, a few changes have been made to the type materialization API on TypeConverter. Materialization has been split into three separate categories:
* Argument Materialization: This type of materialization is used when converting the types of block arguments via `convertRegionTypes`. This is useful for contextually inserting additional conversion operations when converting a block argument type, such as when converting the types of a function signature.
* Source Materialization: This type of materialization is used to convert a type that is legal for the converter into an illegal type, generally a source type. This may be called when uses of an illegal type persist after the conversion process has finished.
* Target Materialization: This type of materialization is used to convert an illegal, or source, type into a legal, or target, type. This type of materialization is used when applying a pattern on an operation whose operand types have not yet been converted.
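A minimal C++ sketch of the three-way split, assuming a toy converter in which types and values are plain strings. `ToyTypeConverter` and its members are hypothetical, named after the new `add*Materialization` and `materialize*Conversion` methods in the header diff. Each category keeps its own callback list, so a callback registered for one purpose is never consulted for another. Trying the most recently registered callback first is an assumption of this sketch.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

using Inputs = std::vector<std::string>;
// A callback returns std::nullopt to decline, letting another callback try.
using Materialization =
    std::function<std::optional<std::string>(const std::string &, const Inputs &)>;

class ToyTypeConverter {
public:
  void addArgumentMaterialization(Materialization fn) { argument.push_back(std::move(fn)); }
  void addSourceMaterialization(Materialization fn) { source.push_back(std::move(fn)); }
  void addTargetMaterialization(Materialization fn) { target.push_back(std::move(fn)); }

  // Used when rewriting block arguments (e.g. function signatures).
  std::string materializeArgumentConversion(const std::string &type, const Inputs &in) const {
    return run(argument, type, in);
  }
  // Used when uses of an illegal (source) type outlive the conversion.
  std::string materializeSourceConversion(const std::string &type, const Inputs &in) const {
    return run(source, type, in);
  }
  // Used when a pattern needs operands of a legal (target) type.
  std::string materializeTargetConversion(const std::string &type, const Inputs &in) const {
    return run(target, type, in);
  }

private:
  // Returns "" (standing in for a null Value) if every callback declines.
  static std::string run(const std::vector<Materialization> &fns,
                         const std::string &type, const Inputs &in) {
    for (auto it = fns.rbegin(); it != fns.rend(); ++it)
      if (std::optional<std::string> result = (*it)(type, in))
        return *result;
    return "";
  }

  std::vector<Materialization> argument, source, target;
};
```

Registering only a target materialization, as in the usage below, leaves source and argument requests unanswered, which is the point of keeping the lists separate.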

Differential Revision: https://reviews.llvm.org/D82831

Added: 
    mlir/test/Transforms/test-legalize-type-conversion.mlir

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/IR/Value.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 26b7ce6ea6c3..8bffb9649d1f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -113,20 +113,40 @@ class TypeConverter {
 
   /// Register a materialization function, which must be convertible to the
   /// following form:
-  ///   `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`,
+  ///   `Optional<Value>(OpBuilder &, T, ValueRange, Location)`,
   /// where `T` is any subclass of `Type`. This function is responsible for
-  /// creating an operation, using the PatternRewriter and Location provided,
-  /// that "casts" a range of values into a single value of the given type `T`.
-  /// It must return a Value of the converted type on success, an `llvm::None`
-  /// if it failed but other materialization can be attempted, and `nullptr` on
+  /// creating an operation, using the OpBuilder and Location provided, that
+  /// "casts" a range of values into a single value of the given type `T`. It
+  /// must return a Value of the converted type on success, an `llvm::None` if
+  /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. It will only be called for (sub)types of `T`.
   /// Materialization functions must be provided when a type conversion
   /// results in more than one type, or if a type conversion may persist after
   /// the conversion has finished.
+  ///
+  /// This method registers a materialization that will be called when
+  /// converting an illegal block argument type, to a legal type.
   template <typename FnT,
             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
-  void addMaterialization(FnT &&callback) {
-    registerMaterialization(
+  void addArgumentMaterialization(FnT &&callback) {
+    argumentMaterializations.emplace_back(
+        wrapMaterialization<T>(std::forward<FnT>(callback)));
+  }
+  /// This method registers a materialization that will be called when
+  /// converting a legal type to an illegal source type. This is used when
+  /// conversions to an illegal type must persist beyond the main conversion.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
+  void addSourceMaterialization(FnT &&callback) {
+    sourceMaterializations.emplace_back(
+        wrapMaterialization<T>(std::forward<FnT>(callback)));
+  }
+  /// This method registers a materialization that will be called when
+  /// converting type from an illegal, or source, type to a legal type.
+  template <typename FnT,
+            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
+  void addTargetMaterialization(FnT &&callback) {
+    targetMaterializations.emplace_back(
         wrapMaterialization<T>(std::forward<FnT>(callback)));
   }
 
@@ -182,9 +202,24 @@ class TypeConverter {
   Optional<SignatureConversion> convertBlockSignature(Block *block);
 
   /// Materialize a conversion from a set of types into one result type by
-  /// generating a cast operation of some kind.
-  Value materializeConversion(PatternRewriter &rewriter, Location loc,
-                              Type resultType, ValueRange inputs);
+  /// generating a cast sequence of some kind. See the respective
+  /// `add*Materialization` for more information on the context for these
+  /// methods.
+  Value materializeArgumentConversion(OpBuilder &builder, Location loc,
+                                      Type resultType, ValueRange inputs) {
+    return materializeConversion(argumentMaterializations, builder, loc,
+                                 resultType, inputs);
+  }
+  Value materializeSourceConversion(OpBuilder &builder, Location loc,
+                                    Type resultType, ValueRange inputs) {
+    return materializeConversion(sourceMaterializations, builder, loc,
+                                 resultType, inputs);
+  }
+  Value materializeTargetConversion(OpBuilder &builder, Location loc,
+                                    Type resultType, ValueRange inputs) {
+    return materializeConversion(targetMaterializations, builder, loc,
+                                 resultType, inputs);
+  }
 
 private:
   /// The signature of the callback used to convert a type. If the new set of
@@ -193,8 +228,15 @@ class TypeConverter {
   using ConversionCallbackFn =
       std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
 
-  using MaterializationCallbackFn = std::function<Optional<Value>(
-      PatternRewriter &, Type, ValueRange, Location)>;
+  /// The signature of the callback used to materialize a conversion.
+  using MaterializationCallbackFn =
+      std::function<Optional<Value>(OpBuilder &, Type, ValueRange, Location)>;
+
+  /// Attempt to materialize a conversion using one of the provided
+  /// materialization functions.
+  Value materializeConversion(
+      MutableArrayRef<MaterializationCallbackFn> materializations,
+      OpBuilder &builder, Location loc, Type resultType, ValueRange inputs);
 
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
@@ -240,24 +282,21 @@ class TypeConverter {
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) {
     return [callback = std::forward<FnT>(callback)](
-               PatternRewriter &rewriter, Type resultType, ValueRange inputs,
+               OpBuilder &builder, Type resultType, ValueRange inputs,
                Location loc) -> Optional<Value> {
       if (T derivedType = resultType.dyn_cast<T>())
-        return callback(rewriter, derivedType, inputs, loc);
+        return callback(builder, derivedType, inputs, loc);
       return llvm::None;
     };
   }
 
-  /// Register a materialization.
-  void registerMaterialization(MaterializationCallbackFn &&callback) {
-    materializations.emplace_back(std::move(callback));
-  }
-
   /// The set of registered conversion functions.
   SmallVector<ConversionCallbackFn, 4> conversions;
 
   /// The list of registered materialization functions.
-  SmallVector<MaterializationCallbackFn, 2> materializations;
+  SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
+  SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
+  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
 
   /// A set of cached conversions to avoid recomputing in the common case.
   /// Direct 1-1 conversions are the most common, so this cache stores the
@@ -325,7 +364,7 @@ class ConversionPattern : public RewritePattern {
 
 protected:
   /// An optional type converter for use by this pattern.
-  TypeConverter *typeConverter;
+  TypeConverter *typeConverter = nullptr;
 
 private:
   using RewritePattern::rewrite;
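The `add*Materialization` templates above deduce the target type `T` from the callback's signature and wrap the callback so it only fires for results of that type. Below is a standalone C++17 sketch of that technique, with a hand-rolled `function_traits` and `dynamic_cast` standing in for `llvm::function_traits` and `Type::dyn_cast`; the toy `Type` hierarchy and string values are hypothetical. The real code reads argument 1 because argument 0 is the `OpBuilder &`; the toy callbacks drop the builder, so they read argument 0.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

// Toy type hierarchy (hypothetical stand-ins for mlir::Type subclasses).
struct Type { virtual ~Type() = default; };
struct IntegerType : Type { unsigned width = 32; };
struct MemRefType : Type {};

// Extract parameter types from a lambda's call operator.
template <typename F>
struct function_traits : function_traits<decltype(&F::operator())> {};
template <typename C, typename R, typename... Args>
struct function_traits<R (C::*)(Args...) const> {
  template <std::size_t I>
  using arg_t = std::tuple_element_t<I, std::tuple<Args...>>;
};

using Inputs = std::vector<std::string>;
using MaterializationFn =
    std::function<std::optional<std::string>(const Type &, const Inputs &)>;

// Wrap a callback taking `const T &` so it declines (nullopt) for other types.
template <typename FnT,
          typename T = std::remove_cv_t<std::remove_reference_t<
              typename function_traits<std::decay_t<FnT>>::template arg_t<0>>>>
MaterializationFn wrapMaterialization(FnT &&callback) {
  return [cb = std::forward<FnT>(callback)](const Type &type, const Inputs &inputs)
             -> std::optional<std::string> {
    if (const T *derived = dynamic_cast<const T *>(&type))
      return cb(*derived, inputs);
    return std::nullopt;
  };
}
```

Because the wrapper returns `std::nullopt` for non-matching types, callbacks written for different `Type` subclasses can share one list without interfering.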

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 91a4867ad307..080264e666cf 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -150,19 +150,42 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // Materialization for memrefs creates descriptor structs from individual
   // values constituting them, when descriptors are used, i.e. more than one
   // value represents a memref.
-  addMaterialization([&](PatternRewriter &rewriter,
-                         UnrankedMemRefType resultType, ValueRange inputs,
-                         Location loc) -> Optional<Value> {
+  addArgumentMaterialization(
+      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+          Location loc) -> Optional<Value> {
+        if (inputs.size() == 1)
+          return llvm::None;
+        return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
+                                              inputs);
+      });
+  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                                 ValueRange inputs,
+                                 Location loc) -> Optional<Value> {
     if (inputs.size() == 1)
       return llvm::None;
-    return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType,
-                                          inputs);
+    return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
   });
-  addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType,
-                         ValueRange inputs, Location loc) -> Optional<Value> {
-    if (inputs.size() == 1)
+  // Add generic source and target materializations to handle cases where
+  // non-LLVM types persist after an LLVM conversion.
+  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs,
+                               Location loc) -> Optional<Value> {
+    if (inputs.size() != 1)
+      return llvm::None;
+    // FIXME: These should check LLVM::DialectCastOp can actually be constructed
+    // from the input and result.
+    return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
+        .getResult();
+  });
+  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs,
+                               Location loc) -> Optional<Value> {
+    if (inputs.size() != 1)
       return llvm::None;
-    return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs);
+    // FIXME: These should check LLVM::DialectCastOp can actually be constructed
+    // from the input and result.
+    return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
+        .getResult();
   });
 }
 

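The LLVM memref callbacks above decline with `llvm::None` when a single value already represents the memref, and `applySignatureConversion` (later in this diff) falls back to using the argument directly for 1->1 mappings. A minimal C++ sketch of that interplay, assuming string stand-ins for values; `packDescriptor` and `replaceArgument` are hypothetical names, and the real code builds descriptor structs with `MemRefDescriptor::pack` rather than strings.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

using Inputs = std::vector<std::string>;
using ArgMaterialization =
    std::function<std::optional<std::string>(const std::string &, const Inputs &)>;

// Mirrors the shape of the memref callbacks: pack N > 1 values into one
// descriptor, but decline when a single value already represents the memref.
std::optional<std::string> packDescriptor(const std::string &resultType,
                                          const Inputs &inputs) {
  if (inputs.size() == 1)
    return std::nullopt;
  std::string s = "pack<" + resultType + ">(";
  for (std::size_t i = 0; i < inputs.size(); ++i)
    s += (i ? ", " : "") + inputs[i];
  return s + ")";
}

// Hypothetical caller modeled on applySignatureConversion: try each argument
// materialization; if all decline, a 1->1 mapping uses the new argument as-is.
std::string replaceArgument(const std::vector<ArgMaterialization> &fns,
                            const std::string &origType, const Inputs &replArgs) {
  for (const ArgMaterialization &fn : fns)
    if (std::optional<std::string> v = fn(origType, replArgs))
      return *v;
  assert(replArgs.size() == 1 && "couldn't materialize the result of 1->N conversion");
  return replArgs.front();
}
```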
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index be1d27141390..aa376993ae71 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -222,6 +222,16 @@ void LowerABIAttributesPass::runOnOperation() {
   spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));
 
   SPIRVTypeConverter typeConverter(targetEnv);
+
+  // Insert a bitcast in the case of a pointer type change.
+  typeConverter.addSourceMaterialization([](OpBuilder &builder,
+                                            spirv::PointerType type,
+                                            ValueRange inputs, Location loc) {
+    if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
+      return Value();
+    return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
+  });
+
   OwningRewritePatternList patterns;
   patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
 

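The header comment distinguishes three outcomes for a materialization callback: a value on success, `llvm::None` to let other materializations try, and a null value for unrecoverable failure. The SPIR-V callback above returns a plain `Value`, so, as we read that contract, its `return Value();` ends the search rather than deferring to another materialization. A minimal C++ sketch of the distinction; the toy `Value`, `Callback`, and `materialize` driver are hypothetical and only model the documented contract.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for mlir::Value: an empty name means "null".
struct Value {
  std::string name;
  explicit operator bool() const { return !name.empty(); }
};

using Callback = std::function<std::optional<Value>(const std::string &type,
                                                    const std::vector<Value> &inputs)>;

// Driver following the documented contract: std::nullopt means "try the next
// callback"; an engaged optional, even one holding a null Value, ends the search.
Value materialize(const std::vector<Callback> &callbacks, const std::string &type,
                  const std::vector<Value> &inputs) {
  for (const Callback &cb : callbacks)
    if (std::optional<Value> result = cb(type, inputs))
      return *result;
  return Value{};
}
```

In the usage below, a callback that returns a null `Value` for unsupported inputs blocks the fallback, while one that returns `std::nullopt` lets it run.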
diff  --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 6467a7f2295b..776b32a73d58 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -77,7 +77,11 @@ Operation *Value::getDefiningOp() const {
 Location Value::getLoc() const {
   if (auto *op = getDefiningOp())
     return op->getLoc();
-  return UnknownLoc::get(getContext());
+
+  // Use the location of the parent operation if this is a block argument.
+  // TODO: Should we just add locations to block arguments?
+  Operation *parentOp = cast<BlockArgument>().getOwner()->getParentOp();
+  return parentOp ? parentOp->getLoc() : UnknownLoc::get(getContext());
 }
 
 /// Return the Region in which this Value is defined.
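`Value::getLoc` now gives a block argument the location of the operation that owns its block's region, instead of always returning an unknown location. This matters for diagnostics emitted at `origArg.getLoc()`, such as the new "failed to materialize conversion for block argument" error in `DialectConversion.cpp`. A toy C++ model of the fallback order, with hypothetical `Operation`, `Block`, and `Value` structs and locations as plain strings:

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins; locations are plain strings.
struct Operation { std::string loc; };
struct Block { Operation *parentOp = nullptr; };

// Toy Value: either an op result (definingOp set) or a block argument (owner set).
struct Value {
  Operation *definingOp = nullptr;
  Block *owner = nullptr;

  std::string getLoc() const {
    if (definingOp)
      return definingOp->loc;
    // Block argument: borrow the location of the op owning the block's region.
    Operation *parentOp = owner ? owner->parentOp : nullptr;
    return parentOp ? parentOp->loc : std::string("loc(unknown)");
  }
};
```

A block that is not yet attached to any operation still yields the unknown location, matching the `parentOp ? ... : UnknownLoc::get(...)` fallback in the diff.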

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index b9ed64f573f2..9778958a4588 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -17,6 +17,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/ScopedPrinter.h"
 
 using namespace mlir;
@@ -106,8 +107,15 @@ namespace {
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
   /// Lookup a mapped value within the map. If a mapping for the provided value
-  /// does not exist then return the provided value.
-  Value lookupOrDefault(Value from) const;
+  /// does not exist then return the provided value. If `desiredType` is
+  /// non-null, returns the most recently mapped value with that type. If an
+  /// operand of that type does not exist, defaults to normal behavior.
+  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
+
+  /// Lookup a mapped value within the map, or return null if a mapping does not
+  /// exist. If a mapping exists, this follows the same behavior of
+  /// `lookupOrDefault`.
+  Value lookupOrNull(Value from) const;
 
   /// Map a value to the one provided.
   void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
@@ -121,14 +129,36 @@ struct ConversionValueMapping {
 };
 } // end anonymous namespace
 
-/// Lookup a mapped value within the map. If a mapping for the provided value
-/// does not exist then return the provided value.
-Value ConversionValueMapping::lookupOrDefault(Value from) const {
-  // If this value had a valid mapping, unmap that value as well in the case
-  // that it was also replaced.
-  while (auto mappedValue = mapping.lookupOrNull(from))
+Value ConversionValueMapping::lookupOrDefault(Value from,
+                                              Type desiredType) const {
+  // If there was no desired type, simply find the leaf value.
+  if (!desiredType) {
+    // If this value had a valid mapping, unmap that value as well in the case
+    // that it was also replaced.
+    while (auto mappedValue = mapping.lookupOrNull(from))
+      from = mappedValue;
+    return from;
+  }
+
+  // Otherwise, try to find the deepest value that has the desired type.
+  Value desiredValue;
+  do {
+    if (from.getType() == desiredType)
+      desiredValue = from;
+
+    Value mappedValue = mapping.lookupOrNull(from);
+    if (!mappedValue)
+      break;
     from = mappedValue;
-  return from;
+  } while (true);
+
+  // If the desired value was found use it, otherwise default to the leaf value.
+  return desiredValue ? desiredValue : from;
+}
+
+Value ConversionValueMapping::lookupOrNull(Value from) const {
+  Value result = lookupOrDefault(from);
+  return result == from ? nullptr : result;
 }
 
 //===----------------------------------------------------------------------===//
@@ -209,10 +239,17 @@ struct ArgConverter {
   /// its original state.
   void discardRewrites(Block *block);
 
-  /// Fully replace uses of the old arguments with the new, materializing cast
-  /// operations as necessary.
+  /// Fully replace uses of the old arguments with the new.
   void applyRewrites(ConversionValueMapping &mapping);
 
+  /// Materialize any necessary conversions for converted arguments that have
+  /// live users, using the provided `findLiveUser` to search for a user that
+  /// survives the conversion process.
+  LogicalResult
+  materializeLiveConversions(ConversionValueMapping &mapping,
+                             OpBuilder &builder,
+                             function_ref<Operation *(Value)> findLiveUser);
+
   //===--------------------------------------------------------------------===//
   // Conversion
   //===--------------------------------------------------------------------===//
@@ -307,7 +344,6 @@ void ArgConverter::discardRewrites(Block *block) {
 
 void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
   for (auto &info : conversionInfo) {
-    Block *newBlock = info.first;
     ConvertedBlockInfo &blockInfo = info.second;
     Block *origBlock = blockInfo.origBlock;
 
@@ -318,24 +354,8 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
 
       // Handle the case of a 1->0 value mapping.
       if (!argInfo) {
-        // If a replacement value was given for this argument, use that to
-        // replace all uses.
-        auto argReplacementValue = mapping.lookupOrDefault(origArg);
-        if (argReplacementValue != origArg) {
-          origArg.replaceAllUsesWith(argReplacementValue);
-          continue;
-        }
-        // If there are any dangling uses then replace the argument with one
-        // generated by the type converter. This is necessary as the cast must
-        // persist in the IR after conversion.
-        if (!origArg.use_empty()) {
-          rewriter.setInsertionPointToStart(newBlock);
-          Value newArg = blockInfo.converter->materializeConversion(
-              rewriter, origArg.getLoc(), origArg.getType(), llvm::None);
-          assert(newArg &&
-                 "Couldn't materialize a block argument after 1->0 conversion");
+        if (Value newArg = mapping.lookupOrNull(origArg))
           origArg.replaceAllUsesWith(newArg);
-        }
         continue;
       }
 
@@ -355,6 +375,59 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
   }
 }
 
+LogicalResult ArgConverter::materializeLiveConversions(
+    ConversionValueMapping &mapping, OpBuilder &builder,
+    function_ref<Operation *(Value)> findLiveUser) {
+  for (auto &info : conversionInfo) {
+    Block *newBlock = info.first;
+    ConvertedBlockInfo &blockInfo = info.second;
+    Block *origBlock = blockInfo.origBlock;
+
+    // Process the remapping for each of the original arguments.
+    for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+      // FIXME: We should run the below checks even if the type conversion was
+      // 1->N, but a lot of existing lowering rely on the block argument being
+      // blindly replaced. Those usages should be updated, and this if should be
+      // removed.
+      if (blockInfo.argInfo[i])
+        continue;
+
+      // If the type of this argument changed and the argument is still live, we
+      // need to materialize a conversion.
+      BlockArgument origArg = origBlock->getArgument(i);
+      auto argReplacementValue = mapping.lookupOrDefault(origArg);
+      bool isDroppedArg = argReplacementValue == origArg;
+      if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
+        continue;
+      Operation *liveUser = findLiveUser(origArg);
+      if (!liveUser)
+        continue;
+
+      if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
+        rewriter.setInsertionPointAfter(result.getOwner());
+      else
+        rewriter.setInsertionPointToStart(newBlock);
+      Value newArg = blockInfo.converter->materializeSourceConversion(
+          rewriter, origArg.getLoc(), origArg.getType(),
+          isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
+      if (!newArg) {
+        InFlightDiagnostic diag =
+            emitError(origArg.getLoc())
+            << "failed to materialize conversion for block argument #" << i
+            << " that remained live after conversion, type was "
+            << origArg.getType();
+        if (!isDroppedArg)
+          diag << ", with target type " << argReplacementValue.getType();
+        diag.attachNote(liveUser->getLoc())
+            << "see existing live user here: " << *liveUser;
+        return failure();
+      }
+      mapping.map(origArg, newArg);
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion
 
@@ -417,8 +490,8 @@ Block *ArgConverter::applySignatureConversion(
     // to pack the new values. For 1->1 mappings, if there is no materialization
     // provided, use the argument directly instead.
     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Value newArg = converter.materializeConversion(rewriter, origArg.getLoc(),
-                                                   origArg.getType(), replArgs);
+    Value newArg = converter.materializeArgumentConversion(
+        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
     if (!newArg) {
       assert(replArgs.size() == 1 &&
              "couldn't materialize the result of 1->N conversion");
@@ -516,13 +589,15 @@ class OperationTransactionState {
   SmallVector<Block *, 2> successors;
 };
 
-/// This class represents one requested operation replacement via 'replaceOp'.
+/// This class represents one requested operation replacement via 'replaceOp' or
+/// 'eraseOp`.
 struct OpReplacement {
   OpReplacement() = default;
-  OpReplacement(ValueRange newValues)
-      : newValues(newValues.begin(), newValues.end()) {}
+  OpReplacement(TypeConverter *converter) : converter(converter) {}
 
-  SmallVector<Value, 2> newValues;
+  /// An optional type converter that can be used to materialize conversions
+  /// between the new and old values if necessary.
+  TypeConverter *converter = nullptr;
 };
 
 /// The kind of the block action performed during the rewrite.  Actions can be
@@ -611,9 +686,14 @@ struct ConversionPatternRewriterImpl {
   /// "numActionsToKeep" actions remains.
   void undoBlockActions(unsigned numActionsToKeep = 0);
 
-  /// Remap the given operands to those with potentially different types.
-  void remapValues(Operation::operand_range operands,
-                   SmallVectorImpl<Value> &remapped);
+  /// Remap the given operands to those with potentially different types. The
+  /// provided type converter is used to ensure that the remapped types are
+  /// legal. Returns success if the operands could be remapped, failure
+  /// otherwise.
+  LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
+                            TypeConverter *converter,
+                            Operation::operand_range operands,
+                            SmallVectorImpl<Value> &remapped);
 
   /// Returns true if the given operation is ignored, and does not need to be
   /// converted.
@@ -666,6 +746,11 @@ struct ConversionPatternRewriterImpl {
   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
                                    Location origRegionLoc);
 
+  /// Notifies that a pattern match failed for the given reason.
+  LogicalResult
+  notifyMatchFailure(Location loc,
+                     function_ref<void(Diagnostic &)> reasonCallback);
+
   //===--------------------------------------------------------------------===//
   // State
   //===--------------------------------------------------------------------===//
@@ -712,6 +797,10 @@ struct ConversionPatternRewriterImpl {
   /// explicitly provided.
   TypeConverter defaultTypeConverter;
 
+  /// The current conversion pattern that is being rewritten, or nullptr if
+  /// called from outside of a conversion pattern rewrite.
+  const ConversionPattern *currentConversionPattern = nullptr;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -759,11 +848,9 @@ void ConversionPatternRewriterImpl::discardRewrites() {
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Apply all of the rewrites replacements requested during conversion.
   for (auto &repl : replacements) {
-    for (unsigned i = 0, e = repl.second.newValues.size(); i != e; ++i) {
-      if (auto newValue = repl.second.newValues[i])
-        repl.first->getResult(i).replaceAllUsesWith(
-            mapping.lookupOrDefault(newValue));
-    }
+    for (OpResult result : repl.first->getResults())
+      if (Value newValue = mapping.lookupOrNull(result))
+        result.replaceAllUsesWith(newValue);
 
     // If this operation defines any regions, drop any pending argument
     // rewrites.
@@ -905,11 +992,61 @@ void ConversionPatternRewriterImpl::undoBlockActions(
   blockActions.resize(numActionsToKeep);
 }
 
-void ConversionPatternRewriterImpl::remapValues(
+LogicalResult ConversionPatternRewriterImpl::remapValues(
+    Location loc, PatternRewriter &rewriter, TypeConverter *converter,
     Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
   remapped.reserve(llvm::size(operands));
-  for (Value operand : operands)
-    remapped.push_back(mapping.lookupOrDefault(operand));
+
+  SmallVector<Type, 1> legalTypes;
+  for (auto it : llvm::enumerate(operands)) {
+    Value operand = it.value();
+    Type origType = operand.getType();
+
+    // If a converter was provided, get the desired legal types for this
+    // operand.
+    Type desiredType;
+    if (converter) {
+      // If there is no legal conversion, fail to match this pattern.
+      legalTypes.clear();
+      if (failed(converter->convertType(origType, legalTypes))) {
+        return notifyMatchFailure(loc, [=](Diagnostic &diag) {
+          diag << "unable to convert type for operand #" << it.index()
+               << ", type was " << origType;
+        });
+      }
+      // TODO: There currently isn't any mechanism to do 1->N type conversion
+      // via the PatternRewriter replacement API, so for now we just ignore it.
+      if (legalTypes.size() == 1)
+        desiredType = legalTypes.front();
+    } else {
+      // TODO: What we should do here is just set `desiredType` to `origType`
+      // and then handle the necessary type conversions after the conversion
+      // process has finished. Unfortunately a lot of patterns currently rely on
+      // receiving the new operands even if the types change, so we keep the
+      // original behavior here for now until all of the patterns relying on
+      // this get updated.
+    }
+    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
+
+    // Handle the case where the conversion was 1->1 and the new operand type
+    // isn't legal.
+    Type newOperandType = newOperand.getType();
+    if (converter && desiredType && newOperandType != desiredType) {
+      // Attempt to materialize a conversion for this new value.
+      newOperand = converter->materializeTargetConversion(
+          rewriter, loc, desiredType, newOperand);
+      if (!newOperand) {
+        return notifyMatchFailure(loc, [=](Diagnostic &diag) {
+          diag << "unable to materialize a conversion for "
+                  "operand #"
+               << it.index() << ", from " << newOperandType << " to "
+               << desiredType;
+        });
+      }
+    }
+    remapped.push_back(newOperand);
+  }
+  return success();
 }
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
@@ -987,16 +1124,22 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
   Value newValue, result;
   for (auto it : llvm::zip(newValues, op->getResults())) {
     std::tie(newValue, result) = it;
-    if (!newValue)
+    if (!newValue) {
       resultChanged = true;
-    else
-      mapping.map(result, newValue);
+      continue;
+    }
+    // Remap, and check for any result type changes.
+    mapping.map(result, newValue);
+    resultChanged |= (newValue.getType() != result.getType());
   }
   if (resultChanged)
     operationsWithChangedResults.push_back(replacements.size());
 
   // Record the requested operation replacement.
-  replacements.insert(std::make_pair(op, OpReplacement(newValues)));
+  TypeConverter *converter = nullptr;
+  if (currentConversionPattern)
+    converter = currentConversionPattern->getTypeConverter();
+  replacements.insert(std::make_pair(op, OpReplacement(converter)));
 
   // Mark this operation as recursively ignored so that we don't need to
   // convert any nested operations.
@@ -1041,6 +1184,16 @@ void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
   assert(succeeded(result) && "expected region to have no unreachable blocks");
 }
 
+LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
+    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
+  LLVM_DEBUG({
+    Diagnostic diag(loc, DiagnosticSeverity::Remark);
+    reasonCallback(diag);
+    logger.startLine() << "** Failure : " << diag.str() << "\n";
+  });
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
@@ -1200,12 +1353,7 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
 /// PatternRewriter hook for notifying match failure reasons.
 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
-  LLVM_DEBUG({
-    Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
-    reasonCallback(diag);
-    impl->logger.startLine() << "** Failure : " << diag.str() << "\n";
-  });
-  return failure();
+  return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
 }
 
 /// Return a reference to the internal implementation.
@@ -1221,9 +1369,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
 LogicalResult
 ConversionPattern::matchAndRewrite(Operation *op,
                                    PatternRewriter &rewriter) const {
-  SmallVector<Value, 4> operands;
   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
-  dialectRewriter.getImpl().remapValues(op->getOperands(), operands);
+  auto &rewriterImpl = dialectRewriter.getImpl();
+
+  // Track the current conversion pattern in the rewriter.
+  assert(!rewriterImpl.currentConversionPattern &&
+         "already inside of a pattern rewrite");
+  llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
+      rewriterImpl.currentConversionPattern, this);
+
+  // Remap the operands of the operation.
+  SmallVector<Value, 4> operands;
+  if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
+                                      getTypeConverter(), op->getOperands(),
+                                      operands))) {
+    return failure();
+  }
   return matchAndRewrite(op, operands, dialectRewriter);
 }
 
@@ -1878,6 +2039,24 @@ struct OperationConverter {
   /// remaining artifacts and complete the conversion.
   LogicalResult finalize(ConversionPatternRewriter &rewriter);
 
+  /// Legalize the types of converted block arguments.
+  LogicalResult
+  legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
+                                 ConversionPatternRewriterImpl &rewriterImpl);
+
+  /// Legalize an operation result that was marked as "erased".
+  LogicalResult
+  legalizeErasedResult(Operation *op, OpResult result,
+                       ConversionPatternRewriterImpl &rewriterImpl);
+
+  /// Legalize an operation result that was replaced with a value of a different
+  /// type.
+  LogicalResult
+  legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
+                            TypeConverter *replConverter,
+                            ConversionPatternRewriter &rewriter,
+                            ConversionPatternRewriterImpl &rewriterImpl);
+
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
@@ -1961,33 +2140,145 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 LogicalResult
 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  auto isOpDead = [&](Operation *op) { return rewriterImpl.isOpIgnored(op); };
 
-  // Process the operations with changed results.
-  for (unsigned replIdx : rewriterImpl.operationsWithChangedResults) {
+  // Legalize converted block arguments.
+  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
+    return failure();
+
+  // Process requested operation replacements.
+  for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
+       i != e; ++i) {
+    unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
     auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
-    for (auto it : llvm::zip(repl.first->getResults(), repl.second.newValues)) {
-      Value result = std::get<0>(it), newValue = std::get<1>(it);
+    for (OpResult result : repl.first->getResults()) {
+      Value newValue = rewriterImpl.mapping.lookupOrNull(result);
 
       // If the operation result was replaced with null, all of the uses of this
       // value should be replaced.
-      if (newValue)
+      if (!newValue) {
+        if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
+          return failure();
+        continue;
+      }
+
+      // Otherwise, check to see if the type of the result changed.
+      if (result.getType() == newValue.getType())
         continue;
 
-      auto liveUserIt = llvm::find_if_not(result.getUsers(), isOpDead);
-      if (liveUserIt != result.user_end()) {
-        InFlightDiagnostic diag = repl.first->emitError()
-                                  << "failed to legalize operation '"
-                                  << repl.first->getName()
-                                  << "' marked as erased";
-        diag.attachNote(liveUserIt->getLoc())
-            << "found live user of result #"
-            << result.cast<OpResult>().getResultNumber() << ": " << *liveUserIt;
+      // Legalize this result.
+      rewriter.setInsertionPoint(repl.first);
+      if (failed(legalizeChangedResultType(repl.first, result, newValue,
+                                           repl.second.converter, rewriter,
+                                           rewriterImpl)))
         return failure();
-      }
+
+      // Update the end iterator for this loop in the case it was updated
+      // when legalizing generated conversion operations.
+      e = rewriterImpl.operationsWithChangedResults.size();
+    }
+  }
+  return success();
+}
+
+LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
+    ConversionPatternRewriter &rewriter,
+    ConversionPatternRewriterImpl &rewriterImpl) {
+  // Functor used to check if all users of a value will be dead after
+  // conversion.
+  auto findLiveUser = [&](Value val) {
+    auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
+      return rewriterImpl.isOpIgnored(user);
+    });
+    return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
+  };
+
+  // Materialize any necessary conversions for converted block arguments that
+  // are still live.
+  size_t numCreatedOps = rewriterImpl.createdOps.size();
+  if (failed(rewriterImpl.argConverter.materializeLiveConversions(
+          rewriterImpl.mapping, rewriter, findLiveUser)))
+    return failure();
+
+  // Legalize any newly created operations during argument materialization.
+  for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
+    if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
+      return rewriterImpl.createdOps[i]->emitError()
+             << "failed to legalize conversion operation generated for block "
+                "argument that remained live after conversion";
+    }
+  }
+  return success();
+}
+
+LogicalResult OperationConverter::legalizeErasedResult(
+    Operation *op, OpResult result,
+    ConversionPatternRewriterImpl &rewriterImpl) {
+  // If the operation result was replaced with null, all of the uses of this
+  // value should be replaced.
+  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
+    return rewriterImpl.isOpIgnored(user);
+  });
+  if (liveUserIt != result.user_end()) {
+    InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
+                              << op->getName() << "' marked as erased";
+    diag.attachNote(liveUserIt->getLoc())
+        << "found live user of result #" << result.getResultNumber() << ": "
+        << *liveUserIt;
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult OperationConverter::legalizeChangedResultType(
+    Operation *op, OpResult result, Value newValue,
+    TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
+    ConversionPatternRewriterImpl &rewriterImpl) {
+  // Walk the users of this value to see if there are any live users that
+  // weren't replaced during conversion.
+  auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
+    return rewriterImpl.isOpIgnored(user);
+  });
+  if (liveUserIt == result.user_end())
+    return success();
+
+  // If the replacement has a type converter, attempt to materialize a
+  // conversion back to the original type.
+  if (!replConverter) {
+    // TODO: We should emit an error here, similarly to the case where the
+    // result is replaced with null. Unfortunately a lot of existing
+    // patterns rely on this behavior, so until those patterns are updated
+    // we keep the legacy behavior here of just forwarding the new value.
+    return success();
+  }
+
+  // Track the number of created operations so that new ones can be legalized.
+  size_t numCreatedOps = rewriterImpl.createdOps.size();
+
+  // Materialize a conversion for this live result value.
+  Type resultType = result.getType();
+  Value convertedValue = replConverter->materializeSourceConversion(
+      rewriter, op->getLoc(), resultType, newValue);
+  if (!convertedValue) {
+    InFlightDiagnostic diag = op->emitError()
+                              << "failed to materialize conversion for result #"
+                              << result.getResultNumber() << " of operation '"
+                              << op->getName()
+                              << "' that remained live after conversion";
+    diag.attachNote(liveUserIt->getLoc())
+        << "see existing live user here: " << *liveUserIt;
+    return failure();
+  }
+
+  // Legalize all of the newly created conversion operations.
+  for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
+    if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
+      return op->emitError("failed to legalize conversion operation generated ")
+             << "for result #" << result.getResultNumber() << " of operation '"
+             << op->getName() << "' that remained live after conversion";
     }
   }
 
+  rewriterImpl.mapping.map(result, convertedValue);
   return success();
 }
 
@@ -2136,11 +2427,11 @@ LogicalResult TypeConverter::convertSignatureArgs(TypeRange types,
   return success();
 }
 
-Value TypeConverter::materializeConversion(PatternRewriter &rewriter,
-                                           Location loc, Type resultType,
-                                           ValueRange inputs) {
+Value TypeConverter::materializeConversion(
+    MutableArrayRef<MaterializationCallbackFn> materializations,
+    OpBuilder &builder, Location loc, Type resultType, ValueRange inputs) {
   for (MaterializationCallbackFn &fn : llvm::reverse(materializations))
-    if (Optional<Value> result = fn(rewriter, resultType, inputs, loc))
+    if (Optional<Value> result = fn(builder, resultType, inputs, loc))
       return result.getValue();
   return nullptr;
 }

diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index b8ebdfbf35f1..3b0a17be640b 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -75,15 +75,3 @@ func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
   %0 = rsqrt %arg0 : vector<4x3xf32>
   std.return
 }
-
-// -----
-
-// This should not crash. The first operation cannot be converted, so the
-// second should not match. This attempts to convert `return` to `llvm.return`
-// and complains about non-LLVM types.
-func @unknown_source() -> i32 {
-  %0 = "foo"() : () -> i32
-  %1 = addi %0, %0 : i32
-  // expected-error@+1 {{must be LLVM dialect type}}
-  return %1 : i32
-}

diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index 4e4bf06e6f73..3d37f35b1c46 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -57,9 +57,12 @@ spv.module Logical GLSL450 {
     // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
     // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
     // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
-    // CHECK: [[ARG2:%.*]] = spv._address_of [[VAR2]]
-    // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]]
-    // CHECK: [[ARG0:%.*]] = spv._address_of [[VAR0]]
+    // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
+    // CHECK: [[ARG2:%.*]] = spv.Bitcast [[ADDRESSARG2]]
+    // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
+    // CHECK: [[ARG1:%.*]] = spv.Bitcast [[ADDRESSARG1]]
+    // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
+    // CHECK: [[ARG0:%.*]] = spv.Bitcast [[ADDRESSARG0]]
     %0 = spv._address_of @__builtin_var_WorkgroupId__ : !spv.ptr<vector<3xi32>, Input>
     %1 = spv.Load "Input" %0 : vector<3xi32>
     %2 = spv.CompositeExtract %1[0 : i32] : vector<3xi32>

diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
new file mode 100644
index 000000000000..c56b3c8ca1e2
--- /dev/null
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -test-legalize-type-conversion -allow-unregistered-dialect -split-input-file -verify-diagnostics | FileCheck %s
+
+// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+func @test_invalid_arg_materialization(%arg0: i16) {
+  // expected-note@below {{see existing live user here}}
+  "foo.return"(%arg0) : (i16) -> ()
+}
+
+// -----
+
+// expected-error@below {{failed to legalize conversion operation generated for block argument}}
+func @test_invalid_arg_illegal_materialization(%arg0: i32) {
+  "foo.return"(%arg0) : (i32) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @test_valid_arg_materialization
+func @test_valid_arg_materialization(%arg0: i64) {
+  // CHECK: %[[ARG:.*]] = "test.type_producer"
+  // CHECK: "foo.return"(%[[ARG]]) : (i64)
+
+  "foo.return"(%arg0) : (i64) -> ()
+}
+
+// -----
+
+func @test_invalid_result_materialization() {
+  // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
+  %result = "test.type_producer"() : () -> f16
+
+  // expected-note@below {{see existing live user here}}
+  "foo.return"(%result) : (f16) -> ()
+}
+
+// -----
+
+func @test_invalid_result_materialization() {
+  // expected-error@below {{failed to materialize conversion for result #0 of operation 'test.type_producer' that remained live after conversion}}
+  %result = "test.type_producer"() : () -> f16
+
+  // expected-note@below {{see existing live user here}}
+  "foo.return"(%result) : (f16) -> ()
+}
+
+// -----
+
+func @test_invalid_result_legalization() {
+  // expected-error@below {{failed to legalize conversion operation generated for result #0 of operation 'test.type_producer' that remained live after conversion}}
+  %result = "test.type_producer"() : () -> i16
+  "foo.return"(%result) : (i16) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func @test_valid_result_legalization
+func @test_valid_result_legalization() {
+  // CHECK: %[[RESULT:.*]] = "test.type_producer"() : () -> f64
+  // CHECK: %[[CAST:.*]] = "test.cast"(%[[RESULT]]) : (f64) -> f32
+  // CHECK: "foo.return"(%[[CAST]]) : (f32)
+
+  %result = "test.type_producer"() : () -> f32
+  "foo.return"(%result) : (f32) -> ()
+}

diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 255b1c152a36..5bc947fc8c91 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -485,8 +485,9 @@ struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
   TestTypeConverter() {
     addConversion(convertType);
-    addMaterialization(materializeCast);
-    addMaterialization(materializeOneToOneCast);
+    addArgumentMaterialization(materializeCast);
+    addArgumentMaterialization(materializeOneToOneCast);
+    addSourceMaterialization(materializeCast);
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -519,21 +520,20 @@ struct TestTypeConverter : public TypeConverter {
 
   /// Hook for materializing a conversion. This is necessary because we generate
   /// 1->N type mappings.
-  static Optional<Value> materializeCast(PatternRewriter &rewriter,
-                                         Type resultType, ValueRange inputs,
-                                         Location loc) {
+  static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
+                                         ValueRange inputs, Location loc) {
     if (inputs.size() == 1)
       return inputs[0];
-    return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+    return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
   }
 
   /// Materialize the cast for one-to-one conversion from i64 to f64.
-  static Optional<Value> materializeOneToOneCast(PatternRewriter &rewriter,
+  static Optional<Value> materializeOneToOneCast(OpBuilder &builder,
                                                  IntegerType resultType,
                                                  ValueRange inputs,
                                                  Location loc) {
     if (resultType.getWidth() == 42 && inputs.size() == 1)
-      return rewriter.create<TestCastOp>(loc, resultType, inputs).getResult();
+      return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
     return llvm::None;
   }
 };
@@ -742,6 +742,102 @@ struct TestUnknownRootOpDriver
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Test type conversions
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestTypeConversionProducer
+    : public OpConversionPattern<TestTypeProducerOp> {
+  using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TestTypeProducerOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Type resultType = op.getType();
+    if (resultType.isa<FloatType>())
+      resultType = rewriter.getF64Type();
+    else if (resultType.isInteger(16))
+      resultType = rewriter.getIntegerType(64);
+    else
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType);
+    return success();
+  }
+};
+
+struct TestTypeConversionDriver
+    : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    // Initialize the type converter.
+    TypeConverter converter;
+
+    /// Add the legal set of type conversions.
+    converter.addConversion([](Type type) -> Type {
+      // Treat F64 as legal.
+      if (type.isF64())
+        return type;
+      // Allow converting BF16/F16/F32 to F64.
+      if (type.isBF16() || type.isF16() || type.isF32())
+        return FloatType::getF64(type.getContext());
+      // Otherwise, the type is illegal.
+      return nullptr;
+    });
+    converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) {
+      // Drop all integer types.
+      return success();
+    });
+
+    /// Add the legal set of type materializations.
+    converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
+                                          ValueRange inputs,
+                                          Location loc) -> Value {
+      // Allow casting from F64 back to F32.
+      if (!resultType.isF16() && inputs.size() == 1 &&
+          inputs[0].getType().isF64())
+        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+      // Allow producing an i32 or i64 from nothing.
+      if ((resultType.isInteger(32) || resultType.isInteger(64)) &&
+          inputs.empty())
+        return builder.create<TestTypeProducerOp>(loc, resultType);
+      // Allow producing an i64 from an integer.
+      if (resultType.isa<IntegerType>() && inputs.size() == 1 &&
+          inputs[0].getType().isa<IntegerType>())
+        return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
+      // Otherwise, fail.
+      return nullptr;
+    });
+
+    // Initialize the conversion target.
+    mlir::ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
+      return op.getType().isF64() || op.getType().isInteger(64);
+    });
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return converter.isSignatureLegal(op.getType()) &&
+             converter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
+      // Allow casts from F64 to F32.
+      return (*op.operand_type_begin()).isF64() && op.getType().isF32();
+    });
+
+    // Initialize the set of rewrite patterns.
+    OwningRewritePatternList patterns;
+    patterns.insert<TestTypeConversionProducer>(converter, &getContext());
+    mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
+                                              converter);
+
+    if (failed(applyPartialConversion(getOperation(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// PassRegistration
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 void registerPatternsTestPass() {
   PassRegistration<TestReturnTypeDriver>("test-return-type",
@@ -766,5 +862,9 @@ void registerPatternsTestPass() {
   PassRegistration<TestUnknownRootOpDriver>(
       "test-legalize-unknown-root-patterns",
       "Test public remapped value mechanism in ConversionPatternRewriter");
+
+  PassRegistration<TestTypeConversionDriver>(
+      "test-legalize-type-conversion",
+      "Test various type conversion functionalities in DialectConversion");
 }
 } // namespace mlir


        

