[flang-commits] [flang] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)
Matthias Springer via flang-commits
flang-commits at lists.llvm.org
Tue Dec 31 04:33:10 PST 2024
Markus Böck <markus.boeck02 at gmail.com>, Matthias Springer
<mspringer at nvidia.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/116524 at github.com>
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116524
>From 40612e0473a55a10588ec03105a63c503cf2a112 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 15 Dec 2024 17:36:49 +0100
Subject: [PATCH 1/5] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping`
---
.../lib/Optimizer/CodeGen/BoxedProcedure.cpp | 1 -
mlir/docs/DialectConversion.md | 35 +-
.../mlir/Transforms/DialectConversion.h | 18 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 16 +-
.../EmitC/Transforms/TypeConversions.cpp | 1 -
.../Dialect/Linalg/Transforms/Detensorize.cpp | 1 -
.../Quant/Transforms/StripFuncQuantTypes.cpp | 1 -
.../Utils/SparseTensorDescriptor.cpp | 3 -
.../Vector/Transforms/VectorLinearize.cpp | 1 -
.../Transforms/Utils/DialectConversion.cpp | 432 +++++++++---------
mlir/test/Transforms/test-legalizer.mlir | 7 +-
.../Func/TestDecomposeCallGraphTypes.cpp | 2 +-
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 -
.../lib/Transforms/TestDialectConversion.cpp | 1 -
14 files changed, 224 insertions(+), 296 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 1bb91d252529f0..104ae7408b80c1 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
addConversion([&](TypeDescType ty) {
return TypeDescType::get(convertType(ty.getOfTy()));
});
- addArgumentMaterialization(materializeProcedure);
addSourceMaterialization(materializeProcedure);
addTargetMaterialization(materializeProcedure);
}
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 3168f5e13c7515..abacd5a82c61eb 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure
type safety during the conversion process. There are several types of
materializations depending on the situation.
-* Argument Materialization
-
- - An argument materialization is used when converting the type of a block
- argument during a [signature conversion](#region-signature-conversion).
- The new block argument types are specified in a `SignatureConversion`
- object. An original block argument can be converted into multiple
- block arguments, which is not supported everywhere in the dialect
- conversion. (E.g., adaptors support only a single replacement value for
- each original value.) Therefore, an argument materialization is used to
- convert potentially multiple new block arguments back into a single SSA
- value. An argument materialization is also used when replacing an op
- result with multiple values.
-
* Source Materialization
- A source materialization is used when a value was replaced with a value
@@ -343,17 +330,6 @@ class TypeConverter {
/// Materialization functions must be provided when a type conversion may
/// persist after the conversion has finished.
- /// This method registers a materialization that will be called when
- /// converting (potentially multiple) block arguments that were the result of
- /// a signature conversion of a single block argument, to a single SSA value
- /// with the old argument type.
- template <typename FnT,
- typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
- void addArgumentMaterialization(FnT &&callback) {
- argumentMaterializations.emplace_back(
- wrapMaterialization<T>(std::forward<FnT>(callback)));
- }
-
/// This method registers a materialization that will be called when
/// converting a replacement value back to its original source type.
/// This is used when some uses of the original value persist beyond the main
@@ -406,12 +382,11 @@ done explicitly via a conversion pattern.
To convert the types of block arguments within a Region, a custom hook on the
`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
uses a provided type converter to apply type conversions to all blocks of a
-given region. As noted above, the conversions performed by this method use the
-argument materialization hook on the `TypeConverter`. This hook also takes an
-optional `TypeConverter::SignatureConversion` parameter that applies a custom
-conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to the operation, e.g.,
-`func::FuncOp`, `AffineForOp`, etc.
+given region. This hook also takes an optional
+`TypeConverter::SignatureConversion` parameter that applies a custom conversion
+to the entry block of the region. The types of the entry block arguments are
+often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`,
+etc.
To convert the signature of just one given block, the
`applySignatureConversion` hook can be used.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913e3..9a6975dcf8dfae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,6 +181,10 @@ class TypeConverter {
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
+ ///
+ /// Note: Argument materializations are used only with the 1:N dialect
+ /// conversion driver. The 1:N dialect conversion driver will be removed soon
+ /// and argument materializations will be removed along with it.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
void replaceOp(Operation *op, Operation *newOp) override;
/// Replace the given operation with the new value ranges. The number of op
- /// results and value ranges must match. If an original SSA value is replaced
- /// by multiple SSA values (i.e., a value range has more than 1 element), the
- /// conversion driver will insert an argument materialization to convert the
- /// N SSA values back into 1 SSA value of the original type. The given
- /// operation is erased.
- ///
- /// Note: The argument materialization is a workaround until we have full 1:N
- /// support in the dialect conversion. (It is going to disappear from both
- /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+ /// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
/// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1285,8 +1281,8 @@ struct ConversionConfig {
// represented at the moment.
RewriterBase::Listener *listener = nullptr;
- /// If set to "true", the dialect conversion attempts to build source/target/
- /// argument materializations through the type converter API in lieu of
+ /// If set to "true", the dialect conversion attempts to build source/target
+ /// materializations through the type converter API in lieu of
/// "builtin.unrealized_conversion_cast ops". The conversion process fails if
/// at least one materialization could not be built.
///
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 49e2d943286645..72799e42cf3fd1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -85,7 +85,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs, Location loc,
const LLVMTypeConverter &converter) {
- // An argument materialization must return a value of type
+ // A source materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
Value packed =
@@ -101,7 +101,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder,
MemRefType resultType,
ValueRange inputs, Location loc,
const LLVMTypeConverter &converter) {
- // An argument materialization must return a value of type `resultType`,
+ // A source materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
Value packed =
@@ -234,19 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
.getResult(0);
});
- // Argument materializations convert from the new block argument types
+ // Source materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type.
- addArgumentMaterialization([&](OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
- return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
- *this);
- });
- addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc) {
- return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
- });
addSourceMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType, ValueRange inputs,
Location loc) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
converter.addSourceMaterialization(materializeAsUnrealizedCast);
converter.addTargetMaterialization(materializeAsUnrealizedCast);
- converter.addArgumentMaterialization(materializeAsUnrealizedCast);
}
/// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 0e651f4cee4c36..fc6671ef811759 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
});
addSourceMaterialization(sourceMaterializationCallback);
- addArgumentMaterialization(sourceMaterializationCallback);
}
};
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
addConversion(convertQuantizedType);
addConversion(convertTensorType);
- addArgumentMaterialization(materializeConversion);
addSourceMaterialization(materializeConversion);
addTargetMaterialization(materializeConversion);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
// Required by scf.for 1:N type conversion.
addSourceMaterialization(materializeTuple);
-
- // Required as a workaround until we have full 1:N support.
- addArgumentMaterialization(materializeTuple);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
};
- typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 255b0ba2559ee6..96cbe07f0f12f9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,6 +54,55 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}
+/// Given two insertion points in the same block, choose the later one.
+///
+/// Either insertion point may be `block->begin()` or `block->end()`. The end
+/// iterator denotes "after the last operation" and must never be
+/// dereferenced, so both sentinel positions are resolved before the two
+/// operations are compared with `isBeforeInBlock`.
+static OpBuilder::InsertPoint
+chooseLaterInsertPointInBlock(OpBuilder::InsertPoint a,
+                              OpBuilder::InsertPoint b) {
+  assert(a.getBlock() == b.getBlock() && "expected same block");
+  Block *block = a.getBlock();
+  // The end of the block is at or after every other insertion point.
+  if (a.getPoint() == block->end())
+    return a;
+  if (b.getPoint() == block->end())
+    return b;
+  // The beginning of the block is at or before every other insertion point.
+  if (a.getPoint() == block->begin())
+    return b;
+  if (b.getPoint() == block->begin())
+    return a;
+  // Both insertion points now refer to operations inside the block.
+  if (a.getPoint()->isBeforeInBlock(&*b.getPoint()))
+    return b;
+  return a;
+}
+
+/// Helper function that chooses the later of the two given insertion points.
+///
+/// The two insertion points must be ordered, either by block dominance (same
+/// region) or by region nesting. If they are not, there is no valid answer
+/// and this function aborts via `llvm_unreachable`.
+///
+/// Note: A fresh `DominanceInfo` is constructed whenever the two points are in
+/// different blocks of the same region.
+// TODO: Extend DominanceInfo API to work with block iterators.
+static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a,
+                                                     OpBuilder::InsertPoint b) {
+  Block *blockA = a.getBlock();
+  Block *blockB = b.getBlock();
+
+  // Case 1: Same block.
+  if (blockA == blockB)
+    return chooseLaterInsertPointInBlock(a, b);
+
+  // Case 2: Different block, but same region.
+  Region *regionA = blockA->getParent();
+  Region *regionB = blockB->getParent();
+  if (regionA == regionB) {
+    DominanceInfo domInfo;
+    if (domInfo.properlyDominates(blockA, blockB))
+      return b;
+    if (domInfo.properlyDominates(blockB, blockA))
+      return a;
+    // Neither of the two blocks dominates the other.
+    llvm_unreachable("unable to find valid insertion point");
+  }
+
+  // The parent op is queried through the block, not through the insertion
+  // point: the insertion point may be `block->end()`, which must not be
+  // dereferenced.
+
+  // Case 3: a is (possibly transitively) nested in b's region: choose a.
+  Operation *parentA = blockA->getParentOp();
+  if (parentA && regionB->findAncestorOpInRegion(*parentA))
+    return a;
+
+  // Case 4: b is (possibly transitively) nested in a's region: choose b.
+  Operation *parentB = blockB->getParentOp();
+  if (parentB && regionA->findAncestorOpInRegion(*parentB))
+    return b;
+
+  // Neither insertion point is nested in the other's region.
+  llvm_unreachable("unable to find valid insertion point");
+}
+
/// Helper function that computes an insertion point where the given value is
/// defined and can be used without a dominance violation.
static OpBuilder::InsertPoint computeInsertPoint(Value value) {
@@ -63,11 +113,36 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
return OpBuilder::InsertPoint(insertBlock, insertPt);
}
+/// Helper function that computes an insertion point at which every one of the
+/// given values is defined and can be used without a dominance violation.
+/// The result is the latest of the per-value insertion points. `vals` must be
+/// non-empty.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint result = computeInsertPoint(vals.front());
+  for (size_t i = 1, e = vals.size(); i < e; ++i) {
+    OpBuilder::InsertPoint candidate = computeInsertPoint(vals[i]);
+    result = chooseLaterInsertPoint(result, candidate);
+  }
+  return result;
+}
+
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
+/// A vector of SSA values, optimized for the most common case of a single
+/// value: the inline capacity of 1 holds one value without a heap allocation.
+/// Used as both the key and the mapped type of the conversion value mapping,
+/// so that one value can be mapped to multiple replacement values (1:N).
+using ValueVector = SmallVector<Value, 1>;
+
namespace {
+
+/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
+///
+/// DenseMap requires the empty key and the tombstone key to be distinct from
+/// each other and from every key that is inserted or looked up. An empty
+/// `ValueVector` is therefore not usable as a sentinel:
+///   - empty == tombstone would make erased buckets look empty and break
+///     probing after `erase`;
+///   - an empty vector is a legitimate lookup key (e.g., `find` on a value
+///     that was mapped to zero replacement values).
+/// Vectors of null `Value`s are used instead. NOTE: this assumes real keys
+/// never contain null `Value`s.
+struct ValueVectorMapInfo {
+  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
+  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
+  static ::llvm::hash_code getHashValue(const ValueVector &val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(const ValueVector &lhs, const ValueVector &rhs) {
+    return lhs == rhs;
+  }
+};
+
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
struct ConversionValueMapping {
@@ -75,68 +150,103 @@ struct ConversionValueMapping {
/// false positives.
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
- /// Lookup the most recently mapped value with the desired type in the
+ /// Lookup the most recently mapped values with the desired types in the
/// mapping.
///
/// Special cases:
- /// - If the desired type is "null", simply return the most recently mapped
- /// value.
- /// - If there is no mapping to the desired type, also return the most
- /// recently mapped value.
- /// - If there is no mapping for the given value at all, return the given
- /// value.
- Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
-
- /// Lookup a mapped value within the map, or return null if a mapping does not
- /// exist. If a mapping exists, this follows the same behavior of
- /// `lookupOrDefault`.
- Value lookupOrNull(Value from, Type desiredType = nullptr) const;
+ /// - If the desired type range is empty, simply return the most recently
+ /// mapped values.
+ /// - If there is no mapping to the desired types, also return the most
+ /// recently mapped values.
+ /// - If there is no mapping for the given values at all, return the given
+ /// values.
+ ValueVector lookupOrDefault(ValueVector from,
+ TypeRange desiredTypes = {}) const;
+
+ /// Lookup the given values within the map, or return an empty vector if the
+ /// values are not mapped. If they are mapped, this follows the same behavior
+ /// as `lookupOrDefault`.
+ ValueVector lookupOrNull(const ValueVector &from,
+ TypeRange desiredTypes = {}) const;
/// Map a value to the one provided.
- void map(Value oldVal, Value newVal) {
+ void map(const ValueVector &oldVal, const ValueVector &newVal) {
LLVM_DEBUG({
- for (Value it = newVal; it; it = mapping.lookupOrNull(it))
- assert(it != oldVal && "inserting cyclic mapping");
+ ValueVector next = newVal;
+ while (true) {
+ assert(next != oldVal && "inserting cyclic mapping");
+ auto it = mapping.find(next);
+ if (it == mapping.end())
+ break;
+ next = it->second;
+ }
});
- mapping.map(oldVal, newVal);
- mappedTo.insert(newVal);
+ mapping[oldVal] = newVal;
+ for (Value v : newVal)
+ mappedTo.insert(v);
}
- /// Drop the last mapping for the given value.
- void erase(Value value) { mapping.erase(value); }
+ /// Drop the last mapping for the given values.
+ void erase(ValueVector value) { mapping.erase(value); }
private:
/// Current value mappings.
- IRMapping mapping;
+ DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
/// All SSA values that are mapped to. May contain false positives.
DenseSet<Value> mappedTo;
};
} // namespace
-Value ConversionValueMapping::lookupOrDefault(Value from,
- Type desiredType) const {
- // Try to find the deepest value that has the desired type. If there is no
- // such value, simply return the deepest value.
- Value desiredValue;
+ValueVector
+ConversionValueMapping::lookupOrDefault(ValueVector from,
+ TypeRange desiredTypes) const {
+ // Try to find the deepest values that have the desired types. If there is no
+ // such mapping, simply return the deepest values.
+ ValueVector desiredValue;
do {
- if (!desiredType || from.getType() == desiredType)
+ // Store the current value if the types match.
+ if (desiredTypes.empty() || TypeRange(from) == desiredTypes)
desiredValue = from;
- Value mappedValue = mapping.lookupOrNull(from);
- if (!mappedValue)
+ // If possible, replace each value with (one or multiple) mapped values.
+ ValueVector next;
+ for (Value v : from) {
+ auto it = mapping.find({v});
+ if (it != mapping.end()) {
+ llvm::append_range(next, it->second);
+ } else {
+ next.push_back(v);
+ }
+ }
+ if (next != from) {
+ // If at least one value was replaced, continue the lookup from there.
+ from = next;
+ continue;
+ }
+
+ // Otherwise: Check if there is a mapping for the entire vector. Such
+ // mappings are materializations. (N:M mappings are not supported for value
+ // replacements.)
+ auto it = mapping.find(from);
+ if (it == mapping.end()) {
+ // No mapping found: The lookup stops here.
break;
- from = mappedValue;
+ }
+ from = it->second;
} while (true);
- // If the desired value was found use it, otherwise default to the leaf value.
- return desiredValue ? desiredValue : from;
+ // If the desired values were found use them, otherwise default to the leaf
+ // values.
+ return !desiredValue.empty() ? desiredValue : from;
}
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
- Value result = lookupOrDefault(from, desiredType);
- if (result == from || (desiredType && result.getType() != desiredType))
- return nullptr;
+ValueVector ConversionValueMapping::lookupOrNull(const ValueVector &from,
+ TypeRange desiredTypes) const {
+ ValueVector result = lookupOrDefault(from, desiredTypes);
+ TypeRange resultTypes(result);
+ if (result == from || (!desiredTypes.empty() && resultTypes != desiredTypes))
+ return {};
return result;
}
@@ -651,10 +761,6 @@ class CreateOperationRewrite : public OperationRewrite {
/// The type of materialization.
enum MaterializationKind {
- /// This materialization materializes a conversion for an illegal block
- /// argument type, to the original one.
- Argument,
-
/// This materialization materializes a conversion from an illegal type to a
/// legal one.
Target,
@@ -673,7 +779,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnrealizedConversionCastOp op,
const TypeConverter *converter,
MaterializationKind kind, Type originalType,
- Value mappedValue);
+ ValueVector mappedValues);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,9 +814,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// materializations.
Type originalType;
- /// The value in the conversion value mapping that is being replaced by the
+ /// The values in the conversion value mapping that are being replaced by the
/// results of this unresolved materialization.
- Value mappedValue;
+ ValueVector mappedValues;
};
} // namespace
@@ -779,7 +885,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
LogicalResult remapValues(StringRef valueDiagTag,
std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVector<SmallVector<Value>> &remapped);
+ SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
@@ -820,39 +926,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// If a cast op was built, it can optionally be returned with the `castOp`
/// output argument.
///
- /// If `valueToMap` is set to a non-null Value, then that value is mapped to
+ /// If `valuesToMap` is non-empty, then those values are mapped to
/// the results of the unresolved materialization in the conversion value
/// mapping.
ValueRange buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp = nullptr);
- Value buildUnresolvedMaterialization(
- MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
- const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp = nullptr) {
- return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
- TypeRange(outputType), originalType,
- converter, castOp)
- .front();
- }
-
- /// Build an N:1 materialization for the given original value that was
- /// replaced with the given replacement values.
- ///
- /// This is a workaround around incomplete 1:N support in the dialect
- /// conversion driver. The conversion mapping can store only 1:1 replacements
- /// and the conversion patterns only support single Value replacements in the
- /// adaptor, so N values must be converted back to a single value. This
- /// function will be deleted when full 1:N support has been added.
- ///
- /// This function inserts an argument materialization back to the original
- /// type.
- void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
- ValueRange replacements, Value originalValue,
- const TypeConverter *converter);
/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
@@ -862,16 +943,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value findOrBuildReplacementValue(Value value,
const TypeConverter *converter);
- /// Unpack an N:1 materialization and return the inputs of the
- /// materialization. This function unpacks only those materializations that
- /// were built with `insertNTo1Materialization`.
- ///
- /// This is a workaround around incomplete 1:N support in the dialect
- /// conversion driver. It allows us to write 1:N conversion patterns while
- /// 1:N support is still missing in the conversion value mapping. This
- /// function will be deleted when full 1:N support has been added.
- SmallVector<Value> unpackNTo1Materialization(Value value);
-
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -1041,7 +1112,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
auto *listener =
@@ -1082,7 +1153,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
void ReplaceOperationRewrite::rollback() {
for (auto result : op->getResults())
- rewriterImpl.mapping.erase(result);
+ rewriterImpl.mapping.erase({result});
}
void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
@@ -1101,18 +1172,18 @@ void CreateOperationRewrite::rollback() {
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
const TypeConverter *converter, MaterializationKind kind, Type originalType,
- Value mappedValue)
+ ValueVector mappedValues)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), originalType(originalType),
- mappedValue(mappedValue) {
+ mappedValues(mappedValues) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}
void UnresolvedMaterializationRewrite::rollback() {
- if (mappedValue)
- rewriterImpl.mapping.erase(mappedValue);
+ if (!mappedValues.empty())
+ rewriterImpl.mapping.erase(mappedValues);
rewriterImpl.unresolvedMaterializations.erase(getOperation());
rewriterImpl.nTo1TempMaterializations.erase(getOperation());
op->erase();
@@ -1160,7 +1231,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVector<SmallVector<Value>> &remapped) {
+ SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1239,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type origType = operand.getType();
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
- // Find the most recently mapped value. Unpack all temporary N:1
- // materializations. Such conversions are a workaround around missing
- // 1:N support in the ConversionValueMapping. (The conversion patterns
- // already support 1:N replacements.)
- Value repl = mapping.lookupOrDefault(operand);
- SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
if (!currentTypeConverter) {
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
- // pass through the most recently mapped value.
- remapped.push_back(std::move(unpacked));
+ // pass through the most recently mapped values.
+ remapped.push_back(mapping.lookupOrDefault({operand}));
continue;
}
@@ -1192,51 +1256,28 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
});
return failure();
}
-
// If a type is converted to 0 types, there is nothing to do.
if (legalTypes.empty()) {
remapped.push_back({});
continue;
}
- if (legalTypes.size() != 1) {
- // TODO: This is a 1:N conversion. The conversion value mapping does not
- // store such materializations yet. If the types of the most recently
- // mapped values do not match, build a target materialization.
- ValueRange unpackedRange(unpacked);
- if (TypeRange(unpackedRange) == legalTypes) {
- remapped.push_back(std::move(unpacked));
- continue;
- }
-
- // Insert a target materialization if the current pattern expects
- // different legalized types.
- ValueRange targetMat = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
- /*valueToMap=*/Value(), /*inputs=*/unpacked,
- /*outputType=*/legalTypes, /*originalType=*/origType,
- currentTypeConverter);
- remapped.push_back(targetMat);
+ ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes);
+ if (!repl.empty() && TypeRange(repl) == legalTypes) {
+ // Mapped values have the correct type or there is an existing
+ // materialization. Or the operand is not mapped at all and has the
+ // correct type.
+ remapped.push_back(repl);
continue;
}
- // Handle 1->1 type conversions.
- Type desiredType = legalTypes.front();
- // Try to find a mapped value with the desired type. (Or the operand itself
- // if the value is not mapped at all.)
- Value newOperand = mapping.lookupOrDefault(operand, desiredType);
- if (newOperand.getType() != desiredType) {
- // If the looked up value's type does not have the desired type, it means
- // that the value was replaced with a value of different type and no
- // target materialization was created yet.
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(newOperand),
- operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
- /*outputType=*/desiredType, /*originalType=*/origType,
- currentTypeConverter);
- newOperand = castValue;
- }
- remapped.push_back({newOperand});
+ // Create a materialization for the most recently mapped values.
+ repl = mapping.lookupOrDefault({operand});
+ ValueRange castValues = buildUnresolvedMaterialization(
+ MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+ /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
+ /*originalType=*/origType, currentTypeConverter);
+ remapped.push_back(castValues);
}
return success();
}
@@ -1353,7 +1394,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
+ /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
@@ -1364,7 +1405,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- mapping.map(origArg, repl);
+ mapping.map({origArg}, {repl});
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
@@ -1375,13 +1416,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- if (replArgs.size() == 1) {
- mapping.map(origArg, replArgs.front());
- } else {
- insertNTo1Materialization(
- OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
- }
+ ValueVector replArgVals = llvm::map_to_vector<1>(
+ replArgs, [](BlockArgument arg) -> Value { return arg; });
+ mapping.map({origArg}, replArgVals);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1402,7 +1439,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
UnrealizedConversionCastOp *castOp) {
assert((!originalType || kind == MaterializationKind::Target) &&
@@ -1410,10 +1447,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// Avoid materializing an unnecessary cast.
if (TypeRange(inputs) == outputTypes) {
- if (valueToMap) {
- assert(inputs.size() == 1 && "1:N mapping is not supported");
- mapping.map(valueToMap, inputs.front());
- }
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, inputs);
return inputs;
}
@@ -1423,37 +1458,21 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
- if (valueToMap) {
- assert(outputTypes.size() == 1 && "1:N mapping is not supported");
- mapping.map(valueToMap, convertOp.getResult(0));
- }
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- originalType, valueToMap);
+ originalType, valuesToMap);
return convertOp.getResults();
}
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
- OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
- Value originalValue, const TypeConverter *converter) {
- // Insert argument materialization back to the original type.
- Type originalType = originalValue.getType();
- UnrealizedConversionCastOp argCastOp;
- buildUnresolvedMaterialization(
- MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
- /*inputs=*/replacements, originalType,
- /*originalType=*/Type(), converter, &argCastOp);
- if (argCastOp)
- nTo1TempMaterializations.insert(argCastOp);
-}
-
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
// Find a replacement value with the same type.
- Value repl = mapping.lookupOrNull(value, value.getType());
- if (repl)
- return repl;
+ ValueVector repl = mapping.lookupOrNull({value}, value.getType());
+ if (!repl.empty())
+ return repl.front();
// Check if the value is dead. No replacement value is needed in that case.
// This is an approximate check that may have false negatives but does not
@@ -1467,8 +1486,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// No replacement value was found. Get the latest replacement value
// (regardless of the type) and build a source materialization to the
// original type.
- repl = mapping.lookupOrNull(value);
- if (!repl) {
+ repl = mapping.lookupOrNull({value});
+ if (repl.empty()) {
// No replacement value is registered in the mapping. This means that the
// value is dropped and no longer needed. (If the value were still needed,
// a source materialization producing a replacement value "out of thin air"
@@ -1478,34 +1497,12 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
}
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
- /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
- /*originalType=*/Type(), converter);
- mapping.map(value, castValue);
+ /*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(),
+ /*originalType=*/Type(), converter)[0];
+ mapping.map({value}, {castValue});
return castValue;
}
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
- // Unpack unrealized_conversion_cast ops that were inserted as a N:1
- // workaround.
- auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
- if (!castOp)
- return {value};
- if (!nTo1TempMaterializations.contains(castOp))
- return {value};
- assert(castOp->getNumResults() == 1 && "expected single result");
-
- SmallVector<Value> result;
- for (Value v : castOp.getOperands()) {
- // Keep unpacking if possible. This is needed because during block
- // signature conversions and 1:N op replacements, the driver may have
- // inserted two materializations back-to-back: first an argument
- // materialization, then a target materialization.
- llvm::append_range(result, unpackNTo1Materialization(v));
- }
- return result;
-}
-
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1554,7 +1551,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Materialize a replacement value "out of thin air".
buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
- result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
+ result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
continue;
@@ -1572,16 +1569,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Remap result to replacement value.
if (repl.empty())
continue;
-
- if (repl.size() == 1) {
- // Single replacement value: replace directly.
- mapping.map(result, repl.front());
- } else {
- // Multiple replacement values: insert N:1 materialization.
- insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
- /*replacements=*/repl, /*outputValue=*/result,
- currentTypeConverter);
- }
+ mapping.map({result}, repl);
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1648,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
SmallVector<ValueRange> newVals;
- for (size_t i = 0; i < newValues.size(); ++i)
- newVals.push_back(newValues.slice(i, 1));
+ for (size_t i = 0; i < newValues.size(); ++i) {
+ if (newValues[i]) {
+ newVals.push_back(newValues.slice(i, 1));
+ } else {
+ newVals.push_back(ValueRange());
+ }
+ }
impl->notifyOpReplaced(op, newVals);
}
@@ -1729,11 +1722,11 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
- impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+ impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to});
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- SmallVector<SmallVector<Value>> remappedValues;
+ SmallVector<ValueVector> remappedValues;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
remappedValues)))
return nullptr;
@@ -1746,7 +1739,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) {
if (keys.empty())
return success();
- SmallVector<SmallVector<Value>> remapped;
+ SmallVector<ValueVector> remapped;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
remapped)))
return failure();
@@ -1872,7 +1865,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
getTypeConverter());
// Remap the operands of the operation.
- SmallVector<SmallVector<Value>> remapped;
+ SmallVector<ValueVector> remapped;
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
op->getOperands(), remapped))) {
return failure();
@@ -2625,19 +2618,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
rewriter.setInsertionPoint(op);
SmallVector<Value> newMaterialization;
switch (rewrite->getMaterializationKind()) {
- case MaterializationKind::Argument: {
- // Try to materialize an argument conversion.
- assert(op->getNumResults() == 1 && "expected single result");
- Value argMat = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
- if (argMat) {
- newMaterialization.push_back(argMat);
- break;
- }
- }
- // If an argument materialization failed, fallback to trying a target
- // materialization.
- [[fallthrough]];
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 297eb5acef21b7..4cd196c5b44b31 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) {
// Contents of the old block are moved to the new block.
// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
-// The new block arguments are used in "test.return".
-// CHECK-NEXT: notifyOperationModified: test.return
-
// The old block is erased.
// CHECK-NEXT: notifyBlockErased
@@ -390,8 +387,8 @@ func.func @caller() {
// CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16)
%0:2 = func.call @callee() : () -> (f32, i24)
- // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
- // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
+ // CHECK-DAG: %[[cast1:.*]] = "test.cast"() : () -> i24
+ // CHECK-DAG: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
// CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> ()
// expected-remark @below{{'test.some_user' is not legalizable}}
"test.some_user"(%0#0, %0#1) : (f32, i24) -> ()
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad50..d0b62e71ab0cf2 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
tupleType.getFlattenedTypes(types);
return success();
});
- typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+ typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildDecomposeTuple);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 826c222990be4f..eae9b887e9d49a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1284,7 +1284,6 @@ struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TestTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d788..a03bf0a1023d57 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
struct PDLLTypeConverter : public TypeConverter {
PDLLTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
>From d6bf2cd3636c115cd7fe26a5213917832ee5eb38 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 23 Dec 2024 14:03:02 +0100
Subject: [PATCH 2/5] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
.../Transforms/Utils/DialectConversion.cpp | 28 ++++++++++---------
1 file changed, 15 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 96cbe07f0f12f9..8d6291f0f4f0d7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -187,7 +187,7 @@ struct ConversionValueMapping {
}
/// Drop the last mapping for the given values.
- void erase(ValueVector value) { mapping.erase(value); }
+ void erase(const ValueVector &value) { mapping.erase(value); }
private:
/// Current value mappings.
@@ -221,7 +221,7 @@ ConversionValueMapping::lookupOrDefault(ValueVector from,
}
if (next != from) {
// If at least one value was replaced, continue the lookup from there.
- from = next;
+ from = std::move(next);
continue;
}
@@ -1175,7 +1175,7 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ValueVector mappedValues)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), originalType(originalType),
- mappedValues(mappedValues) {
+ mappedValues(std::move(mappedValues)) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
@@ -1265,9 +1265,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes);
if (!repl.empty() && TypeRange(repl) == legalTypes) {
// Mapped values have the correct type or there is an existing
- // materialization. Or the opreand is not mapped at all and has the
+ // materialization. Or the operand is not mapped at all and has the
// correct type.
- remapped.push_back(repl);
+ remapped.push_back(std::move(repl));
continue;
}
@@ -1416,8 +1416,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- ValueVector replArgVals = llvm::map_to_vector<1>(
- replArgs, [](BlockArgument arg) -> Value { return arg; });
+ ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
mapping.map({origArg}, replArgVals);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1462,8 +1461,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- originalType, valuesToMap);
+ appendRewrite<UnresolvedMaterializationRewrite>(
+ convertOp, converter, kind, originalType, std::move(valuesToMap));
return convertOp.getResults();
}
@@ -1495,10 +1494,13 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// `applySignatureConversion`.)
return Value();
}
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
- /*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(),
- /*originalType=*/Type(), converter)[0];
+ Value castValue =
+ buildUnresolvedMaterialization(MaterializationKind::Source,
+ computeInsertPoint(repl), value.getLoc(),
+ /*valuesToMap=*/{value}, /*inputs=*/repl,
+ /*outputType=*/value.getType(),
+ /*originalType=*/Type(), converter)
+ .front();
mapping.map({value}, {castValue});
return castValue;
}
>From d8d77d66eaadede0e20647c74c4833eca1ba5e60 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 29 Dec 2024 13:25:20 +0100
Subject: [PATCH 3/5] rebase fixes
---
mlir/test/Transforms/test-legalizer.mlir | 9 ++-------
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 8 --------
2 files changed, 2 insertions(+), 15 deletions(-)
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 4cd196c5b44b31..ae7d344b7167f9 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -491,13 +491,8 @@ func.func @test_1_to_n_block_signature_conversion() {
// CHECK-LABEL: func @test_multiple_1_to_n_replacement()
// CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
-// TODO: There should be a single cast (i.e., a single target materialization).
-// This is currently not possible due to 1:N limitations of the conversion
-// mapping. Instead, we have 3 argument materializations.
-// CHECK: %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16
-// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16
-// CHECK: %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16
-// CHECK: "test.valid"(%[[cast3]]) : (f16) -> ()
+// CHECK: %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16
+// CHECK: "test.valid"(%[[cast]]) : (f16) -> ()
func.func @test_multiple_1_to_n_replacement() {
%0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
"test.invalid"(%0) : (f16) -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eae9b887e9d49a..5b7c36c9b97bf4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1264,14 +1264,6 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
// Replace test.multiple_1_to_n_replacement with test.step_1.
Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
// Now replace test.step_1 with test.legal_op.
- // TODO: Ideally, it should not be necessary to reset the insertion point
- // here. Based on the API calls, it looks like test.step_1 is entirely
- // erased. But that's not the case: an argument materialization will
- // survive. And that argument materialization will be used by the users of
- // `op`. If we don't reset the insertion point here, we get dominance
- // errors. This will be fixed when we have 1:N support in the conversion
- // value mapping.
- rewriter.setInsertionPoint(repl1);
replaceWithDoubleResults(repl1, "test.legal_op");
return success();
}
>From 2efb6d71e535323f05f79b39ef429d91ec7d436f Mon Sep 17 00:00:00 2001
From: Markus Böck <markus.boeck02 at gmail.com>
Date: Sat, 21 Dec 2024 19:01:19 +0100
Subject: [PATCH 4/5] use universal references for `map`
---
.../Transforms/Utils/DialectConversion.cpp | 36 ++++++++++++++-----
1 file changed, 27 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8d6291f0f4f0d7..2a5c11c3d32ec3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -169,10 +169,15 @@ struct ConversionValueMapping {
ValueVector lookupOrNull(const ValueVector &from,
TypeRange desiredTypes = {}) const;
+ template <typename T>
+ struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
+
/// Map a value to the one provided.
- void map(const ValueVector &oldVal, const ValueVector &newVal) {
+ template <typename OldVal, typename NewVal>
+ std::enable_if_t<IsValueVector<OldVal>{} && IsValueVector<NewVal>{}>
+ map(OldVal &&oldVal, NewVal &&newVal) {
LLVM_DEBUG({
- ValueVector next = newVal;
+ ValueVector next(newVal);
while (true) {
assert(next != oldVal && "inserting cyclic mapping");
auto it = mapping.find(next);
@@ -181,9 +186,22 @@ struct ConversionValueMapping {
next = it->second;
}
});
- mapping[oldVal] = newVal;
for (Value v : newVal)
mappedTo.insert(v);
+
+ mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+ }
+
+ template <typename OldVal, typename NewVal>
+ std::enable_if_t<!IsValueVector<OldVal>{} || !IsValueVector<NewVal>{}>
+ map(OldVal &&oldVal, NewVal &&newVal) {
+ if constexpr (IsValueVector<OldVal>{}) {
+ map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+ } else if constexpr (IsValueVector<NewVal>{}) {
+ map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+ } else {
+ map(ValueVector{oldVal}, ValueVector{newVal});
+ }
}
/// Drop the last mapping for the given values.
@@ -1405,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- mapping.map({origArg}, {repl});
+ mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
@@ -1417,7 +1435,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
- mapping.map({origArg}, replArgVals);
+ mapping.map(origArg, std::move(replArgVals));
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1447,7 +1465,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// Avoid materializing an unnecessary cast.
if (TypeRange(inputs) == outputTypes) {
if (!valuesToMap.empty())
- mapping.map(valuesToMap, inputs);
+ mapping.map(std::move(valuesToMap), inputs);
return inputs;
}
@@ -1501,7 +1519,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
/*outputType=*/value.getType(),
/*originalType=*/Type(), converter)
.front();
- mapping.map({value}, {castValue});
+ mapping.map(value, castValue);
return castValue;
}
@@ -1571,7 +1589,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Remap result to replacement value.
if (repl.empty())
continue;
- mapping.map({result}, repl);
+ mapping.map(result, repl);
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1724,7 +1742,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
- impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to});
+ impl->mapping.map(impl->mapping.lookupOrDefault({from}), to);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
>From 8e9b48a6daff04a6d15cbacc2d60d422a6f53722 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 30 Dec 2024 19:35:01 +0100
Subject: [PATCH 5/5] address comments
---
.../Transforms/Utils/DialectConversion.cpp | 63 ++++++++++---------
1 file changed, 32 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2a5c11c3d32ec3..3571e017158be9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -137,10 +137,12 @@ namespace {
struct ValueVectorMapInfo {
static ValueVector getEmptyKey() { return ValueVector{}; }
static ValueVector getTombstoneKey() { return ValueVector{}; }
- static ::llvm::hash_code getHashValue(ValueVector val) {
+ static ::llvm::hash_code getHashValue(const ValueVector &val) {
return ::llvm::hash_combine_range(val.begin(), val.end());
}
- static bool isEqual(ValueVector LHS, ValueVector RHS) { return LHS == RHS; }
+ static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
+ return LHS == RHS;
+ }
};
/// This class wraps a IRMapping to provide recursive lookup
@@ -159,20 +161,18 @@ struct ConversionValueMapping {
/// - If there is no mapping to the desired types, also return the most
/// recently mapped values.
/// - If there is no mapping for the given values at all, return the given
- /// values.
- ValueVector lookupOrDefault(ValueVector from,
- TypeRange desiredTypes = {}) const;
+ /// value.
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
- /// Lookup the given values within the map, or return an empty vector if the
- /// values are not mapped. If they are mapped, this follows the same behavior
+ /// Lookup the given value within the map, or return an empty vector if the
+ /// value is not mapped. If it is mapped, this follows the same behavior
/// as `lookupOrDefault`.
- ValueVector lookupOrNull(const ValueVector &from,
- TypeRange desiredTypes = {}) const;
+ ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
template <typename T>
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
- /// Map a value to the one provided.
+ /// Map a value vector to the one provided.
template <typename OldVal, typename NewVal>
std::enable_if_t<IsValueVector<OldVal>{} && IsValueVector<NewVal>{}>
map(OldVal &&oldVal, NewVal &&newVal) {
@@ -192,6 +192,7 @@ struct ConversionValueMapping {
mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
}
+ /// Map a value vector or single value to the one provided.
template <typename OldVal, typename NewVal>
std::enable_if_t<!IsValueVector<OldVal>{} || !IsValueVector<NewVal>{}>
map(OldVal &&oldVal, NewVal &&newVal) {
@@ -217,19 +218,20 @@ struct ConversionValueMapping {
} // namespace
ValueVector
-ConversionValueMapping::lookupOrDefault(ValueVector from,
+ConversionValueMapping::lookupOrDefault(Value from,
TypeRange desiredTypes) const {
// Try to find the deepest values that have the desired types. If there is no
// such mapping, simply return the deepest values.
ValueVector desiredValue;
+ ValueVector current{from};
do {
// Store the current value if the types match.
- if (desiredTypes.empty() || TypeRange(from) == desiredTypes)
- desiredValue = from;
+ if (TypeRange(current) == desiredTypes)
+ desiredValue = current;
// If possible, Replace each value with (one or multiple) mapped values.
ValueVector next;
- for (Value v : from) {
+ for (Value v : current) {
auto it = mapping.find({v});
if (it != mapping.end()) {
llvm::append_range(next, it->second);
@@ -237,33 +239,35 @@ ConversionValueMapping::lookupOrDefault(ValueVector from,
next.push_back(v);
}
}
- if (next != from) {
+ if (next != current) {
// If at least one value was replaced, continue the lookup from there.
- from = std::move(next);
+ current = std::move(next);
continue;
}
// Otherwise: Check if there is a mapping for the entire vector. Such
// mappings are materializations. (N:M mapping are not supported for value
// replacements.)
- auto it = mapping.find(from);
+ auto it = mapping.find(current);
if (it == mapping.end()) {
// No mapping found: The lookup stops here.
break;
}
- from = it->second;
+ current = it->second;
} while (true);
// If the desired values were found use them, otherwise default to the leaf
// values.
- return !desiredValue.empty() ? desiredValue : from;
+ // Note: If `desiredTypes` is empty, this function always returns `current`.
+ return !desiredValue.empty() ? desiredValue : current;
}
-ValueVector ConversionValueMapping::lookupOrNull(const ValueVector &from,
+ValueVector ConversionValueMapping::lookupOrNull(Value from,
TypeRange desiredTypes) const {
ValueVector result = lookupOrDefault(from, desiredTypes);
TypeRange resultTypes(result);
- if (result == from || (!desiredTypes.empty() && resultTypes != desiredTypes))
+ if (result == ValueVector{from} ||
+ (!desiredTypes.empty() && resultTypes != desiredTypes))
return {};
return result;
}
@@ -1261,7 +1265,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped values.
- remapped.push_back(mapping.lookupOrDefault({operand}));
+ remapped.push_back(mapping.lookupOrDefault(operand));
continue;
}
@@ -1280,7 +1284,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
continue;
}
- ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes);
+ ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
if (!repl.empty() && TypeRange(repl) == legalTypes) {
// Mapped values have the correct type or there is an existing
// materialization. Or the operand is not mapped at all and has the
@@ -1290,7 +1294,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
}
// Create a materialization for the most recently mapped values.
- repl = mapping.lookupOrDefault({operand});
+ repl = mapping.lookupOrDefault(operand);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1428,10 +1432,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
continue;
}
- // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
- // dialect conversion. Therefore, we need an argument materialization to
- // turn the replacement block arguments into a single SSA value that can be
- // used as a replacement.
+ // This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
@@ -1487,7 +1488,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
// Find a replacement value with the same type.
- ValueVector repl = mapping.lookupOrNull({value}, value.getType());
+ ValueVector repl = mapping.lookupOrNull(value, value.getType());
if (!repl.empty())
return repl.front();
@@ -1503,7 +1504,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// No replacement value was found. Get the latest replacement value
// (regardless of the type) and build a source materialization to the
// original type.
- repl = mapping.lookupOrNull({value});
+ repl = mapping.lookupOrNull(value);
if (repl.empty()) {
// No replacement value is registered in the mapping. This means that the
// value is dropped and no longer needed. (If the value were still needed,
@@ -1742,7 +1743,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
- impl->mapping.map(impl->mapping.lookupOrDefault({from}), to);
+ impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
More information about the flang-commits
mailing list