[Mlir-commits] [mlir] [mlir][Transforms] Add 1:N `matchAndRewrite` overload (PR #116470)
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 18 17:12:53 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116470
>From 8f9fafd0d61451e4ff1ef0e1beb1b9b1a2a05cc7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 12 Nov 2024 05:14:43 +0100
Subject: [PATCH 1/4] replace with multiple
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Apply suggestions from code review
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
address comments
[WIP] 1:N conversion pattern
update test cases
---
.../mlir/Conversion/LLVMCommon/Pattern.h | 35 ++-
.../mlir/Transforms/DialectConversion.h | 63 +++++
.../Transforms/DecomposeCallGraphTypes.cpp | 56 +---
.../Func/Transforms/FuncConversions.cpp | 5 +-
.../Transforms/StructuralTypeConversions.cpp | 106 +++-----
.../Transforms/SparseTensorCodegen.cpp | 114 ++++----
.../Transforms/Utils/SparseTensorDescriptor.h | 16 +-
.../Transforms/Utils/DialectConversion.cpp | 251 ++++++++++++------
.../decompose-call-graph-types.mlir | 38 +--
9 files changed, 381 insertions(+), 303 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f3bf5b66398e09..86ea87b55af1cd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -143,6 +143,8 @@ template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
+ using OneToNOpAdaptor =
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
@@ -153,8 +155,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
- rewriter);
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+ }
+ void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
@@ -162,8 +169,15 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- return matchAndRewrite(cast<SourceOp>(op),
- OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+ }
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+ rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
+ virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewrite(op, adaptor, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
private:
using ConvertToLLVMPattern::match;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index de47765006f81e..e4eeb39b9c0741 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -537,6 +537,10 @@ class ConversionPattern : public RewritePattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
+ virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
/// Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult
@@ -547,6 +551,11 @@ class ConversionPattern : public RewritePattern {
rewrite(op, operands, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult matchAndRewrite(Operation *op,
@@ -574,6 +583,15 @@ class ConversionPattern : public RewritePattern {
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}
+ /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
+ /// try to extract the single value of each range to construct the inputs
+ /// for a 1:1 adaptor.
+ ///
+ /// This function produces a fatal error if at least one range has 0 or
+ /// more than 1 value: "pattern 'name' does not support 1:N conversion"
+ SmallVector<Value>
+ getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
+
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
@@ -589,6 +607,8 @@ template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
+ using OneToNOpAdaptor =
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +627,24 @@ class OpConversionPattern : public ConversionPattern {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
+ void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+ rewriter);
+ }
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
@@ -623,6 +655,12 @@ class OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
+ virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +669,13 @@ class OpConversionPattern : public ConversionPattern {
rewrite(op, adaptor, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
private:
using ConversionPattern::matchAndRewrite;
@@ -656,11 +701,20 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
+ void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewrite(cast<SourceOp>(op), operands, rewriter);
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
@@ -668,6 +722,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
+ virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +734,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
private:
using ConversionPattern::matchAndRewrite;
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index a08764326a80b6..03be00328bda33 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -13,40 +13,6 @@
using namespace mlir;
using namespace mlir::func;
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// If the given value can be decomposed with the type converter, decompose it.
-/// Otherwise, return the given value.
-// TODO: Value decomposition should happen automatically through a 1:N adaptor.
-// This function will disappear when the 1:1 and 1:N drivers are merged.
-static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
- Value value,
- const TypeConverter *converter) {
- // Try to convert the given value's type. If that fails, just return the
- // given value.
- SmallVector<Type> convertedTypes;
- if (failed(converter->convertType(value.getType(), convertedTypes)))
- return {value};
- if (convertedTypes.empty())
- return {};
-
- // If the given value's type is already legal, just return the given value.
- TypeRange convertedTypeRange(convertedTypes);
- if (convertedTypeRange == TypeRange(value.getType()))
- return {value};
-
- // Try to materialize a target conversion. If the materialization did not
- // produce values of the requested type, the materialization failed. Just
- // return the given value in that case.
- SmallVector<Value> result = converter->materializeTargetConversion(
- builder, loc, convertedTypeRange, value);
- if (result.empty())
- return {value};
- return result;
-}
-
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands()) {
- // TODO: We can directly take the values from the adaptor once this is a
- // 1:N conversion pattern.
- llvm::append_range(newOperands,
- decomposeValue(rewriter, operand.getLoc(), operand,
- getTypeConverter()));
- }
+ for (ValueRange operand : adaptor.getOperands())
+ llvm::append_range(newOperands, operand);
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(CallOp op, OpAdaptor adaptor,
+ matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Create the operands list of the new `CallOp`.
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands()) {
- // TODO: We can directly take the values from the adaptor once this is a
- // 1:N conversion pattern.
- llvm::append_range(newOperands,
- decomposeValue(rewriter, operand.getLoc(), operand,
- getTypeConverter()));
- }
+ for (ValueRange operand : adaptor.getOperands())
+ llvm::append_range(newOperands, operand);
// Create the new result types for the new `CallOp` and track the number of
// replacement types for each original op result.
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..d81f822f7d4b51 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -21,7 +21,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
/// Hook for derived classes to implement combined matching and rewriting.
LogicalResult
- matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Convert the original function results.
SmallVector<Type, 1> convertedResults;
@@ -37,7 +37,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
// Substitute with the new result types from the corresponding FuncType
// conversion.
rewriter.replaceOpWithNewOp<CallOp>(
- callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+ callOp, callOp.getCallee(), convertedResults,
+ getOneToOneAdaptorOperands(adaptor.getOperands()));
return success();
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 93a78056db1944..c0589044c26ecb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -16,20 +16,18 @@ using namespace mlir::scf;
namespace {
-// Unpacks the single unrealized_conversion_cast using the list of inputs
-// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
-static void unpackUnrealizedConversionCast(Value v,
- SmallVectorImpl<Value> &unpacked) {
- if (auto cast =
- dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
- if (cast.getInputs().size() != 1) {
- // 1 : N type conversion.
- unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
- return;
- }
- }
- // 1 : 1 type conversion.
- unpacked.push_back(v);
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+ assert(values.size() == 1 && "expected single value");
+ return values.front();
}
// CRTP
@@ -40,19 +38,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern<SourceOp>::typeConverter;
using OpConversionPattern<SourceOp>::OpConversionPattern;
- using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+ using OneToNOpAdaptor =
+ typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
//
// Derived classes should provide the following method which performs the
// actual conversion. It should return std::nullopt upon conversion failure
// and return the converted operation upon success.
//
- // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
- // ConversionPatternRewriter &rewriter,
- // TypeRange dstTypes) const;
+ // std::optional<SourceOp> convertSourceOp(
+ // SourceOp op, OneToNOpAdaptor adaptor,
+ // ConversionPatternRewriter &rewriter,
+ // TypeRange dstTypes) const;
LogicalResult
- matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> dstTypes;
SmallVector<unsigned> offsets;
@@ -73,28 +73,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
return rewriter.notifyMatchFailure(op, "could not convert operation");
// Packs the return value.
- SmallVector<Value> packedRets;
+ SmallVector<ValueRange> packedRets;
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
unsigned start = offsets[i - 1], end = offsets[i];
unsigned len = end - start;
ValueRange mappedValue = newOp->getResults().slice(start, len);
- if (len != 1) {
- // 1 : N type conversion.
- Type origType = op.getResultTypes()[i - 1];
- Value mat = typeConverter->materializeSourceConversion(
- rewriter, op.getLoc(), origType, mappedValue);
- if (!mat) {
- return rewriter.notifyMatchFailure(
- op, "Failed to materialize 1:N type conversion");
- }
- packedRets.push_back(mat);
- } else {
- // 1 : 1 type conversion.
- packedRets.push_back(mappedValue.front());
- }
+ packedRets.push_back(mappedValue);
}
- rewriter.replaceOp(op, packedRets);
+ rewriter.replaceOpWithMultiple(op, packedRets);
return success();
}
};
@@ -105,7 +92,7 @@ class ConvertForOpTypes
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
// The callback required by CRTP.
- std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+ std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Create a empty new op and inline the regions from the old op.
@@ -129,16 +116,13 @@ class ConvertForOpTypes
if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
return std::nullopt;
- // Unpacked the iteration arguments.
- SmallVector<Value> flatArgs;
- for (Value arg : adaptor.getInitArgs())
- unpackUnrealizedConversionCast(arg, flatArgs);
-
// We can not do clone as the number of result types after conversion
// might be different.
- ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
- adaptor.getUpperBound(),
- adaptor.getStep(), flatArgs);
+ ForOp newOp = rewriter.create<ForOp>(
+ op.getLoc(), getSingleValue(adaptor.getLowerBound()),
+ getSingleValue(adaptor.getUpperBound()),
+ getSingleValue(adaptor.getStep()),
+ flattenValues(adaptor.getInitArgs()));
// Reserve whatever attributes in the original op.
newOp->setAttrs(op->getAttrs());
@@ -160,12 +144,12 @@ class ConvertIfOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
- std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+ std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
- IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
- adaptor.getCondition(), true);
+ IfOp newOp = rewriter.create<IfOp>(
+ op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
newOp->setAttrs(op->getAttrs());
// We do not need the empty blocks created by rewriter.
@@ -189,15 +173,11 @@ class ConvertWhileOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
- std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+ std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
- // Unpacked the iteration arguments.
- SmallVector<Value> flatArgs;
- for (Value arg : adaptor.getOperands())
- unpackUnrealizedConversionCast(arg, flatArgs);
-
- auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+ auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
+ flattenValues(adaptor.getOperands()));
for (auto i : {0u, 1u}) {
if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
@@ -218,13 +198,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
+ matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> unpackedYield;
- for (Value operand : adaptor.getOperands())
- unpackUnrealizedConversionCast(operand, unpackedYield);
-
- rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(
+ op, flattenValues(adaptor.getOperands()));
return success();
}
};
@@ -235,13 +212,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
public:
using OpConversionPattern<ConditionOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
+ matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> unpackedYield;
- for (Value operand : adaptor.getOperands())
- unpackUnrealizedConversionCast(operand, unpackedYield);
-
- rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 25fca49cb0154a..9184224e7aef4b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -39,25 +39,18 @@ using namespace mlir::sparse_tensor;
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Flattens a list of operands that may contain sparse tensors.
-static void flattenOperands(ValueRange operands,
- SmallVectorImpl<Value> &flattened) {
- // In case of
- // sparse_tensor, c, sparse_tensor
- // ==>
- // memref ..., c, memref ...
- for (auto operand : operands) {
- if (getSparseTensorEncoding(operand.getType())) {
- auto tuple = getTuple(operand);
- // An unrealized_conversion_cast will be inserted by type converter to
- // inter-mix the gap between 1:N conversion between sparse tensors and
- // fields. In this case, take the operands in the cast and replace the
- // sparse tensor output with the flattened type array.
- flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
- } else {
- flattened.push_back(operand);
- }
- }
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+ assert(values.size() == 1 && "expected single value");
+ return values.front();
}
/// Generates a load with proper `index` typing.
@@ -567,12 +560,11 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> flattened;
- flattenOperands(adaptor.getOperands(), flattened);
// Create a return with the flattened value extracted from sparse tensors.
- rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+ rewriter.replaceOpWithNewOp<func::ReturnOp>(
+ op, flattenValues(adaptor.getOperands()));
return success();
}
};
@@ -583,7 +575,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
// The default CallOp converter can not handle 1:N type conversion.
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// In case of:
@@ -596,10 +588,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
return failure();
// (1) Generates new call with flattened return value.
- SmallVector<Value> flattened;
- flattenOperands(adaptor.getOperands(), flattened);
- auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
- finalRetTy, flattened);
+ auto newCall = rewriter.create<func::CallOp>(
+ loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
// (2) Gather sparse tensor returns.
SmallVector<SmallVector<Value>> packedResultVals;
// Tracks the offset of current return value (of the original call)
@@ -643,7 +633,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(LvlOp op, OpAdaptor adaptor,
+ matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::optional<int64_t> lvl = op.getConstantLvlIndex();
RankedTensorType srcType = op.getSource().getType();
@@ -662,7 +652,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
+ matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
@@ -693,7 +683,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
// Since we do in-place sorting, the destinate tensor will have the same set
// of memrefs as the source tensor.
- rewriter.replaceOp(op, adaptor.getInputCoo());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
return success();
}
};
@@ -703,7 +693,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ matchAndRewrite(Op op,
+ typename OpConversionPattern<Op>::OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Simply lowers to specifer.get <field> operation.
auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
@@ -721,14 +712,14 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only rewrite identically annotated source/dest.
auto encDst = getSparseTensorEncoding(op.getType());
auto encSrc = getSparseTensorEncoding(op.getSource().getType());
if (!encDst || encDst != encSrc)
return failure();
- rewriter.replaceOp(op, adaptor.getOperands());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
return success();
}
};
@@ -737,10 +728,10 @@ class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Simply fold the operation.
- rewriter.replaceOp(op, adaptor.getSource());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
return success();
}
};
@@ -756,7 +747,7 @@ class SparseTensorAllocConverter
enableBufferInitialization(enableInit) {}
LogicalResult
- matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
+ matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto resType = getSparseTensorType(op);
if (!resType.hasEncoding())
@@ -791,7 +782,8 @@ class SparseTensorAllocConverter
}
// Level size equals to dimension size since lvl2dim map is an identity map.
SmallVector<Value> lvlSizesValues;
- createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
+ createDimSizes(rewriter, loc, resType,
+ flattenValues(adaptor.getDynamicSizes()),
/*dimSizesValues=*/lvlSizesValues);
// Construct allocation for each field.
@@ -861,7 +853,7 @@ class SparseTensorDeallocConverter
createDeallocs(createDeallocs) {}
LogicalResult
- matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
+ matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto enc = getSparseTensorEncoding(op.getTensor().getType());
if (!enc)
@@ -892,7 +884,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(LoadOp op, OpAdaptor adaptor,
+ matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Prepare descriptor.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
@@ -911,7 +903,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+ matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
@@ -963,16 +955,16 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+ matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
op.getTensor().getType());
- Value values = adaptor.getValues();
- Value filled = adaptor.getFilled();
- Value added = adaptor.getAdded();
- Value count = adaptor.getCount();
+ Value values = getSingleValue(adaptor.getValues());
+ Value filled = getSingleValue(adaptor.getFilled());
+ Value added = getSingleValue(adaptor.getAdded());
+ Value count = getSingleValue(adaptor.getCount());
const SparseTensorType dstType(desc.getRankedTensorType());
Type eltType = dstType.getElementType();
@@ -1005,7 +997,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
SmallVector<Type> flatSpTensorTps = llvm::to_vector(
llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
- params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+ SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
+ params.append(flatLvlCoords.begin(), flatLvlCoords.end());
params.push_back(crd);
params.push_back(value);
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
@@ -1033,9 +1026,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto stt = getSparseTensorType(adaptor.getDest());
+ auto stt = getSparseTensorType(op.getDest());
if (!stt.hasEncoding())
return failure();
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
@@ -1045,8 +1038,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
- params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
- params.push_back(adaptor.getScalar());
+ SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
+ params.append(flatIndices.begin(), flatIndices.end());
+ params.push_back(getSingleValue(adaptor.getScalar()));
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
@@ -1062,7 +1056,7 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
using OpAdaptor = typename ToPositionsOp::Adaptor;
using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested position access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1085,7 +1079,7 @@ class SparseToCoordinatesConverter
using OpAdaptor = typename ToCoordinatesOp::Adaptor;
using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1111,7 +1105,7 @@ class SparseToCoordinatesBufferConverter
using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1133,7 +1127,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
using OpAdaptor = typename ToValuesOp::Adaptor;
using OpConversionPattern<ToValuesOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested values access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1153,7 +1147,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
+ matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
@@ -1173,7 +1167,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
Type srcElemTp = op.getSource().getType().getElementType();
// Fold the trivial cases.
if (retElemTp == srcElemTp && encDst == encSrc) {
- rewriter.replaceOp(op, adaptor.getSource());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
return success();
}
//
@@ -1239,7 +1233,7 @@ class SparseExtractSliceConverter
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
@@ -1296,7 +1290,7 @@ class SparseNumberOfEntriesConverter
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+ matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Query memSizes for the actually stored values.
// FIXME: the nse value computed in this way might be wrong when there is
@@ -1430,7 +1424,7 @@ struct SparseDisassembleOpConverter
: OpConversionPattern(typeConverter, context) {}
LogicalResult
- matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+ matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index 89858546e37e1b..869c7864d75354 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -228,11 +228,6 @@ class MutSparseTensorDescriptor
}
};
-/// Returns the "tuple" value of the adapted tensor.
-inline UnrealizedConversionCastOp getTuple(Value tensor) {
- return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
-}
-
/// Packs the given values as a "tuple" value.
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
ValueRange values) {
@@ -246,16 +241,15 @@ inline Value genTuple(OpBuilder &builder, Location loc,
}
inline SparseTensorDescriptor
-getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
- auto tuple = getTuple(tensor);
- return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
+getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) {
+ return SparseTensorDescriptor(SparseTensorType(type), adaptorValues);
}
inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+getMutDescriptorFromTensorTuple(ValueRange adaptorValues,
+ SmallVectorImpl<Value> &fields,
RankedTensorType type) {
- auto tuple = getTuple(tensor);
- fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
+ fields.assign(adaptorValues.begin(), adaptorValues.end());
return MutSparseTensorDescriptor(SparseTensorType(type), fields);
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5b2cfd370900a8..627b87b92921d8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -67,10 +67,6 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
// ConversionValueMapping
//===----------------------------------------------------------------------===//
-/// A list of replacement SSA values. Optimized for the common case of a single
-/// SSA value.
-using ReplacementValues = SmallVector<Value, 1>;
-
namespace {
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -780,7 +776,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
LogicalResult remapValues(StringRef valueDiagTag,
std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped);
+ SmallVector<SmallVector<Value>> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
@@ -814,13 +810,27 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Materializations
//===--------------------------------------------------------------------===//
- /// Build an unresolved materialization operation given an output type and set
- /// of input operands.
- Value buildUnresolvedMaterialization(MaterializationKind kind,
- OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, Type outputType,
- Type originalType,
- const TypeConverter *converter);
+ /// Build an unresolved materialization operation given a range of output
+ /// types and a list of input operands. Returns the inputs if their types
+ /// match the output types.
+ ///
+ /// If a cast op was built, it can optionally be returned with the `castOp`
+ /// output argument.
+ ValueRange buildUnresolvedMaterialization(
+ MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+ ValueRange inputs, TypeRange outputTypes, Type originalType,
+ const TypeConverter *converter,
+ UnrealizedConversionCastOp *castOp = nullptr);
+ Value buildUnresolvedMaterialization(
+ MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+ ValueRange inputs, Type outputType, Type originalType,
+ const TypeConverter *converter,
+ UnrealizedConversionCastOp *castOp = nullptr) {
+ return buildUnresolvedMaterialization(kind, ip, loc, inputs,
+ TypeRange(outputType), originalType,
+ converter, castOp)
+ .front();
+ }
/// Build an N:1 materialization for the given original value that was
/// replaced with the given replacement values.
@@ -838,6 +848,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
+ /// Unpack an N:1 materialization and return the inputs of the
+ /// materialization. This function unpacks only those materializations that
+ /// were built with `insertNTo1Materialization`.
+ ///
+ /// This is a workaround for incomplete 1:N support in the dialect
+ /// conversion driver. It allows us to write 1:N conversion patterns while
+ /// 1:N support is still missing in the conversion value mapping. This
+ /// function will be deleted when full 1:N support has been added.
+ SmallVector<Value> unpackNTo1Materialization(Value value);
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -847,7 +867,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
OpBuilder::InsertPoint previous) override;
/// Notifies that an op is about to be replaced with the given values.
- void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
+ void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -940,6 +960,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
unresolvedMaterializations;
+ /// A set of all N:1 materializations that were added to work around
+ /// incomplete 1:N support in the dialect conversion driver.
+ DenseSet<UnrealizedConversionCastOp> nTo1TempMaterializations;
+
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1076,6 +1100,7 @@ void UnresolvedMaterializationRewrite::rollback() {
rewriterImpl.mapping.erase(input);
}
rewriterImpl.unresolvedMaterializations.erase(getOperation());
+ rewriterImpl.nTo1TempMaterializations.erase(getOperation());
op->erase();
}
@@ -1119,7 +1144,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped) {
+ SmallVector<SmallVector<Value>> &remapped) {
remapped.reserve(llvm::size(values));
for (const auto &it : llvm::enumerate(values)) {
@@ -1131,7 +1156,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped value.
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back({mapping.lookupOrDefault(operand)});
continue;
}
@@ -1145,15 +1170,32 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
return failure();
}
+ // If a type is converted to 0 types, there is nothing to do.
+ if (legalTypes.empty()) {
+ remapped.push_back({});
+ continue;
+ }
+
if (legalTypes.size() != 1) {
- // TODO: Parts of the dialect conversion infrastructure do not support
- // 1->N type conversions yet. Therefore, if a type is converted to 0 or
- // multiple types, the only thing that we can do for now is passing
- // through the most recently mapped value. Fixing this requires
- // improvements to the `ConversionValueMapping` (to be able to store 1:N
- // mappings) and to the `ConversionPattern` adaptor handling (to be able
- // to pass multiple remapped values for a single operand to the adaptor).
- remapped.push_back(mapping.lookupOrDefault(operand));
+ // TODO: This is a 1:N conversion. The conversion value mapping does not
+ // support such conversions yet. It stores the result of an argument
+ // materialization (i.e., a conversion back into a single SSA value)
+ // instead. Unpack such "workaround" materializations and hand the
+ // original replacement values to the adaptor.
+ Value repl = mapping.lookupOrDefault(operand);
+ SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
+ if (TypeRange(unpacked) == legalTypes) {
+ remapped.push_back(unpacked);
+ continue;
+ }
+
+ // Insert a target materialization if the current pattern expects
+ // different legalized types.
+ ValueRange targetMat = buildUnresolvedMaterialization(
+ MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+ /*inputs=*/repl, /*outputType=*/legalTypes,
+ /*originalType=*/origType, currentTypeConverter);
+ remapped.push_back(targetMat);
continue;
}
@@ -1165,7 +1207,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
if (newOperand.getType() != desiredType) {
// If the looked up value's type does not have the desired type, it means
// that the value was replaced with a value of different type and no
- // source materialization was created yet.
+ // target materialization was created yet.
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc,
@@ -1174,7 +1216,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
mapping.map(newOperand, castValue);
newOperand = castValue;
}
- remapped.push_back(newOperand);
+ remapped.push_back({newOperand});
}
return success();
}
@@ -1329,26 +1371,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
-Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, Type outputType, Type originalType,
- const TypeConverter *converter) {
+ ValueRange inputs, TypeRange outputTypes, Type originalType,
+ const TypeConverter *converter, UnrealizedConversionCastOp *castOp) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
// Avoid materializing an unnecessary cast.
- if (inputs.size() == 1 && inputs.front().getType() == outputType)
- return inputs.front();
+ if (TypeRange(inputs) == outputTypes)
+ return inputs;
// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
- OpBuilder builder(outputType.getContext());
+ OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
- builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
+ if (castOp)
+ *castOp = convertOp;
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
originalType);
- return convertOp.getResult(0);
+ return convertOp.getResults();
}
void ConversionPatternRewriterImpl::insertNTo1Materialization(
@@ -1356,10 +1400,13 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
Value originalValue, const TypeConverter *converter) {
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
- Value argMat =
- buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
- /*inputs=*/replacements, originalType,
- /*originalType=*/Type(), converter);
+ UnrealizedConversionCastOp argCastOp;
+ Value argMat = buildUnresolvedMaterialization(
+ MaterializationKind::Argument, ip, loc,
+ /*inputs=*/replacements, originalType,
+ /*originalType=*/Type(), converter, &argCastOp);
+ if (argCastOp)
+ nTo1TempMaterializations.insert(argCastOp);
mapping.map(originalValue, argMat);
// Insert target materialization to the legalized type.
@@ -1376,14 +1423,36 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
legalOutputType = replacements[0].getType();
}
if (legalOutputType && legalOutputType != originalType) {
+ UnrealizedConversionCastOp targetCastOp;
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat), loc,
/*inputs=*/argMat, /*outputType=*/legalOutputType,
- /*originalType=*/originalType, converter);
+ /*originalType=*/originalType, converter, &targetCastOp);
+ if (targetCastOp)
+ nTo1TempMaterializations.insert(targetCastOp);
mapping.map(argMat, targetMat);
}
}
+SmallVector<Value>
+ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
+ // Only unrealized_conversion_cast ops that were recorded as an N:1
+ // workaround materialization are unpacked. Every other value is returned
+ // as is.
+ auto tempCast = value.getDefiningOp<UnrealizedConversionCastOp>();
+ if (!tempCast || !nTo1TempMaterializations.contains(tempCast))
+ return {value};
+ assert(tempCast->getNumResults() == 1 && "expected single result");
+
+ // The operands of the cast may be workaround materializations themselves,
+ // so unpack them recursively and concatenate the results in operand
+ // order.
+ SmallVector<Value> unpacked;
+ for (Value operand : tempCast.getOperands())
+ llvm::append_range(unpacked, unpackNTo1Materialization(operand));
+ return unpacked;
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1408,7 +1477,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
}
void ConversionPatternRewriterImpl::notifyOpReplaced(
- Operation *op, ArrayRef<ReplacementValues> newValues) {
+ Operation *op, ArrayRef<ValueRange> newValues) {
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");
@@ -1420,8 +1489,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
isUnresolvedMaterialization = true;
// Create mappings for each of the new result values.
- for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) {
- ReplacementValues repl = n;
+ for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
if (repl.empty()) {
// This result was dropped and no replacement value was provided.
if (isUnresolvedMaterialization) {
@@ -1436,7 +1504,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
result.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
- repl.push_back(sourceMat);
+ mapping.map(result, sourceMat);
+ continue;
} else {
// Make sure that the user does not mess with unresolved materializations
// that were inserted by the conversion driver. We keep track of these
@@ -1538,10 +1607,9 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ReplacementValues> newVals(newValues.size());
- for (auto [index, val] : llvm::enumerate(newValues))
- if (val)
- newVals[index].push_back(val);
+ SmallVector<ValueRange> newVals; // One range per result; empty = dropped.
+ for (size_t i = 0, e = newValues.size(); i < e; ++i)
+ newVals.push_back(newValues[i] ? newValues.slice(i, 1) : ValueRange());
impl->notifyOpReplaced(op, newVals);
}
@@ -1553,10 +1621,7 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ReplacementValues> newVals(newValues.size(), {});
- for (auto [index, val] : llvm::enumerate(newValues))
- llvm::append_range(newVals[index], val);
- impl->notifyOpReplaced(op, newVals);
+ impl->notifyOpReplaced(op, newValues);
}
void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1564,7 +1629,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {});
+ SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
impl->notifyOpReplaced(op, nullRepls);
}
@@ -1615,11 +1680,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- SmallVector<Value> remappedValues;
+ SmallVector<SmallVector<Value>> remappedValues;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
remappedValues)))
return nullptr;
- return remappedValues.front();
+ assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
+ return remappedValues.front().front();
}
LogicalResult
@@ -1627,8 +1693,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) {
if (keys.empty())
return success();
- return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
- results);
+ SmallVector<SmallVector<Value>> remapped;
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ remapped)))
+ return failure();
+ for (const auto &values : remapped) {
+ assert(values.size() == 1 && "1:N conversion not supported");
+ results.push_back(values.front());
+ }
+ return success();
}
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1722,6 +1795,19 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
+SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+ ArrayRef<ValueRange> operands) const {
+ SmallVector<Value> oneToOneOperands;
+ oneToOneOperands.reserve(operands.size());
+ for (ValueRange operand : operands) {
+ if (operand.size() != 1)
+ llvm::report_fatal_error("pattern '" + getDebugName() +
+ "' does not support 1:N conversion");
+ oneToOneOperands.push_back(operand.front());
+ }
+ return oneToOneOperands;
+}
+
LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
@@ -1733,12 +1819,14 @@ ConversionPattern::matchAndRewrite(Operation *op,
getTypeConverter());
// Remap the operands of the operation.
- SmallVector<Value, 4> operands;
+ SmallVector<SmallVector<Value>> remapped;
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
- op->getOperands(), operands))) {
+ op->getOperands(), remapped))) {
return failure();
}
- return matchAndRewrite(op, operands, dialectRewriter);
+ SmallVector<ValueRange> remappedAsRange = llvm::map_to_vector(
+ remapped, [](const auto &v) -> ValueRange { return v; });
+ return matchAndRewrite(op, remappedAsRange, dialectRewriter);
}
//===----------------------------------------------------------------------===//
@@ -1965,19 +2053,19 @@ OperationLegalizer::legalizeWithFold(Operation *op,
});
// Try to fold the operation.
- SmallVector<Value, 2> replacementValues;
+ SmallVector<Value, 2> replacementValues;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, replacementValues))) {
+ if (failed(rewriter.tryFold(op, replacementValues))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
- if (replacementValues.empty())
+ if (replacementValues.empty())
return legalize(op, rewriter);
// Insert a replacement for 'op' with the folded replacement values.
- rewriter.replaceOp(op, replacementValues);
+ rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
@@ -2482,45 +2570,52 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
assert(!op.use_empty() &&
"expected that dead materializations have already been DCE'd");
Operation::operand_range inputOperands = op.getOperands();
- Type outputType = op.getResultTypes()[0];
// Try to materialize the conversion.
if (const TypeConverter *converter = rewrite->getConverter()) {
rewriter.setInsertionPoint(op);
- Value newMaterialization;
+ SmallVector<Value> newMaterialization;
switch (rewrite->getMaterializationKind()) {
- case MaterializationKind::Argument:
+ case MaterializationKind::Argument: {
// Try to materialize an argument conversion.
- newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- if (newMaterialization)
+ assert(op->getNumResults() == 1 && "expected single result");
+ Value argMat = converter->materializeArgumentConversion(
+ rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
+ if (argMat) {
+ newMaterialization.push_back(argMat);
break;
+ }
+ }
// If an argument materialization failed, fallback to trying a target
// materialization.
[[fallthrough]];
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
- rewriter, op->getLoc(), outputType, inputOperands,
+ rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
rewrite->getOriginalType());
break;
case MaterializationKind::Source:
- newMaterialization = converter->materializeSourceConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
+ assert(op->getNumResults() == 1 && "expected single result");
+ Value sourceMat = converter->materializeSourceConversion(
+ rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
+ if (sourceMat)
+ newMaterialization.push_back(sourceMat);
break;
}
- if (newMaterialization) {
- assert(newMaterialization.getType() == outputType &&
+ if (!newMaterialization.empty()) {
+ assert(TypeRange(newMaterialization) == op.getResultTypes() &&
"materialization callback produced value of incorrect type");
rewriter.replaceOp(op, newMaterialization);
return success();
}
}
- InFlightDiagnostic diag =
- op->emitError() << "failed to legalize unresolved materialization "
- "from ("
- << inputOperands.getTypes() << ") to (" << outputType
- << ") that remained live after conversion";
+ InFlightDiagnostic diag = op->emitError()
+ << "failed to legalize unresolved materialization "
+ "from ("
+ << inputOperands.getTypes() << ") to ("
+ << op.getResultTypes()
+ << ") that remained live after conversion";
diag.attachNote(op->getUsers().begin()->getLoc())
<< "see existing live user here: " << *op->getUsers().begin();
return failure();
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index b8fad63eb4de67..4e641317ac2f3d 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -9,10 +9,7 @@
// CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32
// CHECK-12N-LABEL: func @identity(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// CHECK-LABEL: func @mixed_recursive_decomposition(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
-// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
-// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
-// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
-// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
-// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
-// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
-// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK: return %[[V7]], %[[V10]] : i1, i2
+// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2
// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
@@ -87,14 +73,8 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-LABEL: func @caller(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
-// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
+// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
// CHECK-SAME: %[[I4:.*]]: i4,
// CHECK-SAME: %[[I5:.*]]: i5,
// CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
-// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
-// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
-// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
+// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[I1:.*]]: i1,
// CHECK-12N-SAME: %[[I2:.*]]: i2,
>From 0e6a8fd3137f15feb9588aa9de6a762e8c45f108 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 17 Nov 2024 10:32:01 +0900
Subject: [PATCH 2/4] Update mlir/lib/Transforms/Utils/DialectConversion.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
mlir/lib/Transforms/Utils/DialectConversion.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 627b87b92921d8..42691dd6ebfabf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1824,8 +1824,8 @@ ConversionPattern::matchAndRewrite(Operation *op,
op->getOperands(), remapped))) {
return failure();
}
- SmallVector<ValueRange> remappedAsRange = llvm::map_to_vector(
- remapped, [](const auto &v) -> ValueRange { return v; });
+ SmallVector<ValueRange> remappedAsRange = llvm::to_vector_of<ValueRange>(
+ remapped);
return matchAndRewrite(op, remappedAsRange, dialectRewriter);
}
>From 8c657ba8e3a1b0ceb02dd4731ddf1486fa3ecb37 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 17 Nov 2024 10:34:00 +0900
Subject: [PATCH 3/4] Update mlir/lib/Transforms/Utils/DialectConversion.cpp
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
mlir/lib/Transforms/Utils/DialectConversion.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 42691dd6ebfabf..adfe7bc770e3dd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1185,7 +1185,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Value repl = mapping.lookupOrDefault(operand);
SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
if (TypeRange(unpacked) == legalTypes) {
- remapped.push_back(unpacked);
+ remapped.push_back(std::move(unpacked));
continue;
}
@@ -1824,8 +1824,8 @@ ConversionPattern::matchAndRewrite(Operation *op,
op->getOperands(), remapped))) {
return failure();
}
- SmallVector<ValueRange> remappedAsRange = llvm::to_vector_of<ValueRange>(
- remapped);
+ SmallVector<ValueRange> remappedAsRange =
+ llvm::to_vector_of<ValueRange>(remapped);
return matchAndRewrite(op, remappedAsRange, dialectRewriter);
}
@@ -2053,19 +2053,19 @@ OperationLegalizer::legalizeWithFold(Operation *op,
});
// Try to fold the operation.
- SmallVector<Value, 2> ValueRange;
+ SmallVector<Value, 2> replacementValues;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, ValueRange))) {
+ if (failed(rewriter.tryFold(op, replacementValues))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
- if (ValueRange.empty())
+ if (replacementValues.empty())
return legalize(op, rewriter);
// Insert a replacement for 'op' with the folded replacement values.
- rewriter.replaceOp(op, ValueRange);
+ rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
>From 0fedcfdbc14ff227ec03981ba91a8b0f5cdc1193 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 18 Nov 2024 12:53:41 +0100
Subject: [PATCH 4/4] address comments
---
.../mlir/Transforms/DialectConversion.h | 6 ++
.../Transforms/Utils/DialectConversion.cpp | 26 +++---
mlir/test/Transforms/test-legalizer.mlir | 11 +++
mlir/test/lib/Dialect/Test/TestOps.td | 5 ++
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 84 +++++++++++++++----
5 files changed, 106 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e4eeb39b9c0741..9abb6eedb9d912 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -543,6 +543,9 @@ class ConversionPattern : public RewritePattern {
}
/// Hook for derived classes to implement combined matching and rewriting.
+ /// This overload supports only 1:1 replacements. The driver calls the 1:N
+ /// overload, which by default forwards to this 1:1 overload or reports a
+ /// fatal error if 1:N replacements were found.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -551,6 +554,9 @@ class ConversionPattern : public RewritePattern {
rewrite(op, operands, rewriter);
return success();
}
+
+ /// Hook for derived classes to implement combined matching and rewriting.
+ /// This overload supports 1:N replacements.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index adfe7bc770e3dd..d4879c1bc333c3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1152,11 +1152,18 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type origType = operand.getType();
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
+ // Find the most recently mapped value. Unpack all temporary N:1
+ // materializations. Such materializations are a workaround for missing
+ // 1:N support in the ConversionValueMapping. (The conversion patterns
+ // already support 1:N replacements.)
+ Value repl = mapping.lookupOrDefault(operand);
+ SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
+
if (!currentTypeConverter) {
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped value.
- remapped.push_back({mapping.lookupOrDefault(operand)});
+ remapped.push_back(std::move(unpacked));
continue;
}
@@ -1178,12 +1185,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
if (legalTypes.size() != 1) {
// TODO: This is a 1:N conversion. The conversion value mapping does not
- // support such conversions yet. It stores the result of an argument
- // materialization (i.e., a conversion back into a single SSA value)
- // instead. Unpack such "workaround" materializations and hand the
- // original replacement values to the adaptor.
- Value repl = mapping.lookupOrDefault(operand);
- SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
+ // store such materializations yet. If the types of the most recently
+ // mapped values do not match, build a target materialization.
if (TypeRange(unpacked) == legalTypes) {
remapped.push_back(std::move(unpacked));
continue;
@@ -1193,7 +1196,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// different legalized types.
ValueRange targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
- /*inputs=*/repl, /*outputType=*/legalTypes,
+ /*inputs=*/unpacked, /*outputType=*/legalTypes,
/*originalType=*/origType, currentTypeConverter);
remapped.push_back(targetMat);
continue;
@@ -1211,7 +1214,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc,
- /*inputs=*/newOperand, /*outputType=*/desiredType,
+ /*inputs=*/unpacked, /*outputType=*/desiredType,
/*originalType=*/origType, currentTypeConverter);
mapping.map(newOperand, castValue);
newOperand = castValue;
@@ -1447,7 +1450,10 @@ ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
SmallVector<Value> result;
for (Value v : castOp.getOperands()) {
- // Keep unpacking if possible.
+ // Keep unpacking if possible. This is needed because during block
+ // signature conversions and 1:N op replacements, the driver may have
+ // inserted two materializations back-to-back: first an argument
+ // materialization, then a target materialization.
llvm::append_range(result, unpackNTo1Materialization(v));
}
return result;
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e5503ee8920424..3ebee795a251df 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -463,3 +463,14 @@ func.func @circular_mapping() {
%0 = "test.erase_op"() : () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
+
+// -----
+
+func.func @test_1_to_n_block_signature_conversion() {
+ "test.duplicate_block_args"() ({
+ ^bb0(%arg0: i64):
+ "test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> ()
+ }) {} : () -> ()
+ "test.return"() : () -> ()
+}
+
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index cfe19a2fd5c08b..d59caacc18ae44 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1886,6 +1886,11 @@ def LegalOpC : TEST_Op<"legal_op_c">,
Arguments<(ins I32)>, Results<(outs I32)>;
def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
+def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> {
+ let arguments = (ins UnitAttr:$is_legal);
+ let regions = (region SizedRegion<1>:$body);
+}
+
// Check that the conversion infrastructure can properly undo the creation of
// operations where an operation was created before its parent, in this case,
// in the parent's builder.
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3df6cff3c0a60b..0b5239168efc43 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -982,9 +982,25 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp(MLIRContext *ctx)
: ConversionPattern("test.invalid", 1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
+ SmallVector<Value> flattened;
+ for (auto it : llvm::enumerate(operands)) {
+ ValueRange range = it.value();
+ if (range.size() == 1) {
+ flattened.push_back(range.front());
+ continue;
+ }
+
+ // This is a 1:N replacement. Insert a test.cast op. (That's what the
+ // argument materialization used to do.)
+ flattened.push_back(
+ rewriter
+ .create<TestCastOp>(op->getLoc(),
+ op->getOperand(it.index()).getType(), range)
+ .getResult());
+ }
+ rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
std::nullopt);
return success();
}
@@ -1010,23 +1026,13 @@ struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType(MLIRContext *ctx)
: ConversionPattern("test.return", 1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
// Check for a return of F32.
if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
return failure();
-
- // Check if the first operation is a cast operation, if it is we use the
- // results directly.
- auto *defOp = operands[0].getDefiningOp();
- if (auto packerOp =
- llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
- rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
- return success();
- }
-
- // Otherwise, fail to match.
- return failure();
+ rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
+ return success();
}
};
@@ -1181,6 +1187,47 @@ class TestEraseOp : public ConversionPattern {
}
};
+/// Matches a test.duplicate_block_args op that is not yet marked legal and
+/// converts its body so each block argument maps to two of the same type.
+class TestDuplicateBlockArgs
+ : public OpConversionPattern<DuplicateBlockArgsOp> {
+ using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (op.getIsLegal()) // Flag is set below once the body was converted.
+ return failure();
+ rewriter.startOpModification(op);
+ Block *body = &op.getBody().front();
+ TypeConverter::SignatureConversion result(body->getNumArguments());
+ for (auto it : llvm::enumerate(body->getArgumentTypes()))
+ result.addInputs(it.index(), {it.value(), it.value()}); // 1:2 mapping.
+ rewriter.applySignatureConversion(body, result, getTypeConverter());
+ op.setIsLegal(true); // Makes the check above reject this op from now on.
+ rewriter.finalizeOpModification(op);
+ return success();
+ }
+};
+
+/// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
+/// op. The pattern supports 1:N replacements and forwards the replacement
+/// values of the single operand as test.valid operands.
+class TestRepetitive1ToNConsumer : public ConversionPattern {
+public:
+ TestRepetitive1ToNConsumer(MLIRContext *ctx)
+ : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Only ops with exactly one original operand are handled; fail otherwise.
+ if (op->getNumOperands() != 1)
+ return failure();
+ rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); // All N values.
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -1258,9 +1305,11 @@ struct TestLegalizePatternDriver
TestUpdateConsumerType, TestNonRootReplacement,
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp>(&getContext());
+ TestUndoPropertiesModification, TestEraseOp,
+ TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
&getContext(), converter);
+ patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1312,6 +1361,9 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
+ target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
+ [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
More information about the Mlir-commits mailing list