[Mlir-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)

Matthias Springer llvmlistbot at llvm.org
Sat Jun 8 03:28:55 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/93412

>From 6c04a065357a15c35c79332c9658036b1073dd5d Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 26 May 2024 14:59:09 +0200
Subject: [PATCH] [draft] Dialect Conversion without Rollback

This commit adds a dialect conversion driver without rollback (the "One-Shot Dialect Conversion"), exposed via `applyPartialOneShotConversion`.

The new driver reuses some functionality of the greedy pattern rewrite driver. This is just a proof of concept; the code is not polished yet.

`OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.
---
 .../mlir/Transforms/DialectConversion.h       |  30 +-
 .../Transforms/GreedyPatternRewriteDriver.h   |   8 +
 .../AffineToStandard/AffineToStandard.cpp     |   5 +-
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    |   5 +-
 .../ComplexToStandard/ComplexToStandard.cpp   |   5 +-
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   |   5 +-
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp |   5 +-
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |   6 +-
 .../Transforms/Utils/DialectConversion.cpp    |  41 ++-
 .../Utils/GreedyPatternRewriteDriver.cpp      | 283 ++++++++++++++++--
 .../AffineToStandard/lower-affine.mlir        |   8 +-
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir |  27 +-
 .../convert-nd-vector-to-llvmir.mlir          |   6 +-
 .../expand-then-convert-to-llvm.mlir          |   4 +-
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  26 +-
 15 files changed, 379 insertions(+), 85 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..76d56073f72b5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType> TargetType convertType(Type t) const {
+  template <typename TargetType>
+  TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
@@ -657,7 +658,7 @@ struct ConversionPatternRewriterImpl;
 /// This class implements a pattern rewriter for use with ConversionPatterns. It
 /// extends the base PatternRewriter and provides special conversion specific
 /// hooks.
-class ConversionPatternRewriter final : public PatternRewriter {
+class ConversionPatternRewriter : public PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
@@ -708,8 +709,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Return the converted values that replace 'keys' with types defined by the
   /// type converter of the currently executing pattern. Returns failure if the
   /// remap failed, success otherwise.
-  LogicalResult getRemappedValues(ValueRange keys,
-                                  SmallVectorImpl<Value> &results);
+  LogicalResult getRemappedValues(ValueRange keys, SmallVector<Value> &results);
+
+  virtual void setCurrentTypeConverter(const TypeConverter *converter);
+
+  virtual const TypeConverter *getCurrentTypeConverter() const;
+
+  /// Populate the operands that are used for constructing the adapter into
+  /// `remapped`.
+  virtual LogicalResult getAdapterOperands(StringRef valueDiagTag,
+                                           std::optional<Location> inputLoc,
+                                           ValueRange values,
+                                           SmallVector<Value> &remapped);
 
   //===--------------------------------------------------------------------===//
   // PatternRewriter Hooks
@@ -755,6 +766,14 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
+protected:
+  /// Protected constructor for `OneShotConversionPatternRewriter`. Does not
+  /// initialize `impl`.
+  explicit ConversionPatternRewriter(MLIRContext *ctx);
+
+  // Hide unsupported pattern rewriter API.
+  using OpBuilder::setListener;
+
 private:
   // Allow OperationConverter to construct new rewriters.
   friend struct OperationConverter;
@@ -765,9 +784,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   explicit ConversionPatternRewriter(MLIRContext *ctx,
                                      const ConversionConfig &config);
 
-  // Hide unsupported pattern rewriter API.
-  using OpBuilder::setListener;
-
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..7513427e0d3ae 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,6 +18,8 @@
 
 namespace mlir {
 
+class ConversionTarget;
+
 /// This enum controls which ops are put on the worklist during a greedy
 /// pattern rewrite.
 enum class GreedyRewriteStrictness {
@@ -78,6 +80,8 @@ class GreedyRewriteConfig {
   ///   excluded.
   GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
 
+  bool enableOperationDce = true;
+
   /// An optional listener that should be notified about IR modifications.
   RewriterBase::Listener *listener = nullptr;
 };
@@ -188,6 +192,10 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                        GreedyRewriteConfig config = GreedyRewriteConfig(),
                        bool *changed = nullptr, bool *allErased = nullptr);
 
+/// Applies `patterns` to `op` and the ops nested within it using the One-Shot
+/// Dialect Conversion driver: IR changes are materialized immediately and are
+/// not rolled back. Only ops that `target` does not report as legal are put on
+/// the worklist; regions of recursively legal ops are not visited.
+LogicalResult
+applyPartialOneShotConversion(Operation *op, const ConversionTarget &target,
+                              const FrozenRewritePatternSet &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 10ccd5c97783b..f7e0ffeaa0d47 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 namespace mlir {
@@ -563,8 +564,8 @@ class LowerAffinePass
     ConversionTarget target(getContext());
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
                            scf::SCFDialect, VectorDialect>();
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                             std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index d882f1184f457..dcc3dd3659cf6 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <type_traits>
 
 namespace mlir {
@@ -479,8 +480,8 @@ struct ArithToLLVMConversionPass
     LLVMTypeConverter converter(&getContext(), options);
     mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
 
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                             std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index d8150aeb828a5..885bf7f9822df 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <memory>
 #include <type_traits>
 
@@ -1346,8 +1347,8 @@ void ConvertComplexToStandardPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
   target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
-  if (failed(
-          applyPartialConversion(getOperation(), target, std::move(patterns))))
+  if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                           std::move(patterns))))
     signalPassFailure();
 }
 } // namespace
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index b8e5aec25286d..0d81277a82996 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringRef.h"
 #include <functional>
 
@@ -240,8 +241,8 @@ struct ConvertControlFlowToLLVM
     LLVMTypeConverter converter(&getContext(), options);
     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
 
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                             std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 23e957288eb95..89c5a980d7dea 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
@@ -291,8 +292,8 @@ struct ConvertMathToLLVMPass
     LLVMTypeConverter converter(&getContext());
     populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
     LLVMConversionTarget target(getContext());
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                             std::move(patterns))))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..931cedc0c5eb9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
@@ -475,9 +476,10 @@ struct ConvertNVGPUToNVVMPass
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
         converter, patterns, target);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialOneShotConversion(getOperation(), target,
+                                             std::move(patterns))))
       signalPassFailure();
+    // applyPartialConversion
   }
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..8f0d560c05d7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1633,6 +1633,9 @@ ConversionPatternRewriter::ConversionPatternRewriter(
   setListener(impl.get());
 }
 
+/// Constructor for subclasses that do not use a `ConversionPatternRewriterImpl`.
+/// `impl` stays null (the default state of the `unique_ptr`) and no listener is
+/// installed, so members that dereference `impl` must not be reached on such an
+/// instance unless the subclass overrides them.
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+    : PatternRewriter(ctx) {}
+
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
@@ -1717,19 +1720,17 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
+  // Route through `getRemappedValues` (and thus the virtual
+  // `getAdapterOperands`) so that derived rewriters can customize remapping.
   SmallVector<Value> remappedValues;
-  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
-                               remappedValues)))
+  if (failed(getRemappedValues(key, remappedValues)))
     return nullptr;
   return remappedValues.front();
 }
 
 LogicalResult
 ConversionPatternRewriter::getRemappedValues(ValueRange keys,
-                                             SmallVectorImpl<Value> &results) {
+                                             SmallVector<Value> &results) {
   if (keys.empty())
     return success();
-  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
-                           results);
+  // Dispatch through the virtual `getAdapterOperands` hook so that rewriters
+  // deriving from this class can provide their own operand remapping.
+  return getAdapterOperands("value", /*inputLoc=*/std::nullopt, keys, results);
 }
 
 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1819,6 +1820,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
 
+/// Stores `converter` as the active type converter in the rewriter's
+/// implementation object.
+void ConversionPatternRewriter::setCurrentTypeConverter(
+    const TypeConverter *converter) {
+  detail::ConversionPatternRewriterImpl &rewriterImpl = getImpl();
+  rewriterImpl.currentTypeConverter = converter;
+}
+
+/// Returns the active type converter kept by the implementation object.
+const TypeConverter *
+ConversionPatternRewriter::getCurrentTypeConverter() const {
+  const detail::ConversionPatternRewriterImpl &rewriterImpl = *impl;
+  return rewriterImpl.currentTypeConverter;
+}
+
+/// Default operand remapping: defers to
+/// `ConversionPatternRewriterImpl::remapValues`.
+LogicalResult ConversionPatternRewriter::getAdapterOperands(
+    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+    SmallVector<Value> &remapped) {
+  detail::ConversionPatternRewriterImpl &rewriterImpl = getImpl();
+  return rewriterImpl.remapValues(valueDiagTag, inputLoc, *this, values,
+                                  remapped);
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionPattern
 //===----------------------------------------------------------------------===//
@@ -1827,16 +1844,18 @@ LogicalResult
 ConversionPattern::matchAndRewrite(Operation *op,
                                    PatternRewriter &rewriter) const {
   auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
-  auto &rewriterImpl = dialectRewriter.getImpl();
 
   // Track the current conversion pattern type converter in the rewriter.
-  llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
-                                             getTypeConverter());
+  const TypeConverter *currentTypeConverter =
+      dialectRewriter.getCurrentTypeConverter();
+  auto resetTypeConverter = llvm::make_scope_exit(
+      [&] { dialectRewriter.setCurrentTypeConverter(currentTypeConverter); });
+  dialectRewriter.setCurrentTypeConverter(getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<Value, 4> operands;
-  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
-                                      op->getOperands(), operands))) {
+  SmallVector<Value> operands;
+  if (failed(dialectRewriter.getAdapterOperands("operand", op->getLoc(),
+                                                op->getOperands(), operands))) {
     return failure();
   }
   return matchAndRewrite(op, operands, dialectRewriter);
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 597cb29ce911b..99e82827cdefb 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -14,10 +14,12 @@
 
 #include "mlir/Config/mlir-config.h"
 #include "mlir/IR/Action.h"
+#include "mlir/IR/Iterators.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/BitVector.h"
@@ -321,7 +323,7 @@ class RandomizedWorklist : public Worklist {
 /// to the worklist in the beginning.
 class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 protected:
-  explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+  explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter,
                                       const FrozenRewritePatternSet &patterns,
                                       const GreedyRewriteConfig &config);
 
@@ -329,7 +331,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
   void addSingleOpToWorklist(Operation *op);
 
   /// Add the given operation and its ancestors to the worklist.
-  void addToWorklist(Operation *op);
+  virtual void addToWorklist(Operation *op);
 
   /// Notify the driver that the specified operation may have been modified
   /// in-place. The operation is added to the worklist.
@@ -356,7 +358,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 
   /// The pattern rewriter that is used for making IR modifications and is
   /// passed to rewrite patterns.
-  PatternRewriter rewriter;
+  PatternRewriter &rewriter;
 
   /// The worklist for this transformation keeps track of the operations that
   /// need to be (re)visited.
@@ -375,6 +377,11 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
   /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 
+#ifndef NDEBUG
+  /// A logger used to emit information during the application process.
+  llvm::ScopedPrinter logger{llvm::dbgs()};
+#endif
+
 private:
   /// Look over the provided operands for any defining operations that should
   /// be re-added to the worklist. This function should be called when an
@@ -394,11 +401,6 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
-#ifndef NDEBUG
-  /// A logger used to emit information during the application process.
-  llvm::ScopedPrinter logger{llvm::dbgs()};
-#endif
-
   /// The low-level pattern applicator.
   PatternApplicator matcher;
 
@@ -409,9 +411,9 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 } // namespace
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
-    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
-    : rewriter(ctx), config(config), matcher(patterns)
+    : rewriter(rewriter), config(config), matcher(patterns)
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
       // clang-format off
       , expensiveChecks(
@@ -476,7 +478,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
     });
 
     // If the operation is trivially dead - remove it.
-    if (isOpTriviallyDead(op)) {
+    if (config.enableOperationDce && isOpTriviallyDead(op)) {
       rewriter.eraseOp(op);
       changed = true;
 
@@ -780,7 +782,7 @@ namespace {
 /// This driver simplfies all ops in a region.
 class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
-  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
+  explicit RegionPatternRewriteDriver(PatternRewriter &rewriter,
                                       const FrozenRewritePatternSet &patterns,
                                       const GreedyRewriteConfig &config,
                                       Region &regions);
@@ -796,9 +798,9 @@ class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
 } // namespace
 
 RegionPatternRewriteDriver::RegionPatternRewriteDriver(
-    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config, Region &region)
-    : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
+    : GreedyPatternRewriteDriver(rewriter, patterns, config), region(region) {
   // Populate strict mode ops.
   if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
     region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
@@ -909,8 +911,8 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Start the pattern driver.
-  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
-                                    region);
+  PatternRewriter rewriter(region.getContext());
+  RegionPatternRewriteDriver driver(rewriter, patterns, config, region);
   LogicalResult converged = std::move(driver).simplify(changed);
   LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -928,7 +930,7 @@ namespace {
 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(
-      MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+      PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
       const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
       llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
 
@@ -950,10 +952,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 } // namespace
 
 MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
-    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
     llvm::SmallDenseSet<Operation *, 4> *survivingOps)
-    : GreedyPatternRewriteDriver(ctx, patterns, config),
+    : GreedyPatternRewriteDriver(rewriter, patterns, config),
       survivingOps(survivingOps) {
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -1040,9 +1042,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Start the pattern driver.
+  PatternRewriter rewriter(ops.front()->getContext());
   llvm::SmallDenseSet<Operation *, 4> surviving;
-  MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     config, ops,
+  MultiOpPatternRewriteDriver driver(rewriter, patterns, config, ops,
                                      allErased ? &surviving : nullptr);
   LogicalResult converged = std::move(driver).simplify(ops, changed);
   if (allErased)
@@ -1053,3 +1055,242 @@ LogicalResult mlir::applyOpPatternsAndFold(
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes. It derives from
+/// `ConversionPatternRewriter` so that the existing conversion patterns can
+/// be used with the One-Shot Dialect Conversion.
+///
+/// Type mismatches are bridged with `unrealized_conversion_cast` ops, which are
+/// cached per (value, type). Cache entries are evicted whenever the casted
+/// value or the cast itself is erased through this rewriter. Otherwise a later
+/// lookup could hit a stale entry (e.g., after the key's memory was reused by
+/// a new value) and hand out a value of an erased op.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+  explicit OneShotConversionPatternRewriter(MLIRContext *ctx)
+      : ConversionPatternRewriter(ctx) {}
+
+  bool canRecoverFromRewriteFailure() const override { return false; }
+
+  void replaceOp(Operation *op, ValueRange newValues) override;
+
+  void replaceOp(Operation *op, Operation *newOp) override {
+    replaceOp(op, newOp->getResults());
+  }
+
+  void eraseOp(Operation *op) override {
+    // Evict cache entries for the whole subtree up front: the base
+    // implementation may erase nested ops without going through this hook.
+    forgetOpValues(op);
+    PatternRewriter::eraseOp(op);
+  }
+
+  void eraseBlock(Block *block) override {
+    forgetBlockArguments(block);
+    for (Operation &nested : *block)
+      forgetOpValues(&nested);
+    PatternRewriter::eraseBlock(block);
+  }
+
+  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+                         ValueRange argValues = std::nullopt) override {
+    // The base implementation erases `source` after inlining, so its block
+    // arguments die (the ops are moved, not erased).
+    forgetBlockArguments(source);
+    PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+  }
+  using PatternRewriter::inlineBlockBefore;
+
+  void startOpModification(Operation *op) override {
+    PatternRewriter::startOpModification(op);
+  }
+
+  void finalizeOpModification(Operation *op) override {
+    PatternRewriter::finalizeOpModification(op);
+  }
+
+  void cancelOpModification(Operation *op) override {
+    PatternRewriter::cancelOpModification(op);
+  }
+
+  void setCurrentTypeConverter(const TypeConverter *converter) override {
+    typeConverter = converter;
+  }
+
+  const TypeConverter *getCurrentTypeConverter() const override {
+    return typeConverter;
+  }
+
+  LogicalResult getAdapterOperands(StringRef valueDiagTag,
+                                   std::optional<Location> inputLoc,
+                                   ValueRange values,
+                                   SmallVector<Value> &remapped) override;
+
+private:
+  /// Build an unrealized_conversion_cast op or look it up in the cache.
+  Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+  /// Drop all cache entries that refer to `value`, either as a casted value or
+  /// as the result of a cached cast.
+  void forgetValue(Value value) {
+    auto castsIt = castCache.find(value);
+    if (castsIt != castCache.end()) {
+      for (Value cast : castsIt->second)
+        castCacheKeys.erase(cast);
+      castCache.erase(castsIt);
+    }
+    auto keyIt = castCacheKeys.find(value);
+    if (keyIt != castCacheKeys.end()) {
+      auto ownerIt = castCache.find(keyIt->second);
+      if (ownerIt != castCache.end())
+        llvm::erase_if(ownerIt->second,
+                       [&](Value cast) { return cast == value; });
+      castCacheKeys.erase(keyIt);
+    }
+  }
+
+  /// Drop all cache entries that refer to arguments of `block`.
+  void forgetBlockArguments(Block *block) {
+    for (BlockArgument arg : block->getArguments())
+      forgetValue(arg);
+  }
+
+  /// Drop all cache entries that refer to values defined by `op` or by any
+  /// op / block nested within it.
+  void forgetOpValues(Operation *op) {
+    op->walk([&](Operation *nested) {
+      for (Value result : nested->getResults())
+        forgetValue(result);
+      for (Region &region : nested->getRegions())
+        for (Block &block : region)
+          forgetBlockArguments(&block);
+    });
+  }
+
+  /// The current type converter. Must be initialized:
+  /// `ConversionPattern::matchAndRewrite` reads it (to restore it afterwards)
+  /// before setting it, and patterns may query remapped values while no
+  /// converter is set.
+  const TypeConverter *typeConverter = nullptr;
+
+  /// Cache of unrealized_conversion_casts, so that identical casts are not
+  /// built multiple times: maps a casted value to the results of the casts
+  /// built for it (at most one per target type).
+  DenseMap<Value, SmallVector<Value, 1>> castCache;
+
+  /// Reverse index into `castCache`: maps a cached cast result to the value it
+  /// is cached under. Needed to evict an entry when the cast op is erased,
+  /// because the cast's input may have been replaced since it was cached.
+  DenseMap<Value, Value> castCacheKeys;
+};
+
+/// Replaces all uses of `op`'s results with `newValues` and erases `op`. A
+/// replacement whose type differs from the replaced result is wrapped in an
+/// unrealized_conversion_cast back to the original type, so existing users
+/// keep seeing the type they expect.
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+                                                 ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size());
+  for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) {
+    if (orig.getType() != repl.getType()) {
+      // Type mismatch: insert unrealized_conversion cast.
+      replaceAllUsesWith(orig, buildUnrealizedConversionCast(
+                                   op->getLoc(), orig.getType(), repl));
+    } else {
+      // Same type: use replacement value directly.
+      replaceAllUsesWith(orig, repl);
+    }
+  }
+  eraseOp(op);
+}
+
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+    Location loc, Type type, Value value) {
+  // Reuse an existing cast of `value` to `type`, if one is cached.
+  auto cachedIt = castCache.find(value);
+  if (cachedIt != castCache.end())
+    for (Value cached : cachedIt->second)
+      if (cached.getType() == type)
+        return cached;
+
+  // Insert cast at the beginning of the block (for block arguments) or right
+  // after the defining op, so that it dominates every use of `value`.
+  OpBuilder::InsertionGuard g(*this);
+  Block *insertBlock = value.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (auto inputRes = dyn_cast<OpResult>(value))
+    insertPt = std::next(inputRes.getOwner()->getIterator());
+  setInsertionPoint(insertBlock, insertPt);
+  Value cast =
+      create<UnrealizedConversionCastOp>(loc, type, value).getOutputs()[0];
+  castCache[value].push_back(cast);
+  castCacheKeys[cast] = value;
+  return cast;
+}
+
+/// A variant of the greedy pattern rewrite driver that drives the One-Shot
+/// Dialect Conversion: only ops that `target` does not report as legal are
+/// put on the worklist, and erasing an op does not enqueue any other ops.
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+  ConversionPatternRewriteDriver(PatternRewriter &rewriter,
+                                 const FrozenRewritePatternSet &patterns,
+                                 const GreedyRewriteConfig &config,
+                                 const ConversionTarget &target)
+      : GreedyPatternRewriteDriver(rewriter, patterns, config), target(target) {
+  }
+
+  /// Populate the worklist with all ops that are not legal according to the
+  /// target and start the conversion process.
+  LogicalResult convert(Operation *op) &&;
+
+protected:
+  /// Enqueue `op` (but not its ancestors) if it is not legal.
+  void addToWorklist(Operation *op) override;
+
+  /// Notify the driver that the specified operation was removed. Update the
+  /// worklist as needed: The operation and its children are removed from the
+  /// worklist.
+  void notifyOperationErased(Operation *op) override;
+
+private:
+  /// The conversion target that decides which ops must be converted. Not
+  /// owned; must outlive the driver.
+  const ConversionTarget &target;
+};
+} // namespace
+
+/// Enqueues all non-legal ops nested in (and including) `op`, runs the
+/// worklist, and then checks the result. Returns failure (with an error on
+/// each offending op) if an op that the target explicitly marks as illegal
+/// is still present after the conversion, matching the semantics of
+/// `applyPartialConversion`.
+LogicalResult ConversionPatternRewriteDriver::convert(Operation *op) && {
+  op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
+      [&](Operation *nested) {
+        auto legalityInfo = target.isLegal(nested);
+        if (!legalityInfo) {
+          addSingleOpToWorklist(nested);
+          return WalkResult::advance();
+        }
+        // Don't check this operation's children for conversion if the
+        // operation is recursively legal.
+        return legalityInfo->isRecursivelyLegal ? WalkResult::skip()
+                                                : WalkResult::advance();
+      });
+
+  // Reverse the list so our pop-back loop processes them in-order.
+  // TODO: newly enqueued ops must also be reversed
+  worklist.reverse();
+
+  processWorklist();
+
+  // Ops that could not be converted are only an error if they were explicitly
+  // marked illegal. Regions of recursively legal ops were never converted and
+  // are skipped here as well.
+  // NOTE: assumes that `op` itself was not erased during the conversion.
+  bool illegalOpRemaining = false;
+  op->walk<WalkOrder::PreOrder>([&](Operation *nested) {
+    auto legalityInfo = target.isLegal(nested);
+    if (legalityInfo)
+      return legalityInfo->isRecursivelyLegal ? WalkResult::skip()
+                                              : WalkResult::advance();
+    if (target.isIllegal(nested)) {
+      nested->emitError() << "failed to legalize operation '"
+                          << nested->getName()
+                          << "' that was explicitly marked illegal";
+      illegalOpRemaining = true;
+    }
+    return WalkResult::advance();
+  });
+  return failure(illegalOpRemaining);
+}
+
+/// Only ops without a "legal" verdict from the target need converting, so
+/// legal ops are ignored and, unlike the base driver, ancestors are never
+/// enqueued.
+void ConversionPatternRewriteDriver::addToWorklist(Operation *op) {
+  if (target.isLegal(op).has_value())
+    return;
+  addSingleOpToWorklist(op);
+}
+
+// TODO: Refactor. This is the same as
+// `GreedyPatternRewriteDriver::notifyOperationErased`, but does not add ops to
+// the worklist.
+void ConversionPatternRewriteDriver::notifyOperationErased(Operation *op) {
+  LLVM_DEBUG(logger.startLine() << "** Erase   : '" << op->getName() << "'("
+                                << op << ")\n");
+
+#ifndef NDEBUG
+  // The parent op of the scope region bounds both the worklist and the
+  // "expensive checks"; erasing it would break both, so it is forbidden.
+  const bool erasesScopeParent =
+      config.scope && config.scope->getParentOp() == op;
+  if (erasesScopeParent)
+    llvm_unreachable(
+        "scope region must not be erased during greedy pattern rewrite");
+#endif // NDEBUG
+
+  // Forward the notification to the user-provided listener, if any.
+  if (RewriterBase::Listener *userListener = config.listener)
+    userListener->notifyOperationErased(op);
+
+  // Make sure the erased op is never visited again.
+  worklist.remove(op);
+  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+    strictModeFilteredOps.erase(op);
+}
+
+/// Populate the converted operands in `remapped`, based on the currently set
+/// type converter. Chains of 1:1 unrealized_conversion_casts are looked
+/// through, and a (cached) cast is inserted whenever the resulting value does
+/// not already have the converted type. Returns failure (and notifies the
+/// match-failure listener) if an operand type cannot be converted 1:1.
+LogicalResult OneShotConversionPatternRewriter::getAdapterOperands(
+    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+    SmallVector<Value> &remapped) {
+  // TODO: Refactor. This is mostly copied from the current dialect conversion.
+  const TypeConverter *converter = getCurrentTypeConverter();
+  for (const auto &it : llvm::enumerate(values)) {
+    Value v = it.value();
+
+    // Skip all 1:1 unrealized_conversion_casts in the chain of defining ops.
+    // Casts with several inputs or outputs do not forward a single value and
+    // cannot be looked through.
+    Value vBase = v;
+    while (auto castOp = vBase.getDefiningOp<UnrealizedConversionCastOp>()) {
+      if (castOp.getInputs().size() != 1 || castOp.getOutputs().size() != 1)
+        break;
+      vBase = castOp.getInputs()[0];
+    }
+
+    if (!converter) {
+      // No type converter set. Just replicate what the current type conversion
+      // is doing.
+      // TODO: We may have to distinguish between newly-inserted and
+      // pre-existing unrealized_conversion_casts.
+      remapped.push_back(vBase);
+      continue;
+    }
+
+    Location operandLoc = inputLoc ? *inputLoc : v.getLoc();
+    Type origType = v.getType();
+    size_t index = it.index();
+    SmallVector<Type, 1> legalTypes;
+    if (failed(converter->convertType(origType, legalTypes))) {
+      return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+        diag << "unable to convert type for " << valueDiagTag << " #" << index
+             << ", type was " << origType;
+      });
+    }
+    if (legalTypes.size() != 1) {
+      // TODO: Support 1:0 and 1:N type conversions. Until then, fail the
+      // pattern instead of indexing into `legalTypes` (which may be empty).
+      return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+        diag << "1:N type conversion not supported for " << valueDiagTag
+             << " #" << index << ", type was " << origType;
+      });
+    }
+
+    Type desiredType = legalTypes.front();
+    if (desiredType == vBase.getType()) {
+      // Type already matches. No need to convert anything.
+      remapped.push_back(vBase);
+      continue;
+    }
+
+    remapped.push_back(
+        buildUnrealizedConversionCast(operandLoc, desiredType, vBase));
+  }
+  return success();
+}
+
+/// Entry point of the One-Shot Dialect Conversion: drives `patterns` over
+/// `op` with a rewriter that materializes IR changes immediately.
+LogicalResult
+mlir::applyPartialOneShotConversion(Operation *op,
+                                    const ConversionTarget &target,
+                                    const FrozenRewritePatternSet &patterns) {
+  // Turn off the greedy driver's removal of trivially dead ops (see
+  // `processWorklist`).
+  GreedyRewriteConfig greedyConfig;
+  greedyConfig.enableOperationDce = false;
+
+  OneShotConversionPatternRewriter oneShotRewriter(op->getContext());
+  return ConversionPatternRewriteDriver(oneShotRewriter, patterns, greedyConfig,
+                                        target)
+      .convert(op);
+}
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb..865840a751fa6 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -430,7 +430,7 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index
 #map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)>
 
 // CHECK-LABEL: func @affine_applies(
-func.func @affine_applies(%arg0 : index) {
+func.func @affine_applies(%arg0 : index) -> (index, index, index, index, index) {
 // CHECK: %[[c0:.*]] = arith.constant 0 : index
   %zero = affine.apply #map0()
 
@@ -442,9 +442,7 @@ func.func @affine_applies(%arg0 : index) {
   %102 = arith.constant 102 : index
   %copy = affine.apply #map2(%zero)
 
-// CHECK-NEXT: %[[v0:.*]] = arith.addi %[[c0]], %[[c0]] : index
 // CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[v0]], %[[c1]] : index
   %one = affine.apply #map3(%symbZero)[%zero]
 
 // CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
@@ -466,7 +464,9 @@ func.func @affine_applies(%arg0 : index) {
 // CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index
 // CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
   %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
-  return
+
+// CHECK: return %[[c0]], %[[c0]], %[[c0]], %[[c1]], %[[v13]]
+  return %zero, %symbZero, %copy, %one, %four : index, index, index, index, index
 }
 
 // CHECK-LABEL: func @args_ret_affine_apply(
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 56ae930e6d627..2360d6e71dedc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -321,9 +321,9 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
 // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
 // CHECK-NEXT: = llvm.sext %[[ARG0]] : vector<1xi3> to vector<1xi6>
   %0 = arith.extsi %arg0 : vector<i3> to vector<i6>
-// CHECK-NEXT: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6>
+// CHECK: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6>
   %1 = arith.extui %arg0 : vector<i3> to vector<i6>
-// CHECK-NEXT: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2>
+// CHECK: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2>
   %2 = arith.trunci %arg0 : vector<i3> to vector<i2>
   return
 }
@@ -478,11 +478,12 @@ func.func @mului_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -
 // -----
 
 // CHECK-LABEL: func @cmpf_2dvector(
+//  CHECK-SAME:     %[[FARG0:.*]]: vector<4x3xf32>, %[[FARG1:.*]]: vector<4x3xf32>
 func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xf32>>
-  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]]
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]]
+  // CHECK-DAG: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK-DAG: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xf32>>
   // CHECK: %[[CMP:.*]] = llvm.fcmp "olt" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xf32>
   // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %2[0] : !llvm.array<4 x vector<3xi1>>
   %0 = arith.cmpf olt, %arg0, %arg1 : vector<4x3xf32>
@@ -492,9 +493,10 @@ func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
 // -----
 
 // CHECK-LABEL: func @cmpi_0dvector(
+//  CHECK-SAME:     %[[FARG0:.*]]: vector<i32>, %[[FARG1:.*]]: vector<i32>
 func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]]
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]]
   // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
   %0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
   func.return
@@ -503,11 +505,12 @@ func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
 // -----
 
 // CHECK-LABEL: func @cmpi_2dvector(
+//  CHECK-SAME:     %[[FARG0:.*]]: vector<4x3xi32>, %[[FARG1:.*]]: vector<4x3xi32>
 func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
-  // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi32>>
-  // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]]
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]]
+  // CHECK-DAG: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi32>>
+  // CHECK-DAG: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
   // CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xi32>
   // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %2[0] : !llvm.array<4 x vector<3xi1>>
   %0 = arith.cmpi ult, %arg0, %arg1 : vector<4x3xi32>
diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
index 63989347567b5..b234cbbb35f32 100644
--- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -199,9 +199,9 @@ func.func @bitcast_2d(%arg0: vector<2x4xf32>) {
 
 // CHECK-LABEL: func @select_2d(
 func.func @select_2d(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) {
-  // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0
-  // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1
-  // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2
+  // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0
+  // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1
+  // CHECK-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2
   // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi1>>
   // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
   // CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[ARG2]][0] : !llvm.array<4 x vector<3xi32>>
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index b86103422b074..599bb6ca14c48 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -294,6 +294,7 @@ func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) -
 // CHECK:         %[[MEM:[a-zA-Z0-9]*]]: memref
 func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?xf32, strided<[?, 1], offset: ?>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64
   // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEMREF]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
@@ -316,7 +317,6 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
   // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Sizes and strides @rank 1: static stride 1, dynamic size unchanged from source memref.
   // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64
   // CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 
   %c0 = arith.constant 1 : index
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
 // CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
-// CHECK:           %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
 // CHECK:           %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]]  : i64
 // CHECK:           %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
 // CHECK:           %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 1b046d32f163a..57f805307a065 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -275,8 +275,8 @@ func.func @async_cp_i4(
 // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index
 func.func @async_cp_zfill_f32_align4(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
-  // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
-  // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+  // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
   // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>                                   
   // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
   // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -310,8 +310,8 @@ func.func @async_cp_zfill_f32_align4(
 // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
 func.func @async_cp_zfill_f32_align1(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
-    // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
-  // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+  // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
   // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>                                   
   // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
   // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -533,7 +533,9 @@ func.func @mbarrier_nocomplete() {
 }
 
 // CHECK-LABEL: func @mbarrier_wait
+//  CHECK-SAME:     %[[barriers:.*]]: !nvgpu.mbarrier.group
 func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, %token : !tokenType) {
+// CHECK: %[[barriersCast:.*]] = builtin.unrealized_conversion_cast %[[barriers]]
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %n = arith.constant 100 : index
@@ -545,7 +547,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
 // CHECK: scf.for %[[i:.*]] =
 // CHECK: %[[S2:.+]] = arith.remui %[[i]], %[[c5]] : index
 // CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[S2]] : index to i64
-// CHECK: %[[S4:.+]] = llvm.extractvalue %0[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
+// CHECK: %[[S4:.+]] = llvm.extractvalue %[[barriersCast]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> 
 // CHECK: %[[S5:.+]] = llvm.getelementptr %[[S4]][%[[S3]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
     %mbarId = arith.remui %i, %numBarriers : index
     %isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, !tokenType
@@ -871,9 +873,9 @@ func.func @warpgroup_mma_128_128_64(
       %descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
       %acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>) 
 {
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[UD:.+]] =  llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
@@ -1280,9 +1282,9 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
     to memref<128x128xf32,3>
 
 
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
 // CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
@@ -1296,7 +1298,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
 // CHECK: nvvm.wgmma.fence.aligned
 // CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
 // CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)> 
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: nvvm.wgmma.mma_async
 // CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async



More information about the Mlir-commits mailing list