[Mlir-commits] [mlir] 9c4611f - [mlir] Implement pass utils for 1:N type conversions.

Ingo Müller llvmlistbot at llvm.org
Mon Mar 27 02:02:34 PDT 2023


Author: Ingo Müller
Date: 2023-03-27T09:02:28Z
New Revision: 9c4611f9c7a7055b18f0a30a4c9074b9917e4ab0

URL: https://github.com/llvm/llvm-project/commit/9c4611f9c7a7055b18f0a30a4c9074b9917e4ab0
DIFF: https://github.com/llvm/llvm-project/commit/9c4611f9c7a7055b18f0a30a4c9074b9917e4ab0.diff

LOG: [mlir] Implement pass utils for 1:N type conversions.

The current dialect conversion does not support 1:N type conversions.
This commit implements a (poor-man's) dialect conversion pass that does
just that. To keep the pass independent of the "real" dialect conversion
infrastructure, it provides a specialization of the TypeConverter class
that allows for 1:N target materializations, a specialization of the
RewritePattern and PatternRewriter classes that automatically add
appropriate unrealized casts supporting 1:N type conversions and provide
converted operands for implementing subclasses, and a conversion driver
that applies the provided patterns and replaces the unrealized casts
that haven't folded away with user-provided materializations.
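
The user-provided materializations mentioned above are looked up the way the header comment of `OneToNTypeConverter::materializeTargetConversion` describes. Callbacks are tried in reverse order of registration until one succeeds. The following is a minimal, MLIR-free sketch of that lookup order only. It assumes plain strings as hypothetical stand-ins for `mlir::Type` and `mlir::Value`, and drops the builder and location arguments. `MaterializationRegistry` is an illustrative name and not part of the patch.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for mlir::Type and mlir::Value (not MLIR APIs).
using Type = std::string;
using Value = std::string;

// Same shape as OneToNMaterializationCallbackFn, minus builder and location:
// one input value is materialized as N values of the requested result types,
// or std::nullopt is returned if the callback does not apply.
using MaterializationFn = std::function<std::optional<std::vector<Value>>(
    const std::vector<Type> &resultTypes, const Value &input)>;

// Hypothetical registry modelling the lookup order documented for
// OneToNTypeConverter::materializeTargetConversion.
class MaterializationRegistry {
public:
  void add(MaterializationFn fn) {
    assert(fn && "callback must not be empty");
    callbacks.push_back(std::move(fn));
  }

  // Tries the callbacks from the most recently added to the oldest and
  // returns the result of the first one that succeeds.
  std::optional<std::vector<Value>>
  materialize(const std::vector<Type> &resultTypes, const Value &input) const {
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it)
      if (std::optional<std::vector<Value>> result = (*it)(resultTypes, input))
        return result;
    return std::nullopt;
  }

private:
  std::vector<MaterializationFn> callbacks;
};
```

With this ordering, a generic fallback registered early is consulted only after every callback added later has declined. More specific materializations can therefore be layered on top of it.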

The current pass is powerful enough to express many existing manual
solutions for 1:N type conversions or to extend transforms that previously
didn't support them. Of those, this patch implements call graph type
decomposition (which is currently implemented with a ValueDecomposer
that is only used there).
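
Call graph type decomposition is a 1:N conversion in which a value of an aggregate type is replaced by several values, one per element. The sketch below models the bookkeeping that `OneToNTypeMapping` adds on top of `SignatureConversion`: per-original-type slices of the flat converted types and values, plus the identity check. It assumes a hypothetical string encoding of types in which `tuple<i1,i32>` decomposes into `i1` and `i32`. `TypeMapping` and `decompose` are illustrative names, not part of the patch, and nested tuples are not handled.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Type (not an MLIR API).
using Type = std::string;

// Position and length of the converted types of one original type, analogous
// to the input mapping stored by SignatureConversion.
struct InputMapping {
  std::size_t inputNo;
  std::size_t size;
};

// Hypothetical 1:N conversion: "tuple<a,b,...>" decomposes into its element
// types ("tuple<>" into no types); any other type maps to itself. Nested
// tuples are not handled.
std::vector<Type> decompose(const Type &type) {
  const std::string prefix = "tuple<";
  if (type.rfind(prefix, 0) != 0 || type.back() != '>')
    return {type};
  const std::string body =
      type.substr(prefix.size(), type.size() - prefix.size() - 1);
  std::vector<Type> elements;
  std::size_t start = 0;
  while (start < body.size()) {
    std::size_t comma = body.find(',', start);
    if (comma == std::string::npos)
      comma = body.size();
    elements.push_back(body.substr(start, comma - start));
    start = comma + 1;
  }
  return elements;
}

// Hypothetical counterpart of OneToNTypeMapping: remembers the original types
// and, for each of them, which slice of the converted types replaces it.
class TypeMapping {
public:
  explicit TypeMapping(std::vector<Type> types)
      : originalTypes(std::move(types)) {
    for (const Type &type : originalTypes) {
      std::vector<Type> converted = decompose(type);
      mappings.push_back({convertedTypes.size(), converted.size()});
      convertedTypes.insert(convertedTypes.end(), converted.begin(),
                            converted.end());
    }
  }

  const std::vector<Type> &getConvertedTypes() const { return convertedTypes; }

  // Returns the converted values that replace original value `originalNo`.
  template <typename T>
  std::vector<T> getConvertedValues(const std::vector<T> &convertedValues,
                                    std::size_t originalNo) const {
    assert(convertedValues.size() == convertedTypes.size());
    assert(originalNo < mappings.size());
    auto first = convertedValues.begin() +
                 static_cast<std::ptrdiff_t>(mappings[originalNo].inputNo);
    return std::vector<T>(
        first, first + static_cast<std::ptrdiff_t>(mappings[originalNo].size));
  }

  // True iff some original type is not mapped to exactly itself.
  bool hasNonIdentityConversion() const {
    for (std::size_t i = 0; i < originalTypes.size(); ++i) {
      const InputMapping &mapping = mappings[i];
      if (mapping.size != 1 ||
          convertedTypes[mapping.inputNo] != originalTypes[i])
        return true;
    }
    return false;
  }

private:
  std::vector<Type> originalTypes;
  std::vector<InputMapping> mappings;
  std::vector<Type> convertedTypes;
};
```

Storing slices into one flat list, instead of a nested list per original type, lets the flat list be forwarded unchanged (as `ConvertTypesInFuncCallOp` does with `convertedOperands`) while per-value lookups stay cheap.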

The goal of this pass is to illustrate the effect that 1:N type
conversions could have, to gain experience in how to write patterns
that achieve that effect, and to get feedback on how the APIs of
the dialect conversion should be extended or changed to support such
patterns. The hope is that the "real" dialect conversion eventually
supports such patterns, at which point this pass could be removed
again.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D144469

Added: 
    mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
    mlir/include/mlir/Transforms/OneToNTypeConversion.h
    mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
    mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
    mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
    mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp

Modified: 
    mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Utils/CMakeLists.txt
    mlir/test/Transforms/decompose-call-graph-types.mlir
    mlir/test/lib/Conversion/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
new file mode 100644
index 0000000000000..2fba342ea80e7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
@@ -0,0 +1,26 @@
+//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
+#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
+
+namespace mlir {
+class TypeConverter;
+class RewritePatternSet;
+} // namespace mlir
+
+namespace mlir {
+
+// Populates the provided pattern set with patterns that do 1:N type conversions
+// on func ops. This is intended to be used with `applyPartialOneToNConversion`.
+void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
+                                        RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H

diff  --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
new file mode 100644
index 0000000000000..25beee28c6ed5
--- /dev/null
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -0,0 +1,256 @@
+//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utils for implementing (poor-man's) dialect conversion
+// passes with 1:N type conversions.
+//
+// The main function, `applyPartialOneToNConversion`, first applies a set of
+// `RewritePattern`s, which produce unrealized casts to convert the operands and
+// results from and to the source types, and then replaces all newly added
+// unrealized casts by user-provided materializations. For this to work, the
+// main function requires a special `TypeConverter`, a special
+// `PatternRewriter`, and special `RewritePattern`s, which extend their
+// respective base classes for 1:N type conversions.
+//
+// Note that this is much simpler than the "real" dialect conversion, which
+// checks for legality before applying patterns and probably does many other
+// things. Ideally, some of the extensions here could be integrated there.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
+#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+/// Extends `TypeConverter` with 1:N target materializations. Such
+/// materializations have to materialize one value with a source type as N
+/// values with target types, i.e., in the same direction as the 1:N type
+/// conversion itself (which isn't possible in the base class currently).
+class OneToNTypeConverter : public TypeConverter {
+public:
+  /// Callback that expresses user-provided materialization logic from the given
+  /// value to N values of the given types. This is useful for expressing target
+  /// materializations for 1:N type conversions, which materialize one value in
+  /// a source type as N values in target types.
+  using OneToNMaterializationCallbackFn =
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
+                                                      Value, Location)>;
+
+  /// Creates the mapping of the given range of original types to target types
+  /// of the conversion and stores that mapping in the given (signature)
+  /// conversion. This function simply calls
+  /// `TypeConverter::convertSignatureArgs` and exists here with a different
+  /// name to reflect the broader semantics.
+  LogicalResult computeTypeMapping(TypeRange types,
+                                   SignatureConversion &result) {
+    return convertSignatureArgs(types, result);
+  }
+
+  /// Applies one of the user-provided 1:N target materializations. If several
+  /// exist, they are tried out in the reverse order in which they have been
+  /// added until the first one succeeds. If none succeeds, the function
+  /// returns `std::nullopt`.
+  std::optional<SmallVector<Value>>
+  materializeTargetConversion(OpBuilder &builder, Location loc,
+                              TypeRange resultTypes, Value input) const;
+
+  /// Adds a 1:N target materialization to the converter. Such materializations
+  /// build IR that converts 1 value of the source type into N values with
+  /// target types.
+  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
+    oneToNTargetMaterializations.emplace_back(std::move(callback));
+  }
+
+private:
+  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
+};
+
+/// Stores a 1:N mapping of types and provides several useful accessors. This
+/// class extends `SignatureConversion`, which already supports 1:N type
+/// mappings but lacks some accessors into the mapping as well as access to the
+/// original types.
+class OneToNTypeMapping : public TypeConverter::SignatureConversion {
+public:
+  OneToNTypeMapping(TypeRange originalTypes)
+      : TypeConverter::SignatureConversion(originalTypes.size()),
+        originalTypes(originalTypes) {}
+
+  using TypeConverter::SignatureConversion::getConvertedTypes;
+
+  /// Returns the list of types that corresponds to the original type at the
+  /// given index.
+  TypeRange getConvertedTypes(unsigned originalTypeNo) const;
+
+  /// Returns the list of original types.
+  TypeRange getOriginalTypes() const { return originalTypes; }
+
+  /// Returns the slice of converted values that corresponds to the original
+  /// value at the given index.
+  ValueRange getConvertedValues(ValueRange convertedValues,
+                                unsigned originalValueNo) const;
+
+  /// Fills the given result vector with as many copies of the location of the
+  /// original value as the number of values it is converted to.
+  void convertLocation(Value originalValue, unsigned originalValueNo,
+                       llvm::SmallVectorImpl<Location> &result) const;
+
+  /// Fills the given result vector with as many copies of the location of each
+  /// original value as the number of values they are respectively converted to.
+  void convertLocations(ValueRange originalValues,
+                        llvm::SmallVectorImpl<Location> &result) const;
+
+  /// Returns true iff at least one type conversion maps an input type to a type
+  /// that is different from itself.
+  bool hasNonIdentityConversion() const;
+
+private:
+  llvm::SmallVector<Type> originalTypes;
+};
+
+/// Extends the basic `RewritePattern` class with a type converter member and
+/// some accessors to it. This is useful for patterns that are not
+/// `ConversionPattern`s but still require access to a type converter.
+class RewritePatternWithConverter : public mlir::RewritePattern {
+public:
+  /// Construct a conversion pattern with the given converter, and forward the
+  /// remaining arguments to RewritePattern.
+  template <typename... Args>
+  RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
+      : RewritePattern(std::forward<Args>(args)...),
+        typeConverter(&typeConverter) {}
+
+  /// Return the type converter held by this pattern, or nullptr if the pattern
+  /// does not require type conversion.
+  TypeConverter *getTypeConverter() const { return typeConverter; }
+
+  template <typename ConverterTy>
+  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
+                   ConverterTy *>
+  getTypeConverter() const {
+    return static_cast<ConverterTy *>(typeConverter);
+  }
+
+protected:
+  /// A type converter for use by this pattern.
+  TypeConverter *const typeConverter;
+};
+
+/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
+/// class provides additional rewrite methods that are specific to 1:N type
+/// conversions.
+class OneToNPatternRewriter : public PatternRewriter {
+public:
+  OneToNPatternRewriter(MLIRContext *context) : PatternRewriter(context) {}
+
+  /// Replaces the results of the operation with the specified list of values
+  /// mapped back to the original types as specified in the provided type
+  /// mapping. That type mapping must match the replaced op (i.e., the original
+  /// types must be the same as the result types of the op) and the new values
+  /// (i.e., the converted types must be the same as the types of the new
+  /// values).
+  void replaceOp(Operation *op, ValueRange newValues,
+                 const OneToNTypeMapping &resultMapping);
+  using PatternRewriter::replaceOp;
+
+  /// Applies the given argument conversion to the given block. This consists of
+  /// replacing each original argument with N arguments as specified in the
+  /// argument conversion and inserting unrealized casts from the converted
+  /// values to the original types, which are then used in lieu of the original
+  /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
+  /// with a user-provided argument materialization if necessary.) This is
+  /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
+  /// type conversion properly and probably (2) doesn't handle many other edge
+  /// cases.
+  Block *applySignatureConversion(Block *block,
+                                  OneToNTypeMapping &argumentConversion);
+};
+
+/// Base class for patterns with 1:N type conversions. Derived classes have to
+/// override the `matchAndRewrite` overload that provides additional
+/// information for 1:N type conversions.
+class OneToNConversionPattern : public RewritePatternWithConverter {
+public:
+  using RewritePatternWithConverter::RewritePatternWithConverter;
+
+  /// This function has to be implemented by derived classes and is called from
+  /// the usual overloads. Like in "normal" `DialectConversion`, the function is
+  /// provided with the converted operands (which thus have target types). Since
+  /// 1:N conversions are supported, there is usually no 1:1 relationship between
+  /// the original and the converted operands. Instead, the provided
+  /// `operandMapping` can be used to access the converted operands that
+  /// correspond to a particular original operand. Similarly, `resultMapping`
+  /// is provided to help with assembling the result values, which may have 1:N
+  /// correspondences as well. In that case, the original op should be replaced
+  /// with the overload of `replaceOp` that takes the provided `resultMapping`
+  /// in order to deal with the mapping of converted result values to their
+  /// usages in the original types correctly.
+  virtual LogicalResult matchAndRewrite(Operation *op,
+                                        OneToNPatternRewriter &rewriter,
+                                        const OneToNTypeMapping &operandMapping,
+                                        const OneToNTypeMapping &resultMapping,
+                                        ValueRange convertedOperands) const = 0;
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final;
+};
+
+/// This class is a wrapper around `OneToNConversionPattern` for matching
+/// against instances of a particular op class.
+template <typename SourceOp>
+class OneToNOpConversionPattern : public OneToNConversionPattern {
+public:
+  OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+                            PatternBenefit benefit = 1,
+                            ArrayRef<StringRef> generatedNames = {})
+      : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
+                                benefit, context, generatedNames) {}
+
+  using OneToNConversionPattern::matchAndRewrite;
+
+  /// Overload that derived classes have to override for their op type.
+  virtual LogicalResult matchAndRewrite(SourceOp op,
+                                        OneToNPatternRewriter &rewriter,
+                                        const OneToNTypeMapping &operandMapping,
+                                        const OneToNTypeMapping &resultMapping,
+                                        ValueRange convertedOperands) const = 0;
+
+  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
+                                const OneToNTypeMapping &operandMapping,
+                                const OneToNTypeMapping &resultMapping,
+                                ValueRange convertedOperands) const final {
+    return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping,
+                           resultMapping, convertedOperands);
+  }
+};
+
+/// Applies the given set of patterns recursively on the given op and adds user
+/// materializations where necessary. The patterns are expected to be
+/// `OneToNConversionPattern`s, which help convert the types of the operands
+/// and results of the matched ops. The provided type converter is used to
+/// convert the operands of matched ops from their original types to operands
+/// with different types. Unlike in `DialectConversion`, this supports 1:N type
+/// conversions. At the "boundary" of the pattern application, where converted
+/// results are not consumed by replaced ops that expect the converted operands
+/// or vice versa, the function inserts user materializations from the type
+/// converter. Also unlike `DialectConversion`, there are no legal or illegal
+/// types; the function simply applies the given patterns and does not fail if
+/// some ops or types remain unconverted (i.e., the conversion is only
+/// "partial").
+LogicalResult
+applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+                             const FrozenRewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H

diff  --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 9a5b38ba6ea2c..172019907c3a8 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRFuncTransforms
   DuplicateFunctionElimination.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
+  OneToNFuncConversions.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms

diff  --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
new file mode 100644
index 0000000000000..5e8125ca94283
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -0,0 +1,132 @@
+//===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The patterns in this file are heavily inspired by (and copied from)
+// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
+// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N
+// type conversions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::func;
+
+namespace {
+
+class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
+public:
+  using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
+
+  LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter,
+                                const OneToNTypeMapping &operandMapping,
+                                const OneToNTypeMapping &resultMapping,
+                                ValueRange convertedOperands) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!operandMapping.hasNonIdentityConversion() &&
+        !resultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Create new CallOp.
+    auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
+                                         convertedOperands);
+    newOp->setAttrs(op->getAttrs());
+
+    rewriter.replaceOp(op, newOp->getResults(), resultMapping);
+    return success();
+  }
+};
+
+class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
+public:
+  using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping & /*operandMapping*/,
+                  const OneToNTypeMapping & /*resultMapping*/,
+                  ValueRange /*convertedOperands*/) const override {
+    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+
+    // Construct mapping for function arguments.
+    OneToNTypeMapping argumentMapping(op.getArgumentTypes());
+    if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
+                                                 argumentMapping)))
+      return failure();
+
+    // Construct mapping for function results.
+    OneToNTypeMapping funcResultMapping(op.getResultTypes());
+    if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
+                                                 funcResultMapping)))
+      return failure();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!argumentMapping.hasNonIdentityConversion() &&
+        !funcResultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Update the function signature in-place.
+    auto newType = FunctionType::get(rewriter.getContext(),
+                                     argumentMapping.getConvertedTypes(),
+                                     funcResultMapping.getConvertedTypes());
+    rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
+
+    // Update block signatures.
+    if (!op.isExternal()) {
+      Region *region = &op.getBody();
+      Block *block = &region->front();
+      rewriter.applySignatureConversion(block, argumentMapping);
+    }
+
+    return success();
+  }
+};
+
+class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
+public:
+  using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
+
+  LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter,
+                                const OneToNTypeMapping &operandMapping,
+                                const OneToNTypeMapping & /*resultMapping*/,
+                                ValueRange convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return success();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+
+void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
+                                        RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertTypesInFuncCallOp,
+      ConvertTypesInFuncFuncOp,
+      ConvertTypesInFuncReturnOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}
+
+} // namespace mlir

diff  --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index ba8fa20faf592..6892d00d1d743 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTransformUtils
   GreedyPatternRewriteDriver.cpp
   InliningUtils.cpp
   LoopInvariantCodeMotionUtils.cpp
+  OneToNTypeConversion.cpp
   RegionUtils.cpp
   TopologicalSortUtils.cpp
 

diff  --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
new file mode 100644
index 0000000000000..c0866f87f6833
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -0,0 +1,405 @@
+//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+
+using namespace llvm;
+using namespace mlir;
+
+std::optional<SmallVector<Value>>
+OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultTypes,
+                                                 Value input) const {
+  for (const OneToNMaterializationCallbackFn &fn :
+       llvm::reverse(oneToNTargetMaterializations)) {
+    if (std::optional<SmallVector<Value>> result =
+            fn(builder, resultTypes, input, loc))
+      return *result;
+  }
+  return std::nullopt;
+}
+
+TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
+  TypeRange convertedTypes = getConvertedTypes();
+  if (auto mapping = getInputMapping(originalTypeNo))
+    return convertedTypes.slice(mapping->inputNo, mapping->size);
+  return {};
+}
+
+ValueRange
+OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
+                                      unsigned originalValueNo) const {
+  if (auto mapping = getInputMapping(originalValueNo))
+    return convertedValues.slice(mapping->inputNo, mapping->size);
+  return {};
+}
+
+void OneToNTypeMapping::convertLocation(
+    Value originalValue, unsigned originalValueNo,
+    llvm::SmallVectorImpl<Location> &result) const {
+  if (auto mapping = getInputMapping(originalValueNo))
+    result.append(mapping->size, originalValue.getLoc());
+}
+
+void OneToNTypeMapping::convertLocations(
+    ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
+  assert(originalValues.size() == getOriginalTypes().size());
+  for (auto [i, value] : llvm::enumerate(originalValues))
+    convertLocation(value, i, result);
+}
+
+static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
+  return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
+}
+
+bool OneToNTypeMapping::hasNonIdentityConversion() const {
+  // XXX: I think that the original types and the converted types are the same
+  //      iff there was no non-identity type conversion. If that is true, the
+  //      patterns could actually test whether there is anything useful to do
+  //      without having access to the signature conversion.
+  for (auto [i, originalType] : llvm::enumerate(originalTypes)) {
+    TypeRange types = getConvertedTypes(i);
+    if (!isIdentityConversion(originalType, types)) {
+      assert(TypeRange(originalTypes) != getConvertedTypes());
+      return true;
+    }
+  }
+  assert(TypeRange(originalTypes) == getConvertedTypes());
+  return false;
+}
+
+namespace {
+enum class CastKind {
+  // Casts block arguments in the target type back to the source type. (If
+  // necessary, this cast becomes an argument materialization.)
+  Argument,
+
+  // Casts other values in the target type back to the source type. (If
+  // necessary, this cast becomes a source materialization.)
+  Source,
+
+  // Casts values in the source type to the target type. (If necessary, this
+  // cast becomes a target materialization.)
+  Target
+};
+}
+
+/// Mapping of enum values to string values.
+StringRef getCastKindName(CastKind kind) {
+  static const std::unordered_map<CastKind, StringRef> castKindNames = {
+      {CastKind::Argument, "argument"},
+      {CastKind::Source, "source"},
+      {CastKind::Target, "target"}};
+  return castKindNames.at(kind);
+}
+
+/// Attribute name that is used to annotate inserted unrealized casts with their
+/// kind (source, argument, or target).
+static const char *const castKindAttrName =
+    "__one-to-n-type-conversion_cast-kind__";
+
+/// Builds an `UnrealizedConversionCastOp` from the given inputs to the given
+/// result types. Returns the result values of the cast.
+static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
+                                      ValueRange inputs, CastKind kind) {
+  // Create cast.
+  Location loc = builder.getUnknownLoc();
+  if (!inputs.empty())
+    loc = inputs.front().getLoc();
+  auto castOp =
+      builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
+
+  // Store cast kind as attribute.
+  auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
+  castOp->setAttr(castKindAttrName, kindAttr);
+
+  return castOp->getResults();
+}
+
+/// Builds one `UnrealizedConversionCastOp` for each of the given original
+/// values using the respective target types given in the provided conversion
+/// mapping and returns the results of these casts. If the conversion mapping of
+/// a value maps a type to itself (i.e., is an identity conversion), then no
+/// cast is inserted and the original value is returned instead.
+/// Note that these unrealized casts are different from target materializations
+/// in that they are *always* inserted, even if they immediately fold away, such
+/// that patterns always see valid intermediate IR, whereas materializations are
+/// only used in the places where the unrealized casts *don't* fold away.
+static SmallVector<Value>
+buildUnrealizedForwardCasts(ValueRange originalValues,
+                            OneToNTypeMapping &conversion,
+                            RewriterBase &rewriter, CastKind kind) {
+
+  // Convert each operand one by one.
+  SmallVector<Value> convertedValues;
+  convertedValues.reserve(conversion.getConvertedTypes().size());
+  for (auto [idx, originalValue] : llvm::enumerate(originalValues)) {
+    TypeRange convertedTypes = conversion.getConvertedTypes(idx);
+
+    // Identity conversion: keep operand as is.
+    if (isIdentityConversion(originalValue.getType(), convertedTypes)) {
+      convertedValues.push_back(originalValue);
+      continue;
+    }
+
+    // Non-identity conversion: materialize target types.
+    ValueRange castResult =
+        buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind);
+    convertedValues.append(castResult.begin(), castResult.end());
+  }
+
+  return convertedValues;
+}
+
+/// Builds one `UnrealizedConversionCastOp` for each sequence of the given
+/// converted values to one value of the type they originated from, i.e., a
+/// "reverse" conversion from N converted values back to one value of the
+/// original type, using the given (forward) type conversion. If a given value
+/// was mapped to a value of the same type (i.e., the conversion in the mapping
+/// is an identity conversion), then the "converted" value is returned without
+/// cast.
+/// Note that these unrealized casts are different from source materializations
+/// in that they are *always* inserted, even if they immediately fold away, such
+/// that patterns always see valid intermediate IR, whereas materializations are
+/// only used in the places where the unrealized casts *don't* fold away.
+static SmallVector<Value>
+buildUnrealizedBackwardsCasts(ValueRange convertedValues,
+                              const OneToNTypeMapping &typeConversion,
+                              RewriterBase &rewriter) {
+  assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
+
+  // Create unrealized cast op for each converted result of the op.
+  SmallVector<Value> recastValues;
+  TypeRange originalTypes = typeConversion.getOriginalTypes();
+  recastValues.reserve(originalTypes.size());
+  auto convertedValueIt = convertedValues.begin();
+  for (auto [idx, originalType] : llvm::enumerate(originalTypes)) {
+    TypeRange convertedTypes = typeConversion.getConvertedTypes(idx);
+    size_t numConvertedValues = convertedTypes.size();
+    if (isIdentityConversion(originalType, convertedTypes)) {
+      // Identity conversion: take result as is.
+      recastValues.push_back(*convertedValueIt);
+    } else {
+      // Non-identity conversion: cast back to source type.
+      ValueRange recastValue = buildUnrealizedCast(
+          rewriter, originalType,
+          ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
+          CastKind::Source);
+      assert(recastValue.size() == 1);
+      recastValues.push_back(recastValue.front());
+    }
+    convertedValueIt += numConvertedValues;
+  }
+
+  return recastValues;
+}
+
+void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
+                                      const OneToNTypeMapping &resultMapping) {
+  // Create a cast back to the original types and replace the results of the
+  // original op with those.
+  assert(newValues.size() == resultMapping.getConvertedTypes().size());
+  assert(op->getResultTypes() == resultMapping.getOriginalTypes());
+  SmallVector<Value> castResults =
+      buildUnrealizedBackwardsCasts(newValues, resultMapping, *this);
+  replaceOp(op, castResults);
+}
+
+Block *OneToNPatternRewriter::applySignatureConversion(
+    Block *block, OneToNTypeMapping &argumentConversion) {
+  // Create a new block with the updated signature right before the original
+  // block; the original block is merged into it below.
+  SmallVector<Location> locs;
+  argumentConversion.convertLocations(block->getArguments(), locs);
+  Block *newBlock =
+      createBlock(block, argumentConversion.getConvertedTypes(), locs);
+  replaceAllUsesWith(block, newBlock);
+
+  // Create necessary casts in new block.
+  SmallVector<Value> castResults;
+  for (auto [i, arg] : llvm::enumerate(block->getArguments())) {
+    TypeRange convertedTypes = argumentConversion.getConvertedTypes(i);
+    ValueRange newArgs =
+        argumentConversion.getConvertedValues(newBlock->getArguments(), i);
+    if (isIdentityConversion(arg.getType(), convertedTypes)) {
+      // Identity conversion: take argument as is.
+      assert(newArgs.size() == 1);
+      castResults.push_back(newArgs.front());
+    } else {
+      // Non-identity conversion: cast the converted arguments to the original
+      // type.
+      PatternRewriter::InsertionGuard g(*this);
+      setInsertionPointToStart(newBlock);
+      ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs,
+                                                  CastKind::Argument);
+      assert(castResult.size() == 1);
+      castResults.push_back(castResult.front());
+    }
+  }
+
+  // Merge old block into new block such that we only have the latter with the
+  // new signature.
+  mergeBlocks(block, newBlock, castResults);
+
+  return newBlock;
+}
+
+LogicalResult
+OneToNConversionPattern::matchAndRewrite(Operation *op,
+                                         PatternRewriter &rewriter) const {
+  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+
+  // Construct conversion mapping for results.
+  Operation::result_type_range originalResultTypes = op->getResultTypes();
+  OneToNTypeMapping resultMapping(originalResultTypes);
+  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
+                                               resultMapping)))
+    return failure();
+
+  // Construct conversion mapping for operands.
+  Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
+  OneToNTypeMapping operandMapping(originalOperandTypes);
+  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
+                                               operandMapping)))
+    return failure();
+
+  // Cast operands to target types.
+  SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts(
+      op->getOperands(), operandMapping, rewriter, CastKind::Target);
+
+  // Create a `OneToNPatternRewriter` for the pattern, which provides additional
+  // functionality.
+  // TODO(ingomueller): I guess it would be better to use only one rewriter
+  //                    throughout the whole pass, but that would require
+  //                    driving the pattern application ourselves, which is a
+  //                    lot of additional boilerplate code. This seems to work
+  //                    fine, so I leave it like this for the time being.
+  OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext());
+  oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint());
+  oneToNPatternRewriter.setListener(rewriter.getListener());
+
+  // Apply actual pattern.
+  if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
+                             resultMapping, convertedOperands)))
+    return failure();
+
+  return success();
+}
+
+namespace mlir {
+
+// This function applies the provided patterns using
+// `applyPatternsAndFoldGreedily` and then replaces all newly inserted
+// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
+// from target to source types inserted by a `OneToNConversionPattern` normally
+// fold away with the "forward" casts from source to target types inserted by
+// the next pattern.) To identify which casts are "newly inserted", all casts
+// inserted by this pass are annotated with a string attribute that also
+// records the kind of the cast (source, argument, or target).
+LogicalResult
+applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+                             const FrozenRewritePatternSet &patterns) {
+#ifndef NDEBUG
+  // Remember existing unrealized casts. This data structure is only used in
+  // asserts; building it only for that purpose may be overkill.
+  SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
+  op->walk([&](UnrealizedConversionCastOp castOp) {
+    assert(!castOp->hasAttr(castKindAttrName));
+    existingCasts.insert(castOp);
+  });
+#endif // NDEBUG
+
+  // Apply provided conversion patterns.
+  if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
+    emitError(op->getLoc()) << "failed to apply conversion patterns";
+    return failure();
+  }
+
+  // Find all unrealized casts inserted by the pass that haven't folded away.
+  SmallVector<UnrealizedConversionCastOp> worklist;
+  op->walk([&](UnrealizedConversionCastOp castOp) {
+    if (castOp->hasAttr(castKindAttrName)) {
+      assert(!existingCasts.contains(castOp));
+      worklist.push_back(castOp);
+    }
+  });
+
+  // Replace new casts with user materializations.
+  IRRewriter rewriter(op->getContext());
+  for (UnrealizedConversionCastOp castOp : worklist) {
+    TypeRange resultTypes = castOp->getResultTypes();
+    ValueRange operands = castOp->getOperands();
+    StringRef castKind =
+        castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
+    rewriter.setInsertionPoint(castOp);
+
+#ifndef NDEBUG
+    // Determine whether operands or results are already legal to test some
+    // assumptions for the different kinds of materializations. These properties
+    // are only used in asserts and it may be overkill to compute them.
+    bool areOperandTypesLegal = llvm::all_of(
+        operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
+    bool areResultsTypesLegal = llvm::all_of(
+        resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
+#endif // NDEBUG
+
+    // Add materialization and remember materialized results.
+    SmallVector<Value> materializedResults;
+    if (castKind == getCastKindName(CastKind::Target)) {
+      // Target materialization.
+      assert(!areOperandTypesLegal && areResultsTypesLegal &&
+             operands.size() == 1 && "found unexpected target cast");
+      std::optional<SmallVector<Value>> maybeResults =
+          typeConverter.materializeTargetConversion(
+              rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (!maybeResults) {
+        emitError(castOp->getLoc())
+            << "failed to create target materialization";
+        return failure();
+      }
+      materializedResults = maybeResults.value();
+    } else {
+      // Source and argument materializations.
+      assert(areOperandTypesLegal && !areResultsTypesLegal &&
+             resultTypes.size() == 1 && "found unexpected cast");
+      std::optional<Value> maybeResult;
+      if (castKind == getCastKindName(CastKind::Source)) {
+        // Source materialization.
+        maybeResult = typeConverter.materializeSourceConversion(
+            rewriter, castOp->getLoc(), resultTypes.front(),
+            castOp.getOperands());
+      } else {
+        // Argument materialization.
+        assert(castKind == getCastKindName(CastKind::Argument) &&
+               "unexpected value of cast kind attribute");
+        assert(llvm::all_of(operands,
+                            [&](Value v) { return v.isa<BlockArgument>(); }));
+        maybeResult = typeConverter.materializeArgumentConversion(
+            rewriter, castOp->getLoc(), resultTypes.front(),
+            castOp.getOperands());
+      }
+      if (!maybeResult.has_value() || !maybeResult.value()) {
+        emitError(castOp->getLoc())
+            << "failed to create " << castKind << " materialization";
+        return failure();
+      }
+      materializedResults = {maybeResult.value()};
+    }
+
+    // Replace the cast with the result of the materialization.
+    rewriter.replaceOp(castOp, materializedResults);
+  }
+
+  return success();
+}
+
+} // namespace mlir
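After `applyPatternsAndFoldGreedily` returns, `applyPartialOneToNConversion` only revisits casts that carry the cast-kind attribute and dispatches on that kind: target casts (one operand, N results) go to the target materialization, while source and argument casts (N operands, one result) go to the source or argument materialization. The standalone C++ sketch below models only that selection and dispatch, without MLIR. `CastRecord`, `MaterializationCounts`, and `replaceTaggedCasts` are hypothetical names, and counting stands in for actually building the materialization; where the real driver asserts on an unexpected cast shape, the sketch returns `false`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of an unrealized cast remaining after pattern application.
// `kind` is empty for casts that existed before the pass ran (i.e., casts
// without the cast-kind attribute); those are never touched.
struct CastRecord {
  std::optional<std::string> kind; // "source", "argument", or "target"
  std::size_t numOperands;
  std::size_t numResults;
};

// Hypothetical counters standing in for calls to user-provided
// materializations.
struct MaterializationCounts {
  int target = 0;
  int source = 0;
  int argument = 0;
};

// Dispatches every tagged cast on its kind. Returns false for a cast whose
// shape contradicts its kind (target casts must be 1:N, source and argument
// casts must be N:1) or whose kind is unknown.
bool replaceTaggedCasts(const std::vector<CastRecord> &casts,
                        MaterializationCounts &counts) {
  for (const CastRecord &cast : casts) {
    if (!cast.kind)
      continue; // Pre-existing cast: not inserted by this pass.
    const std::string &kind = *cast.kind;
    if (kind == "target") {
      if (cast.numOperands != 1)
        return false;
      ++counts.target;
    } else if (kind == "source" || kind == "argument") {
      if (cast.numResults != 1)
        return false;
      if (kind == "source")
        ++counts.source;
      else
        ++counts.argument;
    } else {
      return false; // Unknown cast kind.
    }
  }
  return true;
}
```

Unlike this model, the real driver also emits an error and returns `failure()` when a materialization callback does not produce a value.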

diff  --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index 604e948afaf6e..51b63ba4c0ad9 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s
 
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-func-ops" \
+// RUN: | FileCheck %s --check-prefix=CHECK-12N
+
 // Test case: Most basic case of a 1:N decomposition, an identity function.
 
 // CHECK-LABEL:   func @identity(
@@ -9,6 +13,10 @@
 // CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
 // CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
 // CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK-12N-LABEL:   func @identity(
+// CHECK-12N-SAME:                   %[[ARG0:.*]]: i1,
+// CHECK-12N-SAME:                   %[[ARG1:.*]]: i32) -> (i1, i32) {
+// CHECK-12N:           return %[[ARG0]], %[[ARG1]] : i1, i32
 func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
   return %arg0 : tuple<i1, i32>
 }
@@ -20,6 +28,9 @@ func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
 // CHECK-LABEL:   func @identity_1_to_1_no_materializations(
 // CHECK-SAME:                                              %[[ARG0:.*]]: i1) -> i1 {
 // CHECK:           return %[[ARG0]] : i1
+// CHECK-12N-LABEL:   func @identity_1_to_1_no_materializations(
+// CHECK-12N-SAME:                                              %[[ARG0:.*]]: i1) -> i1 {
+// CHECK-12N:           return %[[ARG0]] : i1
 func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> {
   return %arg0 : tuple<i1>
 }
@@ -31,6 +42,9 @@ func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> {
 // CHECK-LABEL:   func @recursive_decomposition(
 // CHECK-SAME:                                   %[[ARG0:.*]]: i1) -> i1 {
 // CHECK:           return %[[ARG0]] : i1
+// CHECK-12N-LABEL:   func @recursive_decomposition(
+// CHECK-12N-SAME:                                   %[[ARG0:.*]]: i1) -> i1 {
+// CHECK-12N:           return %[[ARG0]] : i1
 func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tuple<tuple<i1>>> {
   return %arg0 : tuple<tuple<tuple<i1>>>
 }
@@ -54,6 +68,10 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
 // CHECK:           %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
 // CHECK:           %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple<i2>) -> i2
 // CHECK:           return %[[V7]], %[[V10]] : i1, i2
+// CHECK-12N-LABEL:   func @mixed_recursive_decomposition(
+// CHECK-12N-SAME:                 %[[ARG0:.*]]: i1,
+// CHECK-12N-SAME:                 %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-12N:           return %[[ARG0]], %[[ARG1]] : i1, i2
 func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> {
   return %arg0 : tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
 }
@@ -63,6 +81,7 @@ func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<
 // Test case: Check decomposition of calls.
 
 // CHECK-LABEL:   func private @callee(i1, i32) -> (i1, i32)
+// CHECK-12N-LABEL:   func private @callee(i1, i32) -> (i1, i32)
 func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
 
 // CHECK-LABEL:   func @caller(
@@ -76,6 +95,11 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
 // CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
 // CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
 // CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK-12N-LABEL:   func @caller(
+// CHECK-12N-SAME:                 %[[ARG0:.*]]: i1,
+// CHECK-12N-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
+// CHECK-12N:           %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
+// CHECK-12N:           return %[[V0]]#0, %[[V0]]#1 : i1, i32
 func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
   %0 = call @callee(%arg0) : (tuple<i1, i32>) -> tuple<i1, i32>
   return %0 : tuple<i1, i32>
@@ -86,10 +110,15 @@ func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
 // Test case: Type that decomposes to nothing (that is, a 1:0 decomposition).
 
 // CHECK-LABEL:   func private @callee()
+// CHECK-12N-LABEL:   func private @callee()
 func.func private @callee(tuple<>) -> tuple<>
+
 // CHECK-LABEL:   func @caller() {
 // CHECK:           call @callee() : () -> ()
 // CHECK:           return
+// CHECK-12N-LABEL:   func @caller() {
+// CHECK-12N:           call @callee() : () -> ()
+// CHECK-12N:           return
 func.func @caller(%arg0: tuple<>) -> tuple<> {
   %0 = call @callee(%arg0) : (tuple<>) -> (tuple<>)
   return %0 : tuple<>
@@ -105,6 +134,11 @@ func.func @caller(%arg0: tuple<>) -> tuple<> {
 // CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
 // CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
 // CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK-12N-LABEL:   func @unconverted_op_result() -> (i1, i32) {
+// CHECK-12N:           %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
+// CHECK-12N:           %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
+// CHECK-12N:           %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
+// CHECK-12N:           return %[[RET0]], %[[RET1]] : i1, i32
 func.func @unconverted_op_result() -> tuple<i1, i32> {
   %0 = "test.source"() : () -> (tuple<i1, i32>)
   return %0 : tuple<i1, i32>
@@ -125,6 +159,16 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
 // CHECK:           %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
 // CHECK:           %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
 // CHECK:           return %[[V3]], %[[V5]] : i1, i32
+// CHECK-12N-LABEL:   func @nested_unconverted_op_result(
+// CHECK-12N-SAME:                 %[[ARG0:.*]]: i1,
+// CHECK-12N-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
+// CHECK-12N:           %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
+// CHECK-12N:           %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
+// CHECK-12N:           %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
+// CHECK-12N:           %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
+// CHECK-12N:           %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
+// CHECK-12N:           %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
+// CHECK-12N:           return %[[V3]], %[[V5]] : i1, i32
 func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> {
   %0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>)
   return %0 : tuple<i1, tuple<i32>>
@@ -136,6 +180,7 @@ func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1
 // This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions.
 
 // CHECK-LABEL:   func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+// CHECK-12N-LABEL:   func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
 func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
 
 // CHECK-LABEL:   func @caller(
@@ -153,6 +198,15 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
 // CHECK:           %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
 // CHECK:           %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
 // CHECK:           return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
+// CHECK-12N-LABEL:   func @caller(
+// CHECK-12N-SAME:                 %[[I1:.*]]: i1,
+// CHECK-12N-SAME:                 %[[I2:.*]]: i2,
+// CHECK-12N-SAME:                 %[[I3:.*]]: i3,
+// CHECK-12N-SAME:                 %[[I4:.*]]: i4,
+// CHECK-12N-SAME:                 %[[I5:.*]]: i5,
+// CHECK-12N-SAME:                 %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
+// CHECK-12N:           %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+// CHECK-12N:           return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
 func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple<i2>, %arg3: i3, %arg4: tuple<i4, i5>, %arg5: i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) {
   %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
   return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6
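The CHECK-12N expectations above follow from flattening each tuple type recursively: `tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>` becomes `(i1, i2)`, and the arguments of the mixed `@callee` signature expand to 0, 1, 1, 1, 2, and 1 values, respectively. A minimal standalone C++ sketch of that computation follows; it assumes a hypothetical `ModelType` in place of MLIR's `TupleType` and mirrors what `getFlattenedTypes` does for the test type converter, without using MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an MLIR type: either a leaf type such as "i1" or
// a tuple of element types.
struct ModelType {
  std::string name;                // Leaf name; unused for tuples.
  std::vector<ModelType> elements; // Element types; only used for tuples.
  bool isTuple = false;
};

ModelType makeLeaf(const std::string &name) {
  return ModelType{name, {}, false};
}

ModelType makeTuple(std::vector<ModelType> elements) {
  return ModelType{"", std::move(elements), true};
}

// Recursively collects the leaf types of `type`. A non-tuple type maps to
// itself (1:1) and an empty tuple maps to nothing (1:0).
void flatten(const ModelType &type, std::vector<std::string> &out) {
  if (!type.isTuple) {
    out.push_back(type.name);
    return;
  }
  for (const ModelType &element : type.elements)
    flatten(element, out);
}

// Returns how many converted values each of `types` expands to.
std::vector<std::size_t> convertedCounts(const std::vector<ModelType> &types) {
  std::vector<std::size_t> counts;
  counts.reserve(types.size());
  for (const ModelType &type : types) {
    std::vector<std::string> flat;
    flatten(type, flat);
    counts.push_back(flat.size());
  }
  return counts;
}
```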

diff  --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
index 14f0e0dbe1802..14df652ac7dfd 100644
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(FuncToLLVM)
+add_subdirectory(OneToNTypeConversion)
 add_subdirectory(VectorToSPIRV)

diff  --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
new file mode 100644
index 0000000000000..418978688c90d
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_library(MLIRTestOneToNTypeConversionPass
+  TestOneToNTypeConversionPass.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRFuncTransforms
+  MLIRIR
+  MLIRTestDialect
+  MLIRTransformUtils
+ )
+
+target_include_directories(MLIRTestOneToNTypeConversionPass
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
+  )

diff  --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
new file mode 100644
index 0000000000000..220bcb58bf788
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -0,0 +1,245 @@
+//===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms
+/// in `applyPartialOneToNConversion` by converting built-in tuples to the
+/// elements they consist of as well as some dummy ops operating on these
+/// tuples.
+struct TestOneToNTypeConversionPass
+    : public PassWrapper<TestOneToNTypeConversionPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass)
+
+  TestOneToNTypeConversionPass() = default;
+  TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<test::TestDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-one-to-n-type-conversion";
+  }
+
+  StringRef getDescription() const final {
+    return "Test pass for 1:N type conversion";
+  }
+
+  Option<bool> convertFuncOps{*this, "convert-func-ops",
+                              llvm::cl::desc("Enable conversion on func ops"),
+                              llvm::cl::init(false)};
+
+  Option<bool> convertTupleOps{*this, "convert-tuple-ops",
+                               llvm::cl::desc("Enable conversion on tuple ops"),
+                               llvm::cl::init(false)};
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestOneToNTypeConversionPass() {
+  PassRegistration<TestOneToNTypeConversionPass>();
+}
+} // namespace test
+} // namespace mlir
+
+namespace {
+
+/// Test pattern for the `make_tuple` op from the test dialect that converts
+/// this kind of op into its "decomposed" form, i.e., the elements of the tuple
+/// that is being produced by `test.make_tuple`, which are really just the
+/// operands of this op.
+class ConvertMakeTupleOp
+    : public OneToNOpConversionPattern<::test::MakeTupleOp> {
+public:
+  using OneToNOpConversionPattern<
+      ::test::MakeTupleOp>::OneToNOpConversionPattern;
+
+  LogicalResult matchAndRewrite(::test::MakeTupleOp op,
+                                OneToNPatternRewriter &rewriter,
+                                const OneToNTypeMapping &operandMapping,
+                                const OneToNTypeMapping &resultMapping,
+                                ValueRange convertedOperands) const override {
+    // Simply replace the current op with the converted operands.
+    rewriter.replaceOp(op, convertedOperands, resultMapping);
+    return success();
+  }
+};
+
+/// Test pattern for the `get_tuple_element` op from the test dialect that
+/// converts this kind of op into its "decomposed" form, i.e., instead of
+/// "physically" extracting one element from the tuple, we forward the one
+/// element of the decomposed form that is being extracted (or the several
+/// elements in case that element is a nested tuple).
+class ConvertGetTupleElementOp
+    : public OneToNOpConversionPattern<::test::GetTupleElementOp> {
+public:
+  using OneToNOpConversionPattern<
+      ::test::GetTupleElementOp>::OneToNOpConversionPattern;
+
+  LogicalResult matchAndRewrite(::test::GetTupleElementOp op,
+                                OneToNPatternRewriter &rewriter,
+                                const OneToNTypeMapping &operandMapping,
+                                const OneToNTypeMapping &resultMapping,
+                                ValueRange convertedOperands) const override {
+    // Construct mapping for tuple element types.
+    auto stateType = op->getOperand(0).getType().cast<TupleType>();
+    TypeRange originalElementTypes = stateType.getTypes();
+    OneToNTypeMapping elementMapping(originalElementTypes);
+    if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
+                                                   elementMapping)))
+      return failure();
+
+    // Compute converted operands corresponding to original input tuple.
+    ValueRange convertedTuple =
+        operandMapping.getConvertedValues(convertedOperands, 0);
+
+    // Get those converted operands that correspond to the index-th element of
+    // the original input tuple.
+    size_t index = op.getIndex();
+    ValueRange extractedElement =
+        elementMapping.getConvertedValues(convertedTuple, index);
+
+    rewriter.replaceOp(op, extractedElement, resultMapping);
+
+    return success();
+  }
+};
+
+} // namespace
+
+static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter,
+                                                RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertMakeTupleOp,
+      ConvertGetTupleElementOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}
+
+/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
+/// given tuple value. If some tuple elements are, in turn, tuples, the elements
+/// of those are extracted recursively such that the returned values have the
+/// flattened element types of the input tuple, which are expected to match
+/// `resultTypes`.
+///
+/// This function has been copied (with small adaptations) from
+/// TestDecomposeCallGraphTypes.cpp.
+static std::optional<SmallVector<Value>>
+buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
+                        Location loc) {
+  TupleType inputType = input.getType().dyn_cast<TupleType>();
+  if (!inputType)
+    return {};
+
+  SmallVector<Value> values;
+  for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
+    Value element = builder.create<::test::GetTupleElementOp>(
+        loc, elementType, input, builder.getI32IntegerAttr(idx));
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Recurse if the current element is also a tuple.
+      SmallVector<Type> flatRecursiveTypes;
+      nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
+      std::optional<SmallVector<Value>> recursiveValues =
+          buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
+      if (!recursiveValues.has_value())
+        return {};
+      values.append(recursiveValues.value());
+    } else {
+      values.push_back(element);
+    }
+  }
+  return values;
+}
+
+/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
+/// type `resultType`. If that type is nested, each nested tuple is built
+/// recursively with another `test.make_tuple` op.
+///
+/// This function has been copied (with small adaptations) from
+/// TestDecomposeCallGraphTypes.cpp.
+static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
+                                             TupleType resultType,
+                                             ValueRange inputs, Location loc) {
+  // Build one value for each element at this nesting level.
+  SmallVector<Value> elements;
+  elements.reserve(resultType.getTypes().size());
+  ValueRange::iterator inputIt = inputs.begin();
+  for (Type elementType : resultType.getTypes()) {
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Determine how many input values are needed for the nested elements of
+      // the nested TupleType and advance inputIt by that number.
+      // TODO: We only need the *number* of nested types, not the types
+      //       themselves. Maybe it's worth adding a more efficient overload?
+      SmallVector<Type> nestedFlattenedTypes;
+      nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
+      size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
+      ValueRange nestedFlattenedElements(inputIt,
+                                         inputIt + numNestedFlattenedTypes);
+      inputIt += numNestedFlattenedTypes;
+
+      // Recurse on the values for the nested TupleType.
+      std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
+                                                  nestedFlattenedElements, loc);
+      if (!res.has_value())
+        return {};
+
+      // The tuple constructed by the conversion is the element value.
+      elements.push_back(res.value());
+    } else {
+      // Base case: take one input as is.
+      elements.push_back(*inputIt++);
+    }
+  }
+
+  // Assemble the tuple from the elements.
+  return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
+}
+
+void TestOneToNTypeConversionPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  auto *context = &getContext();
+
+  // Assemble type converter.
+  OneToNTypeConverter typeConverter;
+
+  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](TupleType tupleType, SmallVectorImpl<Type> &types) {
+        tupleType.getFlattenedTypes(types);
+        return success();
+      });
+
+  typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+  typeConverter.addSourceMaterialization(buildMakeTupleOp);
+  typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+
+  // Assemble patterns.
+  RewritePatternSet patterns(context);
+  if (convertTupleOps)
+    populateDecomposeTuplesTestPatterns(typeConverter, patterns);
+  if (convertFuncOps)
+    populateFuncTypeConversionPatterns(typeConverter, patterns);
+
+  // Run conversion.
+  if (failed(applyPartialOneToNConversion(module, typeConverter,
+                                          std::move(patterns))))
+    return signalPassFailure();
+}
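`ConvertGetTupleElementOp` does not create any IR for the extraction; it only selects a slice of the already-decomposed operands, using `operandMapping` to find the values of the input tuple and a second mapping over the tuple's element types to find the values of the `index`-th element. The standalone C++ sketch below models that offset arithmetic with a hypothetical `OneToNMappingModel` that stores, for each original value, how many converted values it maps to. It illustrates the idea behind `getConvertedValues(convertedValues, idx)` and is not the actual `OneToNTypeMapping` implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of 1:N bookkeeping: original value `i` maps to
// `counts[i]` consecutive converted values (possibly zero).
class OneToNMappingModel {
public:
  explicit OneToNMappingModel(std::vector<std::size_t> counts)
      : counts_(std::move(counts)) {
    // Prefix sums give the position of each original value's first converted
    // value in the flat list of converted values.
    offsets_.reserve(counts_.size());
    std::size_t offset = 0;
    for (std::size_t count : counts_) {
      offsets_.push_back(offset);
      offset += count;
    }
    total_ = offset;
  }

  std::size_t numConverted() const { return total_; }

  // Returns the converted values that belong to original value `idx`.
  std::vector<std::string>
  getConvertedValues(const std::vector<std::string> &converted,
                     std::size_t idx) const {
    assert(converted.size() == total_ && idx < counts_.size());
    auto begin =
        converted.begin() + static_cast<std::ptrdiff_t>(offsets_[idx]);
    auto end = begin + static_cast<std::ptrdiff_t>(counts_[idx]);
    return std::vector<std::string>(begin, end);
  }

private:
  std::vector<std::size_t> counts_;
  std::vector<std::size_t> offsets_;
  std::size_t total_ = 0;
};
```

For example, an operand of type `tuple<i1, tuple<i2, i3>>` decomposes into three values; an operand mapping with counts `{3}` recovers all three, and an element mapping with counts `{1, 2}` then selects one value for index 0 and two values for the nested tuple at index 1.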

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index f84fbe631cf16..c43056906da8c 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -33,6 +33,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestDialect
     MLIRTestDynDialect
     MLIRTestIR
+    MLIRTestOneToNTypeConversionPass
     MLIRTestPass
     MLIRTestPDLL
     MLIRTestReducer

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e7ca06b4bef9d..855c6a6e0f28a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -107,6 +107,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
 void registerTestPadFusion();
 void registerTestPDLByteCodePass();
@@ -218,6 +219,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestPDLByteCodePass();


        

