[llvm] [mlir] [mlir][Transforms] Delete 1:N dialect conversion driver (PR #121389)
Matthias Springer via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 3 07:37:55 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/121389
>From cded913a87864e39a31ede01bc13b35190942359 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 31 Dec 2024 13:32:25 +0100
Subject: [PATCH] [mlir][Transforms] Remove 1:N dialect conversion driver
---
.../Func/Transforms/OneToNFuncConversions.h | 26 -
.../mlir/Dialect/SCF/Transforms/Patterns.h | 10 -
.../SPIRV/Transforms/SPIRVConversion.h | 1 -
.../Dialect/SparseTensor/Transforms/Passes.h | 1 -
.../mlir/Transforms/DialectConversion.h | 20 -
.../mlir/Transforms/OneToNTypeConversion.h | 290 -----------
.../Dialect/Func/Transforms/CMakeLists.txt | 1 -
.../Func/Transforms/OneToNFuncConversions.cpp | 87 ----
.../lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 -
.../SCF/Transforms/OneToNTypeConversion.cpp | 215 --------
.../SPIRV/Transforms/SPIRVConversion.cpp | 7 +-
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 -
.../Transforms/Utils/DialectConversion.cpp | 11 -
.../Transforms/Utils/OneToNTypeConversion.cpp | 458 ------------------
.../one-to-n-type-conversion.mlir | 140 ------
...f-structural-one-to-n-type-conversion.mlir | 183 -------
.../decompose-call-graph-types.mlir | 53 --
mlir/test/lib/Conversion/CMakeLists.txt | 1 -
.../OneToNTypeConversion/CMakeLists.txt | 21 -
.../TestOneToNTypeConversionPass.cpp | 261 ----------
mlir/tools/mlir-opt/CMakeLists.txt | 1 -
mlir/tools/mlir-opt/mlir-opt.cpp | 2 -
.../llvm-project-overlay/mlir/BUILD.bazel | 2 -
.../mlir/test/BUILD.bazel | 17 -
24 files changed, 4 insertions(+), 1806 deletions(-)
delete mode 100644 mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
delete mode 100644 mlir/include/mlir/Transforms/OneToNTypeConversion.h
delete mode 100644 mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
delete mode 100644 mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
delete mode 100644 mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
delete mode 100644 mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
delete mode 100644 mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
delete mode 100644 mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
delete mode 100644 mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
deleted file mode 100644
index c9e407daf9bf8c..00000000000000
--- a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
+++ /dev/null
@@ -1,26 +0,0 @@
-//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
-#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
-
-namespace mlir {
-class TypeConverter;
-class RewritePatternSet;
-} // namespace mlir
-
-namespace mlir {
-
-// Populates the provided pattern set with patterns that do 1:N type conversions
-// on func ops. This is intended to be used with `applyPartialOneToNConversion`.
-void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 9c1479d28c305f..00c8a5c0c517b7 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -63,16 +63,6 @@ void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter,
void populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target);
-/// Populates the provided pattern set with patterns that do 1:N type
-/// conversions on (some) SCF ops. This is intended to be used with
-/// applyPartialOneToNConversion.
-/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
-/// 1:N support has been added to the regular dialect conversion driver.
-LLVM_DEPRECATED("Use populateSCFStructuralTypeConversions() instead",
- "populateSCFStructuralTypeConversions")
-void populateSCFStructuralOneToNTypeConversions(
- const TypeConverter &typeConverter, RewritePatternSet &patterns);
-
/// Populate patterns for SCF software pipelining transformation. See the
/// ForLoopPipeliningPattern for the transformation details.
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index bed4d66ccd6cbe..3d22ec918f4c5f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -20,7 +20,6 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/LogicalResult.h"
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 2e9c297f20182a..acd347c530d58b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -16,7 +16,6 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
//===----------------------------------------------------------------------===//
// Include the generated pass header (which needs some early definitions).
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfae..7e5389a83855a5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -45,13 +45,11 @@ class TypeConverter {
// Copy the registered conversions, but not the caches
TypeConverter(const TypeConverter &other)
: conversions(other.conversions),
- argumentMaterializations(other.argumentMaterializations),
sourceMaterializations(other.sourceMaterializations),
targetMaterializations(other.targetMaterializations),
typeAttributeConversions(other.typeAttributeConversions) {}
TypeConverter &operator=(const TypeConverter &other) {
conversions = other.conversions;
- argumentMaterializations = other.argumentMaterializations;
sourceMaterializations = other.sourceMaterializations;
targetMaterializations = other.targetMaterializations;
typeAttributeConversions = other.typeAttributeConversions;
@@ -177,21 +175,6 @@ class TypeConverter {
/// can be a TypeRange; in that case, the function must return a
/// SmallVector<Value>.
- /// This method registers a materialization that will be called when
- /// converting (potentially multiple) block arguments that were the result of
- /// a signature conversion of a single block argument, to a single SSA value
- /// with the old block argument type.
- ///
- /// Note: Argument materializations are used only with the 1:N dialect
- /// conversion driver. The 1:N dialect conversion driver will be removed soon
- /// and so will be argument materializations.
- template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<1>>
- void addArgumentMaterialization(FnT &&callback) {
- argumentMaterializations.emplace_back(
- wrapMaterialization<T>(std::forward<FnT>(callback)));
- }
-
/// This method registers a materialization that will be called when
/// converting a replacement value back to its original source type.
/// This is used when some uses of the original value persist beyond the main
@@ -319,8 +302,6 @@ class TypeConverter {
/// generating a cast sequence of some kind. See the respective
/// `add*Materialization` for more information on the context for these
/// methods.
- Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const;
Value materializeSourceConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs) const;
Value materializeTargetConversion(OpBuilder &builder, Location loc,
@@ -507,7 +488,6 @@ class TypeConverter {
SmallVector<ConversionCallbackFn, 4> conversions;
/// The list of registered materialization functions.
- SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
deleted file mode 100644
index 9c74bf916d971b..00000000000000
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ /dev/null
@@ -1,290 +0,0 @@
-//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Note: The 1:N dialect conversion is deprecated and will be removed soon.
-// 1:N support has been added to the regular dialect conversion driver.
-//
-// This file provides utils for implementing (poor-man's) dialect conversion
-// passes with 1:N type conversions.
-//
-// The main function, `applyPartialOneToNConversion`, first applies a set of
-// `RewritePattern`s, which produce unrealized casts to convert the operands and
-// results from and to the source types, and then replaces all newly added
-// unrealized casts by user-provided materializations. For this to work, the
-// main function requires a special `TypeConverter`, a special
-// `PatternRewriter`, and special RewritePattern`s, which extend their
-// respective base classes for 1:N type converions.
-//
-// Note that this is much more simple-minded than the "real" dialect conversion,
-// which checks for legality before applying patterns and does probably many
-// other additional things. Ideally, some of the extensions here could be
-// integrated there.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
-#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
-
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/SmallVector.h"
-
-namespace mlir {
-
-/// Stores a 1:N mapping of types and provides several useful accessors. This
-/// class extends `SignatureConversion`, which already supports 1:N type
-/// mappings but lacks some accessors into the mapping as well as access to the
-/// original types.
-class OneToNTypeMapping : public TypeConverter::SignatureConversion {
-public:
- OneToNTypeMapping(TypeRange originalTypes)
- : TypeConverter::SignatureConversion(originalTypes.size()),
- originalTypes(originalTypes) {}
-
- using TypeConverter::SignatureConversion::getConvertedTypes;
-
- /// Returns the list of types that corresponds to the original type at the
- /// given index.
- TypeRange getConvertedTypes(unsigned originalTypeNo) const;
-
- /// Returns the list of original types.
- TypeRange getOriginalTypes() const { return originalTypes; }
-
- /// Returns the slice of converted values that corresponds the original value
- /// at the given index.
- ValueRange getConvertedValues(ValueRange convertedValues,
- unsigned originalValueNo) const;
-
- /// Fills the given result vector with as many copies of the location of the
- /// original value as the number of values it is converted to.
- void convertLocation(Value originalValue, unsigned originalValueNo,
- llvm::SmallVectorImpl<Location> &result) const;
-
- /// Fills the given result vector with as many copies of the lociation of each
- /// original value as the number of values they are respectively converted to.
- void convertLocations(ValueRange originalValues,
- llvm::SmallVectorImpl<Location> &result) const;
-
- /// Returns true iff at least one type conversion maps an input type to a type
- /// that is different from itself.
- bool hasNonIdentityConversion() const;
-
-private:
- llvm::SmallVector<Type> originalTypes;
-};
-
-/// Extends the basic `RewritePattern` class with a type converter member and
-/// some accessors to it. This is useful for patterns that are not
-/// `ConversionPattern`s but still require access to a type converter.
-class RewritePatternWithConverter : public mlir::RewritePattern {
-public:
- /// Construct a conversion pattern with the given converter, and forward the
- /// remaining arguments to RewritePattern.
- template <typename... Args>
- RewritePatternWithConverter(const TypeConverter &typeConverter,
- Args &&...args)
- : RewritePattern(std::forward<Args>(args)...),
- typeConverter(&typeConverter) {}
-
- /// Return the type converter held by this pattern, or nullptr if the pattern
- /// does not require type conversion.
- const TypeConverter *getTypeConverter() const { return typeConverter; }
-
- template <typename ConverterTy>
- std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
- const ConverterTy *>
- getTypeConverter() const {
- return static_cast<const ConverterTy *>(typeConverter);
- }
-
-protected:
- /// A type converter for use by this pattern.
- const TypeConverter *const typeConverter;
-};
-
-/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
-/// class provides additional rewrite methods that are specific to 1:N type
-/// conversions.
-class OneToNPatternRewriter : public PatternRewriter {
-public:
- OneToNPatternRewriter(MLIRContext *context,
- OpBuilder::Listener *listener = nullptr)
- : PatternRewriter(context, listener) {}
-
- /// Replaces the results of the operation with the specified list of values
- /// mapped back to the original types as specified in the provided type
- /// mapping. That type mapping must match the replaced op (i.e., the original
- /// types must be the same as the result types of the op) and the new values
- /// (i.e., the converted types must be the same as the types of the new
- /// values).
- /// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
- /// 1:N support has been added to the regular dialect conversion driver.
- LLVM_DEPRECATED("Use replaceOpWithMultiple() instead",
- "replaceOpWithMultiple")
- void replaceOp(Operation *op, ValueRange newValues,
- const OneToNTypeMapping &resultMapping);
- using PatternRewriter::replaceOp;
-
- /// Applies the given argument conversion to the given block. This consists of
- /// replacing each original argument with N arguments as specified in the
- /// argument conversion and inserting unrealized casts from the converted
- /// values to the original types, which are then used in lieu of the original
- /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
- /// with a user-provided argument materialization if necessary.) This is
- /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
- /// type conversion properly and probably (2) doesn't handle many other edge
- /// cases.
- Block *applySignatureConversion(Block *block,
- OneToNTypeMapping &argumentConversion);
-};
-
-/// Base class for patterns with 1:N type conversions. Derived classes have to
-/// overwrite the `matchAndRewrite` overlaod that provides additional
-/// information for 1:N type conversions.
-class OneToNConversionPattern : public RewritePatternWithConverter {
-public:
- using RewritePatternWithConverter::RewritePatternWithConverter;
-
- /// This function has to be implemented by derived classes and is called from
- /// the usual overloads. Like in "normal" `DialectConversion`, the function is
- /// provided with the converted operands (which thus have target types). Since
- /// 1:N conversions are supported, there is usually no 1:1 relationship
- /// between the original and the converted operands. Instead, the provided
- /// `operandMapping` can be used to access the converted operands that
- /// correspond to a particular original operand. Similarly, `resultMapping`
- /// is provided to help with assembling the result values, which may have 1:N
- /// correspondences as well. In that case, the original op should be replaced
- /// with the overload of `replaceOp` that takes the provided `resultMapping`
- /// in order to deal with the mapping of converted result values to their
- /// usages in the original types correctly.
- virtual LogicalResult matchAndRewrite(Operation *op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const = 0;
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const final;
-};
-
-/// This class is a wrapper around `OneToNConversionPattern` for matching
-/// against instances of a particular op class.
-template <typename SourceOp>
-class OneToNOpConversionPattern : public OneToNConversionPattern {
-public:
- OneToNOpConversionPattern(const TypeConverter &typeConverter,
- MLIRContext *context, PatternBenefit benefit = 1,
- ArrayRef<StringRef> generatedNames = {})
- : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
- benefit, context, generatedNames) {}
- /// Generic adaptor around the root op of this pattern using the converted
- /// operands. Importantly, each operand is represented as a *range* of values,
- /// namely the N values each original operand gets converted to. Concretely,
- /// this makes the result type of the accessor functions of the adaptor class
- /// be a `ValueRange`.
- class OpAdaptor
- : public SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> {
- public:
- using RangeT = ArrayRef<ValueRange>;
- using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
- using Properties = typename SourceOp::template InferredProperties<SourceOp>;
-
- OpAdaptor(const OneToNTypeMapping *operandMapping,
- const OneToNTypeMapping *resultMapping,
- const ValueRange *convertedOperands, RangeT values, SourceOp op)
- : BaseT(values, op), operandMapping(operandMapping),
- resultMapping(resultMapping), convertedOperands(convertedOperands) {}
-
- /// Get the type mapping of the original operands to the converted operands.
- const OneToNTypeMapping &getOperandMapping() const {
- return *operandMapping;
- }
-
- /// Get the type mapping of the original results to the converted results.
- const OneToNTypeMapping &getResultMapping() const { return *resultMapping; }
-
- /// Get a flat range of all converted operands. Unlike `getOperands`, which
- /// returns an `ArrayRef` with one `ValueRange` for each original operand,
- /// this function returns a `ValueRange` that contains all converted
- /// operands irrespectively of which operand they originated from.
- ValueRange getFlatOperands() const { return *convertedOperands; }
-
- private:
- const OneToNTypeMapping *operandMapping;
- const OneToNTypeMapping *resultMapping;
- const ValueRange *convertedOperands;
- };
-
- using OneToNConversionPattern::matchAndRewrite;
-
- /// Overload that derived classes have to override for their op type.
- virtual LogicalResult
- matchAndRewrite(SourceOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const = 0;
-
- LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const final {
- // Wrap converted operands and type mappings into an adaptor.
- SmallVector<ValueRange> valueRanges;
- for (int64_t i = 0; i < op->getNumOperands(); i++) {
- auto values = operandMapping.getConvertedValues(convertedOperands, i);
- valueRanges.push_back(values);
- }
- OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
- valueRanges, cast<SourceOp>(op));
-
- // Call overload implemented by the derived class.
- return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
- }
-};
-
-/// Applies the given set of patterns recursively on the given op and adds user
-/// materializations where necessary. The patterns are expected to be
-/// `OneToNConversionPattern`, which help converting the types of the operands
-/// and results of the matched ops. The provided type converter is used to
-/// convert the operands of matched ops from their original types to operands
-/// with different types. Unlike in `DialectConversion`, this supports 1:N type
-/// conversions. Those conversions at the "boundary" of the pattern application,
-/// where converted results are not consumed by replaced ops that expect the
-/// converted operands or vice versa, the function inserts user materializations
-/// from the type converter. Also unlike `DialectConversion`, there are no legal
-/// or illegal types; the function simply applies the given patterns and does
-/// not fail if some ops or types remain unconverted (i.e., the conversion is
-/// only "partial").
-/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
-/// 1:N support has been added to the regular dialect conversion driver.
-LLVM_DEPRECATED("Use applyPartialConversion() instead",
- "applyPartialConversion")
-LogicalResult
-applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
- const FrozenRewritePatternSet &patterns);
-
-/// Add a pattern to the given pattern list to convert the signature of a
-/// FunctionOpInterface op with the given type converter. This only supports
-/// ops which use FunctionType to represent their type. This is intended to be
-/// used with the 1:N dialect conversion.
-/// FIXME: The 1:N dialect conversion is deprecated and will be removed soon.
-/// 1:N support has been added to the regular dialect conversion driver.
-LLVM_DEPRECATED(
- "Use populateFunctionOpInterfaceTypeConversionPattern() instead",
- "populateFunctionOpInterfaceTypeConversionPattern")
-void populateOneToNFunctionOpInterfaceTypeConversionPattern(
- StringRef functionLikeOpName, const TypeConverter &converter,
- RewritePatternSet &patterns);
-template <typename FuncOpT>
-void populateOneToNFunctionOpInterfaceTypeConversionPattern(
- const TypeConverter &converter, RewritePatternSet &patterns) {
- populateOneToNFunctionOpInterfaceTypeConversionPattern(
- FuncOpT::getOperationName(), converter, patterns);
-}
-
-} // namespace mlir
-
-#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 6384d25ee70273..0bed59e109503f 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,7 +1,6 @@
add_mlir_dialect_library(MLIRFuncTransforms
DuplicateFunctionElimination.cpp
FuncConversions.cpp
- OneToNFuncConversions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
deleted file mode 100644
index 3b8982257a9c95..00000000000000
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ /dev/null
@@ -1,87 +0,0 @@
-//===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// The patterns in this file are heavily inspired (and copied from)
-// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
-// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N
-// type conversions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
-
-using namespace mlir;
-using namespace mlir::func;
-
-namespace {
-
-class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
-public:
- using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(CallOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-
- // Nothing to do if the op doesn't have any non-identity conversions for its
- // operands or results.
- if (!adaptor.getOperandMapping().hasNonIdentityConversion() &&
- !resultMapping.hasNonIdentityConversion())
- return failure();
-
- // Create new CallOp.
- auto newOp =
- rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
- adaptor.getFlatOperands(), op->getAttrs());
-
- rewriter.replaceOp(op, newOp->getResults(), resultMapping);
- return success();
- }
-};
-
-class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
-public:
- using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- // Nothing to do if there is no non-identity conversion.
- if (!adaptor.getOperandMapping().hasNonIdentityConversion())
- return failure();
-
- // Convert operands.
- rewriter.modifyOpInPlace(
- op, [&] { op->setOperands(adaptor.getFlatOperands()); });
-
- return success();
- }
-};
-
-} // namespace
-
-namespace mlir {
-
-void populateFuncTypeConversionPatterns(const TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<
- // clang-format off
- ConvertTypesInFuncCallOp,
- ConvertTypesInFuncReturnOp
- // clang-format on
- >(typeConverter, patterns.getContext());
- populateOneToNFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
- typeConverter, patterns);
-}
-
-} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e99b5d0cc26fc7..84dd992bec53a7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_dialect_library(MLIRSCFTransforms
LoopPipelining.cpp
LoopRangeFolding.cpp
LoopSpecialization.cpp
- OneToNTypeConversion.cpp
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
deleted file mode 100644
index 4cd17f77dfb941..00000000000000
--- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp
+++ /dev/null
@@ -1,215 +0,0 @@
-//===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// The patterns in this file are heavily inspired (and copied from)
-// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
-// type conversions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
-
-using namespace mlir;
-using namespace mlir::scf;
-
-class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
-public:
- using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(IfOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-
- // Nothing to do if there is no non-identity conversion.
- if (!resultMapping.hasNonIdentityConversion())
- return failure();
-
- // Create new IfOp.
- TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
- auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
- op.getCondition(), true);
- newOp->setAttrs(op->getAttrs());
-
- // We do not need the empty blocks created by rewriter.
- rewriter.eraseBlock(newOp.elseBlock());
- rewriter.eraseBlock(newOp.thenBlock());
-
- // Inlines block from the original operation.
- rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
- newOp.getThenRegion().end());
- rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
- newOp.getElseRegion().end());
-
- rewriter.replaceOp(op, newOp->getResults(), resultMapping);
- return success();
- }
-};
-
-class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
-public:
- using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(WhileOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- Location loc = op->getLoc();
-
- const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-
- // Nothing to do if the op doesn't have any non-identity conversions for its
- // operands or results.
- if (!operandMapping.hasNonIdentityConversion() &&
- !resultMapping.hasNonIdentityConversion())
- return failure();
-
- // Create new WhileOp.
- TypeRange convertedResultTypes = resultMapping.getConvertedTypes();
-
- auto newOp = rewriter.create<WhileOp>(loc, convertedResultTypes,
- adaptor.getFlatOperands());
- newOp->setAttrs(op->getAttrs());
-
- // Update block signatures.
- std::array<OneToNTypeMapping, 2> blockMappings = {operandMapping,
- resultMapping};
- for (unsigned int i : {0u, 1u}) {
- Region *region = &op.getRegion(i);
- Block *block = &region->front();
-
- rewriter.applySignatureConversion(block, blockMappings[i]);
-
- // Move updated region to new WhileOp.
- Region &dstRegion = newOp.getRegion(i);
- rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
- }
-
- rewriter.replaceOp(op, newOp->getResults(), resultMapping);
- return success();
- }
-};
-
-class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
-public:
- using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(YieldOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- // Nothing to do if there is no non-identity conversion.
- if (!adaptor.getOperandMapping().hasNonIdentityConversion())
- return failure();
-
- // Convert operands.
- rewriter.modifyOpInPlace(
- op, [&] { op->setOperands(adaptor.getFlatOperands()); });
-
- return success();
- }
-};
-
-class ConvertTypesInSCFConditionOp
- : public OneToNOpConversionPattern<ConditionOp> {
-public:
- using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- // Nothing to do if there is no non-identity conversion.
- if (!adaptor.getOperandMapping().hasNonIdentityConversion())
- return failure();
-
- // Convert operands.
- rewriter.modifyOpInPlace(
- op, [&] { op->setOperands(adaptor.getFlatOperands()); });
-
- return success();
- }
-};
-
-class ConvertTypesInSCFForOp final : public OneToNOpConversionPattern<ForOp> {
-public:
- using OneToNOpConversionPattern<ForOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- const OneToNTypeMapping &operandMapping = adaptor.getOperandMapping();
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-
- // Nothing to do if there is no non-identity conversion.
- if (!operandMapping.hasNonIdentityConversion() &&
- !resultMapping.hasNonIdentityConversion())
- return failure();
-
- // If the lower-bound, upper-bound, or step were expanded, abort the
- // conversion. This conversion does not know what to do in such cases.
- ValueRange lbs = adaptor.getLowerBound();
- ValueRange ubs = adaptor.getUpperBound();
- ValueRange steps = adaptor.getStep();
- if (lbs.size() != 1 || ubs.size() != 1 || steps.size() != 1)
- return rewriter.notifyMatchFailure(
- forOp, "index operands converted to multiple values");
-
- Location loc = forOp.getLoc();
-
- Region *region = &forOp.getRegion();
- Block *block = &region->front();
-
- // Construct the new for-op with an empty body.
- ValueRange newInits = adaptor.getFlatOperands().drop_front(3);
- auto newOp =
- rewriter.create<ForOp>(loc, lbs[0], ubs[0], steps[0], newInits);
- newOp->setAttrs(forOp->getAttrs());
-
- // We do not need the empty blocks created by rewriter.
- rewriter.eraseBlock(newOp.getBody());
-
- // Convert the signature of the body region.
- OneToNTypeMapping bodyTypeMapping(block->getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
- bodyTypeMapping)))
- return failure();
-
- // Perform signature conversion on the body block.
- rewriter.applySignatureConversion(block, bodyTypeMapping);
-
- // Splice the old body region into the new for-op.
- Region &dstRegion = newOp.getBodyRegion();
- rewriter.inlineRegionBefore(forOp.getRegion(), dstRegion, dstRegion.end());
-
- rewriter.replaceOp(forOp, newOp.getResults(), resultMapping);
-
- return success();
- }
-};
-
-namespace mlir {
-namespace scf {
-
-void populateSCFStructuralOneToNTypeConversions(
- const TypeConverter &typeConverter, RewritePatternSet &patterns) {
- patterns.add<
- // clang-format off
- ConvertTypesInSCFConditionOp,
- ConvertTypesInSCFForOp,
- ConvertTypesInSCFIfOp,
- ConvertTypesInSCFWhileOp,
- ConvertTypesInSCFYieldOp
- // clang-format on
- >(typeConverter, patterns.getContext());
-}
-
-} // namespace scf
-} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 29f7e8afe0773b..d837e305c4c34e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -29,7 +29,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
@@ -933,7 +932,8 @@ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&entryBlock);
- OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
+ TypeConverter::SignatureConversion oneToNTypeMapping(
+ fnType.getInputs().size());
// For arguments that are of illegal types and require unrolling.
// `unrolledInputNums` stores the indices of arguments that result from
@@ -1073,7 +1073,8 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
return failure();
FunctionType fnType = funcOp.getFunctionType();
- OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
+ TypeConverter::SignatureConversion oneToNTypeMapping(
+ fnType.getResults().size());
Location loc = returnOp.getLoc();
// For the new return op.
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 72eb34f36cf5f6..3ca16239ba33c0 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_library(MLIRTransformUtils
Inliner.cpp
InliningUtils.cpp
LoopInvariantCodeMotionUtils.cpp
- OneToNTypeConversion.cpp
RegionUtils.cpp
WalkPatternRewriteDriver.cpp
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6c3863e4c7f666..f54b9b1c1328bf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2968,17 +2968,6 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return success();
}
-Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
- Location loc,
- Type resultType,
- ValueRange inputs) const {
- for (const MaterializationCallbackFn &fn :
- llvm::reverse(argumentMaterializations))
- if (Value result = fn(builder, resultType, inputs, loc))
- return result;
- return nullptr;
-}
-
Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs) const {
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
deleted file mode 100644
index 6474c59595eb43..00000000000000
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ /dev/null
@@ -1,458 +0,0 @@
-//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/OneToNTypeConversion.h"
-
-#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallSet.h"
-
-#include <unordered_map>
-
-using namespace llvm;
-using namespace mlir;
-
-TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
- TypeRange convertedTypes = getConvertedTypes();
- if (auto mapping = getInputMapping(originalTypeNo))
- return convertedTypes.slice(mapping->inputNo, mapping->size);
- return {};
-}
-
-ValueRange
-OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
- unsigned originalValueNo) const {
- if (auto mapping = getInputMapping(originalValueNo))
- return convertedValues.slice(mapping->inputNo, mapping->size);
- return {};
-}
-
-void OneToNTypeMapping::convertLocation(
- Value originalValue, unsigned originalValueNo,
- llvm::SmallVectorImpl<Location> &result) const {
- if (auto mapping = getInputMapping(originalValueNo))
- result.append(mapping->size, originalValue.getLoc());
-}
-
-void OneToNTypeMapping::convertLocations(
- ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
- assert(originalValues.size() == getOriginalTypes().size());
- for (auto [i, value] : llvm::enumerate(originalValues))
- convertLocation(value, i, result);
-}
-
-static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
- return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
-}
-
-bool OneToNTypeMapping::hasNonIdentityConversion() const {
- // XXX: I think that the original types and the converted types are the same
- // iff there was no non-identity type conversion. If that is true, the
- // patterns could actually test whether there is anything useful to do
- // without having access to the signature conversion.
- for (auto [i, originalType] : llvm::enumerate(originalTypes)) {
- TypeRange types = getConvertedTypes(i);
- if (!isIdentityConversion(originalType, types)) {
- assert(TypeRange(originalTypes) != getConvertedTypes());
- return true;
- }
- }
- assert(TypeRange(originalTypes) == getConvertedTypes());
- return false;
-}
-
-namespace {
-enum class CastKind {
- // Casts block arguments in the target type back to the source type. (If
- // necessary, this cast becomes an argument materialization.)
- Argument,
-
- // Casts other values in the target type back to the source type. (If
- // necessary, this cast becomes a source materialization.)
- Source,
-
- // Casts values in the source type to the target type. (If necessary, this
- // cast becomes a target materialization.)
- Target
-};
-} // namespace
-
-/// Mapping of enum values to string values.
-StringRef getCastKindName(CastKind kind) {
- static const std::unordered_map<CastKind, StringRef> castKindNames = {
- {CastKind::Argument, "argument"},
- {CastKind::Source, "source"},
- {CastKind::Target, "target"}};
- return castKindNames.at(kind);
-}
-
-/// Attribute name that is used to annotate inserted unrealized casts with their
-/// kind (source, argument, or target).
-static const char *const castKindAttrName =
- "__one-to-n-type-conversion_cast-kind__";
-
-/// Builds an `UnrealizedConversionCastOp` from the given inputs to the given
-/// result types. Returns the result values of the cast.
-static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
- ValueRange inputs, CastKind kind) {
- // Special case: 1-to-N conversion with N = 0. No need to build an
- // UnrealizedConversionCastOp because the op will always be dead.
- if (resultTypes.empty())
- return ValueRange();
-
- // Create cast.
- Location loc = builder.getUnknownLoc();
- if (!inputs.empty())
- loc = inputs.front().getLoc();
- auto castOp =
- builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
-
- // Store cast kind as attribute.
- auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
- castOp->setAttr(castKindAttrName, kindAttr);
-
- return castOp->getResults();
-}
-
-/// Builds one `UnrealizedConversionCastOp` for each of the given original
-/// values using the respective target types given in the provided conversion
-/// mapping and returns the results of these casts. If the conversion mapping of
-/// a value maps a type to itself (i.e., is an identity conversion), then no
-/// cast is inserted and the original value is returned instead.
-/// Note that these unrealized casts are different from target materializations
-/// in that they are *always* inserted, even if they immediately fold away, such
-/// that patterns always see valid intermediate IR, whereas materializations are
-/// only used in the places where the unrealized casts *don't* fold away.
-static SmallVector<Value>
-buildUnrealizedForwardCasts(ValueRange originalValues,
- OneToNTypeMapping &conversion,
- RewriterBase &rewriter, CastKind kind) {
-
- // Convert each operand one by one.
- SmallVector<Value> convertedValues;
- convertedValues.reserve(conversion.getConvertedTypes().size());
- for (auto [idx, originalValue] : llvm::enumerate(originalValues)) {
- TypeRange convertedTypes = conversion.getConvertedTypes(idx);
-
- // Identity conversion: keep operand as is.
- if (isIdentityConversion(originalValue.getType(), convertedTypes)) {
- convertedValues.push_back(originalValue);
- continue;
- }
-
- // Non-identity conversion: materialize target types.
- ValueRange castResult =
- buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind);
- convertedValues.append(castResult.begin(), castResult.end());
- }
-
- return convertedValues;
-}
-
-/// Builds one `UnrealizedConversionCastOp` for each sequence of the given
-/// original values to one value of the type they originated from, i.e., a
-/// "reverse" conversion from N converted values back to one value of the
-/// original type, using the given (forward) type conversion. If a given value
-/// was mapped to a value of the same type (i.e., the conversion in the mapping
-/// is an identity conversion), then the "converted" value is returned without
-/// cast.
-/// Note that these unrealized casts are different from source materializations
-/// in that they are *always* inserted, even if they immediately fold away, such
-/// that patterns always see valid intermediate IR, whereas materializations are
-/// only used in the places where the unrealized casts *don't* fold away.
-static SmallVector<Value>
-buildUnrealizedBackwardsCasts(ValueRange convertedValues,
- const OneToNTypeMapping &typeConversion,
- RewriterBase &rewriter) {
- assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
-
- // Create unrealized cast op for each converted result of the op.
- SmallVector<Value> recastValues;
- TypeRange originalTypes = typeConversion.getOriginalTypes();
- recastValues.reserve(originalTypes.size());
- auto convertedValueIt = convertedValues.begin();
- for (auto [idx, originalType] : llvm::enumerate(originalTypes)) {
- TypeRange convertedTypes = typeConversion.getConvertedTypes(idx);
- size_t numConvertedValues = convertedTypes.size();
- if (isIdentityConversion(originalType, convertedTypes)) {
- // Identity conversion: take result as is.
- recastValues.push_back(*convertedValueIt);
- } else {
- // Non-identity conversion: cast back to source type.
- ValueRange recastValue = buildUnrealizedCast(
- rewriter, originalType,
- ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
- CastKind::Source);
- assert(recastValue.size() == 1);
- recastValues.push_back(recastValue.front());
- }
- convertedValueIt += numConvertedValues;
- }
-
- return recastValues;
-}
-
-void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
- const OneToNTypeMapping &resultMapping) {
- // Create a cast back to the original types and replace the results of the
- // original op with those.
- assert(newValues.size() == resultMapping.getConvertedTypes().size());
- assert(op->getResultTypes() == resultMapping.getOriginalTypes());
- PatternRewriter::InsertionGuard g(*this);
- setInsertionPointAfter(op);
- SmallVector<Value> castResults =
- buildUnrealizedBackwardsCasts(newValues, resultMapping, *this);
- replaceOp(op, castResults);
-}
-
-Block *OneToNPatternRewriter::applySignatureConversion(
- Block *block, OneToNTypeMapping &argumentConversion) {
- PatternRewriter::InsertionGuard g(*this);
-
- // Split the block at the beginning to get a new block to use for the
- // updated signature.
- SmallVector<Location> locs;
- argumentConversion.convertLocations(block->getArguments(), locs);
- Block *newBlock =
- createBlock(block, argumentConversion.getConvertedTypes(), locs);
- replaceAllUsesWith(block, newBlock);
-
- // Create necessary casts in new block.
- SmallVector<Value> castResults;
- for (auto [i, arg] : llvm::enumerate(block->getArguments())) {
- TypeRange convertedTypes = argumentConversion.getConvertedTypes(i);
- ValueRange newArgs =
- argumentConversion.getConvertedValues(newBlock->getArguments(), i);
- if (isIdentityConversion(arg.getType(), convertedTypes)) {
- // Identity conversion: take argument as is.
- assert(newArgs.size() == 1);
- castResults.push_back(newArgs.front());
- } else {
- // Non-identity conversion: cast the converted arguments to the original
- // type.
- PatternRewriter::InsertionGuard g(*this);
- setInsertionPointToStart(newBlock);
- ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs,
- CastKind::Argument);
- assert(castResult.size() == 1);
- castResults.push_back(castResult.front());
- }
- }
-
- // Merge old block into new block such that we only have the latter with the
- // new signature.
- mergeBlocks(block, newBlock, castResults);
-
- return newBlock;
-}
-
-LogicalResult
-OneToNConversionPattern::matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
- auto *typeConverter = getTypeConverter();
-
- // Construct conversion mapping for results.
- Operation::result_type_range originalResultTypes = op->getResultTypes();
- OneToNTypeMapping resultMapping(originalResultTypes);
- if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
- resultMapping)))
- return failure();
-
- // Construct conversion mapping for operands.
- Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
- OneToNTypeMapping operandMapping(originalOperandTypes);
- if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
- operandMapping)))
- return failure();
-
- // Cast operands to target types.
- SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts(
- op->getOperands(), operandMapping, rewriter, CastKind::Target);
-
- // Create a `OneToNPatternRewriter` for the pattern, which provides additional
- // functionality.
- // TODO(ingomueller): I guess it would be better to use only one rewriter
- // throughout the whole pass, but that would require to
- // drive the pattern application ourselves, which is a lot
- // of additional boilerplate code. This seems to work fine,
- // so I leave it like this for the time being.
- OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(),
- rewriter.getListener());
- oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint());
-
- // Apply actual pattern.
- if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
- resultMapping, convertedOperands)))
- return failure();
-
- return success();
-}
-
-namespace mlir {
-
-// This function applies the provided patterns using
-// `applyPatternsGreedily` and then replaces all newly inserted
-// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
-// from target to source types inserted by a `OneToNConversionPattern` normally
-// fold away with the "forward" casts from source to target types inserted by
-// the next pattern.) To understand which casts are "newly inserted", all casts
-// inserted by this pass are annotated with a string attribute that also
-// documents which kind of the cast (source, argument, or target).
-LogicalResult
-applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
- const FrozenRewritePatternSet &patterns) {
-#ifndef NDEBUG
- // Remember existing unrealized casts. This data structure is only used in
- // asserts; building it only for that purpose may be an overkill.
- SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
- op->walk([&](UnrealizedConversionCastOp castOp) {
- assert(!castOp->hasAttr(castKindAttrName));
- existingCasts.insert(castOp);
- });
-#endif // NDEBUG
-
- // Apply provided conversion patterns.
- if (failed(applyPatternsGreedily(op, patterns))) {
- emitError(op->getLoc()) << "failed to apply conversion patterns";
- return failure();
- }
-
- // Find all unrealized casts inserted by the pass that haven't folded away.
- SmallVector<UnrealizedConversionCastOp> worklist;
- op->walk([&](UnrealizedConversionCastOp castOp) {
- if (castOp->hasAttr(castKindAttrName)) {
- assert(!existingCasts.contains(castOp));
- worklist.push_back(castOp);
- }
- });
-
- // Replace new casts with user materializations.
- IRRewriter rewriter(op->getContext());
- for (UnrealizedConversionCastOp castOp : worklist) {
- TypeRange resultTypes = castOp->getResultTypes();
- ValueRange operands = castOp->getOperands();
- StringRef castKind =
- castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
- rewriter.setInsertionPoint(castOp);
-
-#ifndef NDEBUG
- // Determine whether operands or results are already legal to test some
- // assumptions for the different kind of materializations. These properties
- // are only used it asserts and it may be overkill to compute them.
- bool areOperandTypesLegal = llvm::all_of(
- operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
- bool areResultsTypesLegal = llvm::all_of(
- resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
-#endif // NDEBUG
-
- // Add materialization and remember materialized results.
- SmallVector<Value> materializedResults;
- if (castKind == getCastKindName(CastKind::Target)) {
- // Target materialization.
- assert(!areOperandTypesLegal && areResultsTypesLegal &&
- operands.size() == 1 && "found unexpected target cast");
- materializedResults = typeConverter.materializeTargetConversion(
- rewriter, castOp->getLoc(), resultTypes, operands.front());
- if (materializedResults.empty()) {
- emitError(castOp->getLoc())
- << "failed to create target materialization";
- return failure();
- }
- } else {
- // Source and argument materializations.
- assert(areOperandTypesLegal && !areResultsTypesLegal &&
- resultTypes.size() == 1 && "found unexpected cast");
- std::optional<Value> maybeResult;
- if (castKind == getCastKindName(CastKind::Source)) {
- // Source materialization.
- maybeResult = typeConverter.materializeSourceConversion(
- rewriter, castOp->getLoc(), resultTypes.front(),
- castOp.getOperands());
- } else {
- // Argument materialization.
- assert(castKind == getCastKindName(CastKind::Argument) &&
- "unexpected value of cast kind attribute");
- assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
- maybeResult = typeConverter.materializeArgumentConversion(
- rewriter, castOp->getLoc(), resultTypes.front(),
- castOp.getOperands());
- }
- if (!maybeResult.has_value() || !maybeResult.value()) {
- emitError(castOp->getLoc())
- << "failed to create " << castKind << " materialization";
- return failure();
- }
- materializedResults = {maybeResult.value()};
- }
-
- // Replace the cast with the result of the materialization.
- rewriter.replaceOp(castOp, materializedResults);
- }
-
- return success();
-}
-
-namespace {
-class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
-public:
- FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
- MLIRContext *ctx,
- const TypeConverter &converter)
- : OneToNConversionPattern(converter, functionLikeOpName, /*benefit=*/1,
- ctx) {}
-
- LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const override {
- auto funcOp = cast<FunctionOpInterface>(op);
- auto *typeConverter = getTypeConverter();
-
- // Construct mapping for function arguments.
- OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
- if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
- argumentMapping)))
- return failure();
-
- // Construct mapping for function results.
- OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
- if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
- funcResultMapping)))
- return failure();
-
- // Nothing to do if the op doesn't have any non-identity conversions for its
- // operands or results.
- if (!argumentMapping.hasNonIdentityConversion() &&
- !funcResultMapping.hasNonIdentityConversion())
- return failure();
-
- // Update the function signature in-place.
- auto newType = FunctionType::get(rewriter.getContext(),
- argumentMapping.getConvertedTypes(),
- funcResultMapping.getConvertedTypes());
- rewriter.modifyOpInPlace(op, [&] { funcOp.setType(newType); });
-
- // Update block signatures.
- if (!funcOp.isExternal()) {
- Region *region = &funcOp.getFunctionBody();
- Block *block = &region->front();
- rewriter.applySignatureConversion(block, argumentMapping);
- }
-
- return success();
- }
-};
-} // namespace
-
-void populateOneToNFunctionOpInterfaceTypeConversionPattern(
- StringRef functionLikeOpName, const TypeConverter &converter,
- RewritePatternSet &patterns) {
- patterns.add<FunctionOpInterfaceSignatureConversion>(
- functionLikeOpName, patterns.getContext(), converter);
-}
-} // namespace mlir
diff --git a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
deleted file mode 100644
index 611ec0265cd37b..00000000000000
--- a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
+++ /dev/null
@@ -1,140 +0,0 @@
-// RUN: mlir-opt %s -split-input-file \
-// RUN: -test-one-to-n-type-conversion="convert-tuple-ops" \
-// RUN: | FileCheck --check-prefix=CHECK-TUP %s
-
-// RUN: mlir-opt %s -split-input-file \
-// RUN: -test-one-to-n-type-conversion="convert-func-ops" \
-// RUN: | FileCheck --check-prefix=CHECK-FUNC %s
-
-// RUN: mlir-opt %s -split-input-file \
-// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-tuple-ops" \
-// RUN: | FileCheck --check-prefix=CHECK-BOTH %s
-
-// Test case: Matching nested packs and unpacks just disappear.
-
-// CHECK-TUP-LABEL: func.func @pack_unpack(
-// CHECK-TUP-SAME: %[[ARG0:.*]]: i1,
-// CHECK-TUP-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-TUP-DAG: return %[[ARG0]], %[[ARG1]] : i1, i2
-func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) {
- %0 = "test.make_tuple"() : () -> tuple<>
- %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2>
- %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>>
- %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
- %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
- %5 = "test.get_tuple_element"(%3) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
- %6 = "test.get_tuple_element"(%3) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
- %7 = "test.get_tuple_element"(%6) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
- %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple<i2>) -> i2
- return %5, %8 : i1, i2
-}
-
-// -----
-
-// Test case: Appropriate materializations are created depending on which ops
-// are converted.
-
-// If we only convert the tuple ops, the original `get_tuple_element` ops will
-// disappear but one target materialization will be inserted from the
-// unconverted function arguments to each of the return values (which have
-// redundancy among themselves).
-//
-// CHECK-TUP-LABEL: func.func @materializations_tuple_args(
-// CHECK-TUP-SAME: %[[ARG0:.*]]: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) {
-// CHECK-TUP-DAG: %[[V0:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-TUP-DAG: %[[V1:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-TUP-DAG: %[[V2:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-TUP-DAG: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-TUP-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK-TUP-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-TUP-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-TUP-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-TUP-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-TUP-DAG: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK-TUP-DAG: return %[[V1]], %[[V9]] : i1, i2
-
-// If we only convert the func ops, argument materializations are created from
-// the converted tuple elements back to the tuples that the `get_tuple_element`
-// ops expect.
-//
-// CHECK-FUNC-LABEL: func.func @materializations_tuple_args(
-// CHECK-FUNC-SAME: %[[ARG0:.*]]: i1,
-// CHECK-FUNC-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-FUNC-DAG: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
-// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
-// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
-// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2
-
-// If we convert both tuple and func ops, basically everything disappears.
-//
-// CHECK-BOTH-LABEL: func.func @materializations_tuple_args(
-// CHECK-BOTH-SAME: %[[ARG0:.*]]: i1,
-// CHECK-BOTH-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-BOTH-DAG: return %[[ARG0]], %[[ARG1]] : i1, i2
-
-func.func @materializations_tuple_args(%arg0: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) {
- %0 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
- %1 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
- %2 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
- %3 = "test.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
- %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<i2>) -> i2
- return %1, %4 : i1, i2
-}
-// -----
-
-// Test case: Appropriate materializations are created depending on which ops
-// are converted.
-
-// If we only convert the tuple ops, the original `make_tuple` ops will
-// disappear but a source materialization will be inserted from the result of
-// conversion (which, for `make_tuple`, are the original ops that get forwarded)
-// to the operands of the unconverted op with the original type (i.e.,
-// `return`).
-
-// CHECK-TUP-LABEL: func.func @materializations_tuple_return(
-// CHECK-TUP-SAME: %[[ARG0:.*]]: i1,
-// CHECK-TUP-SAME: %[[ARG1:.*]]: i2) -> tuple<tuple<>, i1, tuple<tuple<i2>>> {
-// CHECK-TUP-DAG: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK-TUP-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
-// CHECK-TUP-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
-// CHECK-TUP-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
-// CHECK-TUP-DAG: return %[[V3]] : tuple<tuple<>, i1, tuple<tuple<i2>>>
-
-// If we only convert the func ops, target materializations are created from
-// original tuples produced by `make_tuple` to its constituent elements that the
-// converted op (i.e., `return`) expect.
-//
-// CHECK-FUNC-LABEL: func.func @materializations_tuple_return(
-// CHECK-FUNC-SAME: %[[ARG0:.*]]: i1,
-// CHECK-FUNC-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-FUNC-DAG: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
-// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
-// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
-// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
-// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
-// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2
-
-// If we convert both tuple and func ops, basically everything disappears.
-//
-// CHECK-BOTH-LABEL: func.func @materializations_tuple_return(
-// CHECK-BOTH-SAME: %[[ARG0:.*]]: i1,
-// CHECK-BOTH-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-BOTH-DAG: return %[[ARG0]], %[[ARG1]] : i1, i2
-
-func.func @materializations_tuple_return(%arg0: i1, %arg1: i2) -> tuple<tuple<>, i1, tuple<tuple<i2>>> {
- %0 = "test.make_tuple"() : () -> tuple<>
- %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2>
- %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>>
- %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
- return %3 : tuple<tuple<>, i1, tuple<tuple<i2>>>
-}
diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
deleted file mode 100644
index 535ab68e8d893c..00000000000000
--- a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir
+++ /dev/null
@@ -1,183 +0,0 @@
-// RUN: mlir-opt %s -split-input-file \
-// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \
-// RUN: | FileCheck %s
-
-// Test case: Nested 1:N type conversion is carried through scf.if and
-// scf.yield.
-
-// CHECK-LABEL: func.func @if_result(
-// CHECK-SAME: %[[ARG0:.*]]: i1,
-// CHECK-SAME: %[[ARG1:.*]]: i2,
-// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) {
-// CHECK-NEXT: %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) {
-// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
-// CHECK-NEXT: } else {
-// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[V0]]#0, %[[V0]]#1 : i1, i2
-func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
- %0 = scf.if %arg1 -> (tuple<tuple<>, i1, tuple<i2>>) {
- scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
- } else {
- scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
- }
- return %0 : tuple<tuple<>, i1, tuple<i2>>
-}
-
-// -----
-
-// Test case: Nested 1:N type conversion is carried through scf.if and
-// scf.yield and unconverted ops inside have proper materializations.
-
-// CHECK-LABEL: func.func @if_tuple_ops(
-// CHECK-SAME: %[[ARG0:.*]]: i1,
-// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 {
-// CHECK-NEXT: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) {
-// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
-// CHECK-NEXT: scf.yield %[[V5]] : i1
-// CHECK-NEXT: } else {
-// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
-// CHECK-NEXT: scf.yield %[[V8]] : i1
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[V2]] : i1
-func.func @if_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
- %0 = scf.if %arg1 -> (tuple<tuple<>, i1>) {
- %1 = "test.op"(%arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
- scf.yield %1 : tuple<tuple<>, i1>
- } else {
- %1 = "test.source"() : () -> tuple<tuple<>, i1>
- scf.yield %1 : tuple<tuple<>, i1>
- }
- return %0 : tuple<tuple<>, i1>
-}
-// -----
-
-// Test case: Nested 1:N type conversion is carried through scf.while,
-// scf.condition, and scf.yield.
-
-// CHECK-LABEL: func.func @while_operands_results(
-// CHECK-SAME: %[[ARG0:.*]]: i1,
-// CHECK-SAME: %[[ARG1:.*]]: i2,
-// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) {
-// %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) {
-// scf.condition(%arg2) %[[ARG3]], %[[ARG4]] : i1, i2
-// } do {
-// ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2):
-// scf.yield %[[ARG5]], %[[ARG4]] : i1, i2
-// }
-// return %[[V0]]#0, %[[V0]]#1 : i1, i2
-func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
- %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
- scf.condition(%arg1) %arg2 : tuple<tuple<>, i1, tuple<i2>>
- } do {
- ^bb0(%arg2: tuple<tuple<>, i1, tuple<i2>>):
- scf.yield %arg2 : tuple<tuple<>, i1, tuple<i2>>
- }
- return %0 : tuple<tuple<>, i1, tuple<i2>>
-}
-
-// -----
-
-// Test case: Nested 1:N type conversion is carried through scf.while,
-// scf.condition, and unconverted ops inside have proper materializations.
-
-// CHECK-LABEL: func.func @while_tuple_ops(
-// CHECK-SAME: %[[ARG0:.*]]: i1,
-// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 {
-// CHECK-NEXT: %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 {
-// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK-NEXT: %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
-// CHECK-NEXT: scf.condition(%[[ARG1]]) %[[V5]] : i1
-// CHECK-NEXT: } do {
-// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: i1):
-// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
-// CHECK-NEXT: scf.yield %[[V8]] : i1
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[V0]] : i1
-func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
- %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
- %1 = "test.op"(%arg2) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
- scf.condition(%arg1) %1 : tuple<tuple<>, i1>
- } do {
- ^bb0(%arg2: tuple<tuple<>, i1>):
- %1 = "test.source"() : () -> tuple<tuple<>, i1>
- scf.yield %1 : tuple<tuple<>, i1>
- }
- return %0 : tuple<tuple<>, i1>
-}
-
-// -----
-
-// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield.
-
-// CHECK-LABEL: func.func @for_operands_results(
-// CHECK-SAME: %[[ARG0:.*]]: i1,
-// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-NEXT: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-NEXT: %[[OUT:.+]]:2 = scf.for %arg2 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER0:.+]] = %[[ARG0]], %[[ITER1:.+]] = %[[ARG1]]) -> (i1, i2) {
-// CHECK-NEXT: scf.yield %[[ITER0]], %[[ITER1]] : i1, i2
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[OUT]]#0, %[[OUT]]#1 : i1, i2
-
-func.func @for_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c10 = arith.constant 10 : index
-
- %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1, tuple<i2>> {
- scf.yield %acc : tuple<tuple<>, i1, tuple<i2>>
- }
-
- return %0 : tuple<tuple<>, i1, tuple<i2>>
-}
-
-// -----
-
-// Test case: Nested 1:N type conversion is carried through scf.for and scf.yield
-
-// CHECK-LABEL: func.func @for_tuple_ops(
-// CHECK-SAME: %[[ARG0:.+]]: i1) -> i1 {
-// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-NEXT: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-NEXT: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-NEXT: %[[FOR:.+]] = scf.for %arg1 = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[ITER:.+]] = %[[ARG0]]) -> (i1) {
-// CHECK-NEXT: %[[V1:.+]] = "test.make_tuple"() : () -> tuple<>
-// CHECK-NEXT: %[[V2:.+]] = "test.make_tuple"(%[[V1]], %[[ITER]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V3:.+]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V4:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V5:.+]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
-// CHECK-NEXT: scf.yield %[[V5]] : i1
-// CHECK-NEXT: }
-// CHECK-NEXT: %[[V6:.+]] = "test.make_tuple"() : () -> tuple<>
-// CHECK-NEXT: %[[V7:.+]] = "test.make_tuple"(%[[V6]], %[[FOR]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V8:.+]] = "test.op"(%[[V7]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
-// CHECK-NEXT: %[[V9:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
-// CHECK-NEXT: %[[V10:.+]] = "test.get_tuple_element"(%[[V8]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
-// CHECK-NEXT: return %[[V10]] : i1
-
-func.func @for_tuple_ops(%arg0: tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c10 = arith.constant 10 : index
-
- %0 = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %arg0) -> tuple<tuple<>, i1> {
- %1 = "test.op"(%acc) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
- scf.yield %1 : tuple<tuple<>, i1>
- }
-
- %1 = "test.op"(%0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
- return %1 : tuple<tuple<>, i1>
-}
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index 4e641317ac2f3d..55d78d9fedebb3 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -1,19 +1,11 @@
// RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s
-// RUN: mlir-opt %s -split-input-file \
-// RUN: -test-one-to-n-type-conversion="convert-func-ops" \
-// RUN: | FileCheck %s --check-prefix=CHECK-12N
-
// Test case: Most basic case of a 1:N decomposition, an identity function.
// CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32
-// CHECK-12N-LABEL: func @identity(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32
func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
return %arg0 : tuple<i1, i32>
}
@@ -25,9 +17,6 @@ func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
// CHECK-LABEL: func @identity_1_to_1_no_materializations(
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 {
// CHECK: return %[[ARG0]] : i1
-// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 {
-// CHECK-12N: return %[[ARG0]] : i1
func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> {
return %arg0 : tuple<i1>
}
@@ -39,9 +28,6 @@ func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> {
// CHECK-LABEL: func @recursive_decomposition(
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 {
// CHECK: return %[[ARG0]] : i1
-// CHECK-12N-LABEL: func @recursive_decomposition(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 {
-// CHECK-12N: return %[[ARG0]] : i1
func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tuple<tuple<i1>>> {
return %arg0 : tuple<tuple<tuple<i1>>>
}
@@ -54,10 +40,6 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2
-// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2
func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> {
return %arg0 : tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
}
@@ -67,7 +49,6 @@ func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<
// Test case: Check decomposition of calls.
// CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32)
-// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32)
func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-LABEL: func @caller(
@@ -75,11 +56,6 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32
-// CHECK-12N-LABEL: func @caller(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
-// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32
func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
%0 = call @callee(%arg0) : (tuple<i1, i32>) -> tuple<i1, i32>
return %0 : tuple<i1, i32>
@@ -90,15 +66,11 @@ func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
// Test case: Type that decomposes to nothing (that is, a 1:0 decomposition).
// CHECK-LABEL: func private @callee()
-// CHECK-12N-LABEL: func private @callee()
func.func private @callee(tuple<>) -> tuple<>
// CHECK-LABEL: func @caller() {
// CHECK: call @callee() : () -> ()
// CHECK: return
-// CHECK-12N-LABEL: func @caller() {
-// CHECK-12N: call @callee() : () -> ()
-// CHECK-12N: return
func.func @caller(%arg0: tuple<>) -> tuple<> {
%0 = call @callee(%arg0) : (tuple<>) -> (tuple<>)
return %0 : tuple<>
@@ -114,11 +86,6 @@ func.func @caller(%arg0: tuple<>) -> tuple<> {
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
-// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) {
-// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
-// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32
func.func @unconverted_op_result() -> tuple<i1, i32> {
%0 = "test.source"() : () -> (tuple<i1, i32>)
return %0 : tuple<i1, i32>
@@ -139,16 +106,6 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32>
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32
// CHECK: return %[[V3]], %[[V5]] : i1, i32
-// CHECK-12N-LABEL: func @nested_unconverted_op_result(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
-// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
-// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
-// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<i1, tuple<i32>>) -> i1
-// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32>
-// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32
-// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32
func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> {
%0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>)
return %0 : tuple<i1, tuple<i32>>
@@ -160,7 +117,6 @@ func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1
// This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions.
// CHECK-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
// CHECK-LABEL: func @caller(
@@ -172,15 +128,6 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
// CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
-// CHECK-12N-LABEL: func @caller(
-// CHECK-12N-SAME: %[[I1:.*]]: i1,
-// CHECK-12N-SAME: %[[I2:.*]]: i2,
-// CHECK-12N-SAME: %[[I3:.*]]: i3,
-// CHECK-12N-SAME: %[[I4:.*]]: i4,
-// CHECK-12N-SAME: %[[I5:.*]]: i5,
-// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
-// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple<i2>, %arg3: i3, %arg4: tuple<i4, i5>, %arg5: i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) {
%0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6
diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
index 19975f671b081d..c09496be729be2 100644
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1,5 +1,4 @@
add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToLLVM)
add_subdirectory(MathToVCIX)
-add_subdirectory(OneToNTypeConversion)
add_subdirectory(VectorToSPIRV)
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
deleted file mode 100644
index b72302202f72b0..00000000000000
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
+++ /dev/null
@@ -1,21 +0,0 @@
-add_mlir_library(MLIRTestOneToNTypeConversionPass
- TestOneToNTypeConversionPass.cpp
-
- EXCLUDE_FROM_LIBMLIR
-
- LINK_LIBS PUBLIC
- MLIRFuncDialect
- MLIRFuncTransforms
- MLIRIR
- MLIRPass
- MLIRSCFDialect
- MLIRSCFTransforms
- MLIRTestDialect
- MLIRTransformUtils
- )
-
-target_include_directories(MLIRTestOneToNTypeConversionPass
- PRIVATE
- ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
- ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
- )
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
deleted file mode 100644
index b18dfd8bb22cb1..00000000000000
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ /dev/null
@@ -1,261 +0,0 @@
-//===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "TestDialect.h"
-#include "TestOps.h"
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
-
-using namespace mlir;
-
-namespace {
-/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms
-/// in `applyPartialOneToNConversion` by converting built-in tuples to the
-/// elements they consist of as well as some dummy ops operating on these
-/// tuples.
-struct TestOneToNTypeConversionPass
- : public PassWrapper<TestOneToNTypeConversionPass,
- OperationPass<ModuleOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass)
-
- TestOneToNTypeConversionPass() = default;
- TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass)
- : PassWrapper(pass) {}
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<test::TestDialect>();
- }
-
- StringRef getArgument() const final {
- return "test-one-to-n-type-conversion";
- }
-
- StringRef getDescription() const final {
- return "Test pass for 1:N type conversion";
- }
-
- Option<bool> convertFuncOps{*this, "convert-func-ops",
- llvm::cl::desc("Enable conversion on func ops"),
- llvm::cl::init(false)};
-
- Option<bool> convertSCFOps{*this, "convert-scf-ops",
- llvm::cl::desc("Enable conversion on scf ops"),
- llvm::cl::init(false)};
-
- Option<bool> convertTupleOps{*this, "convert-tuple-ops",
- llvm::cl::desc("Enable conversion on tuple ops"),
- llvm::cl::init(false)};
-
- void runOnOperation() override;
-};
-
-} // namespace
-
-namespace mlir {
-namespace test {
-void registerTestOneToNTypeConversionPass() {
- PassRegistration<TestOneToNTypeConversionPass>();
-}
-} // namespace test
-} // namespace mlir
-
-namespace {
-
-/// Test pattern on for the `make_tuple` op from the test dialect that converts
-/// this kind of op into it's "decomposed" form, i.e., the elements of the tuple
-/// that is being produced by `test.make_tuple`, which are really just the
-/// operands of this op.
-class ConvertMakeTupleOp
- : public OneToNOpConversionPattern<::test::MakeTupleOp> {
-public:
- using OneToNOpConversionPattern<
- ::test::MakeTupleOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(::test::MakeTupleOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- // Simply replace the current op with the converted operands.
- rewriter.replaceOp(op, adaptor.getFlatOperands(),
- adaptor.getResultMapping());
- return success();
- }
-};
-
-/// Test pattern on for the `get_tuple_element` op from the test dialect that
-/// converts this kind of op into it's "decomposed" form, i.e., instead of
-/// "physically" extracting one element from the tuple, we forward the one
-/// element of the decomposed form that is being extracted (or the several
-/// elements in case that element is a nested tuple).
-class ConvertGetTupleElementOp
- : public OneToNOpConversionPattern<::test::GetTupleElementOp> {
-public:
- using OneToNOpConversionPattern<
- ::test::GetTupleElementOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(::test::GetTupleElementOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- // Construct mapping for tuple element types.
- auto stateType = cast<TupleType>(op->getOperand(0).getType());
- TypeRange originalElementTypes = stateType.getTypes();
- OneToNTypeMapping elementMapping(originalElementTypes);
- if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
- elementMapping)))
- return failure();
-
- // Compute converted operands corresponding to original input tuple.
- assert(adaptor.getOperands().size() == 1 &&
- "expected 'get_tuple_element' to have one operand");
- ValueRange convertedTuple = adaptor.getOperands()[0];
-
- // Got those converted operands that correspond to the index-th element ofq
- // the original input tuple.
- size_t index = op.getIndex();
- ValueRange extractedElement =
- elementMapping.getConvertedValues(convertedTuple, index);
-
- rewriter.replaceOp(op, extractedElement, adaptor.getResultMapping());
-
- return success();
- }
-};
-
-} // namespace
-
-static void
-populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<
- // clang-format off
- ConvertMakeTupleOp,
- ConvertGetTupleElementOp
- // clang-format on
- >(typeConverter, patterns.getContext());
-}
-
-/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
-/// given tuple value. If some tuple elements are, in turn, tuples, the elements
-/// of those are extracted recursively such that the returned values have the
-/// same types as `resultTypes.getFlattenedTypes()`.
-///
-/// This function has been copied (with small adaptions) from
-/// TestDecomposeCallGraphTypes.cpp.
-static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
- TypeRange resultTypes,
- ValueRange inputs,
- Location loc) {
- if (inputs.size() != 1)
- return {};
- Value input = inputs.front();
-
- TupleType inputType = dyn_cast<TupleType>(input.getType());
- if (!inputType)
- return {};
-
- SmallVector<Value> values;
- for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
- Value element = builder.create<::test::GetTupleElementOp>(
- loc, elementType, input, builder.getI32IntegerAttr(idx));
- if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
- // Recurse if the current element is also a tuple.
- SmallVector<Type> flatRecursiveTypes;
- nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
- std::optional<SmallVector<Value>> resursiveValues =
- buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
- if (!resursiveValues.has_value())
- return {};
- values.append(resursiveValues.value());
- } else {
- values.push_back(element);
- }
- }
- return values;
-}
-
-/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
-/// type `resultType`. If that type is nested, each nested tuple is built
-/// recursively with another `test.make_tuple` op.
-///
-/// This function has been copied (with small adaptions) from
-/// TestDecomposeCallGraphTypes.cpp.
-static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
- ValueRange inputs, Location loc) {
- // Build one value for each element at this nesting level.
- SmallVector<Value> elements;
- elements.reserve(resultType.getTypes().size());
- ValueRange::iterator inputIt = inputs.begin();
- for (Type elementType : resultType.getTypes()) {
- if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
- // Determine how many input values are needed for the nested elements of
- // the nested TupleType and advance inputIt by that number.
- // TODO: We only need the *number* of nested types, not the types itself.
- // Maybe it's worth adding a more efficient overload?
- SmallVector<Type> nestedFlattenedTypes;
- nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
- size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
- ValueRange nestedFlattenedelements(inputIt,
- inputIt + numNestedFlattenedTypes);
- inputIt += numNestedFlattenedTypes;
-
- // Recurse on the values for the nested TupleType.
- Value res = buildMakeTupleOp(builder, nestedTupleType,
- nestedFlattenedelements, loc);
- if (!res)
- return Value();
-
- // The tuple constructed by the conversion is the element value.
- elements.push_back(res);
- } else {
- // Base case: take one input as is.
- elements.push_back(*inputIt++);
- }
- }
-
- // Assemble the tuple from the elements.
- return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
-}
-
-void TestOneToNTypeConversionPass::runOnOperation() {
- ModuleOp module = getOperation();
- auto *context = &getContext();
-
- // Assemble type converter.
- TypeConverter typeConverter;
-
- typeConverter.addConversion([](Type type) { return type; });
- typeConverter.addConversion(
- [](TupleType tupleType, SmallVectorImpl<Type> &types) {
- tupleType.getFlattenedTypes(types);
- return success();
- });
-
- typeConverter.addArgumentMaterialization(buildMakeTupleOp);
- typeConverter.addSourceMaterialization(buildMakeTupleOp);
- typeConverter.addTargetMaterialization(buildGetTupleElementOps);
- // Test the other target materialization variant that takes the original type
- // as additional argument. This materialization function always fails.
- typeConverter.addTargetMaterialization(
- [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
- Location loc, Type originalType) -> SmallVector<Value> { return {}; });
-
- // Assemble patterns.
- RewritePatternSet patterns(context);
- if (convertTupleOps)
- populateDecomposeTuplesTestPatterns(typeConverter, patterns);
- if (convertFuncOps)
- populateFuncTypeConversionPatterns(typeConverter, patterns);
- if (convertSCFOps)
- scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns);
-
- // Run conversion.
- if (failed(applyPartialOneToNConversion(module, typeConverter,
- std::move(patterns))))
- return signalPassFailure();
-}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 3563d66fa9e798..670f13caa9fafb 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -40,7 +40,6 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestDialect
MLIRTestDynDialect
MLIRTestIR
- MLIRTestOneToNTypeConversionPass
MLIRTestPass
MLIRTestReducer
MLIRTestTransforms
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 960f7037a1b61f..3542d7898f32cb 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -132,7 +132,6 @@ void registerTestMeshSimplificationsPass();
void registerTestMultiBuffering();
void registerTestNextAccessPass();
void registerTestNVGPULowerings();
-void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
void registerTestOpLoweringPasses();
void registerTestPadFusion();
@@ -271,7 +270,6 @@ void registerTestPasses() {
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestNVGPULowerings();
- mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestOpLoweringPasses();
mlir::test::registerTestPadFusion();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e823af2f147120..3398b9c63927dc 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7986,7 +7986,6 @@ cc_library(
"include/mlir/Transforms/GreedyPatternRewriteDriver.h",
"include/mlir/Transforms/Inliner.h",
"include/mlir/Transforms/LoopInvariantCodeMotionUtils.h",
- "include/mlir/Transforms/OneToNTypeConversion.h",
"include/mlir/Transforms/RegionUtils.h",
"include/mlir/Transforms/WalkPatternRewriteDriver.h",
],
@@ -9901,7 +9900,6 @@ cc_binary(
"//mlir/test:TestMemRef",
"//mlir/test:TestMesh",
"//mlir/test:TestNVGPU",
- "//mlir/test:TestOneToNTypeConversion",
"//mlir/test:TestPDLL",
"//mlir/test:TestPass",
"//mlir/test:TestReducer",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 7d51a3829e9120..a010809274e4c2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -631,23 +631,6 @@ cc_library(
],
)
-cc_library(
- name = "TestOneToNTypeConversion",
- srcs = glob(["lib/Conversion/OneToNTypeConversion/*.cpp"]),
- includes = ["lib/Dialect/Test"],
- deps = [
- ":TestDialect",
- "//mlir:FuncDialect",
- "//mlir:FuncTransforms",
- "//mlir:IR",
- "//mlir:Pass",
- "//mlir:SCFDialect",
- "//mlir:SCFTransforms",
- "//mlir:TransformUtils",
- "//mlir:Transforms",
- ],
-)
-
cc_library(
name = "TestVectorToSPIRV",
srcs = glob(["lib/Conversion/VectorToSPIRV/*.cpp"]),
More information about the llvm-commits
mailing list