[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun May 26 06:07:47 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/93412

This commit adds a dialect conversion driver without rollback: `OneShotDialectConversionDriver`

The new driver reuses some functionality of the greedy pattern rewrite driver. This is just a proof of concept; the code is not polished yet.

`OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.

Switched two conversion passes to the new driver to show what kind of changes are needed: `nvgpu-to-nvvm` and `complex-to-standard`. Only the `nvgpu-to-nvvm` test expectations had to be updated.


>From 11dfecd793b6e84224511ef6b065964b64784c49 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 26 May 2024 14:59:09 +0200
Subject: [PATCH] [draft] Dialect Conversion without Rollback

This commit adds a dialect conversion driver without rollback: `OneShotDialectConversionDriver`

The new driver reuses some functionality of the greedy pattern rewrite driver. This is just a proof of concept; the code is not polished yet.

`OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.
---
 .../mlir/Transforms/DialectConversion.h       |  27 +-
 .../Transforms/GreedyPatternRewriteDriver.h   |   6 +
 .../ComplexToStandard/ComplexToStandard.cpp   |   5 +-
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |   6 +-
 .../Transforms/Utils/DialectConversion.cpp    |  33 ++-
 .../Utils/GreedyPatternRewriteDriver.cpp      | 280 ++++++++++++++++--
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  26 +-
 7 files changed, 336 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..8d1c125660d10 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType> TargetType convertType(Type t) const {
+  template <typename TargetType>
+  TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
@@ -657,7 +658,7 @@ struct ConversionPatternRewriterImpl;
 /// This class implements a pattern rewriter for use with ConversionPatterns. It
 /// extends the base PatternRewriter and provides special conversion specific
 /// hooks.
-class ConversionPatternRewriter final : public PatternRewriter {
+class ConversionPatternRewriter : public PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
@@ -711,6 +712,17 @@ class ConversionPatternRewriter final : public PatternRewriter {
   LogicalResult getRemappedValues(ValueRange keys,
                                   SmallVectorImpl<Value> &results);
 
+  virtual void setCurrentTypeConverter(const TypeConverter *converter);
+
+  virtual const TypeConverter *getCurrentTypeConverter() const;
+
+  /// Populate the operands that are used for constructing the adapter into
+  /// `remapped`.
+  virtual LogicalResult getAdapterOperands(StringRef valueDiagTag,
+                                           std::optional<Location> inputLoc,
+                                           ValueRange values,
+                                           SmallVector<Value> &remapped);
+
   //===--------------------------------------------------------------------===//
   // PatternRewriter Hooks
   //===--------------------------------------------------------------------===//
@@ -755,6 +767,14 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
+protected:
+  /// Protected constructor for `OneShotConversionPatternRewriter`. Does not
+  /// initialize `impl`.
+  explicit ConversionPatternRewriter(MLIRContext *ctx);
+
+  // Hide unsupported pattern rewriter API.
+  using OpBuilder::setListener;
+
 private:
   // Allow OperationConverter to construct new rewriters.
   friend struct OperationConverter;
@@ -765,9 +785,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   explicit ConversionPatternRewriter(MLIRContext *ctx,
                                      const ConversionConfig &config);
 
-  // Hide unsupported pattern rewriter API.
-  using OpBuilder::setListener;
-
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..21156ca59f674 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,6 +18,8 @@
 
 namespace mlir {
 
+class ConversionTarget;
+
 /// This enum controls which ops are put on the worklist during a greedy
 /// pattern rewrite.
 enum class GreedyRewriteStrictness {
@@ -188,6 +190,10 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                        GreedyRewriteConfig config = GreedyRewriteConfig(),
                        bool *changed = nullptr, bool *allErased = nullptr);
 
+LogicalResult
+applyPartialOneShotConversion(Operation *op, const ConversionTarget &target,
+                              const FrozenRewritePatternSet &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index d8150aeb828a5..885bf7f9822df 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <memory>
 #include <type_traits>
 
@@ -1346,8 +1347,8 @@ void ConvertComplexToStandardPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
-  if (failed(
-          applyPartialConversion(getOperation(), target, std::move(patterns))))
+  if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                           std::move(patterns))))
     signalPassFailure();
 }
 } // namespace
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..931cedc0c5eb9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
@@ -475,9 +476,10 @@ struct ConvertNVGPUToNVVMPass
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                             std::move(patterns))))
       signalPassFailure();
+    // applyPartialConversion
   }
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..7e48df33fded5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1633,6 +1633,9 @@ ConversionPatternRewriter::ConversionPatternRewriter(
   setListener(impl.get());
 }
 
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+    : PatternRewriter(ctx) {} // `impl` is left default-constructed (null).
+
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
@@ -1819,6 +1822,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
 
+void ConversionPatternRewriter::setCurrentTypeConverter(
+    const TypeConverter *converter) {
+  getImpl().currentTypeConverter = converter; // Stored in the legacy impl.
+}
+
+const TypeConverter *
+ConversionPatternRewriter::getCurrentTypeConverter() const {
+  return (*impl).currentTypeConverter; // Read back from the legacy impl.
+}
+
+LogicalResult ConversionPatternRewriter::getAdapterOperands(
+    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+    SmallVector<Value> &remapped) {
+  return getImpl().remapValues(valueDiagTag, inputLoc, *this, values, remapped);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionPattern
 //===----------------------------------------------------------------------===//
@@ -1827,16 +1846,18 @@ LogicalResult
 ConversionPattern::matchAndRewrite(Operation *op,
                                    PatternRewriter &rewriter) const {
   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
-  auto &rewriterImpl = dialectRewriter.getImpl();
 
   // Track the current conversion pattern type converter in the rewriter.
-  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
-                                             getTypeConverter());
+  const TypeConverter *currentTypeConverter =
+      dialectRewriter.getCurrentTypeConverter();
+  auto resetTypeConverter = llvm::make_scope_exit(
+      [&] { dialectRewriter.setCurrentTypeConverter(currentTypeConverter); });
+  dialectRewriter.setCurrentTypeConverter(getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<Value, 4> operands;
-  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
-                                      op->getOperands(), operands))) {
+  SmallVector<Value> operands;
+  if (failed(dialectRewriter.getAdapterOperands("operand", op->getLoc(),
+                                                op->getOperands(), operands))) {
     return failure();
   }
   return matchAndRewrite(op, operands, dialectRewriter);
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 597cb29ce911b..5312f4a1cacbf 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -14,10 +14,12 @@
 
 #include "mlir/Config/mlir-config.h"
 #include "mlir/IR/Action.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/BitVector.h"
@@ -321,7 +323,7 @@ class RandomizedWorklist : public Worklist {
 /// to the worklist in the beginning.
 class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 protected:
-  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+  explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter,
                                       const FrozenRewritePatternSet &patterns,
                                       const GreedyRewriteConfig &config);
 
@@ -329,7 +331,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
   void addSingleOpToWorklist(Operation *op);
 
   /// Add the given operation and its ancestors to the worklist.
-  void addToWorklist(Operation *op);
+  virtual void addToWorklist(Operation *op);
 
   /// Notify the driver that the specified operation may have been modified
   /// in-place. The operation is added to the worklist.
@@ -356,7 +358,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 
   /// The pattern rewriter that is used for making IR modifications and is
   /// passed to rewrite patterns.
-  PatternRewriter rewriter;
+  PatternRewriter &rewriter;
 
   /// The worklist for this transformation keeps track of the operations that
   /// need to be (re)visited.
@@ -375,6 +377,11 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
   /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 
+#ifndef NDEBUG
+  /// A logger used to emit information during the application process.
+  llvm::ScopedPrinter logger{llvm::dbgs()};
+#endif
+
 private:
   /// Look over the provided operands for any defining operations that should
   /// be re-added to the worklist. This function should be called when an
@@ -394,11 +401,6 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
-#ifndef NDEBUG
-  /// A logger used to emit information during the application process.
-  llvm::ScopedPrinter logger{llvm::dbgs()};
-#endif
-
   /// The low-level pattern applicator.
   PatternApplicator matcher;
 
@@ -409,9 +411,9 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 } // namespace
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
-    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
-    : rewriter(ctx), config(config), matcher(patterns)
+    : rewriter(rewriter), config(config), matcher(patterns)
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
       // clang-format off
       , expensiveChecks(
@@ -780,7 +782,7 @@ namespace {
 /// This driver simplfies all ops in a region.
 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
-  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
+  explicit RegionPatternRewriteDriver(PatternRewriter &rewriter,
                                       const FrozenRewritePatternSet &patterns,
                                       const GreedyRewriteConfig &config,
                                       Region &regions);
@@ -796,9 +798,9 @@ class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
 } // namespace
 
 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
-    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config, Region &region)
-    : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
+    : GreedyPatternRewriteDriver(rewriter, patterns, config), region(region) {
   // Populate strict mode ops.
   if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
     region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
@@ -909,8 +911,8 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Start the pattern driver.
-  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
-                                    region);
+  PatternRewriter rewriter(region.getContext());
+  RegionPatternRewriteDriver driver(rewriter, patterns, config, region);
   LogicalResult converged = std::move(driver).simplify(changed);
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -928,7 +930,7 @@ namespace {
 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(
-      MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+      PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
       const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
 
@@ -950,10 +952,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 } // namespace
 
 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
-    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
     llvm::SmallDenseSet<Operation *, 4> *survivingOps)
-    : GreedyPatternRewriteDriver(ctx, patterns, config),
+    : GreedyPatternRewriteDriver(rewriter, patterns, config),
       survivingOps(survivingOps) {
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -1040,9 +1042,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Start the pattern driver.
+  PatternRewriter rewriter(ops.front()->getContext());
   llvm::SmallDenseSet<Operation *, 4> surviving;
-  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     config, ops,
+  MultiOpPatternRewriteDriver driver(rewriter, patterns, config, ops,
                                      allErased ? &surviving : nullptr);
   LogicalResult converged = std::move(driver).simplify(ops, changed);
   if (allErased)
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold(
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes, so it cannot recover from a failed
+/// rewrite. It derives from `ConversionPatternRewriter` so that the existing
+/// conversion patterns can be used with the One-Shot Dialect Conversion.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+  explicit OneShotConversionPatternRewriter(MLIRContext *ctx)
+      : ConversionPatternRewriter(ctx) {}
+
+  bool canRecoverFromRewriteFailure() const override { return false; }
+
+  void replaceOp(Operation *op, ValueRange newValues) override;
+
+  void replaceOp(Operation *op, Operation *newOp) override {
+    replaceOp(op, newOp->getResults());
+  }
+
+  void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); }
+
+  void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); }
+
+  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+                         ValueRange argValues = std::nullopt) override {
+    PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+  }
+  using PatternRewriter::inlineBlockBefore;
+
+  void startOpModification(Operation *op) override {
+    PatternRewriter::startOpModification(op);
+  }
+
+  void finalizeOpModification(Operation *op) override {
+    PatternRewriter::finalizeOpModification(op);
+  }
+
+  void cancelOpModification(Operation *op) override {
+    PatternRewriter::cancelOpModification(op);
+  }
+
+  void setCurrentTypeConverter(const TypeConverter *converter) override {
+    typeConverter = converter;
+  }
+
+  const TypeConverter *getCurrentTypeConverter() const override {
+    return typeConverter;
+  }
+
+  LogicalResult getAdapterOperands(StringRef valueDiagTag,
+                                   std::optional<Location> inputLoc,
+                                   ValueRange values,
+                                   SmallVector<Value> &remapped) override;
+
+private:
+  /// Build an unrealized_conversion_cast op or look it up in the cache.
+  Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+  /// The current type converter (null if none is set).
+  const TypeConverter *typeConverter = nullptr;
+
+  /// Cache of unrealized_conversion_casts, keyed by (input value, result type),
+  /// so that identical casts are not built multiple times.
+  DenseMap<std::pair<Value, Type>, Value> castCache;
+};
+
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+                                                 ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size());
+  for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) {
+    if (orig.getType() != repl.getType()) {
+      // Type mismatch: insert an unrealized_conversion_cast to the old type.
+      replaceAllUsesWith(orig, buildUnrealizedConversionCast(
+                                   op->getLoc(), orig.getType(), repl));
+    } else {
+      // Same type: use replacement value directly.
+      replaceAllUsesWith(orig, repl);
+    }
+  }
+  eraseOp(op);
+}
+
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+    Location loc, Type type, Value value) {
+  // Reuse a previously built cast of `value` to `type`, if any.
+  auto key = std::make_pair(value, type);
+  auto it = castCache.find(key);
+  if (it != castCache.end())
+    return it->second;
+
+  // Insert the cast right after the defining op of `value`, or at the
+  // beginning of the owner block if `value` is a block argument. The
+  // insertion guard restores the caller's insertion point afterwards.
+  OpBuilder::InsertionGuard g(*this);
+  setInsertionPointAfterValue(value);
+  auto castOp = create<UnrealizedConversionCastOp>(loc, type, value);
+  Value castResult = castOp.getOutputs()[0];
+  castCache[key] = castResult;
+  return castResult;
+}
+
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+  ConversionPatternRewriteDriver(PatternRewriter &rewriter,
+                                 const FrozenRewritePatternSet &patterns,
+                                 const GreedyRewriteConfig &config,
+                                 const ConversionTarget &target)
+      : GreedyPatternRewriteDriver(rewriter, patterns, config), target(target) {
+  }
+
+  /// Populate the worklist with all illegal ops and start the conversion
+  /// process.
+  LogicalResult convert(Operation *op) &&;
+
+protected:
+  void addToWorklist(Operation *op) override; // Only enqueues illegal ops.
+
+  /// Notify the driver that the specified operation was removed. The op is
+  /// dropped from the worklist; unlike the greedy driver, no other ops are
+  /// added to the worklist.
+  void notifyOperationErased(Operation *op) override;
+
+private:
+  const ConversionTarget &target; // Decides which ops must be converted.
+};
+} // namespace
+
+LogicalResult ConversionPatternRewriteDriver::convert(Operation *op) && {
+  op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>([&](Operation *op) {
+    auto legalityInfo = target.isLegal(op);
+    if (!legalityInfo)
+      addSingleOpToWorklist(op);
+    else if (legalityInfo->isRecursivelyLegal)
+      return WalkResult::skip(); // Keep recursively legal subtrees as-is.
+    return WalkResult::advance();
+  });
+  // Reverse for in-order pop-back processing. TODO: reverse newly added ops.
+  worklist.reverse();
+  processWorklist();
+  // Like `applyPartialConversion`: fail if an explicitly illegal op remains.
+  WalkResult status = op->walk<WalkOrder::PreOrder>([&](Operation *nested) {
+    if (auto info = target.isLegal(nested); info && info->isRecursivelyLegal)
+      return WalkResult::skip();
+    if (!target.isIllegal(nested))
+      return WalkResult::advance();
+    nested->emitError("failed to legalize operation that was marked illegal");
+    return WalkResult::interrupt();
+  });
+  return failure(status.wasInterrupted());
+}
+
+void ConversionPatternRewriteDriver::addToWorklist(Operation *op) {
+  if (!target.isLegal(op).has_value())
+    addSingleOpToWorklist(op); // Only the op itself, never its ancestors.
+}
+
+// TODO: Refactor. This is a copy of
+// `GreedyPatternRewriteDriver::notifyOperationErased`, except that it never
+// adds ops to the worklist: the erased op is only removed from it.
+void ConversionPatternRewriteDriver::notifyOperationErased(Operation *op) {
+  LLVM_DEBUG({
+    logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+
+#ifndef NDEBUG
+  // Only ops that are within the configured scope are added to the worklist of
+  // the greedy pattern rewriter. Moreover, the parent op of the scope region is
+  // the part of the IR that is taken into account for the "expensive checks".
+  // A greedy pattern rewrite is not allowed to erase the parent op of the scope
+  // region, as that would break the worklist handling and the expensive checks.
+  if (config.scope && config.scope->getParentOp() == op)
+    llvm_unreachable(
+        "scope region must not be erased during greedy pattern rewrite");
+#endif // NDEBUG
+
+  if (config.listener) // Forward the notification to `config.listener`.
+    config.listener->notifyOperationErased(op);
+
+  worklist.remove(op); // An erased op must not be visited anymore.
+
+  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+    strictModeFilteredOps.erase(op);
+}
+
+/// Populate `remapped` with the adapter operands for `values`, converted with
+/// the current type converter (if any). Fails if a type has no 1:1 conversion.
+LogicalResult OneShotConversionPatternRewriter::getAdapterOperands(
+    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+    SmallVector<Value> &remapped) {
+  // TODO: Refactor. This is mostly copied from the current dialect conversion.
+  for (Value v : values) {
+    // Skip all unrealized_conversion_casts in the chain of defining ops.
+    Value vBase = v;
+    while (auto castOp = vBase.getDefiningOp<UnrealizedConversionCastOp>())
+      vBase = castOp.getInputs()[0];
+
+    if (!getCurrentTypeConverter()) {
+      // No type converter set. Just replicate what the current type conversion
+      // is doing.
+      // TODO: We may have to distinguish between newly-inserted and
+      // pre-existing unrealized_conversion_casts.
+      remapped.push_back(vBase);
+      continue;
+    }
+
+    SmallVector<Type, 1> legalTypes;
+    if (failed(getCurrentTypeConverter()->convertType(v.getType(), legalTypes)))
+      return failure();
+    if (legalTypes.size() != 1) // TODO: Support 1:N and 1:0 conversions.
+      return failure();
+    Type desiredType = legalTypes.front();
+    if (desiredType == vBase.getType()) {
+      // Type already matches. No need to convert anything.
+      remapped.push_back(vBase);
+      continue;
+    }
+
+    Location operandLoc = inputLoc ? *inputLoc : v.getLoc();
+    remapped.push_back(
+        buildUnrealizedConversionCast(operandLoc, desiredType, vBase));
+  }
+  return success();
+}
+
+LogicalResult mlir::applyPartialOneShotConversion(
+    Operation *op, const ConversionTarget &target,
+    const FrozenRewritePatternSet &patterns) {
+  GreedyRewriteConfig defaultConfig; // Default greedy-driver settings.
+  OneShotConversionPatternRewriter oneShotRewriter(op->getContext());
+  ConversionPatternRewriteDriver conversionDriver(oneShotRewriter, patterns,
+                                                  defaultConfig, target);
+  return std::move(conversionDriver).convert(op);
+}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dbf8ead49f78d..f059caf079889 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -275,8 +275,8 @@ func.func @async_cp_i4(
 // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index
 func.func @async_cp_zfill_f32_align4(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
-  // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
-  // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+  // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
   // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>                                   
   // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
   // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -310,8 +310,8 @@ func.func @async_cp_zfill_f32_align4(
 // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
 func.func @async_cp_zfill_f32_align1(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
-    // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
-  // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+  // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
   // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>                                   
   // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
   // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -533,7 +533,9 @@ func.func @mbarrier_nocomplete() {
 }
 
 // CHECK-LABEL: func @mbarrier_wait
+//  CHECK-SAME:     %[[barriers:.*]]: !nvgpu.mbarrier.group
 func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, %token : !tokenType) {
+// CHECK: %[[barriersCast:.*]] = builtin.unrealized_conversion_cast %[[barriers]]
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %n = arith.constant 100 : index
@@ -545,7 +547,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
 // CHECK: scf.for %[[i:.*]] =
 // CHECK: %[[S2:.+]] = arith.remui %[[i]], %[[c5]] : index
 // CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[S2]] : index to i64
-// CHECK: %[[S4:.+]] = llvm.extractvalue %0[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+// CHECK: %[[S4:.+]] = llvm.extractvalue %[[barriersCast]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
 // CHECK: %[[S5:.+]] = llvm.getelementptr %[[S4]][%[[S3]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     %mbarId = arith.remui %i, %numBarriers : index
     %isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, !tokenType
@@ -871,9 +873,9 @@ func.func @warpgroup_mma_128_128_64(
       %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
       %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>) 
 {
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[UD:.+]] =  llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
@@ -1280,9 +1282,9 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
     to memref<128x128xf32,3>
 
 
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
 // CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
@@ -1296,7 +1298,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async



More information about the llvm-branch-commits mailing list