[Mlir-commits] [mlir] 52050f3 - [mlir][Transforms] Dialect Conversion: Simplify block conversion API (#94866)

llvmlistbot at llvm.org
Mon Jun 10 12:49:56 PDT 2024


Author: Matthias Springer
Date: 2024-06-10T21:49:52+02:00
New Revision: 52050f3ff388773b9345d421d968a7d1ee880531

URL: https://github.com/llvm/llvm-project/commit/52050f3ff388773b9345d421d968a7d1ee880531
DIFF: https://github.com/llvm/llvm-project/commit/52050f3ff388773b9345d421d968a7d1ee880531.diff

LOG: [mlir][Transforms] Dialect Conversion: Simplify block conversion API (#94866)

This commit simplifies and improves documentation for the part of the
`ConversionPatternRewriter` API that deals with signature conversions.

There are now two public functions for signature conversion:
* `applySignatureConversion` converts a single block signature. This
function used to take a `Region *` (but converted only the entry block).
It now takes a `Block *`.
* `convertRegionTypes` converts all block signatures of a region.

`convertNonEntryRegionTypes` is removed because it is not widely used
and can easily be expressed with a call to `applySignatureConversion`
inside a loop. (See `Detensorize.cpp` for an example.)
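The loop that replaces `convertNonEntryRegionTypes` can be illustrated with a standalone toy model. This is a sketch, not MLIR code: `ToyBlock`, `ToyRegion`, `convertType`, and the helper functions below are hypothetical stand-ins. The model assumes, as the new header comments state, that converting a block whose signature changes replaces it with a new block. Because the block being visited may be erased, the loop advances its iterator before converting the block. This is the role `llvm::make_early_inc_range` plays in `Detensorize.cpp`.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <vector>

// Hypothetical toy stand-ins for mlir::Block and mlir::Region (not MLIR API).
struct ToyBlock {
  std::vector<std::string> argTypes;
};
using ToyRegion = std::list<ToyBlock>;

// Toy signature conversion: the new argument types of a single block.
using ToySignatureConversion = std::vector<std::string>;

// Models the block-replacement semantics: if the signature changes, a new
// block is inserted in place of `block` and the old block is erased.
// Returns an iterator to the (possibly new) block.
ToyRegion::iterator
applySignatureConversion(ToyRegion &region, ToyRegion::iterator block,
                         const ToySignatureConversion &conversion) {
  assert(block != region.end() && "expected a block of the region");
  if (block->argTypes == conversion)
    return block; // Nothing changes: the original block stays in place.
  ToyRegion::iterator newBlock = region.insert(block, ToyBlock{conversion});
  region.erase(block); // Invalidates every iterator to the old block.
  return newBlock;
}

// Toy type conversion: "tensor" becomes "i32"; other types are kept.
std::string convertType(const std::string &type) {
  return type == "tensor" ? "i32" : type;
}

// Replacement for the removed `convertNonEntryRegionTypes`: convert each
// non-entry block individually.
void convertNonEntryBlocks(ToyRegion &region) {
  if (region.empty())
    return;
  for (ToyRegion::iterator it = std::next(region.begin());
       it != region.end();) {
    ToyRegion::iterator current = it++; // Advance before `current` is erased.
    ToySignatureConversion conversion;
    for (const std::string &type : current->argTypes)
      conversion.push_back(convertType(type));
    applySignatureConversion(region, current, conversion);
  }
}
```

With a plain `++it` at the end of each iteration, the loop would increment an iterator that the conversion had just invalidated; saving the successor first avoids that.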

Note: For consistency, `convertRegionTypes` could be renamed to an
`applySignatureConversion` overload in the future. (Alternatively,
`applySignatureConversion` could be renamed to `convertBlockTypes`.)

This commit also clarifies when a type converter and/or a signature
conversion object is needed and for what purpose.

It also includes an internal refactoring (NFC) of
`ConversionPatternRewriterImpl` (the part that deals with signature
conversions). This part of the codebase was quite convoluted and
unintuitive.
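The new control flow of `ConversionPatternRewriterImpl::convertRegionTypes` (see the `DialectConversion.cpp` hunk) can be condensed into a second toy model. The names are hypothetical and the model is simplified: blocks are updated in place instead of being replaced, no materializations are built, and a `std::function` returning `std::optional` stands in for `TypeConverter::convertBlockSignature`. It shows the order of operations: non-entry blocks always get a signature computed by the converter, while the entry block uses the caller-provided conversion if there is one.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy model (not the MLIR API): a block is its argument types.
using ToySignature = std::vector<std::string>;
struct ToyBlock {
  ToySignature argTypes;
};

// Stands in for `TypeConverter::convertBlockSignature`, which may fail.
using ToyTypeConverter =
    std::function<std::optional<ToySignature>(const ToyBlock &)>;

// Mirrors the control flow of the refactored `convertRegionTypes`.
// Returns false if no signature could be computed for some block.
bool convertRegionTypes(std::vector<ToyBlock> &blocks,
                        const ToyTypeConverter &converter,
                        const ToySignature *entryConversion = nullptr) {
  if (blocks.empty())
    return true; // The real function returns a null block here.

  // 1. Non-entry blocks: the signature is always computed by the converter.
  for (std::size_t i = 1; i < blocks.size(); ++i) {
    std::optional<ToySignature> conversion = converter(blocks[i]);
    if (!conversion)
      return false;
    blocks[i].argTypes = *conversion;
  }

  // 2. Entry block: an explicitly provided conversion takes precedence.
  if (entryConversion) {
    blocks.front().argTypes = *entryConversion;
    return true;
  }
  std::optional<ToySignature> conversion = converter(blocks.front());
  if (!conversion)
    return false;
  blocks.front().argTypes = *conversion;
  return true;
}
```

In this model, a failure can leave earlier blocks already converted, because the loop returns as soon as one signature cannot be computed, mirroring the early `return failure()` in the diff.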

From a functional perspective, this change is NFC. However, the public
API changes, so this commit is not marked as NFC.

Note for LLVM integration: When you see
`applySignatureConversion(region, ...)`, replace it with
`applySignatureConversion(&region->front(), ...)`. In the unlikely case
that you see `convertNonEntryRegionTypes`, apply the same changes that
this commit made to `Detensorize.cpp`.
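The migration can be checked against the removed `ConversionPatternRewriterImpl::applySignatureConversion` overload that took a `Region *`: it converted `region->front()` and returned `nullptr` for an empty region. A caller switching to the `Block *` overload therefore has to handle the empty-region case itself before taking `&region->front()`. The toy model below uses hypothetical types (not MLIR) and, for simplicity, updates the block in place rather than replacing it; it expresses the old region-based entry point in terms of a block-based one.

```cpp
#include <cassert>
#include <list>
#include <string>
#include <vector>

// Hypothetical toy stand-ins (not the MLIR API).
using ToySignature = std::vector<std::string>;
struct ToyBlock {
  ToySignature argTypes;
};
using ToyRegion = std::list<ToyBlock>;

// New-style entry point: converts exactly the given block and returns it.
// (The real rewriter returns a new block when the signature changes.)
ToyBlock *applySignatureConversion(ToyBlock *block,
                                   const ToySignature &conversion) {
  assert(block && "expected a block");
  block->argTypes = conversion;
  return block;
}

// Old-style entry point, expressed via the new one: only the entry block is
// converted, and an empty region yields nullptr.
ToyBlock *applySignatureConversionToEntry(ToyRegion *region,
                                          const ToySignature &conversion) {
  if (region->empty())
    return nullptr;
  return applySignatureConversion(&region->front(), conversion);
}
```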

---------

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

Added: 
    

Modified: 
    mlir/docs/DialectConversion.md
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index a355d5a90e4d1..69781bb868bbf 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -372,19 +372,23 @@ class TypeConverter {
 From the perspective of type conversion, the types of block arguments are a bit
 special. Throughout the conversion process, blocks may move between regions of
 different operations. Given this, the conversion of the types for blocks must be
-done explicitly via a conversion pattern. To convert the types of block
-arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
-be invoked; `convertRegionTypes`. This hook uses a provided type converter to
-apply type conversions to all blocks within a given region, and all blocks that
-move into that region. As noted above, the conversions performed by this method
-use the argument materialization hook on the `TypeConverter`. This hook also
-takes an optional `TypeConverter::SignatureConversion` parameter that applies a
-custom conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
-AffineForOp, etc. To convert the signature of just the region entry block, and
-not any other blocks within the region, the `applySignatureConversion` hook may
-be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
-can be built programmatically:
+done explicitly via a conversion pattern. 
+
+To convert the types of block arguments within a Region, a custom hook on the
+`ConversionPatternRewriter` must be invoked: `convertRegionTypes`. This hook
+uses a provided type converter to apply type conversions to all blocks of a
+given region. As noted above, the conversions performed by this method use the
+argument materialization hook on the `TypeConverter`. This hook also takes an
+optional `TypeConverter::SignatureConversion` parameter that applies a custom
+conversion to the entry block of the region. The types of the entry block
+arguments are often tied semantically to the operation, e.g.,
+`func::FuncOp`, `AffineForOp`, etc.
+
+To convert the signature of just one given block, the
+`applySignatureConversion` hook can be used.
+
+A signature conversion, `TypeConverter::SignatureConversion`, can be built
+programmatically:
 
 ```c++
 class SignatureConversion {

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6c51499f271c..f83f3a3fdf992 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType> TargetType convertType(Type t) const {
+  template <typename TargetType>
+  TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
@@ -661,42 +662,42 @@ class ConversionPatternRewriter final : public PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
-  /// Apply a signature conversion to the entry block of the given region. This
-  /// replaces the entry block with a new block containing the updated
-  /// signature. The new entry block to the region is returned for convenience.
-  /// If no block argument types are changing, the entry original block will be
+  /// Apply a signature conversion to the given block. This replaces the block
+  /// with a new block containing the updated signature. The operations of the
+  /// given block are inlined into the newly-created block, which is returned.
+  ///
+  /// If no block argument types are changing, the original block will be
   /// left in place and returned.
   ///
-  /// If provided, `converter` will be used for any materializations.
+  /// A signature conversion must be provided. (Type converters can construct
+  /// a signature conversion with `convertBlockSignature`.)
+  ///
+  /// Optionally, a type converter can be provided to build materializations.
+  /// Note: If no type converter was provided or the type converter does not
+  /// specify any suitable argument/target materialization rules, the dialect
+  /// conversion may fail to legalize unresolved materializations.
   Block *
-  applySignatureConversion(Region *region,
+  applySignatureConversion(Block *block,
                            TypeConverter::SignatureConversion &conversion,
                            const TypeConverter *converter = nullptr);
 
-  /// Convert the types of block arguments within the given region. This
+  /// Apply a signature conversion to each block in the given region. This
   /// replaces each block with a new block containing the updated signature. If
   /// an updated signature would match the current signature, the respective
-  /// block is left in place as is.
+  /// block is left in place as is. (See `applySignatureConversion` for
+  /// details.) The new entry block of the region is returned.
+  ///
+  /// SignatureConversions are computed with the specified type converter.
+  /// This function returns "failure" if the type converter failed to compute
+  /// a SignatureConversion for at least one block.
   ///
-  /// The entry block may have a special conversion if `entryConversion` is
-  /// provided. On success, the new entry block to the region is returned for
-  /// convenience. Otherwise, failure is returned.
+  /// Optionally, a special SignatureConversion can be specified for the entry
+  /// block. This is because the types of the entry block arguments are often
+  /// tied semantically to the operation.
   FailureOr<Block *> convertRegionTypes(
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Convert the types of block arguments within the given region except for
-  /// the entry region. This replaces each non-entry block with a new block
-  /// containing the updated signature. If an updated signature would match the
-  /// current signature, the respective block is left in place as is.
-  ///
-  /// If special conversion behavior is needed for the non-entry blocks (for
-  /// example, we need to convert only a subset of a BB arguments), such
-  /// behavior can be specified in blockConversions.
-  LogicalResult convertNonEntryRegionTypes(
-      Region *region, const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions);
-
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
 

diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index d90cf931385fc..f62de1f17a666 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     signatureConverter.remapInput(0, newIndVar);
     for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
       signatureConverter.remapInput(i, header->getArgument(i));
-    body = rewriter.applySignatureConversion(&forOp.getRegion(),
+    body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
                                              signatureConverter);
 
     // Move the blocks from the forOp into the loopOp. This is the body of the

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 22968096a6891..af38485291182 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startOpModification(op);
     Region &region = op.getFunctionBody();
-    SmallVector<TypeConverter::SignatureConversion, 2> conversions;
 
-    for (Block &block : llvm::drop_begin(region, 1)) {
-      conversions.emplace_back(block.getNumArguments());
-      TypeConverter::SignatureConversion &back = conversions.back();
+    for (Block &block :
+         llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
+      TypeConverter::SignatureConversion conversion(
+          /*numOrigInputs=*/block.getNumArguments());
 
       for (BlockArgument blockArgument : block.getArguments()) {
         int idx = blockArgument.getArgNumber();
 
         if (blockArgsToDetensor.count(blockArgument))
-          back.addInputs(idx, {getTypeConverter()->convertType(
-                                  block.getArgumentTypes()[idx])});
+          conversion.addInputs(idx, {getTypeConverter()->convertType(
+                                        block.getArgumentTypes()[idx])});
         else
-          back.addInputs(idx, {block.getArgumentTypes()[idx]});
+          conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
       }
-    }
 
-    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
-                                                   conversions))) {
-      rewriter.cancelOpModification(op);
-      return failure();
+      rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
     }
 
     rewriter.finalizeOpModification(op);

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..2f0efe1b1e454 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // Type Conversion
   //===--------------------------------------------------------------------===//
 
-  /// Attempt to convert the signature of the given block, if successful a new
-  /// block is returned containing the new arguments. Returns `block` if it did
-  /// not require conversion.
-  FailureOr<Block *> convertBlockSignature(
-      ConversionPatternRewriter &rewriter, Block *block,
-      const TypeConverter *converter,
-      TypeConverter::SignatureConversion *conversion = nullptr);
-
-  /// Convert the types of non-entry block arguments within the given region.
-  LogicalResult convertNonEntryRegionTypes(
-      ConversionPatternRewriter &rewriter, Region *region,
-      const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
-
-  /// Apply a signature conversion on the given region, using `converter` for
-  /// materializations if not null.
-  Block *
-  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
-                           TypeConverter::SignatureConversion &conversion,
-                           const TypeConverter *converter);
-
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
   convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
@@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 //===----------------------------------------------------------------------===//
 // Type Conversion
 
-FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
-    ConversionPatternRewriter &rewriter, Block *block,
-    const TypeConverter *converter,
-    TypeConverter::SignatureConversion *conversion) {
-  if (conversion)
-    return applySignatureConversion(rewriter, block, converter, *conversion);
-
-  // If a converter wasn't provided, and the block wasn't already converted,
-  // there is nothing we can do.
-  if (!converter)
-    return failure();
-
-  // Try to convert the signature for the block with the provided converter.
-  if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(rewriter, block, converter, *conversion);
-  return failure();
-}
-
-Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    ConversionPatternRewriter &rewriter, Region *region,
-    TypeConverter::SignatureConversion &conversion,
-    const TypeConverter *converter) {
-  if (!region->empty())
-    return *convertBlockSignature(rewriter, &region->front(), converter,
-                                  &conversion);
-  return nullptr;
-}
-
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     ConversionPatternRewriter &rewriter, Region *region,
     const TypeConverter &converter,
@@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
   if (region->empty())
     return nullptr;
 
-  if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
-    return failure();
-
-  FailureOr<Block *> newEntry = convertBlockSignature(
-      rewriter, &region->front(), &converter, entryConversion);
-  return newEntry;
-}
-
-LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    ConversionPatternRewriter &rewriter, Region *region,
-    const TypeConverter &converter,
-    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  regionToConverter[region] = &converter;
-  if (region->empty())
-    return success();
-
-  // Convert the arguments of each block within the region.
-  int blockIdx = 0;
-  assert((blockConversions.empty() ||
-          blockConversions.size() == region->getBlocks().size() - 1) &&
-         "expected either to provide no SignatureConversions at all or to "
-         "provide a SignatureConversion for each non-entry block");
-
+  // Convert the arguments of each non-entry block within the region.
   for (Block &block :
        llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
-    TypeConverter::SignatureConversion *blockConversion =
-        blockConversions.empty()
-            ? nullptr
-            : const_cast<TypeConverter::SignatureConversion *>(
-                  &blockConversions[blockIdx++]);
-
-    if (failed(convertBlockSignature(rewriter, &block, &converter,
-                                     blockConversion)))
+    // Compute the signature for the block with the provided converter.
+    std::optional<TypeConverter::SignatureConversion> conversion =
+        converter.convertBlockSignature(&block);
+    if (!conversion)
       return failure();
-  }
-  return success();
+    // Convert the block with the computed signature.
+    applySignatureConversion(rewriter, &block, &converter, *conversion);
+  }
+
+  // Convert the entry block. If an entry signature conversion was provided,
+  // use that one. Otherwise, compute the signature with the type converter.
+  if (entryConversion)
+    return applySignatureConversion(rewriter, &region->front(), &converter,
+                                    *entryConversion);
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter.convertBlockSignature(&region->front());
+  if (!conversion)
+    return failure();
+  return applySignatureConversion(rewriter, &region->front(), &converter,
+                                  *conversion);
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
 }
 
 Block *ConversionPatternRewriter::applySignatureConversion(
-    Region *region, TypeConverter::SignatureConversion &conversion,
+    Block *block, TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
+  assert(!impl->wasOpReplaced(block->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(*this, region, conversion, converter);
+  return impl->applySignatureConversion(*this, block, converter, conversion);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(*this, region, converter, entryConversion);
 }
 
-LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
-    Region *region, const TypeConverter &converter,
-    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
-  return impl->convertNonEntryRegionTypes(*this, region, converter,
-                                          blockConversions);
-}
-
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                                                            Value to) {
   LLVM_DEBUG({
@@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // If the region of the block has a type converter, try to convert the block
     // directly.
     if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
+      std::optional<TypeConverter::SignatureConversion> conversion =
+          converter->convertBlockSignature(block);
+      if (!conversion) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));
         return failure();
       }
+      impl.applySignatureConversion(rewriter, block, converter, *conversion);
       continue;
     }
 

diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f9f7d4eacf948..a14a5da341098 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter
     if (failed(
             converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
       return failure();
-    rewriter.modifyOpInPlace(
-        op, [&] { rewriter.applySignatureConversion(&region, result); });
+    rewriter.modifyOpInPlace(op, [&] {
+      rewriter.applySignatureConversion(&region.front(), result);
+    });
     return success();
   }
 
