[llvm-branch-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun May 26 06:07:47 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/93412
This commit adds a dialect conversion driver without rollback: `OneShotDialectConversionDriver`.
The new driver reuses some functionality of the greedy pattern rewrite driver. This is just a proof of concept; the code is not polished yet.
`OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.
Two tests were adapted to show the kind of changes needed: `nvgpu-to-nvvm` and `complex-to-standard`.
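For a pass author, the call-site change is small: swap `applyPartialConversion` for `applyPartialOneShotConversion`, as the `complex-to-standard` and `nvgpu-to-nvvm` hunks below show. A minimal sketch follows; the pass name and pattern populator are hypothetical placeholders, only the driver entry point is taken from this patch:

```cpp
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

void MyConversionPass::runOnOperation() {
  // Collect the existing conversion patterns (hypothetical populator).
  RewritePatternSet patterns(&getContext());
  populateMyConversionPatterns(patterns);

  // Describe which ops are legal after the conversion.
  ConversionTarget target(getContext());
  target.addLegalDialect<arith::ArithDialect, math::MathDialect>();

  // One-shot driver: IR changes are materialized immediately, no rollback.
  if (failed(applyPartialOneShotConversion(getOperation(), target,
                                           std::move(patterns))))
    signalPassFailure();
}
```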
From 11dfecd793b6e84224511ef6b065964b64784c49 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 26 May 2024 14:59:09 +0200
Subject: [PATCH] [draft] Dialect Conversion without Rollback
This commit adds a dialect conversion driver without rollback: `OneShotDialectConversionDriver`.
The new driver reuses some functionality of the greedy pattern rewrite driver. This is just a proof of concept; the code is not polished yet.
`OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.
---
.../mlir/Transforms/DialectConversion.h | 27 +-
.../Transforms/GreedyPatternRewriteDriver.h | 6 +
.../ComplexToStandard/ComplexToStandard.cpp | 5 +-
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 6 +-
.../Transforms/Utils/DialectConversion.cpp | 33 ++-
.../Utils/GreedyPatternRewriteDriver.cpp | 280 ++++++++++++++++--
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 26 +-
7 files changed, 336 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..8d1c125660d10 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
- template <typename TargetType> TargetType convertType(Type t) const {
+ template <typename TargetType>
+ TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
@@ -657,7 +658,7 @@ struct ConversionPatternRewriterImpl;
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
-class ConversionPatternRewriter final : public PatternRewriter {
+class ConversionPatternRewriter : public PatternRewriter {
public:
~ConversionPatternRewriter() override;
@@ -711,6 +712,17 @@ class ConversionPatternRewriter final : public PatternRewriter {
LogicalResult getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results);
+ virtual void setCurrentTypeConverter(const TypeConverter *converter);
+
+ virtual const TypeConverter *getCurrentTypeConverter() const;
+
+ /// Populate the operands that are used for constructing the adapter into
+ /// `remapped`.
+ virtual LogicalResult getAdapterOperands(StringRef valueDiagTag,
+ std::optional<Location> inputLoc,
+ ValueRange values,
+ SmallVector<Value> &remapped);
+
//===--------------------------------------------------------------------===//
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//
@@ -755,6 +767,14 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
+protected:
+ /// Protected constructor for `OneShotConversionPatternRewriter`. Does not
+ /// initialize `impl`.
+ explicit ConversionPatternRewriter(MLIRContext *ctx);
+
+ // Hide unsupported pattern rewriter API.
+ using OpBuilder::setListener;
+
private:
// Allow OperationConverter to construct new rewriters.
friend struct OperationConverter;
@@ -765,9 +785,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
explicit ConversionPatternRewriter(MLIRContext *ctx,
const ConversionConfig &config);
- // Hide unsupported pattern rewriter API.
- using OpBuilder::setListener;
-
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..21156ca59f674 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,6 +18,8 @@
namespace mlir {
+class ConversionTarget;
+
/// This enum controls which ops are put on the worklist during a greedy
/// pattern rewrite.
enum class GreedyRewriteStrictness {
@@ -188,6 +190,10 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
+LogicalResult
+applyPartialOneShotConversion(Operation *op, const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns);
+
} // namespace mlir
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index d8150aeb828a5..885bf7f9822df 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include <type_traits>
@@ -1346,8 +1347,8 @@ void ConvertComplexToStandardPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
- if (failed(
- applyPartialConversion(getOperation(), target, std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
} // namespace
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..931cedc0c5eb9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
@@ -475,9 +476,10 @@ struct ConvertNVGPUToNVVMPass
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
+ // applyPartialConversion
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..7e48df33fded5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1633,6 +1633,9 @@ ConversionPatternRewriter::ConversionPatternRewriter(
setListener(impl.get());
}
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ : PatternRewriter(ctx), impl(nullptr) {}
+
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
@@ -1819,6 +1822,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
+void ConversionPatternRewriter::setCurrentTypeConverter(
+ const TypeConverter *converter) {
+ impl->currentTypeConverter = converter;
+}
+
+const TypeConverter *
+ConversionPatternRewriter::getCurrentTypeConverter() const {
+ return impl->currentTypeConverter;
+}
+
+LogicalResult ConversionPatternRewriter::getAdapterOperands(
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+ SmallVector<Value> &remapped) {
+ return impl->remapValues(valueDiagTag, inputLoc, *this, values, remapped);
+}
+
//===----------------------------------------------------------------------===//
// ConversionPattern
//===----------------------------------------------------------------------===//
@@ -1827,16 +1846,18 @@ LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
- auto &rewriterImpl = dialectRewriter.getImpl();
// Track the current conversion pattern type converter in the rewriter.
- llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
- getTypeConverter());
+ const TypeConverter *currentTypeConverter =
+ dialectRewriter.getCurrentTypeConverter();
+ auto resetTypeConverter = llvm::make_scope_exit(
+ [&] { dialectRewriter.setCurrentTypeConverter(currentTypeConverter); });
+ dialectRewriter.setCurrentTypeConverter(getTypeConverter());
// Remap the operands of the operation.
- SmallVector<Value, 4> operands;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
- op->getOperands(), operands))) {
+ SmallVector<Value> operands;
+ if (failed(dialectRewriter.getAdapterOperands("operand", op->getLoc(),
+ op->getOperands(), operands))) {
return failure();
}
return matchAndRewrite(op, operands, dialectRewriter);
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 597cb29ce911b..5312f4a1cacbf 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -14,10 +14,12 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
+#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
@@ -321,7 +323,7 @@ class RandomizedWorklist : public Worklist {
/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
- explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+ explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config);
@@ -329,7 +331,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
void addSingleOpToWorklist(Operation *op);
/// Add the given operation and its ancestors to the worklist.
- void addToWorklist(Operation *op);
+ virtual void addToWorklist(Operation *op);
/// Notify the driver that the specified operation may have been modified
/// in-place. The operation is added to the worklist.
@@ -356,7 +358,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
/// The pattern rewriter that is used for making IR modifications and is
/// passed to rewrite patterns.
- PatternRewriter rewriter;
+ PatternRewriter &rewriter;
/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
@@ -375,6 +377,11 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
/// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+#ifndef NDEBUG
+ /// A logger used to emit information during the application process.
+ llvm::ScopedPrinter logger{llvm::dbgs()};
+#endif
+
private:
/// Look over the provided operands for any defining operations that should
/// be re-added to the worklist. This function should be called when an
@@ -394,11 +401,6 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
-#ifndef NDEBUG
- /// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
-#endif
-
/// The low-level pattern applicator.
PatternApplicator matcher;
@@ -409,9 +411,9 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
} // namespace
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : rewriter(ctx), config(config), matcher(patterns)
+ : rewriter(rewriter), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(
@@ -780,7 +782,7 @@ namespace {
/// This driver simplfies all ops in a region.
class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
- explicit RegionPatternRewriteDriver(MLIRContext *ctx,
+ explicit RegionPatternRewriteDriver(PatternRewriter &rewriter,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config,
Region ®ions);
@@ -796,9 +798,9 @@ class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
} // namespace
RegionPatternRewriteDriver::RegionPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, Region ®ion)
- : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
+ : GreedyPatternRewriteDriver(rewriter, patterns, config), region(region) {
// Populate strict mode ops.
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
@@ -909,8 +911,8 @@ mlir::applyPatternsAndFoldGreedily(Region ®ion,
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Start the pattern driver.
- RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
- region);
+ PatternRewriter rewriter(region.getContext());
+ RegionPatternRewriteDriver driver(rewriter, patterns, config, region);
LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -928,7 +930,7 @@ namespace {
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
@@ -950,10 +952,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
} // namespace
MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps)
- : GreedyPatternRewriteDriver(ctx, patterns, config),
+ : GreedyPatternRewriteDriver(rewriter, patterns, config),
survivingOps(survivingOps) {
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -1040,9 +1042,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Start the pattern driver.
+ PatternRewriter rewriter(ops.front()->getContext());
llvm::SmallDenseSet<Operation *, 4> surviving;
- MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- config, ops,
+ MultiOpPatternRewriteDriver driver(rewriter, patterns, config, ops,
allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
@@ -1053,3 +1055,241 @@ LogicalResult mlir::applyOpPatternsAndFold(
});
return converged;
}
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes. It derives from
+/// `ConversionPatternRewriter` so that the existing conversion patterns can
+/// be used with the One-Shot Dialect Conversion.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+ OneShotConversionPatternRewriter(MLIRContext *ctx)
+ : ConversionPatternRewriter(ctx) {}
+
+ bool canRecoverFromRewriteFailure() const override { return false; }
+
+ void replaceOp(Operation *op, ValueRange newValues) override;
+
+ void replaceOp(Operation *op, Operation *newOp) override {
+ replaceOp(op, newOp->getResults());
+ }
+
+ void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); }
+
+ void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); }
+
+ void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+ ValueRange argValues = std::nullopt) override {
+ PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+ }
+ using PatternRewriter::inlineBlockBefore;
+
+ void startOpModification(Operation *op) override {
+ PatternRewriter::startOpModification(op);
+ }
+
+ void finalizeOpModification(Operation *op) override {
+ PatternRewriter::finalizeOpModification(op);
+ }
+
+ void cancelOpModification(Operation *op) override {
+ PatternRewriter::cancelOpModification(op);
+ }
+
+ void setCurrentTypeConverter(const TypeConverter *converter) override {
+ typeConverter = converter;
+ }
+
+ const TypeConverter *getCurrentTypeConverter() const override {
+ return typeConverter;
+ }
+
+ LogicalResult getAdapterOperands(StringRef valueDiagTag,
+ std::optional<Location> inputLoc,
+ ValueRange values,
+ SmallVector<Value> &remapped) override;
+
+private:
+ /// Build an unrealized_conversion_cast op or look it up in the cache.
+ Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+ /// The current type converter.
+ const TypeConverter *typeConverter;
+
+ /// A cache for unrealized_conversion_casts, to ensure that identical casts
+ /// are not built multiple times.
+ DenseMap<std::pair<Value, Type>, Value> castCache;
+};
+
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+ ValueRange newValues) {
+ assert(op->getNumResults() == newValues.size());
+ for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) {
+ if (orig.getType() != repl.getType()) {
+ // Type mismatch: insert an unrealized_conversion_cast.
+ replaceAllUsesWith(orig, buildUnrealizedConversionCast(
+ op->getLoc(), orig.getType(), repl));
+ } else {
+ // Same type: use replacement value directly.
+ replaceAllUsesWith(orig, repl);
+ }
+ }
+ eraseOp(op);
+}
+
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+ Location loc, Type type, Value value) {
+ auto it = castCache.find(std::make_pair(value, type));
+ if (it != castCache.end())
+ return it->second;
+
+ // Insert cast at the beginning of the block (for block arguments) or right
+ // after the defining op.
+ OpBuilder::InsertionGuard g(*this);
+ Block *insertBlock = value.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(value))
+ insertPt = ++inputRes.getOwner()->getIterator();
+ setInsertionPoint(insertBlock, insertPt);
+ auto castOp = create<UnrealizedConversionCastOp>(loc, type, value);
+ castCache[std::make_pair(value, type)] = castOp.getOutputs()[0];
+ return castOp.getOutputs()[0];
+}
+
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+ ConversionPatternRewriteDriver(PatternRewriter &rewriter,
+ const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config,
+ const ConversionTarget &target)
+ : GreedyPatternRewriteDriver(rewriter, patterns, config), target(target) {
+ }
+
+ /// Populate the worklist with all illegal ops and start the conversion
+ /// process.
+ LogicalResult convert(Operation *op) &&;
+
+protected:
+ void addToWorklist(Operation *op) override;
+
+ /// Notify the driver that the specified operation was removed. Update the
+ /// worklist as needed: The operation and its children are removed from the
+ /// worklist.
+ void notifyOperationErased(Operation *op) override;
+
+private:
+ const ConversionTarget ⌖
+};
+} // namespace
+
+LogicalResult ConversionPatternRewriteDriver::convert(Operation *op) && {
+ op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>([&](Operation *op) {
+ auto legalityInfo = target.isLegal(op);
+ if (!legalityInfo) {
+ addSingleOpToWorklist(op);
+ return WalkResult::advance();
+ }
+ if (legalityInfo->isRecursivelyLegal) {
+ // Don't check this operation's children for conversion if the
+ // operation is recursively legal.
+ return WalkResult::skip();
+ }
+ return WalkResult::advance();
+ });
+
+ // Reverse the list so our pop-back loop processes them in-order.
+ // TODO: newly enqueued ops must also be reversed
+ worklist.reverse();
+
+ processWorklist();
+
+ return success();
+}
+
+void ConversionPatternRewriteDriver::addToWorklist(Operation *op) {
+ if (!target.isLegal(op))
+ addSingleOpToWorklist(op);
+}
+
+// TODO: Refactor. This is the same as
+// `GreedyPatternRewriteDriver::notifyOperationErased`, but does not add ops to
+// the worklist.
+void ConversionPatternRewriteDriver::notifyOperationErased(Operation *op) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+
+#ifndef NDEBUG
+ // Only ops that are within the configured scope are added to the worklist of
+ // the greedy pattern rewriter. Moreover, the parent op of the scope region is
+ // the part of the IR that is taken into account for the "expensive checks".
+ // A greedy pattern rewrite is not allowed to erase the parent op of the scope
+ // region, as that would break the worklist handling and the expensive checks.
+ if (config.scope && config.scope->getParentOp() == op)
+ llvm_unreachable(
+ "scope region must not be erased during greedy pattern rewrite");
+#endif // NDEBUG
+
+ if (config.listener)
+ config.listener->notifyOperationErased(op);
+
+ worklist.remove(op);
+
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ strictModeFilteredOps.erase(op);
+}
+
+/// Populate the converted operands in `remapped`. (Based on the currently set
+/// type converter.)
+LogicalResult OneShotConversionPatternRewriter::getAdapterOperands(
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+ SmallVector<Value> &remapped) {
+ // TODO: Refactor. This is mostly copied from the current dialect conversion.
+ for (Value v : values) {
+ // Skip all unrealized_conversion_casts in the chain of defining ops.
+ Value vBase = v;
+ while (auto castOp = vBase.getDefiningOp<UnrealizedConversionCastOp>())
+ vBase = castOp.getInputs()[0];
+
+ if (!getCurrentTypeConverter()) {
+ // No type converter set. Just replicate what the current type conversion
+ // is doing.
+ // TODO: We may have to distinguish between newly-inserted and
+ // pre-existing unrealized_conversion_casts.
+ remapped.push_back(vBase);
+ continue;
+ }
+
+ Type desiredType;
+ SmallVector<Type, 1> legalTypes;
+ if (failed(getCurrentTypeConverter()->convertType(v.getType(), legalTypes)))
+ return failure();
+ assert(legalTypes.size() == 1 && "1:N conversion not supported yet");
+ desiredType = legalTypes.front();
+ if (desiredType == vBase.getType()) {
+ // Type already matches. No need to convert anything.
+ remapped.push_back(vBase);
+ continue;
+ }
+
+ Location operandLoc = inputLoc ? *inputLoc : v.getLoc();
+ remapped.push_back(
+ buildUnrealizedConversionCast(operandLoc, desiredType, vBase));
+ }
+ return success();
+}
+
+LogicalResult
+mlir::applyPartialOneShotConversion(Operation *op,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns) {
+ GreedyRewriteConfig config;
+ OneShotConversionPatternRewriter rewriter(op->getContext());
+ ConversionPatternRewriteDriver driver(rewriter, patterns, config, target);
+ return std::move(driver).convert(op);
+}
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dbf8ead49f78d..f059caf079889 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -275,8 +275,8 @@ func.func @async_cp_i4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index
func.func @async_cp_zfill_f32_align4(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
- // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
- // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+ // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
// CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
// CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -310,8 +310,8 @@ func.func @async_cp_zfill_f32_align4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_zfill_f32_align1(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
- // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
- // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+ // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
// CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
// CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -533,7 +533,9 @@ func.func @mbarrier_nocomplete() {
}
// CHECK-LABEL: func @mbarrier_wait
+// CHECK-SAME: %[[barriers:.*]]: !nvgpu.mbarrier.group
func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, %token : !tokenType) {
+// CHECK: %[[barriersCast:.*]] = builtin.unrealized_conversion_cast %[[barriers]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%n = arith.constant 100 : index
@@ -545,7 +547,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
// CHECK: scf.for %[[i:.*]] =
// CHECK: %[[S2:.+]] = arith.remui %[[i]], %[[c5]] : index
// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[S2]] : index to i64
-// CHECK: %[[S4:.+]] = llvm.extractvalue %0[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[S4:.+]] = llvm.extractvalue %[[barriersCast]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[S5:.+]] = llvm.getelementptr %[[S4]][%[[S3]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
%mbarId = arith.remui %i, %numBarriers : index
%isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, !tokenType
@@ -871,9 +873,9 @@ func.func @warpgroup_mma_128_128_64(
%descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
%acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
{
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[UD:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
@@ -1280,9 +1282,9 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
to memref<128x128xf32,3>
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
@@ -1296,7 +1298,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.mma_async
// CHECK: nvvm.wgmma.mma_async
// CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async