[Mlir-commits] [mlir] a8416e3 - Revert "[mlir] Implement pass utils for 1:N type conversions."
Ingo Müller
llvmlistbot at llvm.org
Mon Mar 27 02:24:08 PDT 2023
Author: Ingo Müller
Date: 2023-03-27T09:23:57Z
New Revision: a8416e3c047a7d590fe3884ed49d965c4425d5c3
URL: https://github.com/llvm/llvm-project/commit/a8416e3c047a7d590fe3884ed49d965c4425d5c3
DIFF: https://github.com/llvm/llvm-project/commit/a8416e3c047a7d590fe3884ed49d965c4425d5c3.diff
LOG: Revert "[mlir] Implement pass utils for 1:N type conversions."
This reverts commit 9c4611f9c7a7055b18f0a30a4c9074b9917e4ab0.
Added:
Modified:
mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
mlir/lib/Transforms/Utils/CMakeLists.txt
mlir/test/Transforms/decompose-call-graph-types.mlir
mlir/test/lib/Conversion/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
mlir/include/mlir/Transforms/OneToNTypeConversion.h
mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
deleted file mode 100644
index 2fba342ea80e7..0000000000000
--- a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h
+++ /dev/null
@@ -1,26 +0,0 @@
-//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
-#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
-
-namespace mlir {
-class TypeConverter;
-class RewritePatternSet;
-} // namespace mlir
-
-namespace mlir {
-
-// Populates the provided pattern set with patterns that do 1:N type conversions
-// on func ops. This is intended to be used with `applyPartialOneToNConversion`.
-void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
deleted file mode 100644
index 25beee28c6ed5..0000000000000
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ /dev/null
@@ -1,256 +0,0 @@
-//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file provides utils for implementing (poor-man's) dialect conversion
-// passes with 1:N type conversions.
-//
-// The main function, `applyPartialOneToNConversion`, first applies a set of
-// `RewritePattern`s, which produce unrealized casts to convert the operands and
-// results from and to the source types, and then replaces all newly added
-// unrealized casts by user-provided materializations. For this to work, the
-// main function requires a special `TypeConverter`, a special
-// `PatternRewriter`, and special RewritePattern`s, which extend their
-// respective base classes for 1:N type converions.
-//
-// Note that this is much more simple-minded than the "real" dialect conversion,
-// which checks for legality before applying patterns and does probably many
-// other additional things. Ideally, some of the extensions here could be
-// integrated there.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
-#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
-
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/SmallVector.h"
-
-namespace mlir {
-
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
- /// Callback that expresses user-provided materialization logic from the given
- /// value to N values of the given types. This is useful for expressing target
- /// materializations for 1:N type conversions, which materialize one value in
- /// a source type as N values in target types.
- using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
-
- /// Creates the mapping of the given range of original types to target types
- /// of the conversion and stores that mapping in the given (signature)
- /// conversion. This function simply calls
- /// `TypeConverter::convertSignatureArgs` and exists here with a different
- /// name to reflect the broader semantic.
- LogicalResult computeTypeMapping(TypeRange types,
- SignatureConversion &result) {
- return convertSignatureArgs(types, result);
- }
-
- /// Applies one of the user-provided 1:N target materializations. If several
- /// exists, they are tried out in the reverse order in which they have been
- /// added until the first one succeeds. If none succeeds, the functions
- /// returns `std::nullopt`.
- std::optional<SmallVector<Value>>
- materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
-
- /// Adds a 1:N target materialization to the converter. Such materializations
- /// build IR that converts N values with target types into 1 value of the
- /// source type.
- void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
- oneToNTargetMaterializations.emplace_back(std::move(callback));
- }
-
-private:
- SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
-/// Stores a 1:N mapping of types and provides several useful accessors. This
-/// class extends `SignatureConversion`, which already supports 1:N type
-/// mappings but lacks some accessors into the mapping as well as access to the
-/// original types.
-class OneToNTypeMapping : public TypeConverter::SignatureConversion {
-public:
- OneToNTypeMapping(TypeRange originalTypes)
- : TypeConverter::SignatureConversion(originalTypes.size()),
- originalTypes(originalTypes) {}
-
- using TypeConverter::SignatureConversion::getConvertedTypes;
-
- /// Returns the list of types that corresponds to the original type at the
- /// given index.
- TypeRange getConvertedTypes(unsigned originalTypeNo) const;
-
- /// Returns the list of original types.
- TypeRange getOriginalTypes() const { return originalTypes; }
-
- /// Returns the slice of converted values that corresponds the original value
- /// at the given index.
- ValueRange getConvertedValues(ValueRange convertedValues,
- unsigned originalValueNo) const;
-
- /// Fills the given result vector with as many copies of the location of the
- /// original value as the number of values it is converted to.
- void convertLocation(Value originalValue, unsigned originalValueNo,
- llvm::SmallVectorImpl<Location> &result) const;
-
- /// Fills the given result vector with as many copies of the lociation of each
- /// original value as the number of values they are respectively converted to.
- void convertLocations(ValueRange originalValues,
- llvm::SmallVectorImpl<Location> &result) const;
-
- /// Returns true iff at least one type conversion maps an input type to a type
- /// that is different from itself.
- bool hasNonIdentityConversion() const;
-
-private:
- llvm::SmallVector<Type> originalTypes;
-};
-
-/// Extends the basic `RewritePattern` class with a type converter member and
-/// some accessors to it. This is useful for patterns that are not
-/// `ConversionPattern`s but still require access to a type converter.
-class RewritePatternWithConverter : public mlir::RewritePattern {
-public:
- /// Construct a conversion pattern with the given converter, and forward the
- /// remaining arguments to RewritePattern.
- template <typename... Args>
- RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
- : RewritePattern(std::forward<Args>(args)...),
- typeConverter(&typeConverter) {}
-
- /// Return the type converter held by this pattern, or nullptr if the pattern
- /// does not require type conversion.
- TypeConverter *getTypeConverter() const { return typeConverter; }
-
- template <typename ConverterTy>
- std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
- ConverterTy *>
- getTypeConverter() const {
- return static_cast<ConverterTy *>(typeConverter);
- }
-
-protected:
- /// A type converter for use by this pattern.
- TypeConverter *const typeConverter;
-};
-
-/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The
-/// class provides additional rewrite methods that are specific to 1:N type
-/// conversions.
-class OneToNPatternRewriter : public PatternRewriter {
-public:
- OneToNPatternRewriter(MLIRContext *context) : PatternRewriter(context) {}
-
- /// Replaces the results of the operation with the specified list of values
- /// mapped back to the original types as specified in the provided type
- /// mapping. That type mapping must match the replaced op (i.e., the original
- /// types must be the same as the result types of the op) and the new values
- /// (i.e., the converted types must be the same as the types of the new
- /// values).
- void replaceOp(Operation *op, ValueRange newValues,
- const OneToNTypeMapping &resultMapping);
- using PatternRewriter::replaceOp;
-
- /// Applies the given argument conversion to the given block. This consists of
- /// replacing each original argument with N arguments as specified in the
- /// argument conversion and inserting unrealized casts from the converted
- /// values to the original types, which are then used in lieu of the original
- /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts
- /// with a user-provided argument materialization if necessary.) This is
- /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N
- /// type conversion properly and probably (2) doesn't handle many other edge
- /// cases.
- Block *applySignatureConversion(Block *block,
- OneToNTypeMapping &argumentConversion);
-};
-
-/// Base class for patterns with 1:N type conversions. Derived classes have to
-/// overwrite the `matchAndRewrite` overlaod that provides additional
-/// information for 1:N type conversions.
-class OneToNConversionPattern : public RewritePatternWithConverter {
-public:
- using RewritePatternWithConverter::RewritePatternWithConverter;
-
- /// This function has to be implemented by base classes and is called from the
- /// usual overloads. Like in "normal" `DialectConversion`, the function is
- /// provided with the converted operands (which thus have target types). Since
- /// 1:N conversion are supported, there is usually no 1:1 relationship between
- /// the original and the converted operands. Instead, the provided
- /// `operandMapping` can be used to access the converted operands that
- /// correspond to a particular original operand. Similarly, `resultMapping`
- /// is provided to help with assembling the result values, which may have 1:N
- /// correspondences as well. In that case, the original op should be replaced
- /// with the overload of `replaceOp` that takes the provided `resultMapping`
- /// in order to deal with the mapping of converted result values to their
- /// usages in the original types correctly.
- virtual LogicalResult matchAndRewrite(Operation *op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const = 0;
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const final;
-};
-
-/// This class is a wrapper around `OneToNConversionPattern` for matching
-/// against instances of a particular op class.
-template <typename SourceOp>
-class OneToNOpConversionPattern : public OneToNConversionPattern {
-public:
- OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
- PatternBenefit benefit = 1,
- ArrayRef<StringRef> generatedNames = {})
- : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
- benefit, context, generatedNames) {}
-
- using OneToNConversionPattern::matchAndRewrite;
-
- /// Overload that derived classes have to override for their op type.
- virtual LogicalResult matchAndRewrite(SourceOp op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const = 0;
-
- LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const final {
- return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping,
- resultMapping, convertedOperands);
- }
-};
-
-/// Applies the given set of patterns recursively on the given op and adds user
-/// materializations where necessary. The patterns are expected to be
-/// `OneToNConversionPattern`, which help converting the types of the operands
-/// and results of the matched ops. The provided type converter is used to
-/// convert the operands of matched ops from their original types to operands
-/// with different types. Unlike in `DialectConversion`, this supports 1:N type
-/// conversions. Those conversions at the "boundary" of the pattern application,
-/// where converted results are not consumed by replaced ops that expect the
-/// converted operands or vice versa, the function inserts user materializations
-/// from the type converter. Also unlike `DialectConversion`, there are no legal
-/// or illegal types; the function simply applies the given patterns and does
-/// not fail if some ops or types remain unconverted (i.e., the conversion is
-/// only "partial").
-LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
- const FrozenRewritePatternSet &patterns);
-
-} // namespace mlir
-
-#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
index 172019907c3a8..9a5b38ba6ea2c 100644
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRFuncTransforms
DuplicateFunctionElimination.cpp
FuncBufferize.cpp
FuncConversions.cpp
- OneToNFuncConversions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
deleted file mode 100644
index 5e8125ca94283..0000000000000
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ /dev/null
@@ -1,132 +0,0 @@
-//===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// The patterns in this file are heavily inspired (and copied from)
-// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
-// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N
-// type conversions.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
-
-using namespace mlir;
-using namespace mlir::func;
-
-namespace {
-
-class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
-public:
- using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
-
- LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const override {
- Location loc = op->getLoc();
-
- // Nothing to do if the op doesn't have any non-identity conversions for its
- // operands or results.
- if (!operandMapping.hasNonIdentityConversion() &&
- !resultMapping.hasNonIdentityConversion())
- return failure();
-
- // Create new CallOp.
- auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
- convertedOperands);
- newOp->setAttrs(op->getAttrs());
-
- rewriter.replaceOp(op, newOp->getResults(), resultMapping);
- return success();
- }
-};
-
-class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
-public:
- using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping & /*operandMapping*/,
- const OneToNTypeMapping & /*resultMapping*/,
- ValueRange /*convertedOperands*/) const override {
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
-
- // Construct mapping for function arguments.
- OneToNTypeMapping argumentMapping(op.getArgumentTypes());
- if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
- argumentMapping)))
- return failure();
-
- // Construct mapping for function results.
- OneToNTypeMapping funcResultMapping(op.getResultTypes());
- if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
- funcResultMapping)))
- return failure();
-
- // Nothing to do if the op doesn't have any non-identity conversions for its
- // operands or results.
- if (!argumentMapping.hasNonIdentityConversion() &&
- !funcResultMapping.hasNonIdentityConversion())
- return failure();
-
- // Update the function signature in-place.
- auto newType = FunctionType::get(rewriter.getContext(),
- argumentMapping.getConvertedTypes(),
- funcResultMapping.getConvertedTypes());
- rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
-
- // Update block signatures.
- if (!op.isExternal()) {
- Region *region = &op.getBody();
- Block *block = &region->front();
- rewriter.applySignatureConversion(block, argumentMapping);
- }
-
- return success();
- }
-};
-
-class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
-public:
- using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
-
- LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping & /*resultMapping*/,
- ValueRange convertedOperands) const override {
- // Nothing to do if there is no non-identity conversion.
- if (!operandMapping.hasNonIdentityConversion())
- return failure();
-
- // Convert operands.
- rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
-
- return success();
- }
-};
-
-} // namespace
-
-namespace mlir {
-
-void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<
- // clang-format off
- ConvertTypesInFuncCallOp,
- ConvertTypesInFuncFuncOp,
- ConvertTypesInFuncReturnOp
- // clang-format on
- >(typeConverter, patterns.getContext());
-}
-
-} // namespace mlir
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 6892d00d1d743..ba8fa20faf592 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -6,7 +6,6 @@ add_mlir_library(MLIRTransformUtils
GreedyPatternRewriteDriver.cpp
InliningUtils.cpp
LoopInvariantCodeMotionUtils.cpp
- OneToNTypeConversion.cpp
RegionUtils.cpp
TopologicalSortUtils.cpp
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
deleted file mode 100644
index c0866f87f6833..0000000000000
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ /dev/null
@@ -1,405 +0,0 @@
-//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Transforms/OneToNTypeConversion.h"
-
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallSet.h"
-
-using namespace llvm;
-using namespace mlir;
-
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
- Location loc,
- TypeRange resultTypes,
- Value input) const {
- for (const OneToNMaterializationCallbackFn &fn :
- llvm::reverse(oneToNTargetMaterializations)) {
- if (std::optional<SmallVector<Value>> result =
- fn(builder, resultTypes, input, loc))
- return *result;
- }
- return std::nullopt;
-}
-
-TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
- TypeRange convertedTypes = getConvertedTypes();
- if (auto mapping = getInputMapping(originalTypeNo))
- return convertedTypes.slice(mapping->inputNo, mapping->size);
- return {};
-}
-
-ValueRange
-OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
- unsigned originalValueNo) const {
- if (auto mapping = getInputMapping(originalValueNo))
- return convertedValues.slice(mapping->inputNo, mapping->size);
- return {};
-}
-
-void OneToNTypeMapping::convertLocation(
- Value originalValue, unsigned originalValueNo,
- llvm::SmallVectorImpl<Location> &result) const {
- if (auto mapping = getInputMapping(originalValueNo))
- result.append(mapping->size, originalValue.getLoc());
-}
-
-void OneToNTypeMapping::convertLocations(
- ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
- assert(originalValues.size() == getOriginalTypes().size());
- for (auto [i, value] : llvm::enumerate(originalValues))
- convertLocation(value, i, result);
-}
-
-static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
- return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
-}
-
-bool OneToNTypeMapping::hasNonIdentityConversion() const {
- // XXX: I think that the original types and the converted types are the same
- // iff there was no non-identity type conversion. If that is true, the
- // patterns could actually test whether there is anything useful to do
- // without having access to the signature conversion.
- for (auto [i, originalType] : llvm::enumerate(originalTypes)) {
- TypeRange types = getConvertedTypes(i);
- if (!isIdentityConversion(originalType, types)) {
- assert(TypeRange(originalTypes) != getConvertedTypes());
- return true;
- }
- }
- assert(TypeRange(originalTypes) == getConvertedTypes());
- return false;
-}
-
-namespace {
-enum class CastKind {
- // Casts block arguments in the target type back to the source type. (If
- // necessary, this cast becomes an argument materialization.)
- Argument,
-
- // Casts other values in the target type back to the source type. (If
- // necessary, this cast becomes a source materialization.)
- Source,
-
- // Casts values in the source type to the target type. (If necessary, this
- // cast becomes a target materialization.)
- Target
-};
-}
-
-/// Mapping of enum values to string values.
-StringRef getCastKindName(CastKind kind) {
- static const std::unordered_map<CastKind, StringRef> castKindNames = {
- {CastKind::Argument, "argument"},
- {CastKind::Source, "source"},
- {CastKind::Target, "target"}};
- return castKindNames.at(kind);
-}
-
-/// Attribute name that is used to annotate inserted unrealized casts with their
-/// kind (source, argument, or target).
-static const char *const castKindAttrName =
- "__one-to-n-type-conversion_cast-kind__";
-
-/// Builds an `UnrealizedConversionCastOp` from the given inputs to the given
-/// result types. Returns the result values of the cast.
-static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
- ValueRange inputs, CastKind kind) {
- // Create cast.
- Location loc = builder.getUnknownLoc();
- if (!inputs.empty())
- loc = inputs.front().getLoc();
- auto castOp =
- builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
-
- // Store cast kind as attribute.
- auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
- castOp->setAttr(castKindAttrName, kindAttr);
-
- return castOp->getResults();
-}
-
-/// Builds one `UnrealizedConversionCastOp` for each of the given original
-/// values using the respective target types given in the provided conversion
-/// mapping and returns the results of these casts. If the conversion mapping of
-/// a value maps a type to itself (i.e., is an identity conversion), then no
-/// cast is inserted and the original value is returned instead.
-/// Note that these unrealized casts are different from target materializations
-/// in that they are *always* inserted, even if they immediately fold away, such
-/// that patterns always see valid intermediate IR, whereas materializations are
-/// only used in the places where the unrealized casts *don't* fold away.
-static SmallVector<Value>
-buildUnrealizedForwardCasts(ValueRange originalValues,
- OneToNTypeMapping &conversion,
- RewriterBase &rewriter, CastKind kind) {
-
- // Convert each operand one by one.
- SmallVector<Value> convertedValues;
- convertedValues.reserve(conversion.getConvertedTypes().size());
- for (auto [idx, originalValue] : llvm::enumerate(originalValues)) {
- TypeRange convertedTypes = conversion.getConvertedTypes(idx);
-
- // Identity conversion: keep operand as is.
- if (isIdentityConversion(originalValue.getType(), convertedTypes)) {
- convertedValues.push_back(originalValue);
- continue;
- }
-
- // Non-identity conversion: materialize target types.
- ValueRange castResult =
- buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind);
- convertedValues.append(castResult.begin(), castResult.end());
- }
-
- return convertedValues;
-}
-
-/// Builds one `UnrealizedConversionCastOp` for each sequence of the given
-/// original values to one value of the type they originated from, i.e., a
-/// "reverse" conversion from N converted values back to one value of the
-/// original type, using the given (forward) type conversion. If a given value
-/// was mapped to a value of the same type (i.e., the conversion in the mapping
-/// is an identity conversion), then the "converted" value is returned without
-/// cast.
-/// Note that these unrealized casts are different from source materializations
-/// in that they are *always* inserted, even if they immediately fold away, such
-/// that patterns always see valid intermediate IR, whereas materializations are
-/// only used in the places where the unrealized casts *don't* fold away.
-static SmallVector<Value>
-buildUnrealizedBackwardsCasts(ValueRange convertedValues,
- const OneToNTypeMapping &typeConversion,
- RewriterBase &rewriter) {
- assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
-
- // Create unrealized cast op for each converted result of the op.
- SmallVector<Value> recastValues;
- TypeRange originalTypes = typeConversion.getOriginalTypes();
- recastValues.reserve(originalTypes.size());
- auto convertedValueIt = convertedValues.begin();
- for (auto [idx, originalType] : llvm::enumerate(originalTypes)) {
- TypeRange convertedTypes = typeConversion.getConvertedTypes(idx);
- size_t numConvertedValues = convertedTypes.size();
- if (isIdentityConversion(originalType, convertedTypes)) {
- // Identity conversion: take result as is.
- recastValues.push_back(*convertedValueIt);
- } else {
- // Non-identity conversion: cast back to source type.
- ValueRange recastValue = buildUnrealizedCast(
- rewriter, originalType,
- ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
- CastKind::Source);
- assert(recastValue.size() == 1);
- recastValues.push_back(recastValue.front());
- }
- convertedValueIt += numConvertedValues;
- }
-
- return recastValues;
-}
-
-void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
- const OneToNTypeMapping &resultMapping) {
- // Create a cast back to the original types and replace the results of the
- // original op with those.
- assert(newValues.size() == resultMapping.getConvertedTypes().size());
- assert(op->getResultTypes() == resultMapping.getOriginalTypes());
- SmallVector<Value> castResults =
- buildUnrealizedBackwardsCasts(newValues, resultMapping, *this);
- replaceOp(op, castResults);
-}
-
-Block *OneToNPatternRewriter::applySignatureConversion(
- Block *block, OneToNTypeMapping &argumentConversion) {
- // Split the block at the beginning to get a new block to use for the
- // updated signature.
- SmallVector<Location> locs;
- argumentConversion.convertLocations(block->getArguments(), locs);
- Block *newBlock =
- createBlock(block, argumentConversion.getConvertedTypes(), locs);
- replaceAllUsesWith(block, newBlock);
-
- // Create necessary casts in new block.
- SmallVector<Value> castResults;
- for (auto [i, arg] : llvm::enumerate(block->getArguments())) {
- TypeRange convertedTypes = argumentConversion.getConvertedTypes(i);
- ValueRange newArgs =
- argumentConversion.getConvertedValues(newBlock->getArguments(), i);
- if (isIdentityConversion(arg.getType(), convertedTypes)) {
- // Identity conversion: take argument as is.
- assert(newArgs.size() == 1);
- castResults.push_back(newArgs.front());
- } else {
- // Non-identity conversion: cast the converted arguments to the original
- // type.
- PatternRewriter::InsertionGuard g(*this);
- setInsertionPointToStart(newBlock);
- ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs,
- CastKind::Argument);
- assert(castResult.size() == 1);
- castResults.push_back(castResult.front());
- }
- }
-
- // Merge old block into new block such that we only have the latter with the
- // new signature.
- mergeBlocks(block, newBlock, castResults);
-
- return newBlock;
-}
-
-LogicalResult
-OneToNConversionPattern::matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const {
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
-
- // Construct conversion mapping for results.
- Operation::result_type_range originalResultTypes = op->getResultTypes();
- OneToNTypeMapping resultMapping(originalResultTypes);
- if (failed(typeConverter->computeTypeMapping(originalResultTypes,
- resultMapping)))
- return failure();
-
- // Construct conversion mapping for operands.
- Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
- OneToNTypeMapping operandMapping(originalOperandTypes);
- if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
- operandMapping)))
- return failure();
-
- // Cast operands to target types.
- SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts(
- op->getOperands(), operandMapping, rewriter, CastKind::Target);
-
- // Create a `OneToNPatternRewriter` for the pattern, which provides additional
- // functionality.
- // TODO(ingomueller): I guess it would be better to use only one rewriter
- // throughout the whole pass, but that would require to
- // drive the pattern application ourselves, which is a lot
- // of additional boilerplate code. This seems to work fine,
- // so I leave it like this for the time being.
- OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext());
- oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint());
- oneToNPatternRewriter.setListener(rewriter.getListener());
-
- // Apply actual pattern.
- if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
- resultMapping, convertedOperands)))
- return failure();
-
- return success();
-}
-
-namespace mlir {
-
-// This function applies the provided patterns using
-// `applyPatternsAndFoldGreedily` and then replaces all newly inserted
-// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
-// from target to source types inserted by a `OneToNConversionPattern` normally
-// fold away with the "forward" casts from source to target types inserted by
-// the next pattern.) To understand which casts are "newly inserted", all casts
-// inserted by this pass are annotated with a string attribute that also
-// documents which kind of the cast (source, argument, or target).
-LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
- const FrozenRewritePatternSet &patterns) {
-#ifndef NDEBUG
- // Remember existing unrealized casts. This data structure is only used in
- // asserts; building it only for that purpose may be an overkill.
- SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
- op->walk([&](UnrealizedConversionCastOp castOp) {
- assert(!castOp->hasAttr(castKindAttrName));
- existingCasts.insert(castOp);
- });
-#endif // NDEBUG
-
- // Apply provided conversion patterns.
- if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
- emitError(op->getLoc()) << "failed to apply conversion patterns";
- return failure();
- }
-
- // Find all unrealized casts inserted by the pass that haven't folded away.
- SmallVector<UnrealizedConversionCastOp> worklist;
- op->walk([&](UnrealizedConversionCastOp castOp) {
- if (castOp->hasAttr(castKindAttrName)) {
- assert(!existingCasts.contains(castOp));
- worklist.push_back(castOp);
- }
- });
-
- // Replace new casts with user materializations.
- IRRewriter rewriter(op->getContext());
- for (UnrealizedConversionCastOp castOp : worklist) {
- TypeRange resultTypes = castOp->getResultTypes();
- ValueRange operands = castOp->getOperands();
- StringRef castKind =
- castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
- rewriter.setInsertionPoint(castOp);
-
-#ifndef NDEBUG
- // Determine whether operands or results are already legal to test some
- // assumptions for the different kind of materializations. These properties
- // are only used it asserts and it may be overkill to compute them.
- bool areOperandTypesLegal = llvm::all_of(
- operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
- bool areResultsTypesLegal = llvm::all_of(
- resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
-#endif // NDEBUG
-
- // Add materialization and remember materialized results.
- SmallVector<Value> materializedResults;
- if (castKind == getCastKindName(CastKind::Target)) {
- // Target materialization.
- assert(!areOperandTypesLegal && areResultsTypesLegal &&
- operands.size() == 1 && "found unexpected target cast");
- std::optional<SmallVector<Value>> maybeResults =
- typeConverter.materializeTargetConversion(
- rewriter, castOp->getLoc(), resultTypes, operands.front());
- if (!maybeResults) {
- emitError(castOp->getLoc())
- << "failed to create target materialization";
- return failure();
- }
- materializedResults = maybeResults.value();
- } else {
- // Source and argument materializations.
- assert(areOperandTypesLegal && !areResultsTypesLegal &&
- resultTypes.size() == 1 && "found unexpected cast");
- std::optional<Value> maybeResult;
- if (castKind == getCastKindName(CastKind::Source)) {
- // Source materialization.
- maybeResult = typeConverter.materializeSourceConversion(
- rewriter, castOp->getLoc(), resultTypes.front(),
- castOp.getOperands());
- } else {
- // Argument materialization.
- assert(castKind == getCastKindName(CastKind::Argument) &&
- "unexpected value of cast kind attribute");
- assert(llvm::all_of(operands,
- [&](Value v) { return v.isa<BlockArgument>(); }));
- maybeResult = typeConverter.materializeArgumentConversion(
- rewriter, castOp->getLoc(), resultTypes.front(),
- castOp.getOperands());
- }
- if (!maybeResult.has_value() || !maybeResult.value()) {
- emitError(castOp->getLoc())
- << "failed to create " << castKind << " materialization";
- return failure();
- }
- materializedResults = {maybeResult.value()};
- }
-
- // Replace the cast with the result of the materialization.
- rewriter.replaceOp(castOp, materializedResults);
- }
-
- return success();
-}
-
-} // namespace mlir
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index 51b63ba4c0ad9..604e948afaf6e 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -1,9 +1,5 @@
// RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s
-// RUN: mlir-opt %s -split-input-file \
-// RUN: -test-one-to-n-type-conversion="convert-func-ops" \
-// RUN: | FileCheck %s --check-prefix=CHECK-12N
-
// Test case: Most basic case of a 1:N decomposition, an identity function.
// CHECK-LABEL: func @identity(
@@ -13,10 +9,6 @@
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
-// CHECK-12N-LABEL: func @identity(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32
func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
return %arg0 : tuple<i1, i32>
}
@@ -28,9 +20,6 @@ func.func @identity(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
// CHECK-LABEL: func @identity_1_to_1_no_materializations(
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 {
// CHECK: return %[[ARG0]] : i1
-// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 {
-// CHECK-12N: return %[[ARG0]] : i1
func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> {
return %arg0 : tuple<i1>
}
@@ -42,9 +31,6 @@ func.func @identity_1_to_1_no_materializations(%arg0: tuple<i1>) -> tuple<i1> {
// CHECK-LABEL: func @recursive_decomposition(
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 {
// CHECK: return %[[ARG0]] : i1
-// CHECK-12N-LABEL: func @recursive_decomposition(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 {
-// CHECK-12N: return %[[ARG0]] : i1
func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tuple<tuple<i1>>> {
return %arg0 : tuple<tuple<tuple<i1>>>
}
@@ -68,10 +54,6 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple<i2>) -> i2
// CHECK: return %[[V7]], %[[V10]] : i1, i2
-// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2
func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> {
return %arg0 : tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
}
@@ -81,7 +63,6 @@ func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<
// Test case: Check decomposition of calls.
// CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32)
-// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32)
func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-LABEL: func @caller(
@@ -95,11 +76,6 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
-// CHECK-12N-LABEL: func @caller(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
-// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32
func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
%0 = call @callee(%arg0) : (tuple<i1, i32>) -> tuple<i1, i32>
return %0 : tuple<i1, i32>
@@ -110,15 +86,10 @@ func.func @caller(%arg0: tuple<i1, i32>) -> tuple<i1, i32> {
// Test case: Type that decomposes to nothing (that is, a 1:0 decomposition).
// CHECK-LABEL: func private @callee()
-// CHECK-12N-LABEL: func private @callee()
func.func private @callee(tuple<>) -> tuple<>
-
// CHECK-LABEL: func @caller() {
// CHECK: call @callee() : () -> ()
// CHECK: return
-// CHECK-12N-LABEL: func @caller() {
-// CHECK-12N: call @callee() : () -> ()
-// CHECK-12N: return
func.func @caller(%arg0: tuple<>) -> tuple<> {
%0 = call @callee(%arg0) : (tuple<>) -> (tuple<>)
return %0 : tuple<>
@@ -134,11 +105,6 @@ func.func @caller(%arg0: tuple<>) -> tuple<> {
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
-// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) {
-// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
-// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
-// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
-// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32
func.func @unconverted_op_result() -> tuple<i1, i32> {
%0 = "test.source"() : () -> (tuple<i1, i32>)
return %0 : tuple<i1, i32>
@@ -159,16 +125,6 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
// CHECK: return %[[V3]], %[[V5]] : i1, i32
-// CHECK-12N-LABEL: func @nested_unconverted_op_result(
-// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
-// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
-// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
-// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
-// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
-// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
-// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
-// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32
func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> {
%0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>)
return %0 : tuple<i1, tuple<i32>>
@@ -180,7 +136,6 @@ func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1
// This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions.
// CHECK-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
// CHECK-LABEL: func @caller(
@@ -198,15 +153,6 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
-// CHECK-12N-LABEL: func @caller(
-// CHECK-12N-SAME: %[[I1:.*]]: i1,
-// CHECK-12N-SAME: %[[I2:.*]]: i2,
-// CHECK-12N-SAME: %[[I3:.*]]: i3,
-// CHECK-12N-SAME: %[[I4:.*]]: i4,
-// CHECK-12N-SAME: %[[I5:.*]]: i5,
-// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
-// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple<i2>, %arg3: i3, %arg4: tuple<i4, i5>, %arg5: i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) {
%0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6)
return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6
diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt
index 14df652ac7dfd..14f0e0dbe1802 100644
--- a/mlir/test/lib/Conversion/CMakeLists.txt
+++ b/mlir/test/lib/Conversion/CMakeLists.txt
@@ -1,3 +1,2 @@
add_subdirectory(FuncToLLVM)
-add_subdirectory(OneToNTypeConversion)
add_subdirectory(VectorToSPIRV)
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
deleted file mode 100644
index 418978688c90d..0000000000000
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-add_mlir_library(MLIRTestOneToNTypeConversionPass
- TestOneToNTypeConversionPass.cpp
-
- EXCLUDE_FROM_LIBMLIR
-
- LINK_LIBS PUBLIC
- MLIRFuncDialect
- MLIRFuncTransforms
- MLIRIR
- MLIRTestDialect
- MLIRTransformUtils
- )
-
-target_include_directories(MLIRTestOneToNTypeConversionPass
- PRIVATE
- ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
- ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
- )
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
deleted file mode 100644
index 220bcb58bf788..0000000000000
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ /dev/null
@@ -1,245 +0,0 @@
-//===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "TestDialect.h"
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
-
-using namespace mlir;
-
-namespace {
-/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms
-/// in `applyPartialOneToNConversion` by converting built-in tuples to the
-/// elements they consist of as well as some dummy ops operating on these
-/// tuples.
-struct TestOneToNTypeConversionPass
- : public PassWrapper<TestOneToNTypeConversionPass,
- OperationPass<ModuleOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass)
-
- TestOneToNTypeConversionPass() = default;
- TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass)
- : PassWrapper(pass) {}
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<test::TestDialect>();
- }
-
- StringRef getArgument() const final {
- return "test-one-to-n-type-conversion";
- }
-
- StringRef getDescription() const final {
- return "Test pass for 1:N type conversion";
- }
-
- Option<bool> convertFuncOps{*this, "convert-func-ops",
- llvm::cl::desc("Enable conversion on func ops"),
- llvm::cl::init(false)};
-
- Option<bool> convertTupleOps{*this, "convert-tuple-ops",
- llvm::cl::desc("Enable conversion on tuple ops"),
- llvm::cl::init(false)};
-
- void runOnOperation() override;
-};
-
-} // namespace
-
-namespace mlir {
-namespace test {
-void registerTestOneToNTypeConversionPass() {
- PassRegistration<TestOneToNTypeConversionPass>();
-}
-} // namespace test
-} // namespace mlir
-
-namespace {
-
-/// Test pattern on for the `make_tuple` op from the test dialect that converts
-/// this kind of op into it's "decomposed" form, i.e., the elements of the tuple
-/// that is being produced by `test.make_tuple`, which are really just the
-/// operands of this op.
-class ConvertMakeTupleOp
- : public OneToNOpConversionPattern<::test::MakeTupleOp> {
-public:
- using OneToNOpConversionPattern<
- ::test::MakeTupleOp>::OneToNOpConversionPattern;
-
- LogicalResult matchAndRewrite(::test::MakeTupleOp op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const override {
- // Simply replace the current op with the converted operands.
- rewriter.replaceOp(op, convertedOperands, resultMapping);
- return success();
- }
-};
-
-/// Test pattern on for the `get_tuple_element` op from the test dialect that
-/// converts this kind of op into it's "decomposed" form, i.e., instead of
-/// "physically" extracting one element from the tuple, we forward the one
-/// element of the decomposed form that is being extracted (or the several
-/// elements in case that element is a nested tuple).
-class ConvertGetTupleElementOp
- : public OneToNOpConversionPattern<::test::GetTupleElementOp> {
-public:
- using OneToNOpConversionPattern<
- ::test::GetTupleElementOp>::OneToNOpConversionPattern;
-
- LogicalResult matchAndRewrite(::test::GetTupleElementOp op,
- OneToNPatternRewriter &rewriter,
- const OneToNTypeMapping &operandMapping,
- const OneToNTypeMapping &resultMapping,
- ValueRange convertedOperands) const override {
- // Construct mapping for tuple element types.
- auto stateType = op->getOperand(0).getType().cast<TupleType>();
- TypeRange originalElementTypes = stateType.getTypes();
- OneToNTypeMapping elementMapping(originalElementTypes);
- if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
- elementMapping)))
- return failure();
-
- // Compute converted operands corresponding to original input tuple.
- ValueRange convertedTuple =
- operandMapping.getConvertedValues(convertedOperands, 0);
-
- // Got those converted operands that correspond to the index-th element of
- // the original input tuple.
- size_t index = op.getIndex();
- ValueRange extractedElement =
- elementMapping.getConvertedValues(convertedTuple, index);
-
- rewriter.replaceOp(op, extractedElement, resultMapping);
-
- return success();
- }
-};
-
-} // namespace
-
-static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<
- // clang-format off
- ConvertMakeTupleOp,
- ConvertGetTupleElementOp
- // clang-format on
- >(typeConverter, patterns.getContext());
-}
-
-/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
-/// given tuple value. If some tuple elements are, in turn, tuples, the elements
-/// of those are extracted recursively such that the returned values have the
-/// same types as `resultTypes.getFlattenedTypes()`.
-///
-/// This function has been copied (with small adaptions) from
-/// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
- Location loc) {
- TupleType inputType = input.getType().dyn_cast<TupleType>();
- if (!inputType)
- return {};
-
- SmallVector<Value> values;
- for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
- Value element = builder.create<::test::GetTupleElementOp>(
- loc, elementType, input, builder.getI32IntegerAttr(idx));
- if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
- // Recurse if the current element is also a tuple.
- SmallVector<Type> flatRecursiveTypes;
- nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
- std::optional<SmallVector<Value>> resursiveValues =
- buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
- if (!resursiveValues.has_value())
- return {};
- values.append(resursiveValues.value());
- } else {
- values.push_back(element);
- }
- }
- return values;
-}
-
-/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
-/// type `resultType`. If that type is nested, each nested tuple is built
-/// recursively with another `test.make_tuple` op.
-///
-/// This function has been copied (with small adaptions) from
-/// TestDecomposeCallGraphTypes.cpp.
-static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
- TupleType resultType,
- ValueRange inputs, Location loc) {
- // Build one value for each element at this nesting level.
- SmallVector<Value> elements;
- elements.reserve(resultType.getTypes().size());
- ValueRange::iterator inputIt = inputs.begin();
- for (Type elementType : resultType.getTypes()) {
- if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
- // Determine how many input values are needed for the nested elements of
- // the nested TupleType and advance inputIt by that number.
- // TODO: We only need the *number* of nested types, not the types itself.
- // Maybe it's worth adding a more efficient overload?
- SmallVector<Type> nestedFlattenedTypes;
- nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
- size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
- ValueRange nestedFlattenedelements(inputIt,
- inputIt + numNestedFlattenedTypes);
- inputIt += numNestedFlattenedTypes;
-
- // Recurse on the values for the nested TupleType.
- std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
- nestedFlattenedelements, loc);
- if (!res.has_value())
- return {};
-
- // The tuple constructed by the conversion is the element value.
- elements.push_back(res.value());
- } else {
- // Base case: take one input as is.
- elements.push_back(*inputIt++);
- }
- }
-
- // Assemble the tuple from the elements.
- return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
-}
-
-void TestOneToNTypeConversionPass::runOnOperation() {
- ModuleOp module = getOperation();
- auto *context = &getContext();
-
- // Assemble type converter.
- OneToNTypeConverter typeConverter;
-
- typeConverter.addConversion([](Type type) { return type; });
- typeConverter.addConversion(
- [](TupleType tupleType, SmallVectorImpl<Type> &types) {
- tupleType.getFlattenedTypes(types);
- return success();
- });
-
- typeConverter.addArgumentMaterialization(buildMakeTupleOp);
- typeConverter.addSourceMaterialization(buildMakeTupleOp);
- typeConverter.addTargetMaterialization(buildGetTupleElementOps);
-
- // Assemble patterns.
- RewritePatternSet patterns(context);
- if (convertTupleOps)
- populateDecomposeTuplesTestPatterns(typeConverter, patterns);
- if (convertFuncOps)
- populateFuncTypeConversionPatterns(typeConverter, patterns);
-
- // Run conversion.
- if (failed(applyPartialOneToNConversion(module, typeConverter,
- std::move(patterns))))
- return signalPassFailure();
-}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index c43056906da8c..f84fbe631cf16 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -33,7 +33,6 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestDialect
MLIRTestDynDialect
MLIRTestIR
- MLIRTestOneToNTypeConversionPass
MLIRTestPass
MLIRTestPDLL
MLIRTestReducer
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 855c6a6e0f28a..e7ca06b4bef9d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -107,7 +107,6 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
-void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
void registerTestPadFusion();
void registerTestPDLByteCodePass();
@@ -219,7 +218,6 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
- mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestPadFusion();
mlir::test::registerTestPDLByteCodePass();
More information about the Mlir-commits
mailing list