[Mlir-commits] [mlir] 9df63b2 - [mlir][Transforms] Add 1:N `matchAndRewrite` overload (#116470)
llvmlistbot at llvm.org
Fri Nov 29 16:27:51 PST 2024
Author: Matthias Springer
Date: 2024-11-30T09:27:47+09:00
New Revision: 9df63b2651b2435c02a7d825953ca2ddc65c778e
URL: https://github.com/llvm/llvm-project/commit/9df63b2651b2435c02a7d825953ca2ddc65c778e
DIFF: https://github.com/llvm/llvm-project/commit/9df63b2651b2435c02a7d825953ca2ddc65c778e.diff
LOG: [mlir][Transforms] Add 1:N `matchAndRewrite` overload (#116470)
This commit adds a new `matchAndRewrite` overload to `ConversionPattern`
to support 1:N replacements. This is the first of two main PRs that
merge the 1:1 and 1:N dialect conversion drivers.
The existing `matchAndRewrite` function supports only 1:1 replacements,
as can be seen from the `ArrayRef<Value>` parameter.
```c++
LogicalResult ConversionPattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands /*adaptor values*/,
ConversionPatternRewriter &rewriter) const;
```
This commit adds a `matchAndRewrite` overload that is called by the
dialect conversion driver. By default, this new overload dispatches to
the original 1:1 `matchAndRewrite` implementation. Existing
`ConversionPattern`s do not need to be changed as long as there are no
1:N type conversions or value replacements.
```c++
LogicalResult ConversionPattern::matchAndRewrite(
Operation *op, ArrayRef<ValueRange> operands /*adaptor values*/,
ConversionPatternRewriter &rewriter) const {
// Note: getOneToOneAdaptorOperands produces a fatal error if at least one
// ValueRange has 0 or more than 1 value.
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}
```
The `ConversionValueMapping`, which keeps track of value replacements
and materializations, still does not support 1:N replacements. We still
rely on argument materializations to convert N replacement values back
into a single value. The `ConversionValueMapping` will be generalized to
1:N mappings in the second main PR.
Before handing the adaptor values to a `ConversionPattern`, all argument
materializations are "unpacked". The `ConversionPattern` receives N
replacement values and does not see any argument materializations. This
implementation strategy allows us to use the 1:N infrastructure/API in
`ConversionPattern`s even though some functionality is still missing in
the driver. This strategy was chosen to keep the sizes of the PRs
smaller and to make it easier for downstream users to adapt to API
changes.
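For illustration only (not part of this commit; `MyCastOp` and its `getSource`
accessor are hypothetical), a pattern written against the new overload can
forward the N replacement values of an operand directly:
```c++
struct MyCastOpLowering : public OpConversionPattern<MyCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MyCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Each adaptor getter now returns a ValueRange: the (possibly multiple)
    // replacement values of the corresponding original operand.
    ValueRange convertedSource = adaptor.getSource();
    // Replace the single result of `op` with N values.
    rewriter.replaceOpWithMultiple(op, {convertedSource});
    return success();
  }
};
```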
This commit also updates the "decompose call graphs" transformation
and the "sparse tensor codegen" transformation to use the new 1:N
`ConversionPattern` API.
Note for LLVM integration: If you are using a type converter with 1:N
type conversion rules or if your patterns are performing 1:N
replacements (via `replaceOpWithMultiple` or
`applySignatureConversion`), conversion pattern applications will start
failing (fatal LLVM error) with this error message: `pattern 'name' does
not support 1:N conversion`. The name of the failing pattern is shown in
the error message. These patterns must be updated to the new 1:N
`matchAndRewrite` API.
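For downstream patterns that must migrate, a rough sketch (the op `MyOp`, its
`getInput`/`getIndices` accessors, and `MyLoweredOp` are hypothetical) is to
switch the adaptor type and decide per operand whether to take the single
replacement value or to flatten all replacement values:
```c++
struct MyOpLowering : public OpConversionPattern<MyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MyOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // `input` is expected to be converted 1:1; take the single value.
    assert(adaptor.getInput().size() == 1 && "expected 1:1 conversion");
    Value input = adaptor.getInput().front();

    // `indices` (variadic) may be converted 1:N; flatten all values.
    SmallVector<Value> flatIndices;
    for (ValueRange range : adaptor.getIndices())
      llvm::append_range(flatIndices, range);

    rewriter.replaceOpWithNewOp<MyLoweredOp>(op, input, flatIndices);
    return success();
  }
};
```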
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/decompose-call-graph-types.mlir
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f3bf5b66398e09..86ea87b55af1cd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -143,6 +143,8 @@ template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
+ using OneToNOpAdaptor =
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
@@ -153,8 +155,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
- rewriter);
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+ }
+ void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
@@ -162,8 +169,15 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- return matchAndRewrite(cast<SourceOp>(op),
- OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+ }
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+ rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
+ virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewrite(op, adaptor, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
private:
using ConvertToLLVMPattern::match;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index aac6b7c03548a9..28150e886913e3 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -538,8 +538,15 @@ class ConversionPattern : public RewritePattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
+ virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
/// Hook for derived classes to implement combined matching and rewriting.
+ /// This overload supports only 1:1 replacements. The 1:N overload is called
+ /// by the driver. By default, it calls this 1:1 overload or reports a fatal
+ /// error if 1:N replacements were found.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -549,6 +556,14 @@ class ConversionPattern : public RewritePattern {
return success();
}
+ /// Hook for derived classes to implement combined matching and rewriting.
+ /// This overload supports 1:N replacements.
+ virtual LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
+
/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;
@@ -575,6 +590,15 @@ class ConversionPattern : public RewritePattern {
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}
+ /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
+ /// try to extract the single value of each range to construct the inputs
+ /// for a 1:1 adaptor.
+ ///
+ /// This function produces a fatal error if at least one range has 0 or
+ /// more than 1 value: "pattern 'name' does not support 1:N conversion"
+ SmallVector<Value>
+ getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
+
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
@@ -590,6 +614,8 @@ template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
+ using OneToNOpAdaptor =
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -608,12 +634,24 @@ class OpConversionPattern : public ConversionPattern {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
+ void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+ rewriter);
+ }
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
@@ -624,6 +662,12 @@ class OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
+ virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -632,6 +676,13 @@ class OpConversionPattern : public ConversionPattern {
rewrite(op, adaptor, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
private:
using ConversionPattern::matchAndRewrite;
@@ -657,11 +708,20 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
+ void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewrite(cast<SourceOp>(op), operands, rewriter);
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
@@ -669,6 +729,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
+ virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -677,6 +741,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
private:
using ConversionPattern::matchAndRewrite;
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index a08764326a80b6..03be00328bda33 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -13,40 +13,6 @@
using namespace mlir;
using namespace mlir::func;
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// If the given value can be decomposed with the type converter, decompose it.
-/// Otherwise, return the given value.
-// TODO: Value decomposition should happen automatically through a 1:N adaptor.
-// This function will disappear when the 1:1 and 1:N drivers are merged.
-static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
- Value value,
- const TypeConverter *converter) {
- // Try to convert the given value's type. If that fails, just return the
- // given value.
- SmallVector<Type> convertedTypes;
- if (failed(converter->convertType(value.getType(), convertedTypes)))
- return {value};
- if (convertedTypes.empty())
- return {};
-
- // If the given value's type is already legal, just return the given value.
- TypeRange convertedTypeRange(convertedTypes);
- if (convertedTypeRange == TypeRange(value.getType()))
- return {value};
-
- // Try to materialize a target conversion. If the materialization did not
- // produce values of the requested type, the materialization failed. Just
- // return the given value in that case.
- SmallVector<Value> result = converter->materializeTargetConversion(
- builder, loc, convertedTypeRange, value);
- if (result.empty())
- return {value};
- return result;
-}
-
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands()) {
- // TODO: We can directly take the values from the adaptor once this is a
- // 1:N conversion pattern.
- llvm::append_range(newOperands,
- decomposeValue(rewriter, operand.getLoc(), operand,
- getTypeConverter()));
- }
+ for (ValueRange operand : adaptor.getOperands())
+ llvm::append_range(newOperands, operand);
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(CallOp op, OpAdaptor adaptor,
+ matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
// Create the operands list of the new `CallOp`.
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands()) {
- // TODO: We can directly take the values from the adaptor once this is a
- // 1:N conversion pattern.
- llvm::append_range(newOperands,
- decomposeValue(rewriter, operand.getLoc(), operand,
- getTypeConverter()));
- }
+ for (ValueRange operand : adaptor.getOperands())
+ llvm::append_range(newOperands, operand);
// Create the new result types for the new `CallOp` and track the number of
// replacement types for each original op result.
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index b1cde6ca5d2fca..9e7759bef6d8fd 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -13,6 +13,14 @@
using namespace mlir;
using namespace mlir::func;
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
namespace {
/// Converts the operand and result types of the CallOp, used together with the
/// FuncOpSignatureConversion.
@@ -21,7 +29,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
/// Hook for derived classes to implement combined matching and rewriting.
LogicalResult
- matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Convert the original function results. Keep track of how many result
// types an original result type is converted into.
@@ -38,9 +46,9 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
// Substitute with the new result types from the corresponding FuncType
// conversion.
- auto newCallOp =
- rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
- convertedResults, adaptor.getOperands());
+ auto newCallOp = rewriter.create<CallOp>(
+ callOp.getLoc(), callOp.getCallee(), convertedResults,
+ flattenValues(adaptor.getOperands()));
SmallVector<ValueRange> replacements;
size_t offset = 0;
for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 93a78056db1944..c0589044c26ecb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -16,20 +16,18 @@ using namespace mlir::scf;
namespace {
-// Unpacks the single unrealized_conversion_cast using the list of inputs
-// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
-static void unpackUnrealizedConversionCast(Value v,
- SmallVectorImpl<Value> &unpacked) {
- if (auto cast =
- dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
- if (cast.getInputs().size() != 1) {
- // 1 : N type conversion.
- unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
- return;
- }
- }
- // 1 : 1 type conversion.
- unpacked.push_back(v);
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+ assert(values.size() == 1 && "expected single value");
+ return values.front();
}
// CRTP
@@ -40,19 +38,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern<SourceOp>::typeConverter;
using OpConversionPattern<SourceOp>::OpConversionPattern;
- using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+ using OneToNOpAdaptor =
+ typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
//
// Derived classes should provide the following method which performs the
// actual conversion. It should return std::nullopt upon conversion failure
// and return the converted operation upon success.
//
- // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
- // ConversionPatternRewriter &rewriter,
- // TypeRange dstTypes) const;
+ // std::optional<SourceOp> convertSourceOp(
+ // SourceOp op, OneToNOpAdaptor adaptor,
+ // ConversionPatternRewriter &rewriter,
+ // TypeRange dstTypes) const;
LogicalResult
- matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> dstTypes;
SmallVector<unsigned> offsets;
@@ -73,28 +73,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
return rewriter.notifyMatchFailure(op, "could not convert operation");
// Packs the return value.
- SmallVector<Value> packedRets;
+ SmallVector<ValueRange> packedRets;
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
unsigned start = offsets[i - 1], end = offsets[i];
unsigned len = end - start;
ValueRange mappedValue = newOp->getResults().slice(start, len);
- if (len != 1) {
- // 1 : N type conversion.
- Type origType = op.getResultTypes()[i - 1];
- Value mat = typeConverter->materializeSourceConversion(
- rewriter, op.getLoc(), origType, mappedValue);
- if (!mat) {
- return rewriter.notifyMatchFailure(
- op, "Failed to materialize 1:N type conversion");
- }
- packedRets.push_back(mat);
- } else {
- // 1 : 1 type conversion.
- packedRets.push_back(mappedValue.front());
- }
+ packedRets.push_back(mappedValue);
}
- rewriter.replaceOp(op, packedRets);
+ rewriter.replaceOpWithMultiple(op, packedRets);
return success();
}
};
@@ -105,7 +92,7 @@ class ConvertForOpTypes
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
// The callback required by CRTP.
- std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+ std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Create a empty new op and inline the regions from the old op.
@@ -129,16 +116,13 @@ class ConvertForOpTypes
if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
return std::nullopt;
- // Unpacked the iteration arguments.
- SmallVector<Value> flatArgs;
- for (Value arg : adaptor.getInitArgs())
- unpackUnrealizedConversionCast(arg, flatArgs);
-
// We can not do clone as the number of result types after conversion
// might be different.
- ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
- adaptor.getUpperBound(),
- adaptor.getStep(), flatArgs);
+ ForOp newOp = rewriter.create<ForOp>(
+ op.getLoc(), getSingleValue(adaptor.getLowerBound()),
+ getSingleValue(adaptor.getUpperBound()),
+ getSingleValue(adaptor.getStep()),
+ flattenValues(adaptor.getInitArgs()));
// Reserve whatever attributes in the original op.
newOp->setAttrs(op->getAttrs());
@@ -160,12 +144,12 @@ class ConvertIfOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
- std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+ std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
- IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
- adaptor.getCondition(), true);
+ IfOp newOp = rewriter.create<IfOp>(
+ op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
newOp->setAttrs(op->getAttrs());
// We do not need the empty blocks created by rewriter.
@@ -189,15 +173,11 @@ class ConvertWhileOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
- std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+ std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
- // Unpacked the iteration arguments.
- SmallVector<Value> flatArgs;
- for (Value arg : adaptor.getOperands())
- unpackUnrealizedConversionCast(arg, flatArgs);
-
- auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+ auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
+ flattenValues(adaptor.getOperands()));
for (auto i : {0u, 1u}) {
if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
@@ -218,13 +198,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
+ matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> unpackedYield;
- for (Value operand : adaptor.getOperands())
- unpackUnrealizedConversionCast(operand, unpackedYield);
-
- rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(
+ op, flattenValues(adaptor.getOperands()));
return success();
}
};
@@ -235,13 +212,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
public:
using OpConversionPattern<ConditionOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
+ matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> unpackedYield;
- for (Value operand : adaptor.getOperands())
- unpackUnrealizedConversionCast(operand, unpackedYield);
-
- rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 25fca49cb0154a..20d46f7ca00c54 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -39,25 +39,18 @@ using namespace mlir::sparse_tensor;
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Flattens a list of operands that may contain sparse tensors.
-static void flattenOperands(ValueRange operands,
- SmallVectorImpl<Value> &flattened) {
- // In case of
- // sparse_tensor, c, sparse_tensor
- // ==>
- // memref ..., c, memref ...
- for (auto operand : operands) {
- if (getSparseTensorEncoding(operand.getType())) {
- auto tuple = getTuple(operand);
- // An unrealized_conversion_cast will be inserted by type converter to
- // inter-mix the gap between 1:N conversion between sparse tensors and
- // fields. In this case, take the operands in the cast and replace the
- // sparse tensor output with the flattened type array.
- flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
- } else {
- flattened.push_back(operand);
- }
- }
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const auto &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+ assert(values.size() == 1 && "expected single value");
+ return values.front();
}
/// Generates a load with proper `index` typing.
@@ -567,12 +560,11 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Value> flattened;
- flattenOperands(adaptor.getOperands(), flattened);
// Create a return with the flattened value extracted from sparse tensors.
- rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+ rewriter.replaceOpWithNewOp<func::ReturnOp>(
+ op, flattenValues(adaptor.getOperands()));
return success();
}
};
@@ -583,7 +575,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
// The default CallOp converter can not handle 1:N type conversion.
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// In case of:
@@ -596,10 +588,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
return failure();
// (1) Generates new call with flattened return value.
- SmallVector<Value> flattened;
- flattenOperands(adaptor.getOperands(), flattened);
- auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
- finalRetTy, flattened);
+ auto newCall = rewriter.create<func::CallOp>(
+ loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
// (2) Gather sparse tensor returns.
SmallVector<SmallVector<Value>> packedResultVals;
// Tracks the offset of current return value (of the original call)
@@ -643,7 +633,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(LvlOp op, OpAdaptor adaptor,
+ matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::optional<int64_t> lvl = op.getConstantLvlIndex();
RankedTensorType srcType = op.getSource().getType();
@@ -662,7 +652,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
+ matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
@@ -693,7 +683,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
// Since we do in-place sorting, the destinate tensor will have the same set
// of memrefs as the source tensor.
- rewriter.replaceOp(op, adaptor.getInputCoo());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
return success();
}
};
@@ -702,8 +692,10 @@ template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
+ using typename OpConversionPattern<Op>::OneToNOpAdaptor;
+
LogicalResult
- matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Simply lowers to specifer.get <field> operation.
auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
@@ -721,14 +713,14 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only rewrite identically annotated source/dest.
auto encDst = getSparseTensorEncoding(op.getType());
auto encSrc = getSparseTensorEncoding(op.getSource().getType());
if (!encDst || encDst != encSrc)
return failure();
- rewriter.replaceOp(op, adaptor.getOperands());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
return success();
}
};
@@ -737,10 +729,10 @@ class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Simply fold the operation.
- rewriter.replaceOp(op, adaptor.getSource());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
return success();
}
};
@@ -756,7 +748,7 @@ class SparseTensorAllocConverter
enableBufferInitialization(enableInit) {}
LogicalResult
- matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
+ matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto resType = getSparseTensorType(op);
if (!resType.hasEncoding())
@@ -791,7 +783,8 @@ class SparseTensorAllocConverter
}
// Level size equals to dimension size since lvl2dim map is an identity map.
SmallVector<Value> lvlSizesValues;
- createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
+ createDimSizes(rewriter, loc, resType,
+ flattenValues(adaptor.getDynamicSizes()),
/*dimSizesValues=*/lvlSizesValues);
// Construct allocation for each field.
@@ -861,7 +854,7 @@ class SparseTensorDeallocConverter
createDeallocs(createDeallocs) {}
LogicalResult
- matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
+ matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto enc = getSparseTensorEncoding(op.getTensor().getType());
if (!enc)
@@ -892,7 +885,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(LoadOp op, OpAdaptor adaptor,
+ matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Prepare descriptor.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
@@ -911,7 +904,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+ matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!getSparseTensorEncoding(op.getTensor().getType()))
return failure();
@@ -963,16 +956,16 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+ matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
op.getTensor().getType());
- Value values = adaptor.getValues();
- Value filled = adaptor.getFilled();
- Value added = adaptor.getAdded();
- Value count = adaptor.getCount();
+ Value values = getSingleValue(adaptor.getValues());
+ Value filled = getSingleValue(adaptor.getFilled());
+ Value added = getSingleValue(adaptor.getAdded());
+ Value count = getSingleValue(adaptor.getCount());
const SparseTensorType dstType(desc.getRankedTensorType());
Type eltType = dstType.getElementType();
@@ -1005,7 +998,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
SmallVector<Type> flatSpTensorTps = llvm::to_vector(
llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
- params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+ SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
+ params.append(flatLvlCoords.begin(), flatLvlCoords.end());
params.push_back(crd);
params.push_back(value);
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
@@ -1033,9 +1027,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto stt = getSparseTensorType(adaptor.getDest());
+ auto stt = getSparseTensorType(op.getDest());
if (!stt.hasEncoding())
return failure();
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
@@ -1045,8 +1039,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
- params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
- params.push_back(adaptor.getScalar());
+ SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
+ params.append(flatIndices.begin(), flatIndices.end());
+ params.push_back(getSingleValue(adaptor.getScalar()));
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
@@ -1062,7 +1057,7 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
using OpAdaptor = typename ToPositionsOp::Adaptor;
using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested position access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1085,7 +1080,7 @@ class SparseToCoordinatesConverter
using OpAdaptor = typename ToCoordinatesOp::Adaptor;
using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1111,7 +1106,7 @@ class SparseToCoordinatesBufferConverter
using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested coordinates access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1133,7 +1128,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
using OpAdaptor = typename ToValuesOp::Adaptor;
using OpConversionPattern<ToValuesOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+ matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested values access with corresponding field.
// The view is restricted to the actual size to ensure clients
@@ -1153,7 +1148,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
+ matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
@@ -1173,7 +1168,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
Type srcElemTp = op.getSource().getType().getElementType();
// Fold the trivial cases.
if (retElemTp == srcElemTp && encDst == encSrc) {
- rewriter.replaceOp(op, adaptor.getSource());
+ rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
return success();
}
//
@@ -1239,7 +1234,7 @@ class SparseExtractSliceConverter
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
@@ -1296,7 +1291,7 @@ class SparseNumberOfEntriesConverter
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+ matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Query memSizes for the actually stored values.
// FIXME: the nse value computed in this way might be wrong when there is
@@ -1430,7 +1425,7 @@ struct SparseDisassembleOpConverter
: OpConversionPattern(typeConverter, context) {}
LogicalResult
- matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+ matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
op.getTensor().getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index 89858546e37e1b..869c7864d75354 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -228,11 +228,6 @@ class MutSparseTensorDescriptor
}
};
-/// Returns the "tuple" value of the adapted tensor.
-inline UnrealizedConversionCastOp getTuple(Value tensor) {
- return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
-}
-
/// Packs the given values as a "tuple" value.
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
ValueRange values) {
@@ -246,16 +241,15 @@ inline Value genTuple(OpBuilder &builder, Location loc,
}
inline SparseTensorDescriptor
-getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
- auto tuple = getTuple(tensor);
- return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
+getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) {
+ return SparseTensorDescriptor(SparseTensorType(type), adaptorValues);
}
inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+getMutDescriptorFromTensorTuple(ValueRange adaptorValues,
+ SmallVectorImpl<Value> &fields,
RankedTensorType type) {
- auto tuple = getTuple(tensor);
- fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
+ fields.assign(adaptorValues.begin(), adaptorValues.end());
return MutSparseTensorDescriptor(SparseTensorType(type), fields);
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1424c4974f2d43..613fd6d9d74b1f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -67,10 +67,6 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
// ConversionValueMapping
//===----------------------------------------------------------------------===//
-/// A list of replacement SSA values. Optimized for the common case of a single
-/// SSA value.
-using ReplacementValues = SmallVector<Value, 1>;
-
namespace {
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -783,7 +779,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
LogicalResult remapValues(StringRef valueDiagTag,
std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped);
+ SmallVector<SmallVector<Value>> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
@@ -817,17 +813,31 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Materializations
//===--------------------------------------------------------------------===//
- /// Build an unresolved materialization operation given an output type and set
- /// of input operands.
+ /// Build an unresolved materialization operation given a range of output
+ /// types and a list of input operands. Returns the inputs if their
+ /// types match the output types.
+ ///
+ /// If a cast op was built, it can optionally be returned with the `castOp`
+ /// output argument.
///
/// If `valueToMap` is set to a non-null Value, then that value is mapped to
- /// the result of the unresolved materialization in the conversion value
+ /// the results of the unresolved materialization in the conversion value
/// mapping.
- Value buildUnresolvedMaterialization(MaterializationKind kind,
- OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs,
- Type outputType, Type originalType,
- const TypeConverter *converter);
+ ValueRange buildUnresolvedMaterialization(
+ MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+ Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ Type originalType, const TypeConverter *converter,
+ UnrealizedConversionCastOp *castOp = nullptr);
+ Value buildUnresolvedMaterialization(
+ MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+ Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
+ const TypeConverter *converter,
+ UnrealizedConversionCastOp *castOp = nullptr) {
+ return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
+ TypeRange(outputType), originalType,
+ converter, castOp)
+ .front();
+ }
/// Build an N:1 materialization for the given original value that was
/// replaced with the given replacement values.
@@ -853,6 +863,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value findOrBuildReplacementValue(Value value,
const TypeConverter *converter);
+ /// Unpack an N:1 materialization and return the inputs of the
+ /// materialization. This function unpacks only those materializations that
+ /// were built with `insertNTo1Materialization`.
+ ///
+ /// This is a workaround for incomplete 1:N support in the dialect
+ /// conversion driver. It allows us to write 1:N conversion patterns while
+ /// 1:N support is still missing in the conversion value mapping. This
+ /// function will be deleted when full 1:N support has been added.
+ SmallVector<Value> unpackNTo1Materialization(Value value);
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -862,7 +882,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
OpBuilder::InsertPoint previous) override;
/// Notifies that an op is about to be replaced with the given values.
- void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
+ void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -955,6 +975,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
unresolvedMaterializations;
+ /// A set of all N:1 materializations that were added to work around
+ /// incomplete 1:N support in the dialect conversion driver.
+ DenseSet<UnrealizedConversionCastOp> nTo1TempMaterializations;
+
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1091,6 +1115,7 @@ void UnresolvedMaterializationRewrite::rollback() {
if (mappedValue)
rewriterImpl.mapping.erase(mappedValue);
rewriterImpl.unresolvedMaterializations.erase(getOperation());
+ rewriterImpl.nTo1TempMaterializations.erase(getOperation());
op->erase();
}
@@ -1136,7 +1161,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped) {
+ SmallVector<SmallVector<Value>> &remapped) {
remapped.reserve(llvm::size(values));
for (const auto &it : llvm::enumerate(values)) {
@@ -1144,11 +1169,18 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type origType = operand.getType();
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
+ // Find the most recently mapped value. Unpack all temporary N:1
+ // materializations. Such conversions are a workaround for missing
+ // 1:N support in the ConversionValueMapping. (The conversion patterns
+ // already support 1:N replacements.)
+ Value repl = mapping.lookupOrDefault(operand);
+ SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
+
if (!currentTypeConverter) {
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped value.
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back(std::move(unpacked));
continue;
}
@@ -1162,15 +1194,29 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
return failure();
}
+ // If a type is converted to 0 types, there is nothing to do.
+ if (legalTypes.empty()) {
+ remapped.push_back({});
+ continue;
+ }
+
if (legalTypes.size() != 1) {
- // TODO: Parts of the dialect conversion infrastructure do not support
- // 1->N type conversions yet. Therefore, if a type is converted to 0 or
- // multiple types, the only thing that we can do for now is passing
- // through the most recently mapped value. Fixing this requires
- // improvements to the `ConversionValueMapping` (to be able to store 1:N
- // mappings) and to the `ConversionPattern` adaptor handling (to be able
- // to pass multiple remapped values for a single operand to the adaptor).
- remapped.push_back(mapping.lookupOrDefault(operand));
+ // TODO: This is a 1:N conversion. The conversion value mapping does not
+ // store such materializations yet. If the types of the most recently
+ // mapped values do not match, build a target materialization.
+ if (TypeRange(unpacked) == legalTypes) {
+ remapped.push_back(std::move(unpacked));
+ continue;
+ }
+
+ // Insert a target materialization if the current pattern expects
+ //
diff erent legalized types.
+ ValueRange targetMat = buildUnresolvedMaterialization(
+ MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+ /*valueToMap=*/Value(), /*inputs=*/unpacked,
+ /*outputType=*/legalTypes, /*originalType=*/origType,
+ currentTypeConverter);
+ remapped.push_back(targetMat);
continue;
}
@@ -1182,15 +1228,15 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
if (newOperand.getType() != desiredType) {
// If the looked up value's type does not have the desired type, it means
// that the value was replaced with a value of different type and no
- // source materialization was created yet.
+ // target materialization was created yet.
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
- operandLoc, /*valueToMap=*/newOperand, /*inputs=*/newOperand,
+ operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
/*outputType=*/desiredType, /*originalType=*/origType,
currentTypeConverter);
newOperand = castValue;
}
- remapped.push_back(newOperand);
+ remapped.push_back({newOperand});
}
return success();
}
@@ -1347,31 +1393,38 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
-Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
- const TypeConverter *converter) {
+ Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+ Type originalType, const TypeConverter *converter,
+ UnrealizedConversionCastOp *castOp) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
// Avoid materializing an unnecessary cast.
- if (inputs.size() == 1 && inputs.front().getType() == outputType) {
- if (valueToMap)
+ if (TypeRange(inputs) == outputTypes) {
+ if (valueToMap) {
+ assert(inputs.size() == 1 && "1:N mapping is not supported");
mapping.map(valueToMap, inputs.front());
- return inputs.front();
+ }
+ return inputs;
}
// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
- OpBuilder builder(outputType.getContext());
+ OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
- builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- if (valueToMap)
+ builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
+ if (valueToMap) {
+ assert(outputTypes.size() == 1 && "1:N mapping is not supported");
mapping.map(valueToMap, convertOp.getResult(0));
+ }
+ if (castOp)
+ *castOp = convertOp;
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
originalType, valueToMap);
- return convertOp.getResult(0);
+ return convertOp.getResults();
}
void ConversionPatternRewriterImpl::insertNTo1Materialization(
@@ -1379,10 +1432,13 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
Value originalValue, const TypeConverter *converter) {
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
+ UnrealizedConversionCastOp argCastOp;
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
- /*inputs=*/replacements, originalType, /*originalType=*/Type(),
- converter);
+ /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
+ &argCastOp);
+ if (argCastOp)
+ nTo1TempMaterializations.insert(argCastOp);
// Insert target materialization to the legalized type.
Type legalOutputType;
@@ -1398,11 +1454,14 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
legalOutputType = replacements[0].getType();
}
if (legalOutputType && legalOutputType != originalType) {
- buildUnresolvedMaterialization(MaterializationKind::Target,
- computeInsertPoint(argMat), loc,
- /*valueToMap=*/argMat, /*inputs=*/argMat,
- /*outputType=*/legalOutputType,
- /*originalType=*/originalType, converter);
+ UnrealizedConversionCastOp targetCastOp;
+ buildUnresolvedMaterialization(
+ MaterializationKind::Target, computeInsertPoint(argMat), loc,
+ /*valueToMap=*/argMat, /*inputs=*/argMat,
+ /*outputType=*/legalOutputType, /*originalType=*/originalType,
+ converter, &targetCastOp);
+ if (targetCastOp)
+ nTo1TempMaterializations.insert(targetCastOp);
}
}
@@ -1438,9 +1497,32 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
/*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
/*originalType=*/Type(), converter);
+ mapping.map(value, castValue);
return castValue;
}
+SmallVector<Value>
+ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
+ // Unpack unrealized_conversion_cast ops that were inserted as a N:1
+ // workaround.
+ auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
+ if (!castOp)
+ return {value};
+ if (!nTo1TempMaterializations.contains(castOp))
+ return {value};
+ assert(castOp->getNumResults() == 1 && "expected single result");
+
+ SmallVector<Value> result;
+ for (Value v : castOp.getOperands()) {
+ // Keep unpacking if possible. This is needed because during block
+ // signature conversions and 1:N op replacements, the driver may have
+ // inserted two materializations back-to-back: first an argument
+ // materialization, then a target materialization.
+ llvm::append_range(result, unpackNTo1Materialization(v));
+ }
+ return result;
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1465,7 +1547,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
}
void ConversionPatternRewriterImpl::notifyOpReplaced(
- Operation *op, ArrayRef<ReplacementValues> newValues) {
+ Operation *op, ArrayRef<ValueRange> newValues) {
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");
@@ -1477,8 +1559,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
isUnresolvedMaterialization = true;
// Create mappings for each of the new result values.
- for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) {
- ReplacementValues repl = n;
+ for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
if (repl.empty()) {
// This result was dropped and no replacement value was provided.
if (isUnresolvedMaterialization) {
@@ -1488,12 +1569,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
}
// Materialize a replacement value "out of thin air".
- Value sourceMat = buildUnresolvedMaterialization(
+ buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
- result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
+ result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
- repl.push_back(sourceMat);
+ continue;
} else {
// Make sure that the user does not mess with unresolved materializations
// that were inserted by the conversion driver. We keep track of these
@@ -1595,10 +1676,9 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ReplacementValues> newVals(newValues.size());
- for (auto [index, val] : llvm::enumerate(newValues))
- if (val)
- newVals[index].push_back(val);
+  SmallVector<ValueRange> newVals;
+  for (size_t i = 0; i < newValues.size(); ++i)
+    newVals.push_back(newValues[i] ? newValues.slice(i, 1) : ValueRange());
impl->notifyOpReplaced(op, newVals);
}
@@ -1610,10 +1690,7 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ReplacementValues> newVals(newValues.size(), {});
- for (auto [index, val] : llvm::enumerate(newValues))
- llvm::append_range(newVals[index], val);
- impl->notifyOpReplaced(op, newVals);
+ impl->notifyOpReplaced(op, newValues);
}
void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1621,7 +1698,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {});
+ SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
impl->notifyOpReplaced(op, nullRepls);
}
@@ -1673,11 +1750,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- SmallVector<Value> remappedValues;
+ SmallVector<SmallVector<Value>> remappedValues;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
remappedValues)))
return nullptr;
- return remappedValues.front();
+ assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
+ return remappedValues.front().front();
}
LogicalResult
@@ -1685,8 +1763,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) {
if (keys.empty())
return success();
- return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
- results);
+ SmallVector<SmallVector<Value>> remapped;
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ remapped)))
+ return failure();
+ for (const auto &values : remapped) {
+ assert(values.size() == 1 && "1:N conversion not supported");
+ results.push_back(values.front());
+ }
+ return success();
}
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1780,6 +1865,19 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
+SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+ ArrayRef<ValueRange> operands) const {
+ SmallVector<Value> oneToOneOperands;
+ oneToOneOperands.reserve(operands.size());
+ for (ValueRange operand : operands) {
+ if (operand.size() != 1)
+ llvm::report_fatal_error("pattern '" + getDebugName() +
+ "' does not support 1:N conversion");
+ oneToOneOperands.push_back(operand.front());
+ }
+ return oneToOneOperands;
+}
+
LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
@@ -1791,12 +1889,14 @@ ConversionPattern::matchAndRewrite(Operation *op,
getTypeConverter());
// Remap the operands of the operation.
- SmallVector<Value, 4> operands;
+ SmallVector<SmallVector<Value>> remapped;
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
- op->getOperands(), operands))) {
+ op->getOperands(), remapped))) {
return failure();
}
- return matchAndRewrite(op, operands, dialectRewriter);
+ SmallVector<ValueRange> remappedAsRange =
+ llvm::to_vector_of<ValueRange>(remapped);
+ return matchAndRewrite(op, remappedAsRange, dialectRewriter);
}
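
A downstream pattern that overrides the new `ArrayRef<ValueRange>` overload but only supports 1:1 values can recover the old flat operand list with `getOneToOneAdaptorOperands` shown above, which aborts with the "does not support 1:N conversion" fatal error if any operand was replaced by zero or several values. A rough sketch, using a hypothetical `"my.noop"` op that simply forwards its operands to its results:

```c++
// Hypothetical 1:1-only pattern; only the ArrayRef<ValueRange> overload and
// getOneToOneAdaptorOperands() are taken from this patch.
struct MyNoopLowering : public ConversionPattern {
  MyNoopLowering(MLIRContext *ctx)
      : ConversionPattern("my.noop", /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Flatten the adaptor values; this fails fatally if a 1:N replacement
    // reaches this pattern.
    SmallVector<Value> flatOperands = getOneToOneAdaptorOperands(operands);
    // "my.noop" is assumed to return its operands unchanged.
    rewriter.replaceOp(op, flatOperands);
    return success();
  }
};
```
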
//===----------------------------------------------------------------------===//
@@ -2536,45 +2636,52 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
assert(!op.use_empty() &&
"expected that dead materializations have already been DCE'd");
Operation::operand_range inputOperands = op.getOperands();
- Type outputType = op.getResultTypes()[0];
// Try to materialize the conversion.
if (const TypeConverter *converter = rewrite->getConverter()) {
rewriter.setInsertionPoint(op);
- Value newMaterialization;
+ SmallVector<Value> newMaterialization;
switch (rewrite->getMaterializationKind()) {
- case MaterializationKind::Argument:
+ case MaterializationKind::Argument: {
// Try to materialize an argument conversion.
- newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- if (newMaterialization)
+ assert(op->getNumResults() == 1 && "expected single result");
+ Value argMat = converter->materializeArgumentConversion(
+ rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
+ if (argMat) {
+ newMaterialization.push_back(argMat);
break;
+ }
+ }
// If an argument materialization failed, fallback to trying a target
// materialization.
[[fallthrough]];
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
- rewriter, op->getLoc(), outputType, inputOperands,
+ rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
rewrite->getOriginalType());
break;
case MaterializationKind::Source:
- newMaterialization = converter->materializeSourceConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
+ assert(op->getNumResults() == 1 && "expected single result");
+ Value sourceMat = converter->materializeSourceConversion(
+ rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
+ if (sourceMat)
+ newMaterialization.push_back(sourceMat);
break;
}
- if (newMaterialization) {
- assert(newMaterialization.getType() == outputType &&
+ if (!newMaterialization.empty()) {
+ assert(TypeRange(newMaterialization) == op.getResultTypes() &&
"materialization callback produced value of incorrect type");
rewriter.replaceOp(op, newMaterialization);
return success();
}
}
- InFlightDiagnostic diag =
- op->emitError() << "failed to legalize unresolved materialization "
- "from ("
- << inputOperands.getTypes() << ") to (" << outputType
- << ") that remained live after conversion";
+ InFlightDiagnostic diag = op->emitError()
+ << "failed to legalize unresolved materialization "
+ "from ("
+ << inputOperands.getTypes() << ") to ("
+ << op.getResultTypes()
+ << ") that remained live after conversion";
diag.attachNote(op->getUsers().begin()->getLoc())
<< "see existing live user here: " << *op->getUsers().begin();
return failure();
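
Whether such an unresolved `unrealized_conversion_cast` can be legalized depends on the materialization callbacks registered on the type converter. The snippet below sketches such callbacks for a hypothetical `!my.pair` type that is converted into two values; the `my.pack`/`my.unpack` ops are made up, only the `TypeConverter` registration API is real.

```c++
// Illustration only: materializations for a hypothetical 1:N conversion of
// !my.pair into its two elements.
void registerPairMaterializations(TypeConverter &converter) {
  // N converted values -> one value of the original pair type.
  auto packFn = [](OpBuilder &builder, Type resultType, ValueRange inputs,
                   Location loc) -> Value {
    return builder.create<my::PackOp>(loc, resultType, inputs).getResult();
  };
  converter.addArgumentMaterialization(packFn);
  converter.addSourceMaterialization(packFn);
  // One pair value -> a single legalized value (1:1 target form).
  converter.addTargetMaterialization(
      [](OpBuilder &builder, Type resultType, ValueRange inputs,
         Location loc) -> Value {
        return builder.create<my::UnpackOp>(loc, resultType, inputs)
            .getResult();
      });
}
```
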
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index b8fad63eb4de67..4e641317ac2f3d 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -9,10 +9,7 @@
// CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32
// CHECK-12N-LABEL: func @identity(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// CHECK-LABEL: func @mixed_recursive_decomposition(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
-// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
-// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
-// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
-// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
-// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
-// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
-// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK: return %[[V7]], %[[V10]] : i1, i2
+// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2
// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
@@ -87,14 +73,8 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-LABEL: func @caller(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
-// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
-// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
+// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
// CHECK-SAME: %[[I4:.*]]: i4,
// CHECK-SAME: %[[I5:.*]]: i5,
// CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
-// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
-// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
-// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
+// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[I1:.*]]: i1,
// CHECK-12N-SAME: %[[I2:.*]]: i2,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e05f444afa68f0..d98a6a036e6b1f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -472,3 +472,14 @@ func.func @circular_mapping() {
%0 = "test.erase_op"() : () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
+
+// -----
+
+func.func @test_1_to_n_block_signature_conversion() {
+ "test.duplicate_block_args"() ({
+ ^bb0(%arg0: i64):
+ "test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> ()
+ }) {} : () -> ()
+ "test.return"() : () -> ()
+}
+
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 239d5292180269..d24d52f356d88f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1886,6 +1886,11 @@ def LegalOpC : TEST_Op<"legal_op_c">,
Arguments<(ins I32)>, Results<(outs I32)>;
def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
+def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> {
+ let arguments = (ins UnitAttr:$is_legal);
+ let regions = (region SizedRegion<1>:$body);
+}
+
// Check that the conversion infrastructure can properly undo the creation of
// operations where an operation was created before its parent, in this case,
// in the parent's builder.
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index bbd55938718fe7..8a0bc597c56beb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -982,9 +982,25 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
TestPassthroughInvalidOp(MLIRContext *ctx)
: ConversionPattern("test.invalid", 1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
- rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
+ SmallVector<Value> flattened;
+ for (auto it : llvm::enumerate(operands)) {
+ ValueRange range = it.value();
+ if (range.size() == 1) {
+ flattened.push_back(range.front());
+ continue;
+ }
+
+ // This is a 1:N replacement. Insert a test.cast op. (That's what the
+ // argument materialization used to do.)
+ flattened.push_back(
+ rewriter
+ .create<TestCastOp>(op->getLoc(),
+ op->getOperand(it.index()).getType(), range)
+ .getResult());
+ }
+ rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
std::nullopt);
return success();
}
@@ -1010,23 +1026,13 @@ struct TestSplitReturnType : public ConversionPattern {
TestSplitReturnType(MLIRContext *ctx)
: ConversionPattern("test.return", 1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
// Check for a return of F32.
if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
return failure();
-
- // Check if the first operation is a cast operation, if it is we use the
- // results directly.
- auto *defOp = operands[0].getDefiningOp();
- if (auto packerOp =
- llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
- rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
- return success();
- }
-
- // Otherwise, fail to match.
- return failure();
+ rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
+ return success();
}
};
@@ -1181,6 +1187,47 @@ class TestEraseOp : public ConversionPattern {
}
};
+/// This pattern matches a test.duplicate_block_args op and duplicates all
+/// block arguments.
+class TestDuplicateBlockArgs
+ : public OpConversionPattern<DuplicateBlockArgsOp> {
+ using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (op.getIsLegal())
+ return failure();
+ rewriter.startOpModification(op);
+ Block *body = &op.getBody().front();
+ TypeConverter::SignatureConversion result(body->getNumArguments());
+ for (auto it : llvm::enumerate(body->getArgumentTypes()))
+ result.addInputs(it.index(), {it.value(), it.value()});
+ rewriter.applySignatureConversion(body, result, getTypeConverter());
+ op.setIsLegal(true);
+ rewriter.finalizeOpModification(op);
+ return success();
+ }
+};
+
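
The 1:N expansion above is spelled out manually through `SignatureConversion::addInputs`. A type converter can express the same kind of expansion as a regular conversion rule via the `SmallVectorImpl<Type>` form of `addConversion`; the sketch below uses an arbitrary i64-to-two-i32 split, which is not the rule used by this test.

```c++
// Illustration only: a converter-level 1:N rule that splits every i64 into
// two i32 values. Rules are tried in reverse registration order, so the
// identity fallback is registered first.
TypeConverter buildI64SplitConverter() {
  TypeConverter converter;
  converter.addConversion([](Type type) { return type; });
  converter.addConversion(
      [](IntegerType type,
         SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
        if (type.getWidth() != 64)
          return std::nullopt; // Not applicable; try the other rules.
        results.push_back(IntegerType::get(type.getContext(), 32));
        results.push_back(IntegerType::get(type.getContext(), 32));
        return success();
      });
  return converter;
}
```

Combined with `applySignatureConversion`, such a rule produces the expanded block arguments that the new 1:N `matchAndRewrite` overload then observes as multi-value operands.
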
+/// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid
+/// op. The pattern supports 1:N replacements and forwards the replacement
+/// values of the single operand as test.valid operands.
+class TestRepetitive1ToNConsumer : public ConversionPattern {
+public:
+ TestRepetitive1ToNConsumer(MLIRContext *ctx)
+ : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // A single operand is expected.
+ if (op->getNumOperands() != 1)
+ return failure();
+ rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -1263,9 +1310,11 @@ struct TestLegalizePatternDriver
TestUpdateConsumerType, TestNonRootReplacement,
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
- TestUndoPropertiesModification, TestEraseOp>(&getContext());
+ TestUndoPropertiesModification, TestEraseOp,
+ TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
&getContext(), converter);
+ patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1317,6 +1366,9 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
+ target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
+ [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
+
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;