[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)

Mehdi Amini via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed May 29 12:09:28 PDT 2024


================
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold(
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes. It derives from
+/// `ConversionPatternRewriter` so that the existing conversion patterns can
+/// be used with the One-Shot Dialect Conversion.
+///
+/// Most overrides below dispatch, via a qualified call, to the
+/// `PatternRewriter` implementation of the same method.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+  OneShotConversionPatternRewriter(MLIRContext *ctx)
+      : ConversionPatternRewriter(ctx) {}
+
+  /// This rewriter reports that it cannot recover from a failed rewrite.
+  bool canRecoverFromRewriteFailure() const override { return false; }
+
+  /// Replace all uses of the results of `op` with `newValues` and erase `op`.
+  /// `newValues` must contain exactly one value per result of `op`.
+  void replaceOp(Operation *op, ValueRange newValues) override;
+
+  /// Replace `op` with the results of `newOp`.
+  void replaceOp(Operation *op, Operation *newOp) override {
+    replaceOp(op, newOp->getResults());
+  }
+
+  /// Erase `op` through `PatternRewriter::eraseOp`.
+  void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); }
+
+  /// Erase `block` through `PatternRewriter::eraseBlock`.
+  void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); }
+
+  /// Inline `source` into `dest` before `before` through
+  /// `PatternRewriter::inlineBlockBefore`.
+  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+                         ValueRange argValues = std::nullopt) override {
+    PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+  }
+  // Keep the remaining `PatternRewriter::inlineBlockBefore` overloads visible;
+  // the override above would otherwise hide them.
+  using PatternRewriter::inlineBlockBefore;
+
+  /// Forward the start of an in-place modification of `op` to
+  /// `PatternRewriter`.
+  void startOpModification(Operation *op) override {
+    PatternRewriter::startOpModification(op);
+  }
+
+  /// Forward the end of an in-place modification of `op` to `PatternRewriter`.
+  void finalizeOpModification(Operation *op) override {
+    PatternRewriter::finalizeOpModification(op);
+  }
+
+  /// Forward the cancellation of an in-place modification of `op` to
+  /// `PatternRewriter`.
+  void cancelOpModification(Operation *op) override {
+    PatternRewriter::cancelOpModification(op);
+  }
+
+  /// Remember `converter` as the current type converter.
+  void setCurrentTypeConverter(const TypeConverter *converter) override {
+    typeConverter = converter;
+  }
+
+  /// Return the current type converter, or nullptr if none has been set.
+  const TypeConverter *getCurrentTypeConverter() const override {
+    return typeConverter;
+  }
+
+  LogicalResult getAdapterOperands(StringRef valueDiagTag,
+                                   std::optional<Location> inputLoc,
+                                   ValueRange values,
+                                   SmallVector<Value> &remapped) override;
+
+private:
+  /// Build an unrealized_conversion_cast op or look it up in the cache.
+  Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+  /// The current type converter. Initialized to nullptr so that a query made
+  /// before `setCurrentTypeConverter` does not read an indeterminate pointer.
+  const TypeConverter *typeConverter = nullptr;
+
+  /// A cache of unrealized_conversion_casts, keyed by (input value, result
+  /// type), so that identical casts are not built multiple times.
+  DenseMap<std::pair<Value, Type>, Value> castCache;
+};
+
+/// Replace every result of `op` with the corresponding entry of `newValues`
+/// and then erase `op`. A replacement whose type differs from the result it
+/// replaces is routed through an unrealized_conversion_cast to the result's
+/// original type.
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+                                                 ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size());
+  for (auto [oldResult, newValue] :
+       llvm::zip_equal(op->getResults(), newValues)) {
+    // Same type: the new value is used directly. Different type: the uses are
+    // redirected to a cast that yields the original result type.
+    Value replacement = newValue;
+    if (oldResult.getType() != newValue.getType())
+      replacement = buildUnrealizedConversionCast(
+          op->getLoc(), oldResult.getType(), newValue);
+    replaceAllUsesWith(oldResult, replacement);
+  }
+  eraseOp(op);
+}
+
+/// Return a value of type `type` obtained by casting `value` with an
+/// unrealized_conversion_cast. A cast previously built for the same
+/// (value, type) pair is reused from `castCache`; otherwise a new cast op is
+/// created at `loc` and recorded in the cache.
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+    Location loc, Type type, Value value) {
+  const auto cacheKey = std::make_pair(value, type);
+  if (auto cached = castCache.find(cacheKey); cached != castCache.end())
+    return cached->second;
+
+  // Pick where the cast goes: the start of the parent block for a block
+  // argument, or immediately after the defining op for an op result.
+  Block *parentBlock = value.getParentBlock();
+  Block::iterator castPos = parentBlock->begin();
+  if (auto definingResult = dyn_cast<OpResult>(value))
+    castPos = ++definingResult.getOwner()->getIterator();
+
+  // Restore the caller's insertion point once the cast has been created.
+  OpBuilder::InsertionGuard guard(*this);
+  setInsertionPoint(parentBlock, castPos);
+  auto castOp = create<UnrealizedConversionCastOp>(loc, type, value);
+  Value castResult = castOp.getOutputs()[0];
+  castCache[cacheKey] = castResult;
+  return castResult;
+}
+
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
----------------
joker-eph wrote:

I don't have a problem with dialect conversion complexity actually, other than replaceAllUsesWith not doing what it says it does, and the fact that you have to go through the rewriter for these. But you're not changing these aspects, are you?

> I am reluctant to build something that is not compatible with the large number of existing conversion patterns. 

Without rollback you're incompatible anyway, people can't "just adopt it". But otherwise what kind of incompatibility are you worried about here?


Greedy is a fixed-point iterative algorithm which has a non-trivial cost associated with it. Why is it suitable here?

https://github.com/llvm/llvm-project/pull/93412


More information about the llvm-branch-commits mailing list