[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)

Mehdi Amini via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue May 28 14:04:39 PDT 2024


================
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold(
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes. It derives from
+/// `ConversionPatternRewriter` so that the existing conversion patterns can
+/// be used with the One-Shot Dialect Conversion.
+///
+/// Most overrides below simply forward to the `PatternRewriter`
+/// implementation of the same name rather than to the direct base class.
+/// NOTE(review): presumably this bypasses the deferred bookkeeping of
+/// `ConversionPatternRewriter` -- confirm against the base class.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+  /// Construct a rewriter operating in the given context.
+  OneShotConversionPatternRewriter(MLIRContext *ctx)
+      : ConversionPatternRewriter(ctx) {}
+
+  /// This rewriter reports that it cannot recover from a failed rewrite.
+  bool canRecoverFromRewriteFailure() const override { return false; }
+
+  /// Replace the results of `op` with `newValues` and erase `op`. Results
+  /// whose type differs from their replacement are bridged with an
+  /// unrealized_conversion_cast.
+  void replaceOp(Operation *op, ValueRange newValues) override;
+
+  /// Replace `op` with the results of `newOp`.
+  void replaceOp(Operation *op, Operation *newOp) override {
+    replaceOp(op, newOp->getResults());
+  }
+
+  /// Erase `op` via the `PatternRewriter` implementation.
+  void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); }
+
+  /// Erase `block` via the `PatternRewriter` implementation.
+  void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); }
+
+  /// Inline `source` into `dest` before `before` via the `PatternRewriter`
+  /// implementation; `argValues` is forwarded unchanged.
+  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+                         ValueRange argValues = std::nullopt) override {
+    PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+  }
+  // Keep the remaining `PatternRewriter` overloads visible; the override
+  // above would otherwise hide them.
+  using PatternRewriter::inlineBlockBefore;
+
+  /// Forwarded to the `PatternRewriter` implementation.
+  void startOpModification(Operation *op) override {
+    PatternRewriter::startOpModification(op);
+  }
+
+  /// Forwarded to the `PatternRewriter` implementation.
+  void finalizeOpModification(Operation *op) override {
+    PatternRewriter::finalizeOpModification(op);
+  }
+
+  /// Forwarded to the `PatternRewriter` implementation.
+  void cancelOpModification(Operation *op) override {
+    PatternRewriter::cancelOpModification(op);
+  }
+
+  /// Remember `converter` as the current type converter.
+  void setCurrentTypeConverter(const TypeConverter *converter) override {
+    typeConverter = converter;
+  }
+
+  /// Return the type converter most recently passed to
+  /// `setCurrentTypeConverter`, or null if none has been set yet.
+  const TypeConverter *getCurrentTypeConverter() const override {
+    return typeConverter;
+  }
+
+  /// Declared here, defined out of line (not visible in this hunk).
+  LogicalResult getAdapterOperands(StringRef valueDiagTag,
+                                   std::optional<Location> inputLoc,
+                                   ValueRange values,
+                                   SmallVector<Value> &remapped) override;
+
+private:
+  /// Build an unrealized_conversion_cast op or look it up in the cache.
+  Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+  /// The current type converter. Starts out null so that
+  /// `getCurrentTypeConverter` never returns an indeterminate pointer when it
+  /// is queried before `setCurrentTypeConverter` has been called.
+  const TypeConverter *typeConverter = nullptr;
+
+  /// A cache for unrealized_conversion_casts, keyed by (input value, result
+  /// type), which ensures that identical casts are not built multiple times.
+  DenseMap<std::pair<Value, Type>, Value> castCache;
+};
+
+/// Replace every result of `op` with the corresponding entry of `newValues`
+/// and then erase `op`. `newValues` must have exactly one entry per result.
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+                                                 ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size());
+  for (auto [oldResult, newValue] :
+       llvm::zip_equal(op->getResults(), newValues)) {
+    // Same type: the replacement value can be used directly.
+    Value replacement = newValue;
+    // Type mismatch: route the replacement through an
+    // unrealized_conversion_cast that yields the original result type.
+    if (oldResult.getType() != newValue.getType())
+      replacement = buildUnrealizedConversionCast(
+          op->getLoc(), oldResult.getType(), newValue);
+    replaceAllUsesWith(oldResult, replacement);
+  }
+  eraseOp(op);
+}
+
+/// Return a value of type `type` obtained by casting `value` with an
+/// unrealized_conversion_cast. A cast that was already built for the same
+/// (value, type) pair is reused from `castCache`; otherwise a new cast is
+/// created at `loc`, recorded in the cache and returned.
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+    Location loc, Type type, Value value) {
+  // Build the cache key once; it is used for both the lookup and the insert.
+  const std::pair<Value, Type> cacheKey(value, type);
+  auto it = castCache.find(cacheKey);
+  if (it != castCache.end())
+    return it->second;
+
+  // Insert cast at the beginning of the block (for block arguments) or right
+  // after the defining op. The guard restores the caller's insertion point
+  // when this function returns.
+  OpBuilder::InsertionGuard g(*this);
+  setInsertionPointAfterValue(value);
+  auto castOp = create<UnrealizedConversionCastOp>(loc, type, value);
+  Value castResult = castOp.getOutputs()[0];
+  castCache[cacheKey] = castResult;
+  return castResult;
+}
+
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
----------------
joker-eph wrote:

Relying on the greedy driver makes me concerned about the "one shot" aspect.
I would really like to see a **simple** implementation, with the minimum amount of bookkeeping (ideally not even a worklist!), and relying on the GreedyPatternRewriteDriver does not really go in this direction right now.

https://github.com/llvm/llvm-project/pull/93412


More information about the llvm-branch-commits mailing list