[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed May 29 13:00:24 PDT 2024


================
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold(
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes. It derives from
+/// `ConversionPatternRewriter` so that the existing conversion patterns can
+/// be used with the One-Shot Dialect Conversion.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+  OneShotConversionPatternRewriter(MLIRContext *ctx)
+      : ConversionPatternRewriter(ctx) {}
+
+  bool canRecoverFromRewriteFailure() const override { return false; }
+
+  void replaceOp(Operation *op, ValueRange newValues) override;
+
+  void replaceOp(Operation *op, Operation *newOp) override {
+    replaceOp(op, newOp->getResults());
+  }
+
+  void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); }
+
+  void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); }
+
+  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+                         ValueRange argValues = std::nullopt) override {
+    PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+  }
+  using PatternRewriter::inlineBlockBefore;
+
+  void startOpModification(Operation *op) override {
+    PatternRewriter::startOpModification(op);
+  }
+
+  void finalizeOpModification(Operation *op) override {
+    PatternRewriter::finalizeOpModification(op);
+  }
+
+  void cancelOpModification(Operation *op) override {
+    PatternRewriter::cancelOpModification(op);
+  }
+
+  void setCurrentTypeConverter(const TypeConverter *converter) override {
+    typeConverter = converter;
+  }
+
+  const TypeConverter *getCurrentTypeConverter() const override {
+    return typeConverter;
+  }
+
+  LogicalResult getAdapterOperands(StringRef valueDiagTag,
+                                   std::optional<Location> inputLoc,
+                                   ValueRange values,
+                                   SmallVector<Value> &remapped) override;
+
+private:
+  /// Build an unrealized_conversion_cast op or look it up in the cache.
+  Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+  /// The current type converter.
+  const TypeConverter *typeConverter = nullptr;
+
+  /// A cache for unrealized_conversion_casts, to ensure that identical casts
+  /// are not built multiple times.
+  DenseMap<std::pair<Value, Type>, Value> castCache;
+};
+
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+                                                 ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size());
+  for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) {
+    if (orig.getType() != repl.getType()) {
+      // Type mismatch: insert unrealized_conversion_cast.
+      replaceAllUsesWith(orig, buildUnrealizedConversionCast(
+                                   op->getLoc(), orig.getType(), repl));
+    } else {
+      // Same type: use replacement value directly.
+      replaceAllUsesWith(orig, repl);
+    }
+  }
+  eraseOp(op);
+}
+
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+    Location loc, Type type, Value value) {
+  auto it = castCache.find(std::make_pair(value, type));
+  if (it != castCache.end())
+    return it->second;
+
+  // Insert cast at the beginning of the block (for block arguments) or right
+  // after the defining op.
+  OpBuilder::InsertionGuard g(*this);
+  Block *insertBlock = value.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(value))
+    insertPt = ++inputRes.getOwner()->getIterator();
+  setInsertionPoint(insertBlock, insertPt);
+  auto castOp = create<UnrealizedConversionCastOp>(loc, type, value);
+  castCache[std::make_pair(value, type)] = castOp.getOutputs()[0];
+  return castOp.getOutputs()[0];
+}
+
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
----------------
matthias-springer wrote:

> I don't have problem with dialect conversion complexity actually, other than replaceAllUsesWith not doing what it says it does

I tried to fix that [here](https://github.com/llvm/llvm-project/pull/84725), but it sounded like the rollback logic was getting too complex (and I would agree with that).
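For comparison, `OneShotConversionPatternRewriter::replaceOp` in the hunk above rewires all uses on the spot and then erases the op, inserting an `unrealized_conversion_cast` only when the types differ, with casts deduplicated through the `(value, type)` cache. The following is a standalone toy model of that logic, not MLIR code: `Value`, `Op`, `OneShotRewriter`, `create` and `castsBuilt` are hypothetical stand-ins, and the "IR" is a single block held in a vector.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not MLIR's API): one block of ops with typed results.
struct Op;
struct Value {
  std::string type;
  Op *def = nullptr;  // nullptr would mean "block argument"
};
struct Op {
  std::string name;
  std::vector<Value *> operands;
  std::vector<std::unique_ptr<Value>> results;
};

struct OneShotRewriter {
  std::vector<std::unique_ptr<Op>> block;
  // Same idea as castCache in the hunk: one cast per (value, target type).
  std::map<std::pair<Value *, std::string>, Value *> castCache;
  int castsBuilt = 0;

  // Create an op at position `pos` in the block.
  Op *create(std::size_t pos, std::string name, std::vector<Value *> operands,
             std::vector<std::string> resultTypes) {
    auto op = std::make_unique<Op>();
    op->name = std::move(name);
    op->operands = std::move(operands);
    for (const std::string &t : resultTypes)
      op->results.push_back(std::make_unique<Value>(Value{t, op.get()}));
    Op *raw = op.get();
    block.insert(block.begin() + pos, std::move(op));
    return raw;
  }

  std::size_t indexOf(const Op *op) const {
    for (std::size_t i = 0; i < block.size(); ++i)
      if (block[i].get() == op) return i;
    assert(false && "op not in block");
    return block.size();
  }

  // Rewire every use immediately; no deferred mapping is kept.
  void replaceAllUsesWith(Value *from, Value *to) {
    for (auto &op : block)
      for (Value *&operand : op->operands)
        if (operand == from) operand = to;
  }

  // Reuse a cached cast, else insert one right after the value's definition
  // (or at the start of the block for block arguments).
  Value *buildCast(const std::string &type, Value *value) {
    auto key = std::make_pair(value, type);
    if (auto it = castCache.find(key); it != castCache.end())
      return it->second;
    std::size_t pos = value->def ? indexOf(value->def) + 1 : 0;
    Op *cast = create(pos, "unrealized_conversion_cast", {value}, {type});
    ++castsBuilt;
    return castCache[key] = cast->results[0].get();
  }

  // Replace each result (casting on type mismatch), then erase the op.
  void replaceOp(Op *op, const std::vector<Value *> &newValues) {
    assert(op->results.size() == newValues.size());
    for (std::size_t i = 0; i < newValues.size(); ++i) {
      Value *orig = op->results[i].get();
      Value *repl = newValues[i];
      replaceAllUsesWith(orig, orig->type == repl->type
                                   ? repl
                                   : buildCast(orig->type, repl));
    }
    block.erase(block.begin() + indexOf(op));
  }
};
```

For example, replacing two `i32`-producing ops with the same `i64` value builds a single cast that both former users then share.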

> Without rollback you're incompatible anyway, people can't "just adopt it".

I expect that most dialect conversions do not actually need the rollback. I am not aware of any passes in MLIR that require it. (But users may have their own passes that rely on it, which is why I asked in the RFC.) So my hope is that the existing patterns just work with the new driver. I have only tried it with `NVGPUToNVVM.cpp` and `ComplexToStandard.cpp` so far and still have to migrate a few more passes to get a better picture.

But you are right, it is not a general replacement. That's why I would keep both drivers side by side for a while, so that there is time to handle the tricky cases (if any). Patterns can also be updated gradually, and in a way that keeps them compatible with both drivers.
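The compatibility story rests on the mechanism in the hunk: `OneShotConversionPatternRewriter` derives from `ConversionPatternRewriter` and overrides its virtual hooks, so a pattern that only talks to the rewriter interface runs under either driver. Below is a toy sketch of that idea, assuming a map as the "IR"; `Rewriter`, `DeferredRewriter`, `OneShotRewriter`, `replaceOpWithNew` and `lowerAddI` are hypothetical names, not MLIR's API. One rewriter defers changes so they can be discarded; the other applies them immediately and, as in the hunk, reports `canRecoverFromRewriteFailure() == false`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR: op id -> op name.
using IR = std::map<int, std::string>;

// The interface patterns are written against (loosely mirrors the virtual
// hooks that the one-shot rewriter overrides).
struct Rewriter {
  virtual ~Rewriter() = default;
  virtual bool canRecoverFromRewriteFailure() const = 0;
  virtual void replaceOpWithNew(IR &ir, int id, const std::string &name) = 0;
};

// Records changes and applies them on commit(); rollback() discards them.
struct DeferredRewriter : Rewriter {
  std::vector<std::pair<int, std::string>> log;
  bool canRecoverFromRewriteFailure() const override { return true; }
  void replaceOpWithNew(IR &, int id, const std::string &name) override {
    log.emplace_back(id, name);
  }
  void rollback() { log.clear(); }
  void commit(IR &ir) {
    for (const auto &[id, name] : log) ir[id] = name;
    log.clear();
  }
};

// Applies every change immediately; there is nothing to roll back.
struct OneShotRewriter : Rewriter {
  bool canRecoverFromRewriteFailure() const override { return false; }
  void replaceOpWithNew(IR &ir, int id, const std::string &name) override {
    ir[id] = name;
  }
};

// Hypothetical pattern: mutates the IR only through the Rewriter and finishes
// matching before its first change, so it is safe under both rewriters.
bool lowerAddI(IR &ir, int id, Rewriter &rewriter) {
  if (ir.at(id) != "arith.addi") return false;  // match failure, IR untouched
  rewriter.replaceOpWithNew(ir, id, "llvm.add");
  return true;
}
```

The ordering in `lowerAddI` is the point: without rollback, any change made before a pattern bails out would stay in the IR, so doing all checks up front keeps a pattern portable across both drivers.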

> Greedy is a fixed-point iterative algorithm which has non-trivial cost associated with it. I don't quite get why it is suitable here?

Part of the answer is that I wanted to build something that is compatible with the existing conversion patterns. That requires us to maintain some sort of worklist or to resort to (unbounded) recursion, as the current implementation does.

There are some differences from a greedy fixed-point rewrite. (The following are properties of both the current and the "new" dialect conversion.)

1. Ops are always processed from top to bottom.
2. Patterns are applied only to illegal root ops.
3. A pattern must either remove the illegal root op or modify it in place. (We may be able to tighten the rules to require that an in-place modification *must* make the op legal. That would give us a stronger guarantee about the worst-case number of pattern applications.)

Patterns are allowed to create new ops, which is why the worklist is needed. The second property is an important one: we do not put an op onto the worklist just because something changed in its vicinity (as a greedy pattern rewrite does).
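The three properties, together with the worklist for newly created ops, can be sketched as a small standalone driver. This is a toy model under stated assumptions, not the PR's `ConversionPatternRewriteDriver`: the IR is a list of op names, a "pattern" (hypothetical `Patterns` map) replaces an illegal op with a list of new ops, and in-place modification is not modeled. Only illegal ops ever enter the worklist, legal neighbors of a rewritten op are never revisited, and newly created illegal ops are processed next so the walk stays top to bottom.

```cpp
#include <cassert>
#include <deque>
#include <functional>
#include <list>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy IR: a block is a list of op names, top to bottom.
using Block = std::list<std::string>;
// Hypothetical "pattern": illegal op name -> names of the ops replacing it.
using Patterns = std::map<std::string, std::vector<std::string>>;
using Pred = std::function<bool(const std::string &)>;

// Returns the number of pattern applications, or -1 if some illegal op has
// no pattern. Without rollback, the IR is left partially converted on failure.
int convert(Block &block, const Patterns &patterns, const Pred &isLegal) {
  std::deque<Block::iterator> worklist;
  for (auto it = block.begin(); it != block.end(); ++it)  // top to bottom
    if (!isLegal(*it)) worklist.push_back(it);            // illegal roots only

  int applications = 0;
  while (!worklist.empty()) {
    Block::iterator root = worklist.front();
    worklist.pop_front();
    auto pat = patterns.find(*root);
    if (pat == patterns.end()) return -1;
    ++applications;
    // Insert the replacement ops before the root, then erase the root.
    std::vector<Block::iterator> newIllegal;
    for (const std::string &name : pat->second) {
      auto newIt = block.insert(root, name);
      if (!isLegal(name)) newIllegal.push_back(newIt);  // new illegal ops only
    }
    block.erase(root);
    // Neighbors are not enqueued. New illegal ops go first, in order.
    worklist.insert(worklist.begin(), newIllegal.begin(), newIllegal.end());
  }
  return applications;
}
```

With patterns `old.x -> {mid.x, legal.c}`, `mid.x -> {new.x}` and `old.y -> {new.y}`, the block `{legal.a, old.x, legal.b, old.y}` takes three applications, and the legal ops are never visited.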

So I wouldn't call it a greedy rewrite. Jacques asked how this differs from equipping the greedy pattern rewrite driver with a legality function. That got me thinking: large parts of the greedy driver can be reused; only the condition for putting ops on the worklist is different. The only reason this implementation lives in `GreedyPatternRewriteDriver.cpp` is to save a few lines of code.
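Jacques's question can be made concrete with a toy sketch in which the worklist loop is shared and only the enqueue condition differs. Everything here is hypothetical and is not the `GreedyPatternRewriteDriver` API: the IR is a list of op names, patterns are 1:1 renames, and "vicinity" is simplified to the adjacent ops in the block (a real driver would look at users and producers). The greedy policy offers every op; the conversion policy offers only illegal ops.

```cpp
#include <algorithm>
#include <cassert>
#include <deque>
#include <functional>
#include <iterator>
#include <list>
#include <map>
#include <string>

// Hypothetical toy IR and 1:1 rename "patterns".
using Block = std::list<std::string>;
using Patterns = std::map<std::string, std::string>;
using Pred = std::function<bool(const std::string &)>;

// Shared worklist loop. `shouldEnqueue` is the only policy hook: it is asked
// about every op up front, and about each new op and its neighbors after a
// rewrite. Returns how many ops were popped from the worklist.
int drive(Block &block, const Patterns &patterns, const Pred &shouldEnqueue) {
  std::deque<Block::iterator> worklist;
  auto offer = [&](Block::iterator it) {
    if (shouldEnqueue(*it) &&
        std::find(worklist.begin(), worklist.end(), it) == worklist.end())
      worklist.push_back(it);
  };
  for (auto it = block.begin(); it != block.end(); ++it) offer(it);

  int visits = 0;
  while (!worklist.empty()) {
    Block::iterator op = worklist.front();
    worklist.pop_front();
    ++visits;
    auto pat = patterns.find(*op);
    // No match. (A real conversion driver would fail on an illegal op here.)
    if (pat == patterns.end()) continue;
    Block::iterator newOp = block.insert(op, pat->second);
    block.erase(op);
    // Offer the new op and its adjacent ops (stand-in for "the vicinity").
    offer(newOp);
    if (newOp != block.begin()) offer(std::prev(newOp));
    if (std::next(newOp) != block.end()) offer(std::next(newOp));
  }
  return visits;
}
```

On the block `{x, old.a, y, z, old.b, w}` with renames `old.a -> new.a` and `old.b -> new.b`, both policies produce the same IR, but the greedy policy pops 10 ops while the conversion policy pops 2.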


https://github.com/llvm/llvm-project/pull/93412

