[Mlir-commits] [mlir] 8d67d18 - [mlir][DialectConversion] Refactor how block argument types get converted

River Riddle llvmlistbot at llvm.org
Thu Jun 18 16:01:25 PDT 2020


Author: River Riddle
Date: 2020-06-18T15:59:22-07:00
New Revision: 8d67d187ba1bdb201f83ce25725e9be59b0141a7

URL: https://github.com/llvm/llvm-project/commit/8d67d187ba1bdb201f83ce25725e9be59b0141a7
DIFF: https://github.com/llvm/llvm-project/commit/8d67d187ba1bdb201f83ce25725e9be59b0141a7.diff

LOG: [mlir][DialectConversion] Refactor how block argument types get converted

This revision removes the TypeConverter parameter passed to the apply* methods, and instead moves responsibility for region type conversion to patterns. The types of a region can be converted using the 'convertRegionTypes' method, which acts similarly to the existing 'applySignatureConversion'. This method ensures that all blocks within a region, including any blocks moved into it, have their block argument types converted using the provided converter.
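To make the 'convertRegionTypes' contract concrete, here is a minimal, self-contained C++ toy model. Everything in it (ToyBlock, ToyRegion, ToyTypeConverter, toyConvertBlock, toyConvertRegionTypes, toyMoveBlockInto) is hypothetical and is not the MLIR API. Blocks are reduced to lists of type names, the optional SignatureConversion is reduced to an explicit replacement list for the entry block, and success is a plain bool rather than FailureOr<Block *>. It sketches three behaviors: every block in the region is converted, the entry block can get a custom signature, and the converter is remembered so blocks moved into the region afterwards are converted too.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a block is just its list of argument types.
struct ToyBlock {
  std::vector<std::string> argTypes;
};

// Maps a source type to a converted type, or std::nullopt if unconvertible.
using ToyTypeConverter =
    std::function<std::optional<std::string>(const std::string &)>;

struct ToyRegion {
  std::vector<ToyBlock> blocks;
  // Recorded by toyConvertRegionTypes; applied to blocks moved in later.
  ToyTypeConverter converter;
};

// Convert every argument type of `block`; fails if any type is unconvertible.
bool toyConvertBlock(ToyBlock &block, const ToyTypeConverter &convert) {
  std::vector<std::string> newTypes;
  for (const std::string &type : block.argTypes) {
    std::optional<std::string> newType = convert(type);
    if (!newType)
      return false;
    newTypes.push_back(*newType);
  }
  block.argTypes = std::move(newTypes);
  return true;
}

// Convert all blocks of `region`. If `entryTypes` is non-null it replaces the
// entry block signature (a simplified stand-in for SignatureConversion).
// On failure the region is left unchanged.
bool toyConvertRegionTypes(ToyRegion &region, const ToyTypeConverter &convert,
                           const std::vector<std::string> *entryTypes = nullptr) {
  std::vector<ToyBlock> newBlocks = region.blocks;
  for (std::size_t i = 0; i < newBlocks.size(); ++i) {
    if (i == 0 && entryTypes) {
      newBlocks[i].argTypes = *entryTypes;
      continue;
    }
    if (!toyConvertBlock(newBlocks[i], convert))
      return false;
  }
  region.blocks = std::move(newBlocks);
  region.converter = convert;
  return true;
}

// Move `block` into `region`, converting it if the region has a converter.
bool toyMoveBlockInto(ToyRegion &region, ToyBlock block) {
  if (region.converter && !toyConvertBlock(block, region.converter))
    return false;
  region.blocks.push_back(std::move(block));
  return true;
}
```

Building `newBlocks` before committing keeps the toy region untouched when a type fails to convert; how the real framework recovers from a failed conversion is not modeled here.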

This has the benefit of putting more of the legalization logic under the control of patterns, instead of having it handled explicitly by the driver. It also opens up the possibility of supporting multiple type conversions in the future.

This revision also adds a new utility class `FailureOr<T>` that provides a LogicalResult friendly facility for returning a failure or a valid result value.
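For readers new to `FailureOr<T>`, the following is a minimal sketch of the same idea in standard C++17. It assumes `std::optional` in place of `llvm::Optional` and a simplified, hypothetical `LogicalResult`; it is not the MLIR header itself. It also adds a `const T &` constructor that the class in this commit does not have, so lvalues can be returned directly. The hypothetical `parseDigit` function shows the calling convention also used with `convertRegionTypes` in ConvertGPUToSPIRV.cpp: check `failed(...)` first, then dereference.

```cpp
#include <cassert>
#include <optional>
#include <utility>

// Simplified, hypothetical stand-in for mlir::LogicalResult.
struct LogicalResult {
  bool isSuccess;
};
inline LogicalResult success(bool isSuccess = true) { return {isSuccess}; }
inline LogicalResult failure() { return {false}; }
inline bool failed(LogicalResult result) { return !result.isSuccess; }

// Sketch of FailureOr<T>: either a failure or a valid value of type T.
template <typename T> class [[nodiscard]] FailureOr : public std::optional<T> {
public:
  // Construct from a LogicalResult, which must be a failure.
  FailureOr(LogicalResult result) {
    assert(failed(result) && "success should be constructed with a 'T'");
    (void)result;
  }
  FailureOr() : FailureOr(failure()) {}
  FailureOr(T &&value) : std::optional<T>(std::move(value)) {}
  FailureOr(const T &value) : std::optional<T>(value) {}

  // Lets `failed(x)` work by converting to LogicalResult.
  operator LogicalResult() const { return success(this->has_value()); }

private:
  // Hide the bool conversion, which is easily confused with success.
  using std::optional<T>::operator bool;
  using std::optional<T>::has_value;
};

// Hypothetical usage: parse a decimal digit, or fail.
FailureOr<int> parseDigit(char c) {
  if (c < '0' || c > '9')
    return failure();
  return c - '0';
}
```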

Differential Revision: https://reviews.llvm.org/D81681

Added: 
    

Modified: 
    mlir/docs/DialectConversion.md
    mlir/docs/Tutorials/Toy/Ch-6.md
    mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
    mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Support/LogicalResult.h
    mlir/include/mlir/Transforms/BufferPlacement.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Transforms/TestBufferPlacement.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 7995099636e9..c7174147b72e 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -262,14 +262,21 @@ patterns used in dialect conversion.
 
 ### Region Signature Conversion
 
-From the perspective of type conversion, the entry block to a region is often
-special. The types of the entry block arguments are often tied semantically to
-details on the operation, e.g. FuncOp, AffineForOp, etc. Given this, the
-conversion of the types for this block must be done explicitly via a conversion
-pattern. To convert the signature of a region entry block, a custom hook on the
-ConversionPatternRewriter must be invoked `applySignatureConversion`. A
-signature conversion, `TypeConverter::SignatureConversion`, can be built
-programmatically:
+From the perspective of type conversion, the types of block arguments are a bit
+special. Throughout the conversion process, blocks may move between regions of
+different operations. Given this, the conversion of the types for blocks must be
+done explicitly via a conversion pattern. To convert the types of block
+arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
+be invoked: `convertRegionTypes`. This hook uses a provided type converter to
+apply type conversions to all blocks within the region, and all blocks that move
+into that region. This hook also takes an optional
+`TypeConverter::SignatureConversion` parameter that applies a custom conversion
+to the entry block of the region. The types of the entry block arguments are
+often tied semantically to details on the operation, e.g. FuncOp, AffineForOp,
+etc. To convert the signature of just the region entry block, and not any other
+blocks within the region, the `applySignatureConversion` hook may be used
+instead. A signature conversion, `TypeConverter::SignatureConversion`, can be
+built programmatically:
 
 ```c++
 class SignatureConversion {
@@ -293,5 +300,6 @@ public:
 };
 ```
 
-The `TypeConverter` provides several default utilities for signature conversion:
-`convertSignatureArg`/`convertBlockSignature`.
+The `TypeConverter` provides several default utilities for signature conversion
+and legality checking:
+`convertSignatureArgs`/`convertBlockSignature`/`isLegal(Region *|Type)`.

diff  --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 734eafb0a99b..06f5cd60dca5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -106,8 +106,7 @@ that only legal operations will remain after the conversion.
 
 ```c++
   mlir::ModuleOp module = getOperation();
-  if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
-                                             &typeConverter)))
+  if (mlir::failed(mlir::applyFullConversion(module, target, patterns)))
     signalPassFailure();
 ```
 

diff  --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 43b4c10e62e6..af4130c6a5ca 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -203,7 +203,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
   auto module = getOperation();
-  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+  if (failed(applyFullConversion(module, target, patterns)))
     signalPassFailure();
 }
 

diff  --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 43b4c10e62e6..af4130c6a5ca 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -203,7 +203,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
   auto module = getOperation();
-  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+  if (failed(applyFullConversion(module, target, patterns)))
     signalPassFailure();
 }
 

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0f0228a3dad3..f1c7c39a3e73 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -99,7 +99,7 @@ class Pattern {
   /// pattern.
   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
 
-  /// This contructor is used when a pattern may match against multiple
+  /// This constructor is used when a pattern may match against multiple
   /// different types of operations. The `benefit` is the expected benefit of
   /// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that
   /// the "match any" behavior is what the user actually desired,
@@ -163,28 +163,27 @@ class RewritePattern : public Pattern {
   ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
 
 protected:
-  /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching.
+  /// Construct a rewrite pattern with a certain benefit that matches the
+  /// operation with the given root name.
   RewritePattern(StringRef rootName, PatternBenefit benefit,
                  MLIRContext *context)
       : Pattern(rootName, benefit, context) {}
-  /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching. `MatchAnyOpTypeTag`
-  /// is just a tag to ensure that the "match any" behavior is what the user
-  /// actually desired, `MatchAnyOpTypeTag()` should always be supplied here.
+  /// Construct a rewrite pattern with a certain benefit that matches any
+  /// operation type. `MatchAnyOpTypeTag` is just a tag to ensure that the
+  /// "match any" behavior is what the user actually desired,
+  /// `MatchAnyOpTypeTag()` should always be supplied here.
   RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
       : Pattern(benefit, tag) {}
-  /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching. They can also specify
-  /// the names of operations that may be generated during a successful rewrite.
+  /// Construct a rewrite pattern with a certain benefit that matches the
+  /// operation with the given root name. `generatedNames` contains the names of
+  /// operations that may be generated during a successful rewrite.
   RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
                  PatternBenefit benefit, MLIRContext *context);
-  /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching. They can also specify
-  /// the names of operations that may be generated during a successful rewrite.
-  /// `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
-  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
-  /// always be supplied here.
+  /// Construct a rewrite pattern that may match any operation type.
+  /// `generatedNames` contains the names of operations that may be generated
+  /// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure
+  /// that the "match any" behavior is what the user actually desired,
+  /// `MatchAnyOpTypeTag()` should always be supplied here.
   RewritePattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
                  MLIRContext *context, MatchAnyOpTypeTag tag);
 

diff  --git a/mlir/include/mlir/Support/LogicalResult.h b/mlir/include/mlir/Support/LogicalResult.h
index 216e05698e8a..3e30e9e81567 100644
--- a/mlir/include/mlir/Support/LogicalResult.h
+++ b/mlir/include/mlir/Support/LogicalResult.h
@@ -10,11 +10,12 @@
 #define MLIR_SUPPORT_LOGICAL_RESULT_H
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/Optional.h"
 
 namespace mlir {
 
-// Values that can be used to signal success/failure. This should be used in
-// conjunction with the utility functions below.
+/// Values that can be used to signal success/failure. This should be used in
+/// conjunction with the utility functions below.
 struct LogicalResult {
   enum ResultEnum { Success, Failure } value;
   LogicalResult(ResultEnum v) : value(v) {}
@@ -46,6 +47,28 @@ inline bool failed(LogicalResult result) {
   return result.value == LogicalResult::Failure;
 }
 
+/// This class provides support for representing a failure result, or a valid
+/// value of type `T`. This allows for integrating with LogicalResult, while
+/// also providing a value on the success path.
+template <typename T> class LLVM_NODISCARD FailureOr : public Optional<T> {
+public:
+  /// Allow constructing from a LogicalResult. The result *must* be a failure.
+  /// Success results should use a proper instance of type `T`.
+  FailureOr(LogicalResult result) {
+    assert(failed(result) &&
+           "success should be constructed with an instance of 'T'");
+  }
+  FailureOr() : FailureOr(failure()) {}
+  FailureOr(T &&y) : Optional<T>(std::forward<T>(y)) {}
+
+  operator LogicalResult() const { return success(this->hasValue()); }
+
+private:
+  /// Hide the bool conversion as it easily creates confusion.
+  using Optional<T>::operator bool;
+  using Optional<T>::hasValue;
+};
+
 } // namespace mlir
 
 #endif // MLIR_SUPPORT_LOGICAL_RESULT_H

diff  --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 547db487e454..f8559a9dd939 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -141,12 +141,14 @@ class BufferAssignmentFuncOpConverter
       else
         newResultTypes.push_back(convertedType);
     }
+    if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
+                                           &conversion)))
+      return failure();
 
     // Update the signature of the function.
     rewriter.updateRootInPlace(funcOp, [&] {
       funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
                                               newResultTypes));
-      rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
     });
     return success();
   }

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 2ce95b10d607..d862823930c5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -160,6 +160,9 @@ class TypeConverter {
   /// Return true if the given operation has legal operand and result types.
   bool isLegal(Operation *op);
 
+  /// Return true if the types of block arguments within the region are legal.
+  bool isLegal(Region *region);
+
   /// Return true if the inputs and outputs of the given function type are
   /// legal.
   bool isSignatureLegal(FunctionType ty);
@@ -268,16 +271,15 @@ class TypeConverter {
 // Conversion Patterns
 //===----------------------------------------------------------------------===//
 
-/// Base class for the conversion patterns that require type changes. Specific
-/// conversions must derive this class and implement least one `rewrite` method.
-/// NOTE: These conversion patterns can only be used with the 'apply*' methods
-/// below.
+/// Base class for the conversion patterns. This pattern class enables type
+/// conversions, and other uses specific to the conversion framework. As such,
+/// patterns of this type can only be used with the 'apply*' methods below.
 class ConversionPattern : public RewritePattern {
 public:
   /// Hook for derived classes to implement rewriting. `op` is the (first)
-  /// operation matched by the pattern, `operands` is a list of rewritten values
-  /// that are passed to this operation, `rewriter` can be used to emit the new
-  /// operations. This function should not fail. If some specific cases of
+  /// operation matched by the pattern, `operands` is a list of the rewritten
+  /// operand values that are passed to `op`, `rewriter` can be used to emit the
+  /// new operations. This function should not fail. If some specific cases of
   /// the operation are not supported, these cases should not be matched.
   virtual void rewrite(Operation *op, ArrayRef<Value> operands,
                        ConversionPatternRewriter &rewriter) const {
@@ -298,8 +300,32 @@ class ConversionPattern : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final;
 
+  /// Return the type converter held by this pattern, or nullptr if the pattern
+  /// does not require type conversion.
+  TypeConverter *getTypeConverter() const { return typeConverter; }
+
 protected:
+  /// See `RewritePattern::RewritePattern` for information on the other
+  /// available constructors.
   using RewritePattern::RewritePattern;
+  /// Construct a conversion pattern that matches an operation with the given
+  /// root name. This constructor allows for providing a type converter to use
+  /// within the pattern.
+  ConversionPattern(StringRef rootName, PatternBenefit benefit,
+                    TypeConverter &typeConverter, MLIRContext *ctx)
+      : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
+  /// Construct a conversion pattern that matches any operation type. This
+  /// constructor allows for providing a type converter to use within the
+  /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
+  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
+  /// always be supplied here.
+  ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
+                    MatchAnyOpTypeTag tag)
+      : RewritePattern(benefit, tag), typeConverter(&typeConverter) {}
+
+protected:
+  /// An optional type converter for use by this pattern.
+  TypeConverter *typeConverter;
 
 private:
   using RewritePattern::rewrite;
@@ -312,6 +338,10 @@ template <typename SourceOp>
 struct OpConversionPattern : public ConversionPattern {
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
+  OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
+                          context) {}
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
@@ -367,7 +397,7 @@ struct ConversionPatternRewriterImpl;
 /// hooks.
 class ConversionPatternRewriter final : public PatternRewriter {
 public:
-  ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter);
+  ConversionPatternRewriter(MLIRContext *ctx);
   ~ConversionPatternRewriter() override;
 
   /// Apply a signature conversion to the entry block of the given region. This
@@ -377,6 +407,15 @@ class ConversionPatternRewriter final : public PatternRewriter {
   applySignatureConversion(Region *region,
                            TypeConverter::SignatureConversion &conversion);
 
+  /// Convert the types of block arguments within the given region. This
+  /// replaces each block with a new block containing the updated signature. The
+  /// entry block may have a special conversion if `entryConversion` is
+  /// provided. On success, the new entry block to the region is returned for
+  /// convenience. Otherwise, failure is returned.
+  FailureOr<Block *> convertRegionTypes(
+      Region *region, TypeConverter &converter,
+      TypeConverter::SignatureConversion *entryConversion = nullptr);
+
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
 
@@ -721,36 +760,30 @@ class ConversionTarget {
 /// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If `converter` is
-/// provided, the signatures of blocks and regions are also converted.
-/// If an `unconvertedOps` set is provided, all operations that are found not
-/// to be legalizable to the given `target` are placed within that set. (Note
-/// that if there is an op explicitly marked as illegal, the conversion
-/// terminates and the `unconvertedOps` set will not necessarily be complete.)
+/// returns failure if there are ops explicitly marked as illegal. If an
+/// `unconvertedOps` set is provided, all operations that are found not to be
+/// legalizable to the given `target` are placed within that set. (Note that if
+/// there is an op explicitly marked as illegal, the conversion terminates and
+/// the `unconvertedOps` set will not necessarily be complete.)
 LLVM_NODISCARD LogicalResult
 applyPartialConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
                        const OwningRewritePatternList &patterns,
-                       TypeConverter *converter = nullptr,
                        DenseSet<Operation *> *unconvertedOps = nullptr);
 LLVM_NODISCARD LogicalResult
 applyPartialConversion(Operation *op, ConversionTarget &target,
                        const OwningRewritePatternList &patterns,
-                       TypeConverter *converter = nullptr,
                        DenseSet<Operation *> *unconvertedOps = nullptr);
 
 /// Apply a complete conversion on the given operations, and all nested
 /// operations. This method returns failure if the conversion of any operation
 /// fails, or if there are unreachable blocks in any of the regions nested
-/// within 'ops'. If 'converter' is provided, the signatures of blocks and
-/// regions are also converted.
+/// within 'ops'.
 LLVM_NODISCARD LogicalResult
 applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                    const OwningRewritePatternList &patterns,
-                    TypeConverter *converter = nullptr);
+                    const OwningRewritePatternList &patterns);
 LLVM_NODISCARD LogicalResult
 applyFullConversion(Operation *op, ConversionTarget &target,
-                    const OwningRewritePatternList &patterns,
-                    TypeConverter *converter = nullptr);
+                    const OwningRewritePatternList &patterns);
 
 /// Apply an analysis conversion on the given operations, and all nested
 /// operations. This method analyzes which operations would be successfully
@@ -759,17 +792,15 @@ applyFullConversion(Operation *op, ConversionTarget &target,
 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
 /// operations on success and only pre-existing operations are added to the set.
 /// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops', or if a type conversion failed. If
-/// 'converter' is provided, the signatures of blocks and regions are also
-/// considered for conversion.
-LLVM_NODISCARD LogicalResult applyAnalysisConversion(
-    ArrayRef<Operation *> ops, ConversionTarget &target,
-    const OwningRewritePatternList &patterns,
-    DenseSet<Operation *> &convertedOps, TypeConverter *converter = nullptr);
-LLVM_NODISCARD LogicalResult applyAnalysisConversion(
-    Operation *op, ConversionTarget &target,
-    const OwningRewritePatternList &patterns,
-    DenseSet<Operation *> &convertedOps, TypeConverter *converter = nullptr);
+/// the regions nested within 'ops'.
+LLVM_NODISCARD LogicalResult
+applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+                        const OwningRewritePatternList &patterns,
+                        DenseSet<Operation *> &convertedOps);
+LLVM_NODISCARD LogicalResult
+applyAnalysisConversion(Operation *op, ConversionTarget &target,
+                        const OwningRewritePatternList &patterns,
+                        DenseSet<Operation *> &convertedOps);
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_

diff  --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index 0753e388f72e..b65118b72fdf 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -179,10 +179,7 @@ void ConvertAVX512ToLLVMPass::runOnOperation() {
   target.addLegalDialect<LLVM::LLVMDialect>();
   target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
   target.addIllegalDialect<avx512::AVX512Dialect>();
-  target.addDynamicallyLegalOp<FuncOp>(
-      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  if (failed(applyPartialConversion(getOperation(), target, patterns,
-                                    &converter))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns))) {
     signalPassFailure();
   }
 }

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 4a1fe1afffe6..f6aede4c0d70 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -145,8 +145,9 @@ struct GPUFuncOpLowering : ConvertToLLVMPattern {
     // Move the region to the new function, update the entry block signature.
     rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
                                 llvmFuncOp.end());
-    rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
-                                      signatureConversion);
+    if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), typeConverter,
+                                           &signatureConversion)))
+      return failure();
 
     rewriter.eraseOp(gpuFuncOp);
     return success();

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 0fe767a2b3a6..e4fabe4f441e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -133,7 +133,7 @@ class LowerGpuOpsToNVVMOpsPass
     target.addLegalDialect<NVVM::NVVMDialect>();
     // TODO(csigg): Remove once we support replacing non-root ops.
     target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
-    if (failed(applyPartialConversion(m, target, patterns, &converter)))
+    if (failed(applyPartialConversion(m, target, patterns)))
       signalPassFailure();
   }
 };

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 5707075767ed..2381d615f91b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -67,7 +67,7 @@ class LowerGpuOpsToROCDLOpsPass
     target.addLegalDialect<ROCDL::ROCDLDialect>();
     // TODO(whchung): Remove once we support replacing non-root ops.
     target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
-    if (failed(applyPartialConversion(m, target, patterns, &converter)))
+    if (failed(applyPartialConversion(m, target, patterns)))
       signalPassFailure();
   }
 };

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index 322d08a1b340..2b4829adcdeb 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -164,8 +164,11 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
   TypeConverter::SignatureConversion signatureConverter(
       body->getNumArguments());
   signatureConverter.remapInput(0, newIndVar);
-  body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
-                                           signatureConverter);
+  FailureOr<Block *> newBody = rewriter.convertRegionTypes(
+      &forOp.getLoopBody(), typeConverter, &signatureConverter);
+  if (failed(newBody))
+    return failure();
+  body = *newBody;
 
   // Delete the loop terminator.
   rewriter.eraseOp(body->getTerminator());
@@ -356,9 +359,12 @@ lowerAsEntryFunction(gpu::GPUFuncOp funcOp, SPIRVTypeConverter &typeConverter,
       continue;
     newFuncOp.setAttr(namedAttr.first, namedAttr.second);
   }
+
   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                               newFuncOp.end());
-  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+                                         &signatureConverter)))
+    return nullptr;
   rewriter.eraseOp(funcOp);
 
   spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo);

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 3147eede8819..1f486b96e86c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -61,10 +61,8 @@ void GPUToSPIRVPass::runOnOperation() {
   populateGPUToSPIRVPatterns(context, typeConverter, patterns);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
-  if (failed(applyFullConversion(kernelModules, *target, patterns,
-                                 &typeConverter))) {
+  if (failed(applyFullConversion(kernelModules, *target, patterns)))
     return signalPassFailure();
-  }
 }
 
 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertGPUToSPIRVPass() {

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index f60351098a18..b92ab13bd513 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -383,10 +383,8 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
 
   LLVMConversionTarget target(getContext());
-  target.addDynamicallyLegalOp<FuncOp>(
-      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
   target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-  if (failed(applyFullConversion(module, target, patterns, &converter)))
+  if (failed(applyFullConversion(module, target, patterns)))
     signalPassFailure();
 }
 

diff  --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index d81e269c778b..cc938c8c1594 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -36,8 +36,10 @@ void LinalgToSPIRVPass::runOnOperation() {
 
   // Allow builtin ops.
   target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
-  target->addDynamicallyLegalOp<FuncOp>(
-      [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
+  target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    return typeConverter.isSignatureLegal(op.getType()) &&
+           typeConverter.isLegal(&op.getBody());
+  });
 
   if (failed(applyFullConversion(module, *target, patterns)))
     return signalPassFailure();
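
Because the driver no longer converts block signatures on its own, a FuncOp whose signature is already legal may still contain non-entry blocks with illegal argument types, which is presumably why the updated dynamic legality callbacks also check `isLegal(&op.getBody())`. A small, self-contained C++ model of that check follows; `ToyFunc`, `isLegalType`, `isBodyLegal`, and the rule that only "i32" and "i64" are legal are hypothetical simplifications, not MLIR code. A signature-only check would accept a function whose second block takes an "index" argument; the combined check rejects it.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model: a function's signature plus the argument types of each
// block in its body (block 0 is the entry block).
struct ToyFunc {
  std::vector<std::string> inputTypes;
  std::vector<std::vector<std::string>> blockArgTypes;
};

// Hypothetical target rule: only "i32" and "i64" are legal types.
bool isLegalType(const std::string &type) {
  return type == "i32" || type == "i64";
}

bool allLegal(const std::vector<std::string> &types) {
  return std::all_of(types.begin(), types.end(), isLegalType);
}

// Analogue of TypeConverter::isSignatureLegal: looks at the signature only.
bool isSignatureLegal(const ToyFunc &func) { return allLegal(func.inputTypes); }

// Analogue of TypeConverter::isLegal(Region *): looks at every block.
bool isBodyLegal(const ToyFunc &func) {
  return std::all_of(func.blockArgTypes.begin(), func.blockArgTypes.end(),
                     allLegal);
}

// Mirrors the callback shape used in this commit:
//   isSignatureLegal(op.getType()) && isLegal(&op.getBody())
bool isFuncLegal(const ToyFunc &func) {
  return isSignatureLegal(func) && isBodyLegal(func);
}
```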

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
index c8e2d731c581..8f300541a71d 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -44,8 +44,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addIllegalDialect<spirv::SPIRVDialect>();
   target.addLegalDialect<LLVM::LLVMDialect>();
-
-  if (failed(applyPartialConversion(module, target, patterns, &converter)))
+  if (failed(applyPartialConversion(module, target, patterns)))
     signalPassFailure();
 }
 

diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index e7741145cfe5..d02f5e3de116 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -82,7 +82,8 @@ class ConvertShapeToStandardPass
     target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
     target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      return typeConverter.isSignatureLegal(op.getType());
+      return typeConverter.isSignatureLegal(op.getType()) &&
+             typeConverter.isLegal(&op.getBody());
     });
 
     // Setup conversion patterns.
@@ -92,7 +93,7 @@ class ConvertShapeToStandardPass
 
     // Apply conversion.
     auto module = getOperation();
-    if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+    if (failed(applyFullConversion(module, target, patterns)))
       signalPassFailure();
   }
 };

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index a316f2e56041..19c451fa3fe9 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -398,7 +398,7 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                            MLIRContext *context,
                                            LLVMTypeConverter &typeConverter_,
                                            PatternBenefit benefit)
-    : ConversionPattern(rootOpName, benefit, context),
+    : ConversionPattern(rootOpName, benefit, typeConverter_, context),
       typeConverter(typeConverter_) {}
 
 /*============================================================================*/
@@ -1038,8 +1038,9 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
         attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
-    // Tell the rewriter to convert the region signature.
-    rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
+    if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+                                           &result)))
+      return nullptr;
 
     return newFuncOp;
   }
@@ -1059,6 +1060,9 @@ struct FuncOpConversion : public FuncOpConversionBase {
     auto funcOp = cast<FuncOp>(op);
 
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
+    if (!newFuncOp)
+      return failure();
+
     if (emitWrappers || funcOp.getAttrOfType<UnitAttr>(kEmitIfaceAttrName)) {
       if (newFuncOp.isExternal())
         wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp,
@@ -1095,6 +1099,8 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
     getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo);
 
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
+    if (!newFuncOp)
+      return failure();
     if (newFuncOp.getBody().empty()) {
       rewriter.eraseOp(op);
       return success();
@@ -3172,7 +3178,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
                                           emitCWrappers, useAlignedAlloc);
 
     LLVMConversionTarget target(getContext());
-    if (failed(applyPartialConversion(m, target, patterns, &typeConverter)))
+    if (failed(applyPartialConversion(m, target, patterns)))
       signalPassFailure();
   }
 };

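`convertRegionTypes` returns a `FailureOr<Block *>`, which is the utility described in the log message. The result is either a failure or a value, and the value may itself be `nullptr`: the implementation returns `nullptr` for an empty region. Callers therefore test the result with `failed(...)` instead of comparing it against `nullptr`. The following is a minimal sketch of such a type built on `std::optional`. `LogicalResult`, `FailureOr` and `convertAll` here are hypothetical stand-ins, not the actual definitions in `mlir/Support/LogicalResult.h`.

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins, not MLIR's actual definitions.
struct LogicalResult {
  bool ok;
};
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }
inline bool succeeded(LogicalResult r) { return r.ok; }
inline bool failed(LogicalResult r) { return !r.ok; }

// Either a failure or a valid value of type T (which may be a null pointer).
template <typename T>
class FailureOr {
public:
  // Only failures may be constructed from a LogicalResult.
  FailureOr(LogicalResult r) {
    assert(failed(r) && "success must carry a value");
    (void)r;
  }
  FailureOr(T v) : value(std::move(v)) {}

  // Lets failed()/succeeded() accept a FailureOr directly.
  operator LogicalResult() const { return {value.has_value()}; }

  T &operator*() {
    assert(value.has_value() && "dereferencing a failure");
    return *value;
  }

private:
  std::optional<T> value;
};

// Hypothetical example: succeeds with nullptr for empty input (like an empty
// region), fails if any element is negative, otherwise yields the first one.
inline FailureOr<const int *> convertAll(const std::vector<int> &values) {
  if (values.empty())
    return nullptr;
  for (int v : values)
    if (v < 0)
      return failure();
  return &values.front();
}
```

In the sketch, "succeeded with no value" stays distinct from an actual failure. That distinction is what lets `convertFuncOpToLLVMFuncOp` map only real failures to its own `nullptr` return.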
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4d9734a87f0c..bd9ec9363383 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1195,10 +1195,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   populateStdToLLVMConversionPatterns(converter, patterns);
 
   LLVMConversionTarget target(getContext());
-  target.addDynamicallyLegalOp<FuncOp>(
-      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  if (failed(applyPartialConversion(getOperation(), target, patterns,
-                                    &converter))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns))) {
     signalPassFailure();
   }
 }

diff  --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 665a32c2a99b..37e314f2b1dc 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -173,10 +173,8 @@ void LowerVectorToROCDLPass::runOnOperation() {
   LLVMConversionTarget target(getContext());
   target.addLegalDialect<ROCDL::ROCDLDialect>();
 
-  if (failed(applyPartialConversion(getOperation(), target, patterns,
-                                    &converter))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns)))
     signalPassFailure();
-  }
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 23570625e688..afd94cc06c6e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -136,19 +136,19 @@ struct ConvertLinalgOnTensorsToBuffers
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
       return converter.isSignatureLegal(funcOp.getType()) &&
              llvm::none_of(funcOp.getType().getResults(),
-                           [&](Type type) { return type.isa<MemRefType>(); });
+                           [&](Type type) { return type.isa<MemRefType>(); }) &&
+             converter.isLegal(&funcOp.getBody());
     });
 
     // Walk over all the functions to apply buffer assignment.
-    getOperation().walk([&](FuncOp function) {
+    getOperation().walk([&](FuncOp function) -> WalkResult {
       OwningRewritePatternList patterns;
       BufferAssignmentPlacer placer(function);
       populateConvertLinalgOnTensorsToBuffersPattern(&context, &placer,
                                                      &converter, &patterns);
 
       // Applying full conversion
-      return WalkResult(
-          applyFullConversion(function, target, patterns, &converter));
+      return applyFullConversion(function, target, patterns);
     });
   }
 };

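In `TensorsToBuffers.cpp` the walk callback now declares `-> WalkResult` and returns the `LogicalResult` from `applyFullConversion` directly. It relies on the implicit `LogicalResult`-to-`WalkResult` conversion, in which a failure interrupts the walk and a success advances it. The explicit return type makes the lambda produce a `WalkResult` rather than a deduced `LogicalResult`. The sketch below models this behavior with hypothetical stand-ins (`LogicalResult`, `WalkResult`, `walkFunctions`). It is not MLIR's implementation.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-ins, not MLIR's actual classes.
struct LogicalResult {
  bool ok;
};
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }

class WalkResult {
public:
  // Implicit on purpose: failure interrupts the walk, success advances it.
  WalkResult(LogicalResult r) : interrupted(!r.ok) {}
  bool wasInterrupted() const { return interrupted; }

private:
  bool interrupted;
};

// Visits each function in order and stops at the first interrupting callback.
inline WalkResult
walkFunctions(const std::vector<std::string> &funcs,
              const std::function<WalkResult(const std::string &)> &callback) {
  for (const std::string &f : funcs) {
    WalkResult result = callback(f);
    if (result.wasInterrupted())
      return result;
  }
  return success();
}
```

A callback that fails on one function stops the walk there, so later functions are not converted.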
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 7df2be9c61b7..6bb07b28d022 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -489,7 +489,9 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
 
   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                               newFuncOp.end());
-  rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
+  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+                                         &signatureConverter)))
+    return failure();
   rewriter.eraseOp(funcOp);
   return success();
 }

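`FuncOpConversion` now passes both the type converter and the entry-block `signatureConverter` to `convertRegionTypes`. The implementation in `DialectConversion.cpp` handles the two kinds of block differently:
- the entry block is converted with the given signature conversion, if there is one;
- every other block in the region is converted with the converter.

Any failure is reported back to the pattern. The sketch below models only that control flow. `Type`, `Block`, `Region`, `TypeConverterFn` and `convertRegionTypesSketch` are hypothetical. The entry signature is simplified to a plain list of replacement types, with no 1->N remapping or materialization. The sketch also omits the region-to-converter registration used for blocks moved in later, and unlike the real rewriter it does not roll anything back on failure.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified model: a type is a string, a block is the list of
// its argument types, and a region is a list of blocks (front() = entry).
using Type = std::string;
struct Block {
  std::vector<Type> argTypes;
};
using Region = std::vector<Block>;

// Returns the converted type, or std::nullopt if the type cannot be converted.
using TypeConverterFn = std::function<std::optional<Type>(const Type &)>;

// Converts every argument type of `block`; returns false if any conversion fails.
inline bool convertBlock(Block &block, const TypeConverterFn &convert) {
  std::vector<Type> newTypes;
  for (const Type &t : block.argTypes) {
    std::optional<Type> converted = convert(t);
    if (!converted)
      return false;
    newTypes.push_back(std::move(*converted));
  }
  block.argTypes = std::move(newTypes);
  return true;
}

// Entry block: use `entrySignature` when given, otherwise the converter.
// All other blocks: always the converter. Nothing is rolled back on failure.
inline bool convertRegionTypesSketch(Region &region,
                                     const TypeConverterFn &convert,
                                     const std::vector<Type> *entrySignature) {
  if (region.empty())
    return true; // nothing to convert
  if (entrySignature)
    region.front().argTypes = *entrySignature;
  else if (!convertBlock(region.front(), convert))
    return false;
  for (std::size_t i = 1; i < region.size(); ++i)
    if (!convertBlock(region[i], convert))
      return false;
  return true;
}
```

Non-entry blocks are always converted with the converter, because the signature conversion only describes the function's inputs.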
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 139b6bc093d0..5bd425ae9107 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -201,12 +201,14 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
     }
     signatureConverter.remapInput(argType.index(), replacement);
   }
+  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), typeConverter,
+                                         &signatureConverter)))
+    return failure();
 
   // Creates a new function with the update signature.
   rewriter.updateRootInPlace(funcOp, [&] {
     funcOp.setType(rewriter.getFunctionType(
         signatureConverter.getConvertedTypes(), llvm::None));
-    rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
   });
   return success();
 }
@@ -237,10 +239,8 @@ void LowerABIAttributesPass::runOnOperation() {
     return op->getDialect()->getNamespace() ==
            spirv::SPIRVDialect::getDialectNamespace();
   });
-  if (failed(
-          applyPartialConversion(module, target, patterns, &typeConverter))) {
+  if (failed(applyPartialConversion(module, target, patterns)))
     return signalPassFailure();
-  }
 
   // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point
   // attributes.

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index b06524719577..ecebe61d025f 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -98,7 +98,7 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
 }
 
 //===----------------------------------------------------------------------===//
-// Multi-Level Value Mapper
+// ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
 namespace {
@@ -140,9 +140,7 @@ namespace {
 /// types and extracting the block that contains the old illegal types to allow
 /// for undoing pending rewrites in the case of failure.
 struct ArgConverter {
-  ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
-      : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
-        rewriter(rewriter) {}
+  ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}
 
   /// This structure contains the information pertaining to an argument that has
   /// been converted.
@@ -166,7 +164,8 @@ struct ArgConverter {
   /// This structure contains information pertaining to a block that has had its
   /// signature converted.
   struct ConvertedBlockInfo {
-    ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {}
+    ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
+        : origBlock(origBlock), converter(&converter) {}
 
     /// The original block that was requested to have its signature converted.
     Block *origBlock;
@@ -174,11 +173,26 @@ struct ArgConverter {
     /// The conversion information for each of the arguments. The information is
     /// None if the argument was dropped during conversion.
     SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
+
+    /// The type converter used to convert the arguments.
+    TypeConverter *converter;
   };
 
   /// Return if the signature of the given block has already been converted.
   bool hasBeenConverted(Block *block) const {
-    return conversionInfo.count(block);
+    return conversionInfo.count(block) || convertedBlocks.count(block);
+  }
+
+  /// Set the type converter to use for the given region.
+  void setConverter(Region *region, TypeConverter *typeConverter) {
+    assert(typeConverter && "expected valid type converter");
+    regionToConverter[region] = typeConverter;
+  }
+
+  /// Return the type converter to use for the given region, or null if there
+  /// isn't one.
+  TypeConverter *getConverter(Region *region) {
+    return regionToConverter.lookup(region);
   }
 
   //===--------------------------------------------------------------------===//
@@ -204,32 +218,39 @@ struct ArgConverter {
   //===--------------------------------------------------------------------===//
 
   /// Attempt to convert the signature of the given block, if successful a new
-  /// block is returned containing the new arguments. On failure, nullptr is
-  /// returned.
-  Block *convertSignature(Block *block, ConversionValueMapping &mapping);
+  /// block is returned containing the new arguments. Returns `block` if it did
+  /// not require conversion.
+  FailureOr<Block *> convertSignature(Block *block, TypeConverter &converter,
+                                      ConversionValueMapping &mapping);
 
   /// Apply the given signature conversion on the given block. The new block
-  /// containing the updated signature is returned.
+  /// containing the updated signature is returned. If no conversions were
+  /// necessary, e.g. if the block has no arguments, `block` is returned.
+  /// `converter` is used to generate any necessary cast operations that
+  /// translate between the original argument types and those specified in the
+  /// signature conversion.
   Block *applySignatureConversion(
-      Block *block, TypeConverter::SignatureConversion &signatureConversion,
+      Block *block, TypeConverter &converter,
+      TypeConverter::SignatureConversion &signatureConversion,
       ConversionValueMapping &mapping);
 
   /// Insert a new conversion into the cache.
   void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
 
-  /// A collection of blocks that have had their arguments converted.
+  /// A collection of blocks that have had their arguments converted. This is a
+  /// map from the new replacement block, back to the original block.
   llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
 
+  /// The set of original blocks that were converted.
+  DenseSet<Block *> convertedBlocks;
+
   /// A mapping from valid regions, to those containing the original blocks of a
   /// conversion.
   DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
 
-  /// An instance of the unknown location that is used when materializing
-  /// conversions.
-  Location loc;
-
-  /// The type converter to use when changing types.
-  TypeConverter *typeConverter;
+  /// A mapping of regions to type converters that should be used when
+  /// converting the arguments of blocks within that region.
+  DenseMap<Region *, TypeConverter *> regionToConverter;
 
   /// The pattern rewriter to use when materializing conversions.
   PatternRewriter &rewriter;
@@ -240,6 +261,9 @@ struct ArgConverter {
 // Rewrite Application
 
 void ArgConverter::notifyOpRemoved(Operation *op) {
+  if (conversionInfo.empty())
+    return;
+
   for (Region &region : op->getRegions()) {
     for (Block &block : region) {
       // Drop any rewrites from within.
@@ -277,6 +301,7 @@ void ArgConverter::discardRewrites(Block *block) {
   origBlock->moveBefore(block);
   block->erase();
 
+  convertedBlocks.erase(origBlock);
   conversionInfo.erase(it);
 }
 
@@ -305,8 +330,8 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
         // persist in the IR after conversion.
         if (!origArg.use_empty()) {
           rewriter.setInsertionPointToStart(newBlock);
-          Value newArg = typeConverter->materializeConversion(
-              rewriter, loc, origArg.getType(), llvm::None);
+          Value newArg = blockInfo.converter->materializeConversion(
+              rewriter, origArg.getLoc(), origArg.getType(), llvm::None);
           assert(newArg &&
                  "Couldn't materialize a block argument after 1->0 conversion");
           origArg.replaceAllUsesWith(newArg);
@@ -333,15 +358,23 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
 //===----------------------------------------------------------------------===//
 // Conversion
 
-Block *ArgConverter::convertSignature(Block *block,
-                                      ConversionValueMapping &mapping) {
-  if (auto conversion = typeConverter->convertBlockSignature(block))
-    return applySignatureConversion(block, *conversion, mapping);
-  return nullptr;
+FailureOr<Block *>
+ArgConverter::convertSignature(Block *block, TypeConverter &converter,
+                               ConversionValueMapping &mapping) {
+  // Check if the block was already converted. If the block is detached,
+  // conservatively assume it is going to be deleted.
+  if (hasBeenConverted(block) || !block->getParent())
+    return block;
+
+  // Try to convert the signature for the block with the provided converter.
+  if (auto conversion = converter.convertBlockSignature(block))
+    return applySignatureConversion(block, converter, *conversion, mapping);
+  return failure();
 }
 
 Block *ArgConverter::applySignatureConversion(
-    Block *block, TypeConverter::SignatureConversion &signatureConversion,
+    Block *block, TypeConverter &converter,
+    TypeConverter::SignatureConversion &signatureConversion,
     ConversionValueMapping &mapping) {
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
@@ -359,7 +392,7 @@ Block *ArgConverter::applySignatureConversion(
 
   // Remap each of the original arguments as determined by the signature
   // conversion.
-  ConvertedBlockInfo info(block);
+  ConvertedBlockInfo info(block, converter);
   info.argInfo.resize(origArgCount);
 
   OpBuilder::InsertionGuard guard(rewriter);
@@ -384,10 +417,8 @@ Block *ArgConverter::applySignatureConversion(
     // to pack the new values. For 1->1 mappings, if there is no materialization
     // provided, use the argument directly instead.
     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Value newArg;
-    if (typeConverter)
-      newArg = typeConverter->materializeConversion(
-          rewriter, loc, origArg.getType(), replArgs);
+    Value newArg = converter.materializeConversion(rewriter, origArg.getLoc(),
+                                                   origArg.getType(), replArgs);
     if (!newArg) {
       assert(replArgs.size() == 1 &&
              "couldn't materialize the result of 1->N conversion");
@@ -414,6 +445,7 @@ void ArgConverter::insertConversion(Block *newBlock,
   // Move the original block to the mapped region and emplace the conversion.
   mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
                                    info.origBlock->getIterator());
+  convertedBlocks.insert(info.origBlock);
   conversionInfo.insert({newBlock, std::move(info)});
 }
 
@@ -548,9 +580,8 @@ struct ConversionPatternRewriterImpl {
     };
   };
 
-  ConversionPatternRewriterImpl(PatternRewriter &rewriter,
-                                TypeConverter *converter)
-      : argConverter(converter, rewriter) {}
+  ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+      : argConverter(rewriter) {}
 
   /// Return the current state of the rewriter.
   RewriterState getCurrentState();
@@ -575,13 +606,20 @@ struct ConversionPatternRewriterImpl {
   void applyRewrites();
 
   /// Convert the signature of the given block.
-  LogicalResult convertBlockSignature(Block *block);
+  FailureOr<Block *> convertBlockSignature(
+      Block *block, TypeConverter &converter,
+      TypeConverter::SignatureConversion *conversion = nullptr);
 
   /// Apply a signature conversion on the given region.
   Block *
   applySignatureConversion(Region *region,
                            TypeConverter::SignatureConversion &conversion);
 
+  /// Convert the types of block arguments within the given region.
+  FailureOr<Block *>
+  convertRegionTypes(Region *region, TypeConverter &converter,
+                     TypeConverter::SignatureConversion *entryConversion);
+
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ValueRange newValues);
 
@@ -654,6 +692,10 @@ struct ConversionPatternRewriterImpl {
   /// A logger used to emit diagnostics during the conversion process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
 #endif
+
+  /// A default type converter, used when block conversions do not have one
+  /// explicitly provided.
+  TypeConverter defaultTypeConverter;
 };
 } // end namespace detail
 } // end namespace mlir
@@ -791,7 +833,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
     // If this operation defines any regions, drop any pending argument
     // rewrites.
-    if (argConverter.typeConverter && repl.op->getNumRegions())
+    if (repl.op->getNumRegions())
       argConverter.notifyOpRemoved(repl.op);
   }
 
@@ -826,34 +868,45 @@ void ConversionPatternRewriterImpl::applyRewrites() {
   eraseDanglingBlocks();
 }
 
-LogicalResult
-ConversionPatternRewriterImpl::convertBlockSignature(Block *block) {
-  // Check to see if this block should not be converted:
-  // * There is no type converter.
-  // * The block has already been converted.
-  // * This is an entry block, these are converted explicitly via patterns.
-  if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
-      !block->getParent() || block->isEntryBlock())
-    return success();
-
-  // Otherwise, try to convert the block signature.
-  Block *newBlock = argConverter.convertSignature(block, mapping);
-  if (newBlock)
-    blockActions.push_back(BlockAction::getTypeConversion(newBlock));
-  return success(newBlock);
+FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
+    Block *block, TypeConverter &converter,
+    TypeConverter::SignatureConversion *conversion) {
+  FailureOr<Block *> result =
+      conversion ? argConverter.applySignatureConversion(block, converter,
+                                                         *conversion, mapping)
+                 : argConverter.convertSignature(block, converter, mapping);
+  if (Block *newBlock = result.getValue()) {
+    if (newBlock != block)
+      blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+  }
+  return result;
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
     Region *region, TypeConverter::SignatureConversion &conversion) {
   if (!region->empty()) {
-    Block *newEntry = argConverter.applySignatureConversion(
-        &region->front(), conversion, mapping);
-    blockActions.push_back(BlockAction::getTypeConversion(newEntry));
-    return newEntry;
+    return *convertBlockSignature(&region->front(), defaultTypeConverter,
+                                  &conversion);
   }
   return nullptr;
 }
 
+FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
+    Region *region, TypeConverter &converter,
+    TypeConverter::SignatureConversion *entryConversion) {
+  argConverter.setConverter(region, &converter);
+  if (region->empty())
+    return nullptr;
+
+  // Convert the arguments of each block within the region.
+  FailureOr<Block *> newEntry =
+      convertBlockSignature(&region->front(), converter, entryConversion);
+  for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
+    if (failed(convertBlockSignature(&block, converter)))
+      return failure();
+  return newEntry;
+}
+
 void ConversionPatternRewriterImpl::replaceOp(Operation *op,
                                               ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
@@ -938,10 +991,9 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
 
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx,
-                                                     TypeConverter *converter)
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {}
+      impl(new detail::ConversionPatternRewriterImpl(*this)) {}
 ConversionPatternRewriter::~ConversionPatternRewriter() {}
 
 /// PatternRewriter hook for replacing the results of an operation.
@@ -979,12 +1031,17 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
   block->getParent()->getBlocks().remove(block);
 }
 
-/// Apply a signature conversion to the entry block of the given region.
 Block *ConversionPatternRewriter::applySignatureConversion(
     Region *region, TypeConverter::SignatureConversion &conversion) {
   return impl->applySignatureConversion(region, conversion);
 }
 
+FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
+    Region *region, TypeConverter &converter,
+    TypeConverter::SignatureConversion *entryConversion) {
+  return impl->convertRegionTypes(region, converter, entryConversion);
+}
+
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                                                            Value to) {
   LLVM_DEBUG({
@@ -1163,6 +1220,20 @@ class OperationLegalizer {
                                       ConversionPatternRewriter &rewriter,
                                       RewriterState &curState);
 
+  /// Legalizes the actions registered during the execution of a pattern.
+  LogicalResult legalizePatternBlockActions(Operation *op,
+                                            ConversionPatternRewriter &rewriter,
+                                            ConversionPatternRewriterImpl &impl,
+                                            RewriterState &state,
+                                            RewriterState &newState);
+  LogicalResult legalizePatternCreatedOperations(
+      ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
+      RewriterState &state, RewriterState &newState);
+  LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
+                                           ConversionPatternRewriterImpl &impl,
+                                           RewriterState &state,
+                                           RewriterState &newState);
+
   /// Build an optimistic legalization graph given the provided patterns. This
   /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
   /// patterns for operations that are not directly legal, but may be
@@ -1402,50 +1473,29 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
 LogicalResult OperationLegalizer::legalizePatternResult(
     Operation *op, const RewritePattern &pattern,
     ConversionPatternRewriter &rewriter, RewriterState &curState) {
-  auto &rewriterImpl = rewriter.getImpl();
+  auto &impl = rewriter.getImpl();
 
 #ifndef NDEBUG
-  assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+  assert(impl.pendingRootUpdates.empty() && "dangling root updates");
 #endif
 
-  // If the pattern moved or created any blocks, try to legalize their types.
-  // This ensures that the types of the block arguments are legal for the region
-  // they were moved into.
-  for (unsigned i = curState.numBlockActions,
-                e = rewriterImpl.blockActions.size();
-       i != e; ++i) {
-    auto &action = rewriterImpl.blockActions[i];
-    if (action.kind ==
-            ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
-        action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
-      continue;
-
-    // Convert the block signature.
-    if (failed(rewriterImpl.convertBlockSignature(action.block))) {
-      LLVM_DEBUG(logFailure(rewriterImpl.logger,
-                            "failed to convert types of moved block"));
-      return failure();
-    }
-  }
-
   // Check all of the replacements to ensure that the pattern actually replaced
   // the root operation. We also mark any other replaced ops as 'dead' so that
   // we don't try to legalize them later.
   bool replacedRoot = false;
-  for (unsigned i = curState.numReplacements,
-                e = rewriterImpl.replacements.size();
+  for (unsigned i = curState.numReplacements, e = impl.replacements.size();
        i != e; ++i) {
-    Operation *replacedOp = rewriterImpl.replacements[i].op;
+    Operation *replacedOp = impl.replacements[i].op;
     if (replacedOp == op)
       replacedRoot = true;
     else
-      rewriterImpl.ignoredOps.insert(replacedOp);
+      impl.ignoredOps.insert(replacedOp);
   }
 
   // Check that the root was either updated or replace.
   auto updatedRootInPlace = [&] {
     return llvm::any_of(
-        llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
+        llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
         [op](auto &state) { return state.getOperation() == op; });
   };
   (void)replacedRoot;
@@ -1453,32 +1503,99 @@ LogicalResult OperationLegalizer::legalizePatternResult(
   assert((replacedRoot || updatedRootInPlace()) &&
          "expected pattern to replace the root operation");
 
-  // Recursively legalize each of the operations updated in place.
-  for (unsigned i = curState.numRootUpdates,
-                e = rewriterImpl.rootUpdates.size();
-       i != e; ++i) {
-    auto &state = rewriterImpl.rootUpdates[i];
-    if (failed(legalize(state.getOperation(), rewriter))) {
-      LLVM_DEBUG(logFailure(rewriterImpl.logger,
-                            "operation updated in-place '{0}' was illegal",
-                            op->getName()));
+  // Legalize each of the actions registered during application.
+  RewriterState newState = impl.getCurrentState();
+  if (failed(legalizePatternBlockActions(op, rewriter, impl, curState,
+                                         newState)) ||
+      failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
+      failed(legalizePatternCreatedOperations(rewriter, impl, curState,
+                                              newState))) {
+    return failure();
+  }
+
+  LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully"));
+  return success();
+}
+
+LogicalResult OperationLegalizer::legalizePatternBlockActions(
+    Operation *op, ConversionPatternRewriter &rewriter,
+    ConversionPatternRewriterImpl &impl, RewriterState &state,
+    RewriterState &newState) {
+  SmallPtrSet<Operation *, 16> operationsToIgnore;
+
+  // If the pattern moved or created any blocks, make sure the types of block
+  // arguments get legalized.
+  for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
+       ++i) {
+    auto &action = impl.blockActions[i];
+    if (action.kind ==
+            ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
+        action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
+      continue;
+    // Only check blocks outside of the current operation.
+    Operation *parentOp = action.block->getParentOp();
+    if (!parentOp || parentOp == op || action.block->getNumArguments() == 0)
+      continue;
+
+    // If the region of the block has a type converter, try to convert the block
+    // directly.
+    if (auto *converter =
+            impl.argConverter.getConverter(action.block->getParent())) {
+      if (failed(impl.convertBlockSignature(action.block, *converter))) {
+        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
+                                           "block"));
+        return failure();
+      }
+      continue;
+    }
+
+    // Otherwise, check that this operation isn't one generated by this pattern.
+    // This is because we will attempt to legalize the parent operation, and
+    // blocks in regions created by this pattern will already be legalized later
+    // on. If we haven't built the set yet, build it now.
+    if (operationsToIgnore.empty()) {
+      auto createdOps = ArrayRef<Operation *>(impl.createdOps)
+                            .drop_front(state.numCreatedOps);
+      operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+    }
+
+    // If this operation should be considered for re-legalization, try it.
+    if (operationsToIgnore.insert(parentOp).second &&
+        failed(legalize(parentOp, rewriter))) {
+      LLVM_DEBUG(logFailure(
+          impl.logger, "operation '{0}'({1}) became illegal after block action",
+          parentOp->getName(), parentOp));
       return failure();
     }
   }
-
-  // Recursively legalize each of the new operations.
-  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
-       i != e; ++i) {
-    Operation *op = rewriterImpl.createdOps[i];
+  return success();
+}
+LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
+    ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
+    RewriterState &state, RewriterState &newState) {
+  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
+    Operation *op = impl.createdOps[i];
     if (failed(legalize(op, rewriter))) {
-      LLVM_DEBUG(logFailure(rewriterImpl.logger,
+      LLVM_DEBUG(logFailure(impl.logger,
                             "generated operation '{0}'({1}) was illegal",
                             op->getName(), op));
       return failure();
     }
   }
-
-  LLVM_DEBUG(logSuccess(rewriterImpl.logger, "pattern applied successfully"));
+  return success();
+}
+LogicalResult OperationLegalizer::legalizePatternRootUpdates(
+    ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
+    RewriterState &state, RewriterState &newState) {
+  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
+    Operation *op = impl.rootUpdates[i].getOperation();
+    if (failed(legalize(op, rewriter))) {
+      LLVM_DEBUG(logFailure(impl.logger,
+                            "operation updated in-place '{0}' was illegal",
+                            op->getName()));
+      return failure();
+    }
+  }
   return success();
 }
 
@@ -1699,17 +1816,12 @@ struct OperationConverter {
       : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
 
   /// Converts the given operations to the conversion target.
-  LogicalResult convertOperations(ArrayRef<Operation *> ops,
-                                  TypeConverter *typeConverter);
+  LogicalResult convertOperations(ArrayRef<Operation *> ops);
 
 private:
   /// Converts an operation with the given rewriter.
   LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
 
-  /// Converts the type signatures of the blocks nested within 'op'.
-  LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter,
-                                       Operation *op);
-
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
@@ -1724,21 +1836,6 @@ struct OperationConverter {
 };
 } // end anonymous namespace
 
-LogicalResult
-OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter,
-                                           Operation *op) {
-  // Check to see if type signatures need to be converted.
-  if (!rewriter.getImpl().argConverter.typeConverter)
-    return success();
-
-  for (auto &region : op->getRegions()) {
-    for (auto &block : llvm::make_early_inc_range(region))
-      if (failed(rewriter.getImpl().convertBlockSignature(&block)))
-        return failure();
-  }
-  return success();
-}
-
 LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
                                           Operation *op) {
   // Legalize the given operation.
@@ -1759,24 +1856,16 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
       if (trackedOps)
         trackedOps->insert(op);
     }
-  } else {
+  } else if (mode == OpConversionMode::Analysis) {
     // Analysis conversions don't fail if any operations fail to legalize,
     // they are only interested in the operations that were successfully
     // legalized.
-    if (mode == OpConversionMode::Analysis)
-      trackedOps->insert(op);
-
-    // If legalization succeeded, convert the types any of the blocks within
-    // this operation.
-    if (failed(convertBlockSignatures(rewriter, op)))
-      return failure();
+    trackedOps->insert(op);
   }
   return success();
 }
 
-LogicalResult
-OperationConverter::convertOperations(ArrayRef<Operation *> ops,
-                                      TypeConverter *typeConverter) {
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
   ConversionTarget &target = opLegalizer.getTarget();
@@ -1792,7 +1881,7 @@ OperationConverter::convertOperations(ArrayRef<Operation *> ops,
   }
 
   // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter);
+  ConversionPatternRewriter rewriter(ops.front()->getContext());
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
       return rewriter.getImpl().discardRewrites(), failure();
@@ -1913,6 +2002,13 @@ bool TypeConverter::isLegal(Operation *op) {
   return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
 }
 
+/// Return true if the types of block arguments within the region are legal.
+bool TypeConverter::isLegal(Region *region) {
+  return llvm::all_of(*region, [this](Block &block) {
+    return isLegal(block.getArgumentTypes());
+  });
+}
+
 /// Return true if the inputs and outputs of the given function type are
 /// legal.
 bool TypeConverter::isSignatureLegal(FunctionType ty) {
@@ -1969,7 +2065,7 @@ auto TypeConverter::convertBlockSignature(Block *block)
 namespace {
 struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
   FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : OpConversionPattern(ctx), converter(converter) {}
+      : OpConversionPattern(converter, ctx) {}
 
   /// Hook for derived classes to implement combined matching and rewriting.
   LogicalResult
@@ -1979,22 +2075,20 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
 
     // Convert the original function types.
     TypeConverter::SignatureConversion result(type.getNumInputs());
-    SmallVector<Type, 1> convertedResults;
-    if (failed(converter.convertSignatureArgs(type.getInputs(), result)) ||
-        failed(converter.convertTypes(type.getResults(), convertedResults)))
+    SmallVector<Type, 1> newResults;
+    if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
+        failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
+        failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
+                                           &result)))
       return failure();
 
     // Update the function signature in-place.
     rewriter.updateRootInPlace(funcOp, [&] {
-      funcOp.setType(FunctionType::get(result.getConvertedTypes(),
-                                       convertedResults, funcOp.getContext()));
-      rewriter.applySignatureConversion(&funcOp.getBody(), result);
+      funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
+                                       funcOp.getContext()));
     });
     return success();
   }
-
-  /// The type converter to use when rewriting the signature.
-  TypeConverter &converter;
 };
 } // end anonymous namespace
 
@@ -2128,27 +2222,26 @@ auto ConversionTarget::getOpInfo(OperationName op) const
 /// Apply a partial conversion on the given operations and all nested
 /// operations. This method converts as many operations to the target as
 /// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If `converter` is
-/// provided, the signatures of blocks and regions are also converted.
+/// returns failure if there are ops explicitly marked as illegal.
 /// If an `unconvertedOps` set is provided, all operations that are found not
 /// to be legalizable to the given `target` are placed within that set. (Note
 /// that if there is an op explicitly marked as illegal, the conversion
 /// terminates and the `unconvertedOps` set will not necessarily be complete.)
-LogicalResult mlir::applyPartialConversion(
-    ArrayRef<Operation *> ops, ConversionTarget &target,
-    const OwningRewritePatternList &patterns, TypeConverter *converter,
-    DenseSet<Operation *> *unconvertedOps) {
+LogicalResult
+mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+                             ConversionTarget &target,
+                             const OwningRewritePatternList &patterns,
+                             DenseSet<Operation *> *unconvertedOps) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
                                  unconvertedOps);
-  return opConverter.convertOperations(ops, converter);
+  return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
                              const OwningRewritePatternList &patterns,
-                             TypeConverter *converter,
                              DenseSet<Operation *> *unconvertedOps) {
   return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
-                                converter, unconvertedOps);
+                                unconvertedOps);
 }
 
 /// Apply a complete conversion on the given operations, and all nested
@@ -2156,17 +2249,14 @@ mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
 /// operation fails.
 LogicalResult
 mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
-                          const OwningRewritePatternList &patterns,
-                          TypeConverter *converter) {
+                          const OwningRewritePatternList &patterns) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
-  return opConverter.convertOperations(ops, converter);
+  return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyFullConversion(Operation *op, ConversionTarget &target,
-                          const OwningRewritePatternList &patterns,
-                          TypeConverter *converter) {
-  return applyFullConversion(llvm::makeArrayRef(op), target, patterns,
-                             converter);
+                          const OwningRewritePatternList &patterns) {
+  return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
 }
 
 /// Apply an analysis conversion on the given operations, and all nested
@@ -2175,19 +2265,19 @@ mlir::applyFullConversion(Operation *op, ConversionTarget &target,
 /// were found to be legalizable to the given 'target' are placed within the
 /// provided 'convertedOps' set; note that no actual rewrites are applied to the
 /// operations on success and only pre-existing operations are added to the set.
-LogicalResult mlir::applyAnalysisConversion(
-    ArrayRef<Operation *> ops, ConversionTarget &target,
-    const OwningRewritePatternList &patterns,
-    DenseSet<Operation *> &convertedOps, TypeConverter *converter) {
+LogicalResult
+mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
+                              ConversionTarget &target,
+                              const OwningRewritePatternList &patterns,
+                              DenseSet<Operation *> &convertedOps) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
                                  &convertedOps);
-  return opConverter.convertOperations(ops, converter);
+  return opConverter.convertOperations(ops);
 }
 LogicalResult
 mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                               const OwningRewritePatternList &patterns,
-                              DenseSet<Operation *> &convertedOps,
-                              TypeConverter *converter) {
+                              DenseSet<Operation *> &convertedOps) {
   return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
-                                 convertedOps, converter);
+                                 convertedOps);
 }
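
The new `legalizePatternCreatedOperations` and `legalizePatternRootUpdates` helpers above only revisit the entries recorded between two `RewriterState` snapshots. That is the half-open range from `state.numCreatedOps` (or `numRootUpdates`) up to the matching count in `newState`. What follows is a minimal standalone sketch of that bookkeeping, not MLIR code. Operations are modelled as plain strings, and `MiniRewriterImpl`, `MiniState`, `snapshot`, and `legalizeCreatedBetween` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for the rewriter's bookkeeping: every operation a
// pattern creates is appended to `createdOps`.
struct MiniRewriterImpl {
  std::vector<std::string> createdOps;
};

// Hypothetical snapshot playing the role of RewriterState: it records how
// many operations had been created at the moment it was taken.
struct MiniState {
  std::size_t numCreatedOps;
};

MiniState snapshot(const MiniRewriterImpl &impl) {
  return MiniState{impl.createdOps.size()};
}

// Legalize only the operations created between two snapshots, i.e. the
// half-open range [before.numCreatedOps, after.numCreatedOps).
// `legalize` returns false on failure, and the first failure stops the walk.
// `visited` records which operations were looked at.
bool legalizeCreatedBetween(
    const MiniRewriterImpl &impl, MiniState before, MiniState after,
    const std::function<bool(const std::string &)> &legalize,
    std::vector<std::string> &visited) {
  for (std::size_t i = before.numCreatedOps; i != after.numCreatedOps; ++i) {
    visited.push_back(impl.createdOps[i]);
    if (!legalize(impl.createdOps[i]))
      return false;
  }
  return true;
}
```

Because a snapshot is just a counter, anything created before the pattern ran is never revisited. A failure inside the range stops the walk immediately, much like the early `return failure()` in the helpers above.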

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 284c38bad7b7..5637fa853745 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -153,15 +153,11 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) {
 
 // CHECK-LABEL: @create_block
 func @create_block() {
-  // expected-remark@+1 {{op 'test.container' is not legalizable}}
-  "test.container"() ({
-    // Check that we created a block with arguments.
-    // CHECK-NOT: test.create_block
-    // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
-    // CHECK: test.finish
-    "test.create_block"() : () -> ()
-    "test.finish"() : () -> ()
-  }) : () -> ()
+  // Check that we created a block with arguments.
+  // CHECK-NOT: test.create_block
+  // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
+  "test.create_block"() : () -> ()
+
   // expected-remark@+1 {{op 'std.return' is not legalizable}}
   return
 }
@@ -212,15 +208,12 @@ func @fail_to_convert_region() {
 
 // CHECK-LABEL: @create_illegal_block
 func @create_illegal_block() {
-  // expected-remark@+1 {{op 'test.container' is not legalizable}}
-  "test.container"() ({
-    // Check that we can undo block creation, i.e. that the block was removed.
-    // CHECK: test.create_illegal_block
-    // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
-    // expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}}
-    "test.create_illegal_block"() : () -> ()
-    "test.finish"() : () -> ()
-  }) : () -> ()
+  // Check that we can undo block creation, i.e. that the block was removed.
+  // CHECK: test.create_illegal_block
+  // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
+  // expected-remark@+1 {{op 'test.create_illegal_block' is not legalizable}}
+  "test.create_illegal_block"() : () -> ()
+
   // expected-remark@+1 {{op 'std.return' is not legalizable}}
   return
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index cbab7d7494da..60f663f75adc 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -304,8 +304,7 @@ struct TestUndoBlockErase : public ConversionPattern {
 /// This pattern erases a region operation that has had a type conversion.
 struct TestDropOpSignatureConversion : public ConversionPattern {
   TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) {
-  }
+      : ConversionPattern("test.drop_region_op", 1, converter, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
@@ -313,19 +312,17 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
     Block *entry = &region.front();
 
     // Convert the original entry arguments.
+    TypeConverter &converter = *getTypeConverter();
     TypeConverter::SignatureConversion result(entry->getNumArguments());
-    if (failed(
-            converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
+    if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(),
+                                              result)) ||
+        failed(rewriter.convertRegionTypes(&region, converter, &result)))
       return failure();
 
     // Convert the region signature and just drop the operation.
-    rewriter.applySignatureConversion(&region, result);
     rewriter.eraseOp(op);
     return success();
   }
-
-  /// The type converter to use when rewriting the signature.
-  TypeConverter &converter;
 };
 /// This pattern simply updates the operands of the given operation.
 struct TestPassthroughInvalidOp : public ConversionPattern {
@@ -568,8 +565,10 @@ struct TestLegalizePatternDriver
       return llvm::none_of(op.getOperandTypes(),
                            [](Type type) { return type.isF32(); });
     });
-    target.addDynamicallyLegalOp<FuncOp>(
-        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return converter.isSignatureLegal(op.getType()) &&
+             converter.isLegal(&op.getBody());
+    });
 
     // Expect the type_producer/type_consumer operations to only operate on f64.
     target.addDynamicallyLegalOp<TestTypeProducerOp>(
@@ -591,7 +590,7 @@ struct TestLegalizePatternDriver
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
-      (void)applyPartialConversion(getOperation(), target, patterns, &converter,
+      (void)applyPartialConversion(getOperation(), target, patterns,
                                    &unlegalizedOps);
       // Emit remarks for each legalizable operation.
       for (auto *op : unlegalizedOps)
@@ -606,7 +605,7 @@ struct TestLegalizePatternDriver
         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
       });
 
-      (void)applyFullConversion(getOperation(), target, patterns, &converter);
+      (void)applyFullConversion(getOperation(), target, patterns);
       return;
     }
 
@@ -616,7 +615,7 @@ struct TestLegalizePatternDriver
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
     if (failed(applyAnalysisConversion(getOperation(), target, patterns,
-                                       legalizedOps, &converter)))
+                                       legalizedOps)))
       return signalPassFailure();
 
     // Emit remarks for each legalizable operation.
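
The updated `addDynamicallyLegalOp<FuncOp>` callbacks now also call the new `TypeConverter::isLegal(Region *)`. This fits the move of region type conversion from the driver into patterns. A function whose signature is already legal but whose body still contains a block with an illegal argument type has to stay illegal, so that a pattern gets the chance to run `convertRegionTypes` on it. The standalone model below is a sketch of that combined check, not MLIR code. Types are plain strings, and `MiniTypeConverter`, `MiniFunc`, and `isFuncLegal` are hypothetical names, as is the rule that only `f32` is illegal.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model: types are strings, a block is described by its
// argument types, and a region is a list of blocks.
using Type = std::string;
struct Block {
  std::vector<Type> argumentTypes;
};
using Region = std::vector<Block>;

struct MiniTypeConverter {
  // Assumption for this sketch: "f32" is the only illegal type.
  bool isLegal(const Type &t) const { return t != "f32"; }
  bool isLegal(const std::vector<Type> &types) const {
    return std::all_of(types.begin(), types.end(),
                       [this](const Type &t) { return isLegal(t); });
  }
  // Same shape as TypeConverter::isLegal(Region *): the argument types of
  // every block must be legal, not just those of the entry block.
  bool isLegal(const Region &region) const {
    return std::all_of(region.begin(), region.end(), [this](const Block &b) {
      return isLegal(b.argumentTypes);
    });
  }
};

struct MiniFunc {
  std::vector<Type> inputs;
  std::vector<Type> results;
  Region body;
};

// Dynamic legality in the style of the updated callbacks: both the signature
// and all block arguments in the body must be legal.
bool isFuncLegal(const MiniTypeConverter &converter, const MiniFunc &func) {
  return converter.isLegal(func.inputs) && converter.isLegal(func.results) &&
         converter.isLegal(func.body);
}
```

Checking only the signature would report the first function in a usage like the one below as legal even though a non-entry block still takes an `f32` argument.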

diff  --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 0976f71c0fd3..2fbdfe989d05 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -1,4 +1,4 @@
-//===- TestBufferPlacement.cpp - Test for buffer placement 0----*- C++ -*-===//
+//===- TestBufferPlacement.cpp - Test for buffer placement ------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -140,7 +140,8 @@ struct TestBufferPlacementPreparationPass
 
     // Mark the function whose arguments are in tensor-type illegal.
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
-      return converter.isSignatureLegal(funcOp.getType());
+      return converter.isSignatureLegal(funcOp.getType()) &&
+             converter.isLegal(&funcOp.getBody());
     });
 
     // Walk over all the functions to apply buffer assignment.
@@ -151,7 +152,7 @@ struct TestBufferPlacementPreparationPass
           &context, &placer, &converter, &patterns);
 
       // Applying full conversion
-      return applyFullConversion(function, target, patterns, &converter);
+      return applyFullConversion(function, target, patterns);
     });
   };
 };
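
The `TypeConverter *` parameter is gone from `applyPartialConversion`, `applyFullConversion`, and `applyAnalysisConversion`. As a result, a pattern that needs a converter receives it at construction, as in `ConversionPattern("test.drop_region_op", 1, converter, ctx)`, and retrieves it with `getTypeConverter()`. The toy model below illustrates only that ownership shift and is not the MLIR API. `MiniTypeConverter`, `MiniPattern`, `ConvertBlockArgs`, and `applyPatterns` are hypothetical. Instead of rewriting types directly, the real patterns call `rewriter.convertRegionTypes`, which according to the commit message also covers blocks that are later moved into the region.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <vector>

using Type = std::string;

// Hypothetical converter: maps a type to its converted form.
struct MiniTypeConverter {
  std::function<Type(const Type &)> convert;
};

// Hypothetical pattern base: it is handed the converter at construction
// and exposes it to subclasses, so the driver never needs to see it.
class MiniPattern {
public:
  explicit MiniPattern(const MiniTypeConverter *converter = nullptr)
      : typeConverter(converter) {}
  virtual ~MiniPattern() = default;
  const MiniTypeConverter *getTypeConverter() const { return typeConverter; }
  virtual bool rewrite(std::vector<Type> &blockArgTypes) const = 0;

private:
  const MiniTypeConverter *typeConverter;
};

// A pattern that converts block argument types with its own converter.
class ConvertBlockArgs : public MiniPattern {
public:
  using MiniPattern::MiniPattern;
  bool rewrite(std::vector<Type> &blockArgTypes) const override {
    const MiniTypeConverter *converter = getTypeConverter();
    if (!converter)
      return false; // no converter was supplied to this pattern
    for (Type &t : blockArgTypes)
      t = converter->convert(t);
    return true;
  }
};

// The driver takes only the patterns; no converter parameter is needed.
bool applyPatterns(const std::vector<std::unique_ptr<MiniPattern>> &patterns,
                   std::vector<Type> &blockArgTypes) {
  for (const auto &pattern : patterns)
    if (!pattern->rewrite(blockArgTypes))
      return false;
  return true;
}
```

Because each pattern carries its own converter pointer, different patterns in the same set could in principle use different converters. The commit message mentions this as a possible future direction.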


        

