[Mlir-commits] [mlir] [mlir][Transforms] Dialect Conversion: Add `replaceOpWithMultiple` (PR #115816)

Matthias Springer llvmlistbot at llvm.org
Wed Nov 13 16:34:56 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/115816

>From 001453c12e18c8e5f8d1b762cf5f67d56ec542e9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 12 Nov 2024 05:14:43 +0100
Subject: [PATCH] replace with multiple
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Apply suggestions from code review

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

address comments
---
 .../mlir/Transforms/DialectConversion.h       |  24 +++-
 .../Transforms/DecomposeCallGraphTypes.cpp    |  40 ++----
 .../Transforms/SparseTensorCodegen.cpp        |  43 +++---
 .../Utils/SparseTensorDescriptor.cpp          |  21 +--
 .../Transforms/Utils/DialectConversion.cpp    | 135 +++++++++++++-----
 5 files changed, 164 insertions(+), 99 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5e5957170e646c..de47765006f81e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -795,12 +795,32 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// PatternRewriter hook for replacing an operation.
+  /// Replace the given operation with the new values. The number of op results
+  /// and replacement values must match. The types may differ: the dialect
+  /// conversion driver will reconcile any surviving type mismatches at the end
+  /// of the conversion process with source materializations. The given
+  /// operation is erased.
   void replaceOp(Operation *op, ValueRange newValues) override;
 
-  /// PatternRewriter hook for replacing an operation.
+  /// Replace the given operation with the results of the new op. The number of
+  /// results of both operations must match. The types may differ: the dialect
+  /// conversion driver will reconcile any surviving type mismatches at the end
+  /// of the conversion process with source materializations. The original
+  /// operation is erased.
   void replaceOp(Operation *op, Operation *newOp) override;
 
+  /// Replace the given operation with the new value ranges. The number of op
+  /// results and value ranges must match. If an original SSA value is replaced
+  /// by multiple SSA values (i.e., a value range has more than 1 element), the
+  /// conversion driver will insert an argument materialization to convert the
+  /// N SSA values back into 1 SSA value of the original type. The given
+  /// operation is erased.
+  ///
+  /// Note: The argument materialization is a workaround until we have full 1:N
+  /// support in the dialect conversion. (It is going to disappear from both
+  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+  void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,
   /// otherwise an assert will be issued.
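
For illustration, a minimal usage sketch of the new hook (the op name and the
pattern are hypothetical and not part of this patch; only
`replaceOpWithMultiple` itself is new API):

  // Hypothetical lowering: `MyOp` has a single result that is decomposed into
  // N replacement values. Here the converted operands stand in for whatever
  // values the real lowering would produce.
  struct MyOpLowering : public OpConversionPattern<MyOp> {
    using OpConversionPattern<MyOp>::OpConversionPattern;

    LogicalResult
    matchAndRewrite(MyOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter) const override {
      // One ValueRange per original result: the single result of `op` is
      // replaced by all converted operands. Where a single SSA value of the
      // original type is still required, the driver inserts the N:1 argument
      // materialization described in the header comment.
      rewriter.replaceOpWithMultiple(op, {adaptor.getOperands()});
      return success();
    }
  };
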
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index de4aba2ed327db..a08764326a80b6 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -141,47 +141,31 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
                                         getTypeConverter()));
     }
 
-    // Create the new result types for the new `CallOp` and track the indices in
-    // the new call op's results that correspond to the old call op's results.
-    //
-    // expandedResultIndices[i] = "list of new result indices that old result i
-    // expanded to".
+    // Create the new result types for the new `CallOp` and track the number of
+    // replacement types for each original op result.
     SmallVector<Type, 2> newResultTypes;
-    SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
+    SmallVector<unsigned> expandedResultSizes;
     for (Type resultType : op.getResultTypes()) {
       unsigned oldSize = newResultTypes.size();
       if (failed(typeConverter->convertType(resultType, newResultTypes)))
         return failure();
-      auto &resultMapping = expandedResultIndices.emplace_back();
-      for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
-        resultMapping.push_back(i);
+      expandedResultSizes.push_back(newResultTypes.size() - oldSize);
     }
 
     CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
                                                newResultTypes, newOperands);
 
-    // Build a replacement value for each result to replace its uses. If a
-    // result has multiple mapping values, it needs to be materialized as a
-    // single value.
-    SmallVector<Value, 2> replacedValues;
+    // Build a replacement value range for each original result.
+    SmallVector<ValueRange> replacedValues;
     replacedValues.reserve(op.getNumResults());
+    unsigned startIdx = 0;
     for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
-      auto decomposedValues = llvm::to_vector<6>(
-          llvm::map_range(expandedResultIndices[i],
-                          [&](unsigned i) { return newCallOp.getResult(i); }));
-      if (decomposedValues.empty()) {
-        // No replacement is required.
-        replacedValues.push_back(nullptr);
-      } else if (decomposedValues.size() == 1) {
-        replacedValues.push_back(decomposedValues.front());
-      } else {
-        // Materialize a single Value to replace the original Value.
-        Value materialized = getTypeConverter()->materializeArgumentConversion(
-            rewriter, op.getLoc(), op.getType(i), decomposedValues);
-        replacedValues.push_back(materialized);
-      }
+      ValueRange repl =
+          newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
+      replacedValues.push_back(repl);
+      startIdx += expandedResultSizes[i];
     }
-    rewriter.replaceOp(op, replacedValues);
+    rewriter.replaceOpWithMultiple(op, replacedValues);
     return success();
   }
 };
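
In isolation, the result grouping above amounts to the following helper (a
sketch only, assuming the usual DialectConversion includes; `sizes[i]` is the
number of converted types that original result i expanded to):

  // Group the flat results of the new call into one ValueRange per original
  // result, using the per-result expansion sizes recorded during type
  // conversion.
  static SmallVector<ValueRange> groupResults(ValueRange newResults,
                                              ArrayRef<unsigned> sizes) {
    SmallVector<ValueRange> grouped;
    grouped.reserve(sizes.size());
    unsigned start = 0;
    for (unsigned size : sizes) {
      grouped.push_back(newResults.slice(start, size));
      start += size;
    }
    return grouped;
  }
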
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 062a0ea6cc47cb..6b154dd1613c0d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     flattenOperands(adaptor.getOperands(), flattened);
     auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
                                                  finalRetTy, flattened);
-    // (2) Create cast operation for sparse tensor returns.
-    SmallVector<Value> castedRet;
+    // (2) Gather sparse tensor returns.
+    SmallVector<SmallVector<Value>> packedResultVals;
     // Tracks the offset of current return value (of the original call)
     // relative to the new call (after sparse tensor flattening);
     unsigned retOffset = 0;
@@ -618,21 +618,22 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
       assert(!sparseFlat.empty());
       if (sparseFlat.size() > 1) {
         auto flatSize = sparseFlat.size();
-        ValueRange fields(iterator_range<ResultRange::iterator>(
-            newCall.result_begin() + retOffset,
-            newCall.result_begin() + retOffset + flatSize));
-        castedRet.push_back(genTuple(rewriter, loc, retType, fields));
+        packedResultVals.emplace_back();
+        llvm::append_range(packedResultVals.back(),
+                           newCall.getResults().slice(retOffset, flatSize));
         retOffset += flatSize;
       } else {
         // If this is an 1:1 conversion, no need for casting.
-        castedRet.push_back(newCall.getResult(retOffset));
+        packedResultVals.emplace_back();
+        packedResultVals.back().push_back(newCall.getResult(retOffset));
         retOffset++;
       }
       sparseFlat.clear();
     }
 
-    assert(castedRet.size() == op.getNumResults());
-    rewriter.replaceOp(op, castedRet);
+    assert(packedResultVals.size() == op.getNumResults());
+    rewriter.replaceOpWithMultiple(
+        op, llvm::to_vector_of<ValueRange>(packedResultVals));
     return success();
   }
 };
@@ -776,7 +777,7 @@ class SparseTensorAllocConverter
       // Reuses specifier.
       fields.push_back(desc.getSpecifier());
       assert(fields.size() == desc.getNumFields());
-      rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+      rewriter.replaceOpWithMultiple(op, {fields});
       return success();
     }
 
@@ -796,7 +797,7 @@ class SparseTensorAllocConverter
                       sizeHint, lvlSizesValues, fields);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 
@@ -837,7 +838,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
                       sizeHint, lvlSizesValues, fields);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 
@@ -893,7 +894,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1006,7 +1007,6 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     rewriter.create<scf::YieldOp>(loc, insertRet);
 
     rewriter.setInsertionPointAfter(loop);
-    Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = getTop(op);
     rewriter.setInsertionPointAfter(parent);
@@ -1014,7 +1014,7 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
     return success();
   }
 };
@@ -1041,8 +1041,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
                                     params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op,
-                       genTuple(rewriter, loc, op.getDest().getType(), ret));
+    rewriter.replaceOpWithMultiple(op, {ret});
     return success();
   }
 };
@@ -1215,8 +1214,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
           return true;
         });
 
-    rewriter.replaceOp(
-        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };
@@ -1271,8 +1269,7 @@ class SparseExtractSliceConverter
     // NOTE: we can not generate tuples directly from descriptor here, as the
     // descriptor is holding the original type, yet we want the slice type
     // here (they shared every memref but with an updated specifier).
-    rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
-                                    desc.getFields()));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     }
     desc.setValMemSize(rewriter, loc, memSize);
 
-    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
                    EmitCInterface::Off);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };
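
A note on the call sites above: `{fields}` builds a one-element
ArrayRef<ValueRange>, so all flattened buffer fields become the replacement
range for the op's single sparse tensor result. Spelled out as a sketch (the
helper name is made up):

  // Equivalent to the `replaceOpWithMultiple(op, {fields})` calls above,
  // assuming `op` has exactly one (sparse tensor) result.
  static void replaceWithFields(ConversionPatternRewriter &rewriter,
                                Operation *op, ArrayRef<Value> fields) {
    rewriter.replaceOpWithMultiple(op, {ValueRange(fields)});
  }
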
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index a3db50573c2720..834e3634cc130d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
 // The sparse tensor type converter (defined in Passes.h).
 //===----------------------------------------------------------------------===//
 
+static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
+                              ValueRange inputs, Location loc) {
+  if (!getSparseTensorEncoding(tp))
+    // Not a sparse tensor.
+    return Value();
+  // Sparsifier knows how to cancel out these casts.
+  return genTuple(builder, loc, tp, inputs);
+}
+
 SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
   addConversion([](Type type) { return type; });
   addConversion(convertSparseTensorType);
 
   // Required by scf.for 1:N type conversion.
-  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
-                              ValueRange inputs, Location loc) -> Value {
-    if (!getSparseTensorEncoding(tp))
-      // Not a sparse tensor.
-      return Value();
-    // Sparsifier knows how to cancel out these casts.
-    return genTuple(builder, loc, tp, inputs);
-  });
+  addSourceMaterialization(materializeTuple);
+
+  // Required as a workaround until we have full 1:N support.
+  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
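
For context, the same setup applies to any converter that decomposes a type
1:N. A hypothetical sketch (not part of this patch; assumes the Complex
dialect and the usual DialectConversion includes):

  // Decompose complex values into (real, imag) scalars and register an
  // argument materialization that packs the two scalars back into one value
  // of the original type wherever the driver still needs a single SSA value.
  struct DecomposeComplexConverter : public TypeConverter {
    DecomposeComplexConverter() {
      addConversion([](Type type) { return type; });
      addConversion([](ComplexType type, SmallVectorImpl<Type> &results)
                        -> std::optional<LogicalResult> {
        results.push_back(type.getElementType()); // real part
        results.push_back(type.getElementType()); // imaginary part
        return success();
      });
      // Workaround until full 1:N support: N values -> 1 value.
      addArgumentMaterialization([](OpBuilder &builder, ComplexType type,
                                    ValueRange inputs, Location loc) -> Value {
        if (inputs.size() != 2)
          return Value();
        return builder.create<complex::CreateOp>(loc, type, inputs[0],
                                                 inputs[1]);
      });
    }
  };
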
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0a62628b9ad240..5b2cfd370900a8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -67,6 +67,10 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
+/// A list of replacement SSA values. Optimized for the common case of a single
+/// SSA value.
+using ReplacementValues = SmallVector<Value, 1>;
+
 namespace {
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -818,6 +822,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Type originalType,
                                        const TypeConverter *converter);
 
+  /// Build an N:1 materialization for the given original value that was
+  /// replaced with the given replacement values.
+  ///
+  /// This is a workaround for incomplete 1:N support in the dialect
+  /// conversion driver. The conversion mapping can store only 1:1 replacements
+  /// and the conversion patterns only support single Value replacements in the
+  /// adaptor, so N values must be converted back to a single value. This
+  /// function will be deleted when full 1:N support has been added.
+  ///
+  /// This function inserts an argument materialization back to the original
+  /// type, followed by a target materialization to the legalized type (if
+  /// applicable).
+  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
+                                 ValueRange replacements, Value originalValue,
+                                 const TypeConverter *converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -827,7 +847,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                OpBuilder::InsertPoint previous) override;
 
   /// Notifies that an op is about to be replaced with the given values.
-  void notifyOpReplaced(Operation *op, ValueRange newValues);
+  void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
@@ -1148,7 +1168,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       // source materialization was created yet.
       Value castValue = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
+          operandLoc,
+          /*inputs=*/newOperand, /*outputType=*/desiredType,
           /*originalType=*/origType, currentTypeConverter);
       mapping.map(newOperand, castValue);
       newOperand = castValue;
@@ -1287,33 +1308,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    Value argMat = buildUnresolvedMaterialization(
-        MaterializationKind::Argument,
+    insertNTo1Materialization(
         OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*inputs=*/replArgs, /*outputType=*/origArgType,
-        /*originalType=*/Type(), converter);
-    mapping.map(origArg, argMat);
-
-    Type legalOutputType;
-    if (converter) {
-      legalOutputType = converter->convertType(origArgType);
-    } else if (replArgs.size() == 1) {
-      // When there is no type converter, assume that the new block argument
-      // types are legal. This is reasonable to assume because they were
-      // specified by the user.
-      // FIXME: This won't work for 1->N conversions because multiple output
-      // types are not supported in parts of the dialect conversion. In such a
-      // case, we currently use the original block argument type (produced by
-      // the argument materialization).
-      legalOutputType = replArgs[0].getType();
-    }
-    if (legalOutputType && legalOutputType != origArgType) {
-      Value targetMat = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(argMat),
-          origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
-          /*originalType=*/origArgType, converter);
-      mapping.map(argMat, targetMat);
-    }
+        /*replacements=*/replArgs, /*originalValue=*/origArg, converter);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
   }
 
@@ -1354,6 +1351,39 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   return convertOp.getResult(0);
 }
 
+void ConversionPatternRewriterImpl::insertNTo1Materialization(
+    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
+    Value originalValue, const TypeConverter *converter) {
+  // Insert argument materialization back to the original type.
+  Type originalType = originalValue.getType();
+  Value argMat =
+      buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
+                                     /*inputs=*/replacements, originalType,
+                                     /*originalType=*/Type(), converter);
+  mapping.map(originalValue, argMat);
+
+  // Insert target materialization to the legalized type.
+  Type legalOutputType;
+  if (converter) {
+    legalOutputType = converter->convertType(originalType);
+  } else if (replacements.size() == 1) {
+    // When there is no type converter, assume that the replacement value
+    // types are legal. This is reasonable to assume because they were
+    // specified by the user.
+    // FIXME: This won't work for 1->N conversions because multiple output
+    // types are not supported in parts of the dialect conversion. In such a
+    // case, we currently use the original value type.
+    legalOutputType = replacements[0].getType();
+  }
+  if (legalOutputType && legalOutputType != originalType) {
+    Value targetMat = buildUnresolvedMaterialization(
+        MaterializationKind::Target, computeInsertPoint(argMat), loc,
+        /*inputs=*/argMat, /*outputType=*/legalOutputType,
+        /*originalType=*/originalType, converter);
+    mapping.map(argMat, targetMat);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1377,8 +1407,8 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
   appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
 }
 
-void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
-                                                     ValueRange newValues) {
+void ConversionPatternRewriterImpl::notifyOpReplaced(
+    Operation *op, ArrayRef<ReplacementValues> newValues) {
   assert(newValues.size() == op->getNumResults());
   assert(!ignoredOps.contains(op) && "operation was already replaced");
 
@@ -1390,8 +1420,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
       isUnresolvedMaterialization = true;
 
   // Create mappings for each of the new result values.
-  for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
-    if (!newValue) {
+  for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) {
+    ReplacementValues repl = n;
+    if (repl.empty()) {
       // This result was dropped and no replacement value was provided.
       if (isUnresolvedMaterialization) {
         // Do not create another materializations if we are erasing a
@@ -1400,11 +1431,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
       }
 
       // Materialize a replacement value "out of thin air".
-      newValue = buildUnresolvedMaterialization(
+      Value sourceMat = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
+      repl.push_back(sourceMat);
     } else {
       // Make sure that the user does not mess with unresolved materializations
       // that were inserted by the conversion driver. We keep track of these
@@ -1417,12 +1449,21 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
     }
 
     // Remap result to replacement value.
-    if (newValue)
-      mapping.map(result, newValue);
+    if (repl.empty())
+      continue;
+
+    if (repl.size() == 1) {
+      // Single replacement value: replace directly.
+      mapping.map(result, repl.front());
+    } else {
+      // Multiple replacement values: insert N:1 materialization.
+      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
+                                /*replacements=*/repl, /*originalValue=*/result,
+                                currentTypeConverter);
+    }
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
-
   // Mark this operation and all nested ops as replaced.
   op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
@@ -1497,7 +1538,25 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
-  impl->notifyOpReplaced(op, newValues);
+  SmallVector<ReplacementValues> newVals(newValues.size());
+  for (auto [index, val] : llvm::enumerate(newValues))
+    if (val)
+      newVals[index].push_back(val);
+  impl->notifyOpReplaced(op, newVals);
+}
+
+void ConversionPatternRewriter::replaceOpWithMultiple(
+    Operation *op, ArrayRef<ValueRange> newValues) {
+  assert(op->getNumResults() == newValues.size() &&
+         "incorrect # of replacement values");
+  LLVM_DEBUG({
+    impl->logger.startLine()
+        << "** Replace : '" << op->getName() << "'(" << op << ")\n";
+  });
+  SmallVector<ReplacementValues> newVals(newValues.size(), {});
+  for (auto [index, val] : llvm::enumerate(newValues))
+    llvm::append_range(newVals[index], val);
+  impl->notifyOpReplaced(op, newVals);
 }
 
 void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1505,7 +1564,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
     impl->logger.startLine()
         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
+  SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {});
   impl->notifyOpReplaced(op, nullRepls);
 }
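
For comparison, the existing 1:1 `replaceOp` is now effectively a special case
of the 1:N hook: each non-null value becomes a one-element range and each null
value an empty (dropped) range. As a sketch (the helper name is made up):

  // Express replaceOp(op, newValues) in terms of replaceOpWithMultiple.
  static void replaceViaMultiple(ConversionPatternRewriter &rewriter,
                                 Operation *op, ValueRange newValues) {
    SmallVector<SmallVector<Value, 1>> grouped(newValues.size());
    for (auto [index, value] : llvm::enumerate(newValues))
      if (value)
        grouped[index].push_back(value);
    rewriter.replaceOpWithMultiple(op,
                                   llvm::to_vector_of<ValueRange>(grouped));
  }
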
 


