[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Simplify block conversion API (PR #94866)

Matthias Springer llvmlistbot at llvm.org
Sat Jun 8 12:32:41 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/94866

This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions.

There are now two public functions for signature conversion:
* `applySignatureConversion` converts a single block signature. This function used to take a `Region *` (but converted only the entry block). It now takes a `Block *`.
* `convertRegionTypes` converts all block signatures of a region.
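To make the split between the two functions concrete, here is a minimal toy model. It is a sketch, not MLIR code: `ToyBlock`, `ToyRegion`, `ToySignature`, `computeSignature`, `applySignature`, and `convertRegion` are hypothetical stand-ins. A "signature" is simplified to the list of new argument types.

The model follows the order used by the new `convertRegionTypes` in this patch. It converts non-entry blocks first, then the entry block. An explicitly provided entry signature wins over a computed one, and the function fails if the type converter cannot compute a signature for some block.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins (not MLIR types): a block is a list of argument
// type names, and a region is a list of blocks whose first one is the entry.
using TypeName = std::string;
struct ToyBlock {
  std::vector<TypeName> argTypes;
};
struct ToyRegion {
  std::vector<ToyBlock> blocks;
};

// Simplification: a "signature" is just the list of new argument types.
using ToySignature = std::vector<TypeName>;

// A type converter maps a type to its converted type, or std::nullopt if
// the type cannot be converted.
using ToyTypeConverter =
    std::function<std::optional<TypeName>(const TypeName &)>;

// Rough analog of TypeConverter::convertBlockSignature.
std::optional<ToySignature> computeSignature(const ToyBlock &block,
                                             const ToyTypeConverter &conv) {
  ToySignature sig;
  for (const TypeName &t : block.argTypes) {
    std::optional<TypeName> converted = conv(t);
    if (!converted)
      return std::nullopt;
    sig.push_back(*converted);
  }
  return sig;
}

// Rough analog of applySignatureConversion: touches exactly one block and
// needs an explicit signature.
void applySignature(ToyBlock &block, const ToySignature &sig) {
  block.argTypes = sig;
}

// Rough analog of convertRegionTypes: non-entry blocks first, then the entry
// block, which uses `entrySig` if given. Returns false (cf. failure()) if a
// signature cannot be computed. This toy does not undo blocks it already
// converted before the failure.
bool convertRegion(ToyRegion &region, const ToyTypeConverter &conv,
                   const ToySignature *entrySig = nullptr) {
  if (region.blocks.empty())
    return true;
  for (std::size_t i = 1; i < region.blocks.size(); ++i) {
    std::optional<ToySignature> sig = computeSignature(region.blocks[i], conv);
    if (!sig)
      return false;
    applySignature(region.blocks[i], *sig);
  }
  if (entrySig) {
    applySignature(region.blocks.front(), *entrySig);
    return true;
  }
  std::optional<ToySignature> sig =
      computeSignature(region.blocks.front(), conv);
  if (!sig)
    return false;
  applySignature(region.blocks.front(), *sig);
  return true;
}
```

In this model, converting a single block needs only a signature. Converting a whole region needs a type converter, because the per-block signatures have to be computed somewhere.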

`convertNonEntryRegionTypes` is removed because it is not widely used and can easily be expressed with a call to `applySignatureConversion` inside a loop. (See `Detensorize.cpp` for an example.)
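The loop that replaces `convertNonEntryRegionTypes` in `Detensorize.cpp` iterates with `llvm::make_early_inc_range(llvm::drop_begin(region, 1))`. Each signature conversion swaps the current block for a new one, so the iterator must be advanced before the current block goes away.

The sketch below reproduces that pattern with `std::list`. It is a simplified model: `ToyBlock`, `ToyRegion`, `replaceBlock`, and `convertNonEntryBlocks` are hypothetical. It also converts every argument of a given type, rather than only the selected arguments that `Detensorize.cpp` converts.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins (not MLIR types). std::list is used because erasing
// one element leaves iterators to the other elements valid.
struct ToyBlock {
  std::vector<std::string> argTypes;
};
using ToyRegion = std::list<ToyBlock>;

// Mimics a signature conversion that swaps in a freshly created block: the
// new block is inserted before `block`, then `block` is erased. Any iterator
// still pointing at `block` is invalid afterwards.
ToyRegion::iterator replaceBlock(ToyRegion &region, ToyRegion::iterator block,
                                 std::vector<std::string> newTypes) {
  ToyRegion::iterator newBlock =
      region.insert(block, ToyBlock{std::move(newTypes)});
  region.erase(block);
  return newBlock;
}

// Converts every non-entry block (cf. drop_begin(region, 1)). The iterator is
// advanced before the current block is replaced (cf. make_early_inc_range),
// so the loop never increments an iterator to an erased block.
void convertNonEntryBlocks(ToyRegion &region, const std::string &from,
                           const std::string &to) {
  if (region.empty())
    return;
  for (ToyRegion::iterator it = std::next(region.begin());
       it != region.end();) {
    ToyRegion::iterator current = it++; // early increment
    std::vector<std::string> types = current->argTypes;
    for (std::string &t : types)
      if (t == from)
        t = to;
    replaceBlock(region, current, std::move(types));
  }
}
```

Writing the same loop with a plain `++it` after `replaceBlock` would increment an iterator to an erased element, which is undefined behavior.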

Note: For consistency, `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future.

Also clarify when a type converter and/or signature conversion object is needed and for what purpose.
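Roughly, the two objects have different roles, per the updated header comments:
* A signature conversion describes how each original block argument maps to the new signature. `applySignatureConversion` requires one.
* A type converter is optional there and only builds materializations. `convertRegionTypes` requires a type converter, because it computes the signature conversions itself.

The class below is a hypothetical model of the signature-conversion side only. It loosely mirrors the `addInputs` and `remapInput` calls and the `numOrigInputs` constructor argument seen in this patch. In the model, `remapInput` records only that the input is replaced by an existing value. The real `remapInput` also takes that replacement value, as in `SCFToSPIRV.cpp`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of a signature conversion: for each original block
// argument it records either the new argument types that replace it
// (addInputs, possibly 1:N) or that it is remapped to an existing value
// (remapInput), in which case it gets no new block argument.
class ToySignatureConversion {
public:
  explicit ToySignatureConversion(std::size_t numOrigInputs)
      : mapping(numOrigInputs) {}

  // Original input `origIdx` is replaced by new arguments of `types`.
  void addInputs(std::size_t origIdx, std::vector<std::string> types) {
    assert(origIdx < mapping.size() && "original input out of range");
    mapping[origIdx] = Entry{false, std::move(types)};
  }

  // Original input `origIdx` is replaced by a value defined elsewhere.
  void remapInput(std::size_t origIdx) {
    assert(origIdx < mapping.size() && "original input out of range");
    mapping[origIdx] = Entry{true, {}};
  }

  // New block signature: the concatenation of all added input types.
  std::vector<std::string> getConvertedTypes() const {
    std::vector<std::string> result;
    for (const std::optional<Entry> &e : mapping) {
      assert(e && "every original input must be mapped");
      if (e && !e->remapped)
        result.insert(result.end(), e->types.begin(), e->types.end());
    }
    return result;
  }

private:
  struct Entry {
    bool remapped;
    std::vector<std::string> types;
  };
  std::vector<std::optional<Entry>> mapping;
};
```

A type converter's job, in this picture, is to fill in such an object for a given block. Applying the object is a separate step.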

Internal code refactoring (NFC) of `ConversionPatternRewriterImpl` (the part that deals with signature conversions). This part of the codebase was quite convoluted and unintuitive.

From a functional perspective, this change is NFC. However, the public API changes, so it is not marked as NFC.

Note for LLVM integration: Replace `applySignatureConversion(region, ...)` with `applySignatureConversion(&region->front(), ...)`. The function now takes a `Block *`, and `Region::front()` returns a `Block &`. In the unlikely case that you see `convertNonEntryRegionTypes`, apply the same changes that this commit made to `Detensorize.cpp`.
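The following is a mock of a migrated call site, for illustration only. `Block`, `Region`, `convertEntry`, and the one-argument `applySignatureConversion` below are hypothetical stand-ins, not the MLIR classes. The point it shows is the pointer: `Region::front()` returns a `Block &`, and the new function takes a `Block *`, so the migrated call passes `&region->front()`.

```cpp
#include <cassert>
#include <list>

// Hypothetical minimal stand-ins for mlir::Block and mlir::Region.
struct Block {
  int id;
};
struct Region {
  std::list<Block> blocks;
  // Like Region::front(): returns a reference to the entry block.
  Block &front() { return blocks.front(); }
};

// Stand-in for the new applySignatureConversion, which takes a Block *
// instead of a Region *. Here it just returns the block it was given.
Block *applySignatureConversion(Block *block) { return block; }

// Migrated call site: pass the address of the entry block.
Block *convertEntry(Region *region) {
  return applySignatureConversion(&region->front());
}
```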


From 5c86bfabeb670dfa9ddd64739423165588874ee2 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 8 Jun 2024 21:15:08 +0200
Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Simplify block
 conversion API

This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions.

There are now two public functions for signature conversion:
* `applySignatureConversion` converts a single block signature.
* `convertRegionTypes` converts all block signatures of a region.

Note: `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future.

Also clarify when a type converter and/or signature conversion object is needed and for what purpose.

From a functional perspective, this change is NFC. However, the public API changes, so it is not marked as NFC.
---
 mlir/docs/DialectConversion.md                |  30 +++--
 .../mlir/Transforms/DialectConversion.h       |  43 +++---
 mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp |   2 +-
 .../Dialect/Linalg/Transforms/Detensorize.cpp |  20 ++-
 .../Transforms/Utils/DialectConversion.cpp    | 123 ++++--------------
 5 files changed, 73 insertions(+), 145 deletions(-)

diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index a355d5a90e4d1..8338109eb97c3 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -372,19 +372,23 @@ class TypeConverter {
 From the perspective of type conversion, the types of block arguments are a bit
 special. Throughout the conversion process, blocks may move between regions of
 different operations. Given this, the conversion of the types for blocks must be
-done explicitly via a conversion pattern. To convert the types of block
-arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
-be invoked; `convertRegionTypes`. This hook uses a provided type converter to
-apply type conversions to all blocks within a given region, and all blocks that
-move into that region. As noted above, the conversions performed by this method
-use the argument materialization hook on the `TypeConverter`. This hook also
-takes an optional `TypeConverter::SignatureConversion` parameter that applies a
-custom conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
-AffineForOp, etc. To convert the signature of just the region entry block, and
-not any other blocks within the region, the `applySignatureConversion` hook may
-be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
-can be built programmatically:
+done explicitly via a conversion pattern.
+
+To convert the types of block arguments within a Region, a custom hook on the
+`ConversionPatternRewriter` must be invoked: `convertRegionTypes`. This hook
+uses a provided type converter to apply type conversions to all blocks of a
+given region. As noted above, the conversions performed by this method use the
+argument materialization hook on the `TypeConverter`. `convertRegionTypes` also
+takes an optional `TypeConverter::SignatureConversion` parameter that applies a
+custom conversion to the entry block of the region. The types of the entry
+block arguments are often tied semantically to details on the operation,
+e.g., `func::FuncOp`, `AffineForOp`, etc.
+
+To convert the signature of just one given block, the
+`applySignatureConversion` hook can be used.
+
+A signature conversion, `TypeConverter::SignatureConversion`, can be built
+programmatically:
 
 ```c++
 class SignatureConversion {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..5f4a972748ffc 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType> TargetType convertType(Type t) const {
+  template <typename TargetType>
+  TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
@@ -661,42 +662,38 @@ class ConversionPatternRewriter final : public PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
-  /// Apply a signature conversion to the entry block of the given region. This
-  /// replaces the entry block with a new block containing the updated
-  /// signature. The new entry block to the region is returned for convenience.
+  /// Apply a signature conversion to the given block. This replaces the block
+  /// with a new block containing the updated signature. The operations of the
+  /// given block are inlined into the newly-created block, which is returned.
+  ///
   /// If no block argument types are changing, the entry original block will be
   /// left in place and returned.
   ///
-  /// If provided, `converter` will be used for any materializations.
+  /// A signature conversion must be provided. (Type converters can construct
+  /// a signature conversion with `convertBlockSignature`.) Optionally, a type
+  /// converter can be provided to build materializations.
   Block *
-  applySignatureConversion(Region *region,
+  applySignatureConversion(Block *block,
                            TypeConverter::SignatureConversion &conversion,
                            const TypeConverter *converter = nullptr);
 
-  /// Convert the types of block arguments within the given region. This
+  /// Apply a signature conversion to each block in the given region. This
   /// replaces each block with a new block containing the updated signature. If
   /// an updated signature would match the current signature, the respective
-  /// block is left in place as is.
+  /// block is left in place as is. (See `applySignatureConversion` for
+  /// details.) The new entry block of the region is returned.
+  ///
+  /// SignatureConversions are computed with the specified type converter.
+  /// This function returns "failure" if the type converter failed to compute
+  /// a SignatureConversion for at least one block.
   ///
-  /// The entry block may have a special conversion if `entryConversion` is
-  /// provided. On success, the new entry block to the region is returned for
-  /// convenience. Otherwise, failure is returned.
+  /// Optionally, a special SignatureConversion can be specified for the entry
+  /// block. This is because the types of the entry block arguments are often
+  /// tied semantically to details on the operation.
   FailureOr<Block *> convertRegionTypes(
       Region *region, const TypeConverter &converter,
       TypeConverter::SignatureConversion *entryConversion = nullptr);
 
-  /// Convert the types of block arguments within the given region except for
-  /// the entry region. This replaces each non-entry block with a new block
-  /// containing the updated signature. If an updated signature would match the
-  /// current signature, the respective block is left in place as is.
-  ///
-  /// If special conversion behavior is needed for the non-entry blocks (for
-  /// example, we need to convert only a subset of a BB arguments), such
-  /// behavior can be specified in blockConversions.
-  LogicalResult convertNonEntryRegionTypes(
-      Region *region, const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions);
-
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
 
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index d90cf931385fc..f62de1f17a666 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     signatureConverter.remapInput(0, newIndVar);
     for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
       signatureConverter.remapInput(i, header->getArgument(i));
-    body = rewriter.applySignatureConversion(&forOp.getRegion(),
+    body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
                                              signatureConverter);
 
     // Move the blocks from the forOp into the loopOp. This is the body of the
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 22968096a6891..af38485291182 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startOpModification(op);
     Region &region = op.getFunctionBody();
-    SmallVector<TypeConverter::SignatureConversion, 2> conversions;
 
-    for (Block &block : llvm::drop_begin(region, 1)) {
-      conversions.emplace_back(block.getNumArguments());
-      TypeConverter::SignatureConversion &back = conversions.back();
+    for (Block &block :
+         llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
+      TypeConverter::SignatureConversion conversion(
+          /*numOrigInputs=*/block.getNumArguments());
 
       for (BlockArgument blockArgument : block.getArguments()) {
         int idx = blockArgument.getArgNumber();
 
         if (blockArgsToDetensor.count(blockArgument))
-          back.addInputs(idx, {getTypeConverter()->convertType(
-                                  block.getArgumentTypes()[idx])});
+          conversion.addInputs(idx, {getTypeConverter()->convertType(
+                                        block.getArgumentTypes()[idx])});
         else
-          back.addInputs(idx, {block.getArgumentTypes()[idx]});
+          conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
       }
-    }
 
-    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
-                                                   conversions))) {
-      rewriter.cancelOpModification(op);
-      return failure();
+      rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
     }
 
     rewriter.finalizeOpModification(op);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..2f0efe1b1e454 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // Type Conversion
   //===--------------------------------------------------------------------===//
 
-  /// Attempt to convert the signature of the given block, if successful a new
-  /// block is returned containing the new arguments. Returns `block` if it did
-  /// not require conversion.
-  FailureOr<Block *> convertBlockSignature(
-      ConversionPatternRewriter &rewriter, Block *block,
-      const TypeConverter *converter,
-      TypeConverter::SignatureConversion *conversion = nullptr);
-
-  /// Convert the types of non-entry block arguments within the given region.
-  LogicalResult convertNonEntryRegionTypes(
-      ConversionPatternRewriter &rewriter, Region *region,
-      const TypeConverter &converter,
-      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
-
-  /// Apply a signature conversion on the given region, using `converter` for
-  /// materializations if not null.
-  Block *
-  applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
-                           TypeConverter::SignatureConversion &conversion,
-                           const TypeConverter *converter);
-
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
   convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
@@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 //===----------------------------------------------------------------------===//
 // Type Conversion
 
-FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
-    ConversionPatternRewriter &rewriter, Block *block,
-    const TypeConverter *converter,
-    TypeConverter::SignatureConversion *conversion) {
-  if (conversion)
-    return applySignatureConversion(rewriter, block, converter, *conversion);
-
-  // If a converter wasn't provided, and the block wasn't already converted,
-  // there is nothing we can do.
-  if (!converter)
-    return failure();
-
-  // Try to convert the signature for the block with the provided converter.
-  if (auto conversion = converter->convertBlockSignature(block))
-    return applySignatureConversion(rewriter, block, converter, *conversion);
-  return failure();
-}
-
-Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    ConversionPatternRewriter &rewriter, Region *region,
-    TypeConverter::SignatureConversion &conversion,
-    const TypeConverter *converter) {
-  if (!region->empty())
-    return *convertBlockSignature(rewriter, &region->front(), converter,
-                                  &conversion);
-  return nullptr;
-}
-
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     ConversionPatternRewriter &rewriter, Region *region,
     const TypeConverter &converter,
@@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
   if (region->empty())
     return nullptr;
 
-  if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
-    return failure();
-
-  FailureOr<Block *> newEntry = convertBlockSignature(
-      rewriter, &region->front(), &converter, entryConversion);
-  return newEntry;
-}
-
-LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    ConversionPatternRewriter &rewriter, Region *region,
-    const TypeConverter &converter,
-    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  regionToConverter[region] = &converter;
-  if (region->empty())
-    return success();
-
-  // Convert the arguments of each block within the region.
-  int blockIdx = 0;
-  assert((blockConversions.empty() ||
-          blockConversions.size() == region->getBlocks().size() - 1) &&
-         "expected either to provide no SignatureConversions at all or to "
-         "provide a SignatureConversion for each non-entry block");
-
+  // Convert the arguments of each non-entry block within the region.
   for (Block &block :
        llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
-    TypeConverter::SignatureConversion *blockConversion =
-        blockConversions.empty()
-            ? nullptr
-            : const_cast<TypeConverter::SignatureConversion *>(
-                  &blockConversions[blockIdx++]);
-
-    if (failed(convertBlockSignature(rewriter, &block, &converter,
-                                     blockConversion)))
+    // Compute the signature for the block with the provided converter.
+    std::optional<TypeConverter::SignatureConversion> conversion =
+        converter.convertBlockSignature(&block);
+    if (!conversion)
       return failure();
-  }
-  return success();
+    // Convert the block with the computed signature.
+    applySignatureConversion(rewriter, &block, &converter, *conversion);
+  }
+
+  // Convert the entry block. If an entry signature conversion was provided,
+  // use that one. Otherwise, compute the signature with the type converter.
+  if (entryConversion)
+    return applySignatureConversion(rewriter, &region->front(), &converter,
+                                    *entryConversion);
+  std::optional<TypeConverter::SignatureConversion> conversion =
+      converter.convertBlockSignature(&region->front());
+  if (!conversion)
+    return failure();
+  return applySignatureConversion(rewriter, &region->front(), &converter,
+                                  *conversion);
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
 }
 
 Block *ConversionPatternRewriter::applySignatureConversion(
-    Region *region, TypeConverter::SignatureConversion &conversion,
+    Block *block, TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
+  assert(!impl->wasOpReplaced(block->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(*this, region, conversion, converter);
+  return impl->applySignatureConversion(*this, block, converter, conversion);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   return impl->convertRegionTypes(*this, region, converter, entryConversion);
 }
 
-LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
-    Region *region, const TypeConverter &converter,
-    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
-  assert(!impl->wasOpReplaced(region->getParentOp()) &&
-         "attempting to apply a signature conversion to a block within a "
-         "replaced/erased op");
-  return impl->convertNonEntryRegionTypes(*this, region, converter,
-                                          blockConversions);
-}
-
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                                                            Value to) {
   LLVM_DEBUG({
@@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // If the region of the block has a type converter, try to convert the block
     // directly.
     if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
-      if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
+      std::optional<TypeConverter::SignatureConversion> conversion =
+          converter->convertBlockSignature(block);
+      if (!conversion) {
         LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
                                            "block"));
         return failure();
       }
+      impl.applySignatureConversion(rewriter, block, converter, *conversion);
       continue;
     }
 