[Mlir-commits] [mlir] [mlir][toy] Update dialect conversion example (PR #150826)
Matthias Springer
llvmlistbot at llvm.org
Sun Jul 27 10:50:23 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/150826
From a039b2ae6688d941671d71dc0e14e0e0f549c068 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 27 Jul 2025 11:13:24 +0000
Subject: [PATCH 1/2] [mlir][toy] Update dialect conversion example
---
mlir/docs/Tutorials/Toy/Ch-5.md | 57 +++-----
mlir/docs/Tutorials/Toy/Ch-6.md | 8 ++
.../toy/Ch5/mlir/LowerToAffineLoops.cpp | 126 ++++++++----------
.../toy/Ch6/mlir/LowerToAffineLoops.cpp | 126 ++++++++----------
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 19 +--
.../toy/Ch7/mlir/LowerToAffineLoops.cpp | 122 ++++++++---------
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 19 +--
7 files changed, 211 insertions(+), 266 deletions(-)
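
Reader's note (not part of the patch): the recurring change below is that the
loop-body callback no longer receives the remapped operands as a parameter.
Instead, the lambda captures the `adaptor` of the enclosing `matchAndRewrite`
by reference. The standalone sketch below illustrates that shape using only the
C++ standard library. `Adaptor`, `LoopBodyFn`, `lowerToLoops` and `lowerAdd`
are hypothetical stand-ins chosen for illustration; they are not MLIR APIs and
do not appear in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for an ODS-generated adaptor: it exposes the
// already-remapped operands through named accessors.
struct Adaptor {
  const std::vector<double> &lhs;
  const std::vector<double> &rhs;
  const std::vector<double> &getLhs() const { return lhs; }
  const std::vector<double> &getRhs() const { return rhs; }
};

// Callback that computes the value to store at loop index `iv`.
// Note that it takes no operand parameter.
using LoopBodyFn = std::function<double(std::size_t iv)>;

// Hypothetical analogue of `lowerOpToLoops`: it only knows the iteration
// space and the callback, not the operands of the op being lowered.
std::vector<double> lowerToLoops(std::size_t size, const LoopBodyFn &body) {
  std::vector<double> result(size);
  for (std::size_t iv = 0; iv < size; ++iv)
    result[iv] = body(iv);
  return result;
}

// Hypothetical analogue of `matchAndRewrite` for a binary op: the lambda
// captures `adaptor` by reference instead of being handed the operands.
std::vector<double> lowerAdd(const Adaptor &adaptor) {
  assert(adaptor.getLhs().size() == adaptor.getRhs().size());
  return lowerToLoops(adaptor.getLhs().size(), [&](std::size_t iv) {
    return adaptor.getLhs()[iv] + adaptor.getRhs()[iv];
  });
}
```

Capturing by reference is safe in this sketch because the callback is invoked
synchronously and never outlives the call to `lowerToLoops`. The patch relies
on the same property when it passes `[&]` lambdas to `lowerOpToLoops`, whose
callback type is a non-owning `function_ref`.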
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md
index c750c07ddfc04..17cd6bb412a9a 100644
--- a/mlir/docs/Tutorials/Toy/Ch-5.md
+++ b/mlir/docs/Tutorials/Toy/Ch-5.md
@@ -91,13 +91,11 @@ doesn't matter. See `ConversionTarget::getOpInfo` for the details.
After the conversion target has been defined, we can define how to convert the
*illegal* operations into *legal* ones. Similarly to the canonicalization
framework introduced in [chapter 3](Ch-3.md), the
-[`DialectConversion` framework](../../DialectConversion.md) also uses
-[RewritePatterns](../QuickstartRewrites.md) to perform the conversion logic.
-These patterns may be the `RewritePatterns` seen before or a new type of pattern
-specific to the conversion framework `ConversionPattern`. `ConversionPatterns`
+[`DialectConversion` framework](../../DialectConversion.md) uses a special kind
+of `ConversionPattern` to perform the conversion logic. `ConversionPatterns`
are different from traditional `RewritePatterns` in that they accept an
-additional `operands` parameter containing operands that have been
-remapped/replaced. This is used when dealing with type conversions, as the
+additional `operands` (or `adaptor`) parameter containing operands that have
+been remapped/replaced. This is used when dealing with type conversions, as the
pattern will want to operate on values of the new type but match against the
old. For our lowering, this invariant will be useful as it translates from the
[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being
@@ -106,38 +104,23 @@ look at a snippet of lowering the `toy.transpose` operation:
```c++
/// Lower the `toy.transpose` operation to an affine loop nest.
-struct TransposeOpLowering : public mlir::ConversionPattern {
- TransposeOpLowering(mlir::MLIRContext *ctx)
- : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
-
- /// Match and rewrite the given `toy.transpose` operation, with the given
- /// operands that have been remapped from `tensor<...>` to `memref<...>`.
- llvm::LogicalResult
- matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
- mlir::ConversionPatternRewriter &rewriter) const final {
- auto loc = op->getLoc();
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
- // Call to a helper function that will lower the current operation to a set
- // of affine loops. We provide a functor that operates on the remapped
- // operands, as well as the loop induction variables for the inner most
- // loop body.
- lowerOpToLoops(
- op, operands, rewriter,
- [loc](mlir::PatternRewriter &rewriter,
- ArrayRef<mlir::Value> memRefOperands,
- ArrayRef<mlir::Value> loopIvs) {
- // Generate an adaptor for the remapped operands of the TransposeOp.
- // This allows for using the nice named accessors that are generated
- // by the ODS. This adaptor is automatically provided by the ODS
- // framework.
- TransposeOpAdaptor transposeAdaptor(memRefOperands);
- mlir::Value input = transposeAdaptor.input();
-
- // Transpose the elements by generating a load from the reverse
- // indices.
- SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return mlir::AffineLoadOp::create(rewriter, loc, input, reverseIvs);
- });
+ LogicalResult
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ auto loc = op->getLoc();
+ lowerOpToLoops(op, rewriter,
+ [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input,
+ reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 529de55304206..1ef9351422a59 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -81,6 +81,14 @@ enough for our use case.
LLVMTypeConverter typeConverter(&getContext());
```
+For the `toy.print` lowering, we need a special type converter to ensure that
+the pattern receives a `memref` value in its adaptor. If we were to use the
+LLVM type converter, it would receive an `llvm.struct`, which is the normal
+lowering of a `memref` type to LLVM. If we were to use no type converter at
+all, it would receive a value with the original tensor type. (Note: The dialect
+conversion driver currently passes the "most recently mapped value", i.e., a
+value of unspecified type. This is a bug in the conversion driver.)
+
### Conversion Patterns
Now that the conversion target has been defined, we need to provide the patterns
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..2969d3a795779 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
- // and the loop induction variables. This function will return the value
- // to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ // Call the processing function with the rewriter and the loop
+ // induction variables. This function will return the value to store at
+ // the current index.
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- // that are generated by the ODS.
- typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
- // Generate loads for the element of 'lhs' and 'rhs' at the
- // inner loop.
- auto loadedLhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getLhs(), loopIvs);
- auto loadedRhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
- // Create the binary operation performed on the loaded
- // values.
- return LoweredBinaryOp::create(builder, loc, loadedLhs,
- loadedRhs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ // Generate loads for the element of 'lhs' and 'rhs' at the
+ // inner loop.
+ auto loadedLhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+ auto loadedRhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+ // Create the binary operation performed on the loaded
+ // values.
+ return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+ });
return success();
}
};
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
//===----------------------------------------------------------------------===//
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
- using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+ using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
//===----------------------------------------------------------------------===//
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
//===----------------------------------------------------------------------===//
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
- using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+ using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
//===----------------------------------------------------------------------===//
-struct TransposeOpLowering : public ConversionPattern {
- TransposeOpLowering(MLIRContext *ctx)
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // TransposeOp. This allows for using the nice named
- // accessors that are generated by the ODS.
- toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
- Value input = transposeAdaptor.getInput();
-
- // Transpose the elements by generating a load from the
- // reverse indices.
- SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return affine::AffineLoadOp::create(builder, loc, input,
- reverseIvs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..2969d3a795779 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
- // and the loop induction variables. This function will return the value
- // to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ // Call the processing function with the rewriter and the loop
+ // induction variables. This function will return the value to store at
+ // the current index.
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- // that are generated by the ODS.
- typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
- // Generate loads for the element of 'lhs' and 'rhs' at the
- // inner loop.
- auto loadedLhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getLhs(), loopIvs);
- auto loadedRhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
- // Create the binary operation performed on the loaded
- // values.
- return LoweredBinaryOp::create(builder, loc, loadedLhs,
- loadedRhs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ // Generate loads for the element of 'lhs' and 'rhs' at the
+ // inner loop.
+ auto loadedLhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+ auto loadedRhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+ // Create the binary operation performed on the loaded
+ // values.
+ return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+ });
return success();
}
};
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
//===----------------------------------------------------------------------===//
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
- using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+ using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
//===----------------------------------------------------------------------===//
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
//===----------------------------------------------------------------------===//
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
- using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+ using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
//===----------------------------------------------------------------------===//
-struct TransposeOpLowering : public ConversionPattern {
- TransposeOpLowering(MLIRContext *ctx)
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // TransposeOp. This allows for using the nice named
- // accessors that are generated by the ODS.
- toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
- Value input = transposeAdaptor.getInput();
-
- // Transpose the elements by generating a load from the
- // reverse indices.
- SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return affine::AffineLoadOp::create(builder, loc, input,
- reverseIvs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index e0950ef56f4fe..feb7ab33b76ce 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -55,19 +55,18 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToLLVM RewritePatterns
+// ToyToLLVM Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
/// elements of the array.
-class PrintOpLowering : public ConversionPattern {
+class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
public:
- explicit PrintOpLowering(MLIRContext *context)
- : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+ using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *context = rewriter.getContext();
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
@@ -108,9 +107,8 @@ class PrintOpLowering : public ConversionPattern {
}
// Generate a call to printf for the current element of the loop.
- auto printOp = cast<toy::PrintOp>(op);
auto elementLoad =
- memref::LoadOp::create(rewriter, loc, printOp.getInput(), loopIvs);
+ memref::LoadOp::create(rewriter, loc, adaptor.getInput(), loopIvs);
LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
@@ -223,8 +221,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
- // PrintOp.
- patterns.add<PrintOpLowering>(&getContext());
+ // PrintOp. An identity converter is needed because the PrintOp lowering
+ // operates on MemRefType instead of the lowered LLVM struct type.
+ TypeConverter identityConverter;
+ identityConverter.addConversion([](Type type) { return type; });
+ patterns.add<PrintOpLowering>(identityConverter, &getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..cbe4236050f57 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
+ // Call the processing function with the rewriter
// and the loop induction variables. This function will return the value
// to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- // that are generated by the ODS.
- typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
- // Generate loads for the element of 'lhs' and 'rhs' at the
- // inner loop.
- auto loadedLhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getLhs(), loopIvs);
- auto loadedRhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
- // Create the binary operation performed on the loaded
- // values.
- return LoweredBinaryOp::create(builder, loc, loadedLhs,
- loadedRhs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ // Generate loads for the element of 'lhs' and 'rhs' at the
+ // inner loop.
+ auto loadedLhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+ auto loadedRhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+ // Create the binary operation performed on the loaded
+ // values.
+ return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+ });
return success();
}
};
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
//===----------------------------------------------------------------------===//
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
- using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+ using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
//===----------------------------------------------------------------------===//
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
//===----------------------------------------------------------------------===//
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
- using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+ using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
//===----------------------------------------------------------------------===//
-struct TransposeOpLowering : public ConversionPattern {
- TransposeOpLowering(MLIRContext *ctx)
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // TransposeOp. This allows for using the nice named
- // accessors that are generated by the ODS.
- toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
- Value input = transposeAdaptor.getInput();
-
- // Transpose the elements by generating a load from the
- // reverse indices.
- SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return affine::AffineLoadOp::create(builder, loc, input,
- reverseIvs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 43a84da88e189..af502950e03ff 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -55,19 +55,18 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToLLVM RewritePatterns
+// ToyToLLVM Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
/// elements of the array.
-class PrintOpLowering : public ConversionPattern {
+class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
public:
- explicit PrintOpLowering(MLIRContext *context)
- : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+ using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *context = rewriter.getContext();
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
@@ -108,9 +107,8 @@ class PrintOpLowering : public ConversionPattern {
}
// Generate a call to printf for the current element of the loop.
- auto printOp = cast<toy::PrintOp>(op);
auto elementLoad =
- memref::LoadOp::create(rewriter, loc, printOp.getInput(), loopIvs);
+ memref::LoadOp::create(rewriter, loc, adaptor.getInput(), loopIvs);
LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
@@ -224,8 +222,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
- // PrintOp.
- patterns.add<PrintOpLowering>(&getContext());
+ // PrintOp. An identity converter is needed because the PrintOp lowering
+ // operates on MemRefType instead of the lowered LLVM struct type.
+ TypeConverter identityConverter;
+ identityConverter.addConversion([](Type type) { return type; });
+ patterns.add<PrintOpLowering>(identityConverter, &getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
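Note on the `using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;` line added to `BinaryOpLowering` in this patch: that struct is itself a template, so its base class is a dependent type, and unqualified lookup of `OpAdaptor` does not search dependent bases. The non-template patterns (`ConstantOpLowering`, `ReturnOpLowering`, `TransposeOpLowering`) can name `OpAdaptor` directly. The stand-alone sketch below reproduces the situation without MLIR. `FakePattern`, `FakeAddOp`, `AddLowering`, `BinaryLowering`, and `checkLowerings` are hypothetical stand-ins invented for illustration, not MLIR classes.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an op class with a generated adaptor type.
struct FakeAddOp {
  struct Adaptor {
    std::string lhs, rhs;
    const std::string &getLhs() const { return lhs; }
    const std::string &getRhs() const { return rhs; }
  };
};

// Hypothetical stand-in for OpConversionPattern<SourceOp>: it exposes the
// adaptor type of the op it matches as a member alias.
template <typename SourceOp>
struct FakePattern {
  using OpAdaptor = typename SourceOp::Adaptor;
  explicit FakePattern(int patternBenefit) : benefit(patternBenefit) {}
  virtual ~FakePattern() = default;
  virtual std::string rewrite(const OpAdaptor &adaptor) const = 0;
  int benefit;
};

// Non-template derived class: the base is a concrete type, so the inherited
// alias `OpAdaptor` is found by ordinary unqualified lookup.
struct AddLowering : FakePattern<FakeAddOp> {
  using FakePattern<FakeAddOp>::FakePattern; // inherit constructors
  std::string rewrite(const OpAdaptor &adaptor) const final {
    return adaptor.getLhs() + " + " + adaptor.getRhs();
  }
};

// Template derived class: the base depends on `BinaryOp`, so `OpAdaptor`
// has to be re-exported with `typename` before it can be used unqualified.
template <typename BinaryOp>
struct BinaryLowering : FakePattern<BinaryOp> {
  using FakePattern<BinaryOp>::FakePattern; // inherit constructors
  using OpAdaptor = typename FakePattern<BinaryOp>::OpAdaptor;
  std::string rewrite(const OpAdaptor &adaptor) const final {
    return "(" + adaptor.getLhs() + ", " + adaptor.getRhs() + ")";
  }
};

// Small self-check exercising both lowerings.
inline void checkLowerings() {
  FakeAddOp::Adaptor adaptor{"%a", "%b"};
  assert(AddLowering(1).rewrite(adaptor) == "%a + %b");
  assert(BinaryLowering<FakeAddOp>(1).rewrite(adaptor) == "(%a, %b)");
}
```

Removing the alias from `BinaryLowering` makes the unqualified `OpAdaptor` in its `rewrite` signature fail to resolve, while `AddLowering` compiles without it.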
>From a34ed3d7a21646ed5b2751efbc5a937e8b4ae2ef Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 27 Jul 2025 17:48:54 +0000
Subject: [PATCH 2/2] address comments
---
mlir/docs/Tutorials/Toy/Ch-6.md | 8 --------
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 9 +++------
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 9 +++------
3 files changed, 6 insertions(+), 20 deletions(-)
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 1ef9351422a59..529de55304206 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -81,14 +81,6 @@ enough for our use case.
LLVMTypeConverter typeConverter(&getContext());
```
-For the `toy.print` lowering, we need a special type converter to ensure that
-the pattern receives a `memref` value in its adaptor. If we were to use the
-LLVM type converter, it would receive an `llvm.struct`, which is the normal
-lowering of a `memref` type to LLVM. If we were to use no type converter at
-all, it would receive a value with the original tensor type. (Note: The dialect
-conversion driver currently passes the "most recently mapped value", i.e., a
-value of unspecified type. This is a bug in the conversion driver.)
-
### Conversion Patterns
Now that the conversion target has been defined, we need to provide the patterns
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index feb7ab33b76ce..987dfa1eb9e78 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -108,7 +108,7 @@ class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
// Generate a call to printf for the current element of the loop.
auto elementLoad =
- memref::LoadOp::create(rewriter, loc, adaptor.getInput(), loopIvs);
+ memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs);
LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
@@ -221,11 +221,8 @@ void ToyToLLVMLoweringPass::runOnOperation() {
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
- // PrintOp. An identity converter is needed because the PrintOp lowering
- // operates on MemRefType instead of the lowered LLVM struct type.
- TypeConverter identityConverter;
- identityConverter.addConversion([](Type type) { return type; });
- patterns.add<PrintOpLowering>(identityConverter, &getContext());
+ // PrintOp.
+ patterns.add<PrintOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index af502950e03ff..8b48a8f798beb 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -108,7 +108,7 @@ class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
// Generate a call to printf for the current element of the loop.
auto elementLoad =
- memref::LoadOp::create(rewriter, loc, adaptor.getInput(), loopIvs);
+ memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs);
LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
@@ -222,11 +222,8 @@ void ToyToLLVMLoweringPass::runOnOperation() {
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
- // PrintOp. An identity converter is needed because the PrintOp lowering
- // operates on MemRefType instead of the lowered LLVM struct type.
- TypeConverter identityConverter;
- identityConverter.addConversion([](Type type) { return type; });
- patterns.add<PrintOpLowering>(identityConverter, &getContext());
+ // PrintOp.
+ patterns.add<PrintOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
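Note on the second patch: it switches the `memref::LoadOp` in `PrintOpLowering` from `adaptor.getInput()` to `op.getInput()` and drops the identity type converter again. The difference between the two accessors can be pictured with the stand-alone model below. It is only a sketch: `Value`, `PrintOp`, `PrintOpAdaptor`, `remap`, the two converter functions, and the type strings are hypothetical stand-ins rather than MLIR APIs. The model assumes, as the removed Ch-6.md paragraph describes, that a pattern using the LLVM type converter would receive an `llvm.struct` in its adaptor, while the op itself still carries the `memref` operand that the pattern casts to `MemRefType`. The case of no type converter at all is not modeled.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical stand-in for an SSA value; only its type is modeled.
struct Value {
  std::string type;
};

// Hypothetical stand-in for toy::PrintOp: it stores its original operand.
struct PrintOp {
  Value input;
  const Value &getInput() const { return input; }
};

// Hypothetical stand-in for the generated adaptor: same accessor name, but
// it holds the operand as remapped by a type converter.
struct PrintOpAdaptor {
  Value input;
  const Value &getInput() const { return input; }
};

using TypeConverterFn = std::function<std::string(const std::string &)>;

// Builds the adaptor a pattern would receive under the given converter.
inline PrintOpAdaptor remap(const PrintOp &op, const TypeConverterFn &convert) {
  return PrintOpAdaptor{Value{convert(op.getInput().type)}};
}

// Models a memref-to-LLVM conversion: memref types become structs.
inline std::string llvmLikeConverter(const std::string &type) {
  return type.rfind("memref", 0) == 0 ? "llvm.struct" : type;
}

// Models the identity converter that the first patch introduced.
inline std::string identityConverter(const std::string &type) { return type; }

// Small self-check comparing the op accessor with both adaptor variants.
inline void checkAccessors() {
  PrintOp op{Value{"memref<2x3xf64>"}};
  assert(op.getInput().type == "memref<2x3xf64>");
  assert(remap(op, llvmLikeConverter).getInput().type == "llvm.struct");
  assert(remap(op, identityConverter).getInput().type == "memref<2x3xf64>");
}
```

In this model, reading the operand from `op` gives the `memref` regardless of which converter the pattern was registered with, which is consistent with the pattern going back to `patterns.add<PrintOpLowering>(&getContext())`.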