[Mlir-commits] [mlir] [draft] Dialect Conversion without Rollback (PR #93412)
Matthias Springer
llvmlistbot at llvm.org
Sat Jun 8 03:28:55 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/93412
>From 6c04a065357a15c35c79332c9658036b1073dd5d Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 26 May 2024 14:59:09 +0200
Subject: [PATCH] [draft] Dialect Conversion without Rollback
This commit adds a dialect conversion driver without rollback: `OneShotDialectConversionDriver`.
The new driver reuses some functionality of the greedy pattern rewrite driver. This is just a proof of concept; the code is not polished yet.
`OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.
---
.../mlir/Transforms/DialectConversion.h | 30 +-
.../Transforms/GreedyPatternRewriteDriver.h | 8 +
.../AffineToStandard/AffineToStandard.cpp | 5 +-
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 5 +-
.../ComplexToStandard/ComplexToStandard.cpp | 5 +-
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 5 +-
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 5 +-
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 6 +-
.../Transforms/Utils/DialectConversion.cpp | 41 ++-
.../Utils/GreedyPatternRewriteDriver.cpp | 283 ++++++++++++++++--
.../AffineToStandard/lower-affine.mlir | 8 +-
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 27 +-
.../convert-nd-vector-to-llvmir.mlir | 6 +-
.../expand-then-convert-to-llvm.mlir | 4 +-
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 26 +-
15 files changed, 379 insertions(+), 85 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..76d56073f72b5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
- template <typename TargetType> TargetType convertType(Type t) const {
+ template <typename TargetType>
+ TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
@@ -657,7 +658,7 @@ struct ConversionPatternRewriterImpl;
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
-class ConversionPatternRewriter final : public PatternRewriter {
+class ConversionPatternRewriter : public PatternRewriter {
public:
~ConversionPatternRewriter() override;
@@ -708,8 +709,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Return the converted values that replace 'keys' with types defined by the
/// type converter of the currently executing pattern. Returns failure if the
/// remap failed, success otherwise.
- LogicalResult getRemappedValues(ValueRange keys,
- SmallVectorImpl<Value> &results);
+ LogicalResult getRemappedValues(ValueRange keys, SmallVector<Value> &results);
+
+ virtual void setCurrentTypeConverter(const TypeConverter *converter);
+
+ virtual const TypeConverter *getCurrentTypeConverter() const;
+
+ /// Populate the operands that are used for constructing the adapter into
+ /// `remapped`.
+ virtual LogicalResult getAdapterOperands(StringRef valueDiagTag,
+ std::optional<Location> inputLoc,
+ ValueRange values,
+ SmallVector<Value> &remapped);
//===--------------------------------------------------------------------===//
// PatternRewriter Hooks
@@ -755,6 +766,14 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
+protected:
+ /// Protected constructor for `OneShotConversionPatternRewriter`. Does not
+ /// initialize `impl`.
+ explicit ConversionPatternRewriter(MLIRContext *ctx);
+
+ // Hide unsupported pattern rewriter API.
+ using OpBuilder::setListener;
+
private:
// Allow OperationConverter to construct new rewriters.
friend struct OperationConverter;
@@ -765,9 +784,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
explicit ConversionPatternRewriter(MLIRContext *ctx,
const ConversionConfig &config);
- // Hide unsupported pattern rewriter API.
- using OpBuilder::setListener;
-
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146aac15b9..7513427e0d3ae 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,6 +18,8 @@
namespace mlir {
+class ConversionTarget;
+
/// This enum controls which ops are put on the worklist during a greedy
/// pattern rewrite.
enum class GreedyRewriteStrictness {
@@ -78,6 +80,8 @@ class GreedyRewriteConfig {
/// excluded.
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+ bool enableOperationDce = true;
+
/// An optional listener that should be notified about IR modifications.
RewriterBase::Listener *listener = nullptr;
};
@@ -188,6 +192,10 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
+LogicalResult
+applyPartialOneShotConversion(Operation *op, const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns);
+
} // namespace mlir
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 10ccd5c97783b..f7e0ffeaa0d47 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace mlir {
@@ -563,8 +564,8 @@ class LowerAffinePass
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
scf::SCFDialect, VectorDialect>();
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index d882f1184f457..dcc3dd3659cf6 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <type_traits>
namespace mlir {
@@ -479,8 +480,8 @@ struct ArithToLLVMConversionPass
LLVMTypeConverter converter(&getContext(), options);
mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index d8150aeb828a5..885bf7f9822df 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#include <type_traits>
@@ -1346,8 +1347,8 @@ void ConvertComplexToStandardPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
- if (failed(
- applyPartialConversion(getOperation(), target, std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
} // namespace
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index b8e5aec25286d..0d81277a82996 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
@@ -240,8 +241,8 @@ struct ConvertControlFlowToLLVM
LLVMTypeConverter converter(&getContext(), options);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 23e957288eb95..89c5a980d7dea 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
@@ -291,8 +292,8 @@ struct ConvertMathToLLVMPass
LLVMTypeConverter converter(&getContext());
populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
LLVMConversionTarget target(getContext());
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..931cedc0c5eb9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -25,6 +25,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
@@ -475,9 +476,10 @@ struct ConvertNVGPUToNVVMPass
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialOneShotConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
+ // applyPartialConversion
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..8f0d560c05d7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1633,6 +1633,9 @@ ConversionPatternRewriter::ConversionPatternRewriter(
setListener(impl.get());
}
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ : PatternRewriter(ctx), impl(nullptr) {}
+
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
@@ -1717,19 +1720,17 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<Value> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
- remappedValues)))
+ if (failed(getRemappedValues(key, remappedValues)))
return nullptr;
return remappedValues.front();
}
LogicalResult
ConversionPatternRewriter::getRemappedValues(ValueRange keys,
- SmallVectorImpl<Value> &results) {
+ SmallVector<Value> &results) {
if (keys.empty())
return success();
- return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
- results);
+ return getAdapterOperands("value", /*inputLoc=*/std::nullopt, keys, results);
}
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1819,6 +1820,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
+void ConversionPatternRewriter::setCurrentTypeConverter(
+ const TypeConverter *converter) {
+ impl->currentTypeConverter = converter;
+}
+
+const TypeConverter *
+ConversionPatternRewriter::getCurrentTypeConverter() const {
+ return impl->currentTypeConverter;
+}
+
+LogicalResult ConversionPatternRewriter::getAdapterOperands(
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+ SmallVector<Value> &remapped) {
+ return impl->remapValues(valueDiagTag, inputLoc, *this, values, remapped);
+}
+
//===----------------------------------------------------------------------===//
// ConversionPattern
//===----------------------------------------------------------------------===//
@@ -1827,16 +1844,18 @@ LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
- auto &rewriterImpl = dialectRewriter.getImpl();
// Track the current conversion pattern type converter in the rewriter.
- llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
- getTypeConverter());
+ const TypeConverter *currentTypeConverter =
+ dialectRewriter.getCurrentTypeConverter();
+ auto resetTypeConverter = llvm::make_scope_exit(
+ [&] { dialectRewriter.setCurrentTypeConverter(currentTypeConverter); });
+ dialectRewriter.setCurrentTypeConverter(getTypeConverter());
// Remap the operands of the operation.
- SmallVector<Value, 4> operands;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
- op->getOperands(), operands))) {
+ SmallVector<Value> operands;
+ if (failed(dialectRewriter.getAdapterOperands("operand", op->getLoc(),
+ op->getOperands(), operands))) {
return failure();
}
return matchAndRewrite(op, operands, dialectRewriter);
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 597cb29ce911b..99e82827cdefb 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -14,10 +14,12 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
+#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
@@ -321,7 +323,7 @@ class RandomizedWorklist : public Worklist {
/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
- explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+ explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config);
@@ -329,7 +331,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
void addSingleOpToWorklist(Operation *op);
/// Add the given operation and its ancestors to the worklist.
- void addToWorklist(Operation *op);
+ virtual void addToWorklist(Operation *op);
/// Notify the driver that the specified operation may have been modified
/// in-place. The operation is added to the worklist.
@@ -356,7 +358,7 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
/// The pattern rewriter that is used for making IR modifications and is
/// passed to rewrite patterns.
- PatternRewriter rewriter;
+ PatternRewriter &rewriter;
/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
@@ -375,6 +377,11 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
/// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+#ifndef NDEBUG
+ /// A logger used to emit information during the application process.
+ llvm::ScopedPrinter logger{llvm::dbgs()};
+#endif
+
private:
/// Look over the provided operands for any defining operations that should
/// be re-added to the worklist. This function should be called when an
@@ -394,11 +401,6 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
-#ifndef NDEBUG
- /// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
-#endif
-
/// The low-level pattern applicator.
PatternApplicator matcher;
@@ -409,9 +411,9 @@ class GreedyPatternRewriteDriver : public RewriterBase::Listener {
} // namespace
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : rewriter(ctx), config(config), matcher(patterns)
+ : rewriter(rewriter), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(
@@ -476,7 +478,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
});
// If the operation is trivially dead - remove it.
- if (isOpTriviallyDead(op)) {
+ if (config.enableOperationDce && isOpTriviallyDead(op)) {
rewriter.eraseOp(op);
changed = true;
@@ -780,7 +782,7 @@ namespace {
/// This driver simplfies all ops in a region.
class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
- explicit RegionPatternRewriteDriver(MLIRContext *ctx,
+ explicit RegionPatternRewriteDriver(PatternRewriter &rewriter,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config,
Region &regions);
@@ -796,9 +798,9 @@ class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
} // namespace
RegionPatternRewriteDriver::RegionPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, Region &region)
- : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
+ : GreedyPatternRewriteDriver(rewriter, patterns, config), region(region) {
// Populate strict mode ops.
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
@@ -909,8 +911,8 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Start the pattern driver.
- RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
- region);
+ PatternRewriter rewriter(region.getContext());
+ RegionPatternRewriteDriver driver(rewriter, patterns, config, region);
LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -928,7 +930,7 @@ namespace {
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
@@ -950,10 +952,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
} // namespace
MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps)
- : GreedyPatternRewriteDriver(ctx, patterns, config),
+ : GreedyPatternRewriteDriver(rewriter, patterns, config),
survivingOps(survivingOps) {
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -1040,9 +1042,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Start the pattern driver.
+ PatternRewriter rewriter(ops.front()->getContext());
llvm::SmallDenseSet<Operation *, 4> surviving;
- MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- config, ops,
+ MultiOpPatternRewriteDriver driver(rewriter, patterns, config, ops,
allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
@@ -1053,3 +1055,242 @@ LogicalResult mlir::applyOpPatternsAndFold(
});
return converged;
}
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes. It derives from
+/// `ConversionPatternRewriter` so that the existing conversion patterns can
+/// be used with the One-Shot Dialect Conversion.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+ OneShotConversionPatternRewriter(MLIRContext *ctx)
+ : ConversionPatternRewriter(ctx) {}
+
+ bool canRecoverFromRewriteFailure() const override { return false; }
+
+ void replaceOp(Operation *op, ValueRange newValues) override;
+
+ void replaceOp(Operation *op, Operation *newOp) override {
+ replaceOp(op, newOp->getResults());
+ }
+
+ void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); }
+
+ void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); }
+
+ void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+ ValueRange argValues = std::nullopt) override {
+ PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+ }
+ using PatternRewriter::inlineBlockBefore;
+
+ void startOpModification(Operation *op) override {
+ PatternRewriter::startOpModification(op);
+ }
+
+ void finalizeOpModification(Operation *op) override {
+ PatternRewriter::finalizeOpModification(op);
+ }
+
+ void cancelOpModification(Operation *op) override {
+ PatternRewriter::cancelOpModification(op);
+ }
+
+ void setCurrentTypeConverter(const TypeConverter *converter) override {
+ typeConverter = converter;
+ }
+
+ const TypeConverter *getCurrentTypeConverter() const override {
+ return typeConverter;
+ }
+
+ LogicalResult getAdapterOperands(StringRef valueDiagTag,
+ std::optional<Location> inputLoc,
+ ValueRange values,
+ SmallVector<Value> &remapped) override;
+
+private:
+ /// Build an unrealized_conversion_cast op or look it up in the cache.
+ Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+ /// The current type converter.
+ const TypeConverter *typeConverter;
+
+ /// A cache for unrealized_conversion_casts. It ensures that identical casts
+ /// are not built multiple times.
+ DenseMap<std::pair<Value, Type>, Value> castCache;
+};
+
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+ ValueRange newValues) {
+ assert(op->getNumResults() == newValues.size());
+ for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) {
+ if (orig.getType() != repl.getType()) {
+ // Type mismatch: insert unrealized_conversion_cast.
+ replaceAllUsesWith(orig, buildUnrealizedConversionCast(
+ op->getLoc(), orig.getType(), repl));
+ } else {
+ // Same type: use replacement value directly.
+ replaceAllUsesWith(orig, repl);
+ }
+ }
+ eraseOp(op);
+}
+
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+ Location loc, Type type, Value value) {
+ auto it = castCache.find(std::make_pair(value, type));
+ if (it != castCache.end())
+ return it->second;
+
+ // Insert cast at the beginning of the block (for block arguments) or right
+ // after the defining op.
+ OpBuilder::InsertionGuard g(*this);
+ Block *insertBlock = value.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(value))
+ insertPt = ++inputRes.getOwner()->getIterator();
+ setInsertionPoint(insertBlock, insertPt);
+ auto castOp = create<UnrealizedConversionCastOp>(loc, type, value);
+ castCache[std::make_pair(value, type)] = castOp.getOutputs()[0];
+ return castOp.getOutputs()[0];
+}
+
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+ ConversionPatternRewriteDriver(PatternRewriter &rewriter,
+ const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config,
+ const ConversionTarget &target)
+ : GreedyPatternRewriteDriver(rewriter, patterns, config), target(target) {
+ }
+
+ /// Populate the worklist with all illegal ops and start the conversion
+ /// process.
+ LogicalResult convert(Operation *op) &&;
+
+protected:
+ void addToWorklist(Operation *op) override;
+
+ /// Notify the driver that the specified operation was removed. Update the
+ /// worklist as needed: The operation and its children are removed from the
+ /// worklist.
+ void notifyOperationErased(Operation *op) override;
+
+private:
+ const ConversionTarget &target;
+};
+} // namespace
+
+LogicalResult ConversionPatternRewriteDriver::convert(Operation *op) && {
+ op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>([&](Operation *op) {
+ auto legalityInfo = target.isLegal(op);
+ if (!legalityInfo) {
+ addSingleOpToWorklist(op);
+ return WalkResult::advance();
+ }
+ if (legalityInfo->isRecursivelyLegal) {
+ // Don't check this operation's children for conversion if the
+ // operation is recursively legal.
+ return WalkResult::skip();
+ }
+ return WalkResult::advance();
+ });
+
+ // Reverse the list so our pop-back loop processes them in-order.
+ // TODO: newly enqueued ops must also be reversed
+ worklist.reverse();
+
+ processWorklist();
+
+ return success();
+}
+
+void ConversionPatternRewriteDriver::addToWorklist(Operation *op) {
+ if (!target.isLegal(op))
+ addSingleOpToWorklist(op);
+}
+
+// TODO: Refactor. This is the same as
+// `GreedyPatternRewriteDriver::notifyOperationErased`, but does not add ops to
+// the worklist.
+void ConversionPatternRewriteDriver::notifyOperationErased(Operation *op) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+
+#ifndef NDEBUG
+ // Only ops that are within the configured scope are added to the worklist of
+ // the greedy pattern rewriter. Moreover, the parent op of the scope region is
+ // the part of the IR that is taken into account for the "expensive checks".
+ // A greedy pattern rewrite is not allowed to erase the parent op of the scope
+ // region, as that would break the worklist handling and the expensive checks.
+ if (config.scope && config.scope->getParentOp() == op)
+ llvm_unreachable(
+ "scope region must not be erased during greedy pattern rewrite");
+#endif // NDEBUG
+
+ if (config.listener)
+ config.listener->notifyOperationErased(op);
+
+ worklist.remove(op);
+
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ strictModeFilteredOps.erase(op);
+}
+
+/// Populate the converted operands in `remapped`. (Based on the currently set
+/// type converter.)
+LogicalResult OneShotConversionPatternRewriter::getAdapterOperands(
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+ SmallVector<Value> &remapped) {
+ // TODO: Refactor. This is mostly copied from the current dialect conversion.
+ for (Value v : values) {
+ // Skip all unrealized_conversion_casts in the chain of defining ops.
+ Value vBase = v;
+ while (auto castOp = vBase.getDefiningOp<UnrealizedConversionCastOp>())
+ vBase = castOp.getInputs()[0];
+
+ if (!getCurrentTypeConverter()) {
+ // No type converter set. Just replicate what the current type conversion
+ // is doing.
+ // TODO: We may have to distinguish between newly-inserted and
+ // pre-existing unrealized_conversion_casts.
+ remapped.push_back(vBase);
+ continue;
+ }
+
+ Type desiredType;
+ SmallVector<Type, 1> legalTypes;
+ if (failed(getCurrentTypeConverter()->convertType(v.getType(), legalTypes)))
+ return failure();
+ assert(legalTypes.size() == 1 && "1:N conversion not supported yet");
+ desiredType = legalTypes.front();
+ if (desiredType == vBase.getType()) {
+ // Type already matches. No need to convert anything.
+ remapped.push_back(vBase);
+ continue;
+ }
+
+ Location operandLoc = inputLoc ? *inputLoc : v.getLoc();
+ remapped.push_back(
+ buildUnrealizedConversionCast(operandLoc, desiredType, vBase));
+ }
+ return success();
+}
+
+LogicalResult
+mlir::applyPartialOneShotConversion(Operation *op,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns) {
+ GreedyRewriteConfig config;
+ config.enableOperationDce = false;
+ OneShotConversionPatternRewriter rewriter(op->getContext());
+ ConversionPatternRewriteDriver driver(rewriter, patterns, config, target);
+ return std::move(driver).convert(op);
+}
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb..865840a751fa6 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -430,7 +430,7 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index
#map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)>
// CHECK-LABEL: func @affine_applies(
-func.func @affine_applies(%arg0 : index) {
+func.func @affine_applies(%arg0 : index) -> (index, index, index, index, index) {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
%zero = affine.apply #map0()
@@ -442,9 +442,7 @@ func.func @affine_applies(%arg0 : index) {
%102 = arith.constant 102 : index
%copy = affine.apply #map2(%zero)
-// CHECK-NEXT: %[[v0:.*]] = arith.addi %[[c0]], %[[c0]] : index
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-NEXT: %[[v1:.*]] = arith.addi %[[v0]], %[[c1]] : index
%one = affine.apply #map3(%symbZero)[%zero]
// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
@@ -466,7 +464,9 @@ func.func @affine_applies(%arg0 : index) {
// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] : index
// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
%four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
- return
+
+// CHECK: return %[[c0]], %[[c0]], %[[c0]], %[[c1]], %[[v13]]
+ return %zero, %symbZero, %copy, %one, %four : index, index, index, index, index
}
// CHECK-LABEL: func @args_ret_affine_apply(
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 56ae930e6d627..2360d6e71dedc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -321,9 +321,9 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
// CHECK-NEXT: = llvm.sext %[[ARG0]] : vector<1xi3> to vector<1xi6>
%0 = arith.extsi %arg0 : vector<i3> to vector<i6>
-// CHECK-NEXT: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6>
+// CHECK: = llvm.zext %[[ARG0]] : vector<1xi3> to vector<1xi6>
%1 = arith.extui %arg0 : vector<i3> to vector<i6>
-// CHECK-NEXT: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2>
+// CHECK: = llvm.trunc %[[ARG0]] : vector<1xi3> to vector<1xi2>
%2 = arith.trunci %arg0 : vector<i3> to vector<i2>
return
}
@@ -478,11 +478,12 @@ func.func @mului_extended_vector1d(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -
// -----
// CHECK-LABEL: func @cmpf_2dvector(
+// CHECK-SAME: %[[FARG0:.*]]: vector<4x3xf32>, %[[FARG1:.*]]: vector<4x3xf32>
func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
- // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
- // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
- // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xf32>>
- // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]]
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]]
+ // CHECK-DAG: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK-DAG: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xf32>>
// CHECK: %[[CMP:.*]] = llvm.fcmp "olt" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xf32>
// CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %2[0] : !llvm.array<4 x vector<3xi1>>
%0 = arith.cmpf olt, %arg0, %arg1 : vector<4x3xf32>
@@ -492,9 +493,10 @@ func.func @cmpf_2dvector(%arg0 : vector<4x3xf32>, %arg1 : vector<4x3xf32>) {
// -----
// CHECK-LABEL: func @cmpi_0dvector(
+// CHECK-SAME: %[[FARG0:.*]]: vector<i32>, %[[FARG1:.*]]: vector<i32>
func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
- // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
- // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]]
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]]
// CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[ARG0]], %[[ARG1]] : vector<1xi32>
%0 = arith.cmpi ult, %arg0, %arg1 : vector<i32>
func.return
@@ -503,11 +505,12 @@ func.func @cmpi_0dvector(%arg0 : vector<i32>, %arg1 : vector<i32>) {
// -----
// CHECK-LABEL: func @cmpi_2dvector(
+// CHECK-SAME: %[[FARG0:.*]]: vector<4x3xi32>, %[[FARG1:.*]]: vector<4x3xi32>
func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
- // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
- // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast
- // CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi32>>
- // CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[FARG0]]
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[FARG1]]
+ // CHECK-DAG: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi32>>
+ // CHECK-DAG: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
// CHECK: %[[CMP:.*]] = llvm.icmp "ult" %[[EXTRACT1]], %[[EXTRACT2]] : vector<3xi32>
// CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[CMP]], %2[0] : !llvm.array<4 x vector<3xi1>>
%0 = arith.cmpi ult, %arg0, %arg1 : vector<4x3xi32>
diff --git a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
index 63989347567b5..b234cbbb35f32 100644
--- a/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/convert-nd-vector-to-llvmir.mlir
@@ -199,9 +199,9 @@ func.func @bitcast_2d(%arg0: vector<2x4xf32>) {
// CHECK-LABEL: func @select_2d(
func.func @select_2d(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : vector<4x3xi32>) {
- // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0
- // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1
- // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %arg0
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %arg1
+ // CHECK-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %arg2
// CHECK: %[[EXTRACT1:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.array<4 x vector<3xi1>>
// CHECK: %[[EXTRACT2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.array<4 x vector<3xi32>>
// CHECK: %[[EXTRACT3:.*]] = llvm.extractvalue %[[ARG2]][0] : !llvm.array<4 x vector<3xi32>>
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index b86103422b074..599bb6ca14c48 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -294,6 +294,7 @@ func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) -
// CHECK: %[[MEM:[a-zA-Z0-9]*]]: memref
func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?xf32, strided<[?, 1], offset: ?>> {
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEMREF]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
@@ -316,7 +317,6 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Sizes and strides @rank 1: static stride 1, dynamic size unchanged from source memref.
// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[CST_STRIDE1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[CST_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%c0 = arith.constant 1 : index
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
-// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 1b046d32f163a..57f805307a065 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -275,8 +275,8 @@ func.func @async_cp_i4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index
func.func @async_cp_zfill_f32_align4(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
- // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
- // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+ // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
// CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
// CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -310,8 +310,8 @@ func.func @async_cp_zfill_f32_align4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_zfill_f32_align1(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
- // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
- // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
+ // CHECK-DAG: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+ // CHECK-DAG: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64
// CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64
// CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64
@@ -533,7 +533,9 @@ func.func @mbarrier_nocomplete() {
}
// CHECK-LABEL: func @mbarrier_wait
+// CHECK-SAME: %[[barriers:.*]]: !nvgpu.mbarrier.group
func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, %token : !tokenType) {
+// CHECK: %[[barriersCast:.*]] = builtin.unrealized_conversion_cast %[[barriers]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%n = arith.constant 100 : index
@@ -545,7 +547,7 @@ func.func @mbarrier_wait(%barriers : !nvgpu.mbarrier.group<memorySpace = #gpu.ad
// CHECK: scf.for %[[i:.*]] =
// CHECK: %[[S2:.+]] = arith.remui %[[i]], %[[c5]] : index
// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[S2]] : index to i64
-// CHECK: %[[S4:.+]] = llvm.extractvalue %0[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[S4:.+]] = llvm.extractvalue %[[barriersCast]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[S5:.+]] = llvm.getelementptr %[[S4]][%[[S3]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
%mbarId = arith.remui %i, %numBarriers : index
%isDone = nvgpu.mbarrier.test.wait %barriers[%mbarId], %token : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>, num_barriers = 5>, !tokenType
@@ -871,9 +873,9 @@ func.func @warpgroup_mma_128_128_64(
%descB: !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
%acc: !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>)
{
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> to !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[UD:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
@@ -1280,9 +1282,9 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
to memref<128x128xf32,3>
-// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
-// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
-// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-DAG: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>> to i64
+// CHECK-DAG: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>> to i64
+// CHECK-DAG: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : memref<128x128xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[S4:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S5:.+]] = llvm.extractvalue %[[S4]][0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
@@ -1296,7 +1298,7 @@ func.func @warpgroup_matrix_multiply_m128n128k64(
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: %[[S137:.+]] = llvm.mlir.undef : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
// CHECK: %[[S138:.+]] = llvm.extractvalue %136[0] : !llvm.struct<(struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>, struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>)>
-// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %0, %1, %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+// CHECK: %[[S139:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], %[[S138]], <m = 64, n = 128, k = 16>, D[<f32>, <one>, <wrapped>], A[<f16>, <one>, <row>], B[<f16>, <one>, <row>] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvvm.wgmma.mma_async
// CHECK: nvvm.wgmma.mma_async
// CHECK: %[[S154:.+]] = nvvm.wgmma.mma_async
More information about the Mlir-commits mailing list