[llvm-branch-commits] [mlir] [mlir][Transforms] Dialect Conversion: Add `replaceOpWithMultiple` (PR #115816)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Nov 12 20:12:04 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/115816
>From b425caab826e5d9ad2f078d6f548f3215005bf7f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 12 Nov 2024 05:14:43 +0100
Subject: [PATCH 1/2] replace with multiple
---
mlir/include/mlir/IR/Builders.h | 2 +-
.../mlir/Transforms/DialectConversion.h | 24 ++-
.../Transforms/DecomposeCallGraphTypes.cpp | 40 ++---
.../Transforms/SparseTensorCodegen.cpp | 48 +++---
.../Utils/SparseTensorDescriptor.cpp | 21 ++-
mlir/lib/IR/Builders.cpp | 16 +-
.../Transforms/Utils/DialectConversion.cpp | 153 ++++++++++++------
7 files changed, 186 insertions(+), 118 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 7ef03b87179523..78729376507208 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -353,7 +353,7 @@ class OpBuilder : public Builder {
/// selected insertion point. (E.g., because they are defined in a nested
/// region or because they are not visible in an IsolatedFromAbove region.)
static InsertPoint after(ArrayRef<Value> values,
- const PostDominanceInfo &domInfo);
+ const PostDominanceInfo *domInfo = nullptr);
/// Returns true if this insert point is set.
bool isSet() const { return (block != nullptr); }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5e5957170e646c..e461b7d11602a0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -795,12 +795,32 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }
- /// PatternRewriter hook for replacing an operation.
+ /// Replace the given operation with the new values. The number of op results
+ /// and replacement values must match. The types may differ: the dialect
+ /// conversion driver will reconcile any surviving type mismatches at the end
+ /// of the conversion process with source materializations. The given
+ /// operation is erased.
void replaceOp(Operation *op, ValueRange newValues) override;
- /// PatternRewriter hook for replacing an operation.
+ /// Replace the given operation with the results of the new op. The number of
+ /// op results must match. The types may differ: the dialect conversion
+ /// driver will reconcile any surviving type mismatches at the end of the
+ /// conversion process with source materializations. The original operation
+ /// is erased.
void replaceOp(Operation *op, Operation *newOp) override;
+ /// Replace the given operation with the new value range. The number of op
+ /// results and value ranges must match. If an original SSA value is replaced
+ /// by multiple SSA values (i.e., value range has more than 1 element), the
+ /// conversion driver will insert an argument materialization to convert the
+ /// N SSA values back into 1 SSA value of the original type. The given
+ /// operation is erased.
+ ///
+ /// Note: The argument materialization is a workaround until we have full 1:N
+ /// support in the dialect conversion. (It is going to disappear from both
+ /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+ void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+
/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
/// otherwise an assert will be issued.
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index de4aba2ed327db..a08764326a80b6 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -141,47 +141,31 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
getTypeConverter()));
}
- // Create the new result types for the new `CallOp` and track the indices in
- // the new call op's results that correspond to the old call op's results.
- //
- // expandedResultIndices[i] = "list of new result indices that old result i
- // expanded to".
+ // Create the new result types for the new `CallOp` and track the number of
+ // replacement types for each original op result.
SmallVector<Type, 2> newResultTypes;
- SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
+ SmallVector<unsigned> expandedResultSizes;
for (Type resultType : op.getResultTypes()) {
unsigned oldSize = newResultTypes.size();
if (failed(typeConverter->convertType(resultType, newResultTypes)))
return failure();
- auto &resultMapping = expandedResultIndices.emplace_back();
- for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
- resultMapping.push_back(i);
+ expandedResultSizes.push_back(newResultTypes.size() - oldSize);
}
CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
newResultTypes, newOperands);
- // Build a replacement value for each result to replace its uses. If a
- // result has multiple mapping values, it needs to be materialized as a
- // single value.
- SmallVector<Value, 2> replacedValues;
+ // Build a replacement value for each result to replace its uses.
+ SmallVector<ValueRange> replacedValues;
replacedValues.reserve(op.getNumResults());
+ unsigned startIdx = 0;
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
- auto decomposedValues = llvm::to_vector<6>(
- llvm::map_range(expandedResultIndices[i],
- [&](unsigned i) { return newCallOp.getResult(i); }));
- if (decomposedValues.empty()) {
- // No replacement is required.
- replacedValues.push_back(nullptr);
- } else if (decomposedValues.size() == 1) {
- replacedValues.push_back(decomposedValues.front());
- } else {
- // Materialize a single Value to replace the original Value.
- Value materialized = getTypeConverter()->materializeArgumentConversion(
- rewriter, op.getLoc(), op.getType(i), decomposedValues);
- replacedValues.push_back(materialized);
- }
+ ValueRange repl =
+ newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
+ replacedValues.push_back(repl);
+ startIdx += expandedResultSizes[i];
}
- rewriter.replaceOp(op, replacedValues);
+ rewriter.replaceOpWithMultiple(op, replacedValues);
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 062a0ea6cc47cb..09509278d7749a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
flattenOperands(adaptor.getOperands(), flattened);
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
finalRetTy, flattened);
- // (2) Create cast operation for sparse tensor returns.
- SmallVector<Value> castedRet;
+ // (2) Gather sparse tensor returns.
+ SmallVector<SmallVector<Value>> packedResultVals;
// Tracks the offset of current return value (of the original call)
// relative to the new call (after sparse tensor flattening);
unsigned retOffset = 0;
@@ -618,21 +618,27 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
- ValueRange fields(iterator_range<ResultRange::iterator>(
- newCall.result_begin() + retOffset,
- newCall.result_begin() + retOffset + flatSize));
- castedRet.push_back(genTuple(rewriter, loc, retType, fields));
+ packedResultVals.push_back(SmallVector<Value>());
+ llvm::append_range(packedResultVals.back(),
+ iterator_range<ResultRange::iterator>(
+ newCall.result_begin() + retOffset,
+ newCall.result_begin() + retOffset + flatSize));
retOffset += flatSize;
} else {
// If this is an 1:1 conversion, no need for casting.
- castedRet.push_back(newCall.getResult(retOffset));
+ packedResultVals.push_back(SmallVector<Value>());
+ packedResultVals.back().push_back(newCall.getResult(retOffset));
retOffset++;
}
sparseFlat.clear();
}
- assert(castedRet.size() == op.getNumResults());
- rewriter.replaceOp(op, castedRet);
+ assert(packedResultVals.size() == op.getNumResults());
+ SmallVector<ValueRange> ranges;
+ ranges.reserve(packedResultVals.size());
+ for (const SmallVector<Value> &vec : packedResultVals)
+ ranges.push_back(ValueRange(vec));
+ rewriter.replaceOpWithMultiple(op, ranges);
return success();
}
};
@@ -776,7 +782,7 @@ class SparseTensorAllocConverter
// Reuses specifier.
fields.push_back(desc.getSpecifier());
assert(fields.size() == desc.getNumFields());
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -796,7 +802,7 @@ class SparseTensorAllocConverter
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -837,7 +843,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -893,7 +899,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1006,7 +1012,6 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<scf::YieldOp>(loc, insertRet);
rewriter.setInsertionPointAfter(loop);
- Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = getTop(op);
rewriter.setInsertionPointAfter(parent);
@@ -1014,7 +1019,7 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, result);
+ rewriter.replaceOpWithMultiple(op, {loop->getResults()});
return success();
}
};
@@ -1041,8 +1046,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op,
- genTuple(rewriter, loc, op.getDest().getType(), ret));
+ rewriter.replaceOpWithMultiple(op, {ret});
return success();
}
};
@@ -1215,8 +1219,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
return true;
});
- rewriter.replaceOp(
- op, genTuple(rewriter, loc, op.getResult().getType(), fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
@@ -1271,8 +1274,7 @@ class SparseExtractSliceConverter
// NOTE: we can not generate tuples directly from descriptor here, as the
// descriptor is holding the original type, yet we want the slice type
// here (they shared every memref but with an updated specifier).
- rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
- desc.getFields()));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1403,7 +1405,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
}
desc.setValMemSize(rewriter, loc, memSize);
- rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1577,7 +1579,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
EmitCInterface::Off);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index a3db50573c2720..834e3634cc130d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
// The sparse tensor type converter (defined in Passes.h).
//===----------------------------------------------------------------------===//
+static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
+                              ValueRange inputs, Location loc) {
+  // Only sparse tensors are packed; the sparsifier cancels out these casts.
+  if (getSparseTensorEncoding(tp))
+    return genTuple(builder, loc, tp, inputs);
+  // Not a sparse tensor: decline by returning a null Value.
+  return Value();
+}
+
SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
addConversion([](Type type) { return type; });
addConversion(convertSparseTensorType);
// Required by scf.for 1:N type conversion.
-  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
-                              ValueRange inputs, Location loc) -> Value {
-    if (!getSparseTensorEncoding(tp))
-      // Not a sparse tensor.
-      return Value();
-    // Sparsifier knows how to cancel out these casts.
-    return genTuple(builder, loc, tp, inputs);
-  });
+  addSourceMaterialization(materializeTuple);
+
+  // Re-packs N:1 replacements into a sparse tensor (until full 1:N support).
+  addArgumentMaterialization(materializeTuple);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 4714c3cace6c78..e85a86e94282ec 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -645,7 +645,7 @@ void OpBuilder::cloneRegionBefore(Region ®ion, Block *before) {
OpBuilder::InsertPoint
OpBuilder::InsertPoint::after(ArrayRef<Value> values,
- const PostDominanceInfo &domInfo) {
+ const PostDominanceInfo *domInfo) {
// Helper function that computes the point after v's definition.
auto computeAfterIp = [](Value v) -> std::pair<Block *, Block::iterator> {
if (auto blockArg = dyn_cast<BlockArgument>(v))
@@ -658,12 +658,18 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
assert(!values.empty() && "expected at least one Value");
auto [block, blockIt] = computeAfterIp(values.front());
+ if (values.size() == 1) {
+ // Fast path: There is only one value.
+ return InsertPoint(block, blockIt);
+ }
+
// Check the other values one-by-one and update the insertion point if
// needed.
+ assert(domInfo && "domInfo expected if >1 values");
for (Value v : values.drop_front()) {
auto [candidateBlock, candidateBlockIt] = computeAfterIp(v);
- if (domInfo.postDominantes(candidateBlock, candidateBlockIt, block,
- blockIt)) {
+ if (domInfo->postDominantes(candidateBlock, candidateBlockIt, block,
+ blockIt)) {
// The point after v's definition post-dominates the current (and all
// previous) insertion points. Note: Post-dominance is transitive.
block = candidateBlock;
@@ -671,8 +677,8 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
continue;
}
- if (!domInfo.postDominantes(block, blockIt, candidateBlock,
- candidateBlockIt)) {
+ if (!domInfo->postDominantes(block, blockIt, candidateBlock,
+ candidateBlockIt)) {
// The point after v's definition and the current insertion point do not
// post-dominate each other. Therefore, there is no insertion point that
// post-dominates all values.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0a62628b9ad240..2f6c0a1ab0bd3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,20 +54,14 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}
-/// Helper function that computes an insertion point where the given value is
-/// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(Value value) {
- Block *insertBlock = value.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(value))
- insertPt = ++inputRes.getOwner()->getIterator();
- return OpBuilder::InsertPoint(insertBlock, insertPt);
-}
-
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
+/// A list of replacement SSA values. Optimized for the common case of a single
+/// SSA value.
+using ReplacementValues = SmallVector<Value, 1>;
+
namespace {
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -818,6 +813,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Type originalType,
const TypeConverter *converter);
+  /// Build an N:1 materialization for `originalValue`, which was replaced
+  /// with the values in `replacements`.
+  ///
+  /// This is a workaround for incomplete 1:N support in the dialect
+  /// conversion driver. The conversion mapping can store only 1:1 replacements
+  /// and the conversion patterns only support single Value replacements in the
+  /// adaptor, so N values must be converted back to a single value. This
+  /// function will be deleted when full 1:N support has been added.
+  ///
+  /// This function inserts an argument materialization (at `ip`) back to the
+  /// original type, followed by a target materialization to the legalized
+  /// type if that type is known and differs from the original type.
+  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
+                                 ValueRange replacements, Value originalValue,
+                                 const TypeConverter *converter);
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -827,7 +838,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
OpBuilder::InsertPoint previous) override;
/// Notifies that an op is about to be replaced with the given values.
- void notifyOpReplaced(Operation *op, ValueRange newValues);
+ void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -1147,8 +1158,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// that the value was replaced with a value of different type and no
// source materialization was created yet.
Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(newOperand),
- operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
+ MaterializationKind::Target,
+ OpBuilder::InsertPoint::after(newOperand), operandLoc,
+ /*inputs=*/newOperand, /*outputType=*/desiredType,
/*originalType=*/origType, currentTypeConverter);
mapping.map(newOperand, castValue);
newOperand = castValue;
@@ -1287,33 +1299,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- Value argMat = buildUnresolvedMaterialization(
- MaterializationKind::Argument,
+ insertNTo1Materialization(
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*inputs=*/replArgs, /*outputType=*/origArgType,
- /*originalType=*/Type(), converter);
- mapping.map(origArg, argMat);
-
- Type legalOutputType;
- if (converter) {
- legalOutputType = converter->convertType(origArgType);
- } else if (replArgs.size() == 1) {
- // When there is no type converter, assume that the new block argument
- // types are legal. This is reasonable to assume because they were
- // specified by the user.
- // FIXME: This won't work for 1->N conversions because multiple output
- // types are not supported in parts of the dialect conversion. In such a
- // case, we currently use the original block argument type (produced by
- // the argument materialization).
- legalOutputType = replArgs[0].getType();
- }
- if (legalOutputType && legalOutputType != origArgType) {
- Value targetMat = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(argMat),
- origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
- /*originalType=*/origArgType, converter);
- mapping.map(argMat, targetMat);
- }
+ /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
}
@@ -1354,6 +1342,39 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
return convertOp.getResult(0);
}
+void ConversionPatternRewriterImpl::insertNTo1Materialization(
+    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
+    Value originalValue, const TypeConverter *converter) {
+  // Insert argument materialization (at `ip`) back to the original type.
+  Type originalType = originalValue.getType();
+  Value argMat = buildUnresolvedMaterialization(
+      MaterializationKind::Argument, ip, loc, /*inputs=*/replacements,
+      /*outputType=*/originalType, /*originalType=*/Type(), converter);
+  mapping.map(originalValue, argMat);
+
+  // Insert target materialization to the legalized type, but only if a
+  // legalized type is known and it differs from the original type.
+  Type legalOutputType;
+  if (converter) {
+    legalOutputType = converter->convertType(originalType);
+  } else if (replacements.size() == 1) {
+    // When there is no type converter, assume that the single replacement
+    // value type is legal: it was specified by the user.
+    // FIXME: This won't work for 1->N conversions because multiple output
+    // types are not supported in parts of the dialect conversion. With no
+    // converter and N > 1 replacements, no target materialization is inserted
+    // and the original value type is used.
+    legalOutputType = replacements[0].getType();
+  }
+  if (legalOutputType && legalOutputType != originalType) {
+    Value targetMat = buildUnresolvedMaterialization(
+        MaterializationKind::Target, OpBuilder::InsertPoint::after(argMat), loc,
+        /*inputs=*/argMat, /*outputType=*/legalOutputType,
+        /*originalType=*/originalType, converter);
+    mapping.map(argMat, targetMat);
+  }
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1377,10 +1398,11 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
}
-void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
- ValueRange newValues) {
+void ConversionPatternRewriterImpl::notifyOpReplaced(
+ Operation *op, ArrayRef<ReplacementValues> newValues) {
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");
+ PostDominanceInfo domInfo;
// Check if replaced op is an unresolved materialization, i.e., an
// unrealized_conversion_cast op that was created by the conversion driver.
@@ -1390,8 +1412,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
isUnresolvedMaterialization = true;
// Create mappings for each of the new result values.
- for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
- if (!newValue) {
+ for (auto [n, result] : llvm::zip(newValues, op->getResults())) {
+ ReplacementValues repl = n;
+ if (repl.empty()) {
// This result was dropped and no replacement value was provided.
if (isUnresolvedMaterialization) {
// Do not create another materializations if we are erasing a
@@ -1400,11 +1423,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
}
// Materialize a replacement value "out of thin air".
- newValue = buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(result),
+ Value sourceMat = buildUnresolvedMaterialization(
+ MaterializationKind::Source, OpBuilder::InsertPoint::after(result),
result.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
+ repl.push_back(sourceMat);
} else {
// Make sure that the user does not mess with unresolved materializations
// that were inserted by the conversion driver. We keep track of these
@@ -1417,12 +1441,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
}
// Remap result to replacement value.
- if (newValue)
- mapping.map(result, newValue);
+ if (!repl.empty()) {
+ if (repl.size() == 1) {
+ // Single replacement value: replace directly.
+ mapping.map(result, repl.front());
+ } else {
+ // Multiple replacement values: insert N:1 materialization.
+ insertNTo1Materialization(OpBuilder::InsertPoint::after(repl, &domInfo),
+ result.getLoc(),
+ /*replacements=*/repl, /*outputValue=*/result,
+ currentTypeConverter);
+ }
+ }
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
-
// Mark this operation and all nested ops as replaced.
op->walk([&](Operation *op) { replacedOps.insert(op); });
}
@@ -1497,7 +1530,25 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- impl->notifyOpReplaced(op, newValues);
+ SmallVector<ReplacementValues> newVals(newValues.size(), {});
+ for (auto it : llvm::enumerate(newValues))
+ if (Value val = it.value())
+ newVals[it.index()].push_back(val);
+ impl->notifyOpReplaced(op, newVals);
+}
+
+void ConversionPatternRewriter::replaceOpWithMultiple(
+    Operation *op, ArrayRef<ValueRange> newValues) {
+  assert(op->getNumResults() == newValues.size() &&
+         "incorrect # of replacement values");
+  LLVM_DEBUG({
+    impl->logger.startLine()
+        << "** Replace : '" << op->getName() << "'(" << op << ")\n";
+  });
+  SmallVector<ReplacementValues> newVals(newValues.size());
+  for (auto [index, vals] : llvm::enumerate(newValues))
+    llvm::append_range(newVals[index], vals);
+  impl->notifyOpReplaced(op, newVals);
}
void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1505,7 +1556,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
+ SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {});
impl->notifyOpReplaced(op, nullRepls);
}
@@ -2596,7 +2647,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(newValue),
+ MaterializationKind::Source, OpBuilder::InsertPoint::after(newValue),
originalValue.getLoc(),
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
/*originalType=*/Type(), converter);
>From b59db4636891df96b7569c9737e361431f898909 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 13 Nov 2024 13:11:56 +0900
Subject: [PATCH 2/2] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
.../mlir/Transforms/DialectConversion.h | 4 +--
.../Transforms/SparseTensorCodegen.cpp | 6 ++--
.../Transforms/Utils/DialectConversion.cpp | 33 ++++++++++---------
3 files changed, 21 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e461b7d11602a0..de47765006f81e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -809,9 +809,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// is erased.
void replaceOp(Operation *op, Operation *newOp) override;
- /// Replace the given operation with the new value range. The number of op
+ /// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. If an original SSA value is replaced
- /// by multiple SSA values (i.e., value range has more than 1 element), the
+ /// by multiple SSA values (i.e., a value range has more than 1 element), the
/// conversion driver will insert an argument materialization to convert the
/// N SSA values back into 1 SSA value of the original type. The given
/// operation is erased.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 09509278d7749a..8b6841f202def7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -620,13 +620,11 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
auto flatSize = sparseFlat.size();
packedResultVals.push_back(SmallVector<Value>());
llvm::append_range(packedResultVals.back(),
- iterator_range<ResultRange::iterator>(
- newCall.result_begin() + retOffset,
- newCall.result_begin() + retOffset + flatSize));
+ newCall.getResults().slice(retOffset, flatSize));
retOffset += flatSize;
} else {
// If this is an 1:1 conversion, no need for casting.
- packedResultVals.push_back(SmallVector<Value>());
+ packedResultVals.emplace_back();
packedResultVals.back().push_back(newCall.getResult(retOffset));
retOffset++;
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2f6c0a1ab0bd3b..23d7b059829a9b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1412,7 +1412,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
isUnresolvedMaterialization = true;
// Create mappings for each of the new result values.
- for (auto [n, result] : llvm::zip(newValues, op->getResults())) {
+ for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) {
ReplacementValues repl = n;
if (repl.empty()) {
// This result was dropped and no replacement value was provided.
@@ -1441,17 +1441,18 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
}
// Remap result to replacement value.
- if (!repl.empty()) {
- if (repl.size() == 1) {
- // Single replacement value: replace directly.
- mapping.map(result, repl.front());
- } else {
- // Multiple replacement values: insert N:1 materialization.
- insertNTo1Materialization(OpBuilder::InsertPoint::after(repl, &domInfo),
- result.getLoc(),
- /*replacements=*/repl, /*outputValue=*/result,
- currentTypeConverter);
- }
+ if (repl.empty())
+ continue;
+
+ if (repl.size() == 1) {
+ // Single replacement value: replace directly.
+ mapping.map(result, repl.front());
+ } else {
+ // Multiple replacement values: insert N:1 materialization.
+ insertNTo1Materialization(OpBuilder::InsertPoint::after(repl, &domInfo),
+ result.getLoc(),
+ /*replacements=*/repl, /*outputValue=*/result,
+ currentTypeConverter);
}
}
@@ -1530,10 +1531,10 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ReplacementValues> newVals(newValues.size(), {});
- for (auto it : llvm::enumerate(newValues))
- if (Value val = it.value())
- newVals[it.index()].push_back(val);
+ SmallVector<ReplacementValues> newVals(newValues.size());
+ for (auto [index, val] : llvm::enumerate(newValues))
+ if (val)
+ newVals[index].push_back(val);
impl->notifyOpReplaced(op, newVals);
}
More information about the llvm-branch-commits
mailing list