[Mlir-commits] [mlir] [mlir][toy] Update dialect conversion example (PR #150826)
Matthias Springer
llvmlistbot at llvm.org
Sun Jul 27 04:15:19 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/150826
The Toy tutorial used an outdated API. Update the example to:
* Use the `OpAdaptor` in all places.
* Do not mix `RewritePattern` and `ConversionPattern`. This cannot be done safely and should not be advertised in the example code.
* Always use values from the adaptor and never from the original op when constructing new IR.
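The third point can be illustrated without MLIR headers. Below is a minimal model that uses only the standard library and treats a value as nothing more than a name plus a type string. `Value`, `OriginalOp`, `ValueMapping`, `Adaptor` and `makeAdaptor` are hypothetical stand-ins, not the MLIR API. The original op keeps pointing at its pre-conversion `tensor` operand. The adaptor looks each operand up in the driver's mapping, so it sees the replacement `memref` value.

```c++
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical model, not the MLIR API: a value is a name plus a type.
struct Value {
  std::string name;
  std::string type;
};

// The original op still refers to its pre-conversion operands.
struct OriginalOp {
  std::vector<Value> operands;
};

// Stand-in for the conversion driver's mapping from original values to
// their replacements (e.g. a tensor result replaced by a memref alloc).
using ValueMapping = std::map<std::string, Value>;

// Stand-in for an ODS-generated adaptor: named accessors over the
// remapped operands.
struct Adaptor {
  std::vector<Value> operands;
  const Value &getInput() const { return operands.at(0); }
};

// Build the adaptor the way a driver conceptually would: every operand that
// has a replacement is swapped for it; all other operands are kept as-is.
Adaptor makeAdaptor(const OriginalOp &op, const ValueMapping &mapping) {
  Adaptor adaptor;
  for (const Value &operand : op.operands) {
    auto it = mapping.find(operand.name);
    adaptor.operands.push_back(it != mapping.end() ? it->second : operand);
  }
  return adaptor;
}
```

In this model, building new IR from the original op's operand would hand the new ops the stale `tensor` value. Reading through the adaptor gives the replacement instead. The real driver also follows chains of replacements and inserts type materializations, which this sketch leaves out.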
From 20bc42e2ec0b51adb3baa21a5a70692b241109c8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 27 Jul 2025 11:13:24 +0000
Subject: [PATCH] [mlir][toy] Update dialect conversion example
---
mlir/docs/Tutorials/Toy/Ch-5.md | 70 ++++------
mlir/docs/Tutorials/Toy/Ch-6.md | 8 ++
.../toy/Ch5/mlir/LowerToAffineLoops.cpp | 126 ++++++++----------
.../toy/Ch6/mlir/LowerToAffineLoops.cpp | 126 ++++++++----------
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 19 +--
.../toy/Ch7/mlir/LowerToAffineLoops.cpp | 122 ++++++++---------
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 19 +--
7 files changed, 218 insertions(+), 272 deletions(-)
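The `LowerToAffineLoops.cpp` changes drop the `memRefOperands` parameter from `LoopIterationFn`. Each pattern's callback now captures its `adaptor` by reference (`[&]`) instead. The self-contained sketch below mirrors that shape, using plain C++ loops in place of `affine.for` and a row-major 2-D `Buffer` in place of a `memref`. `Buffer`, `lowerToLoops`, `transpose` and `add` are hypothetical names chosen for illustration. They are not part of the Toy example.

```c++
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a lowered buffer (row-major 2-D memref).
struct Buffer {
  std::size_t rows, cols;
  std::vector<double> data;
  double load(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
  void store(std::size_t i, std::size_t j, double v) { data[i * cols + j] = v; }
};

// Analogue of the new LoopIterationFn: it receives only the loop induction
// variables; any operands are captured by the callback itself.
using LoopIterationFn = std::function<double(std::size_t i, std::size_t j)>;

// Analogue of lowerOpToLoops: allocate the result, run the loop nest, and
// store whatever the callback returns at each index.
Buffer lowerToLoops(std::size_t rows, std::size_t cols,
                    const LoopIterationFn &processIteration) {
  Buffer result{rows, cols, std::vector<double>(rows * cols)};
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      result.store(i, j, processIteration(i, j));
  return result;
}

// Transpose: iterate over the result shape and load from reversed indices.
Buffer transpose(const Buffer &input) {
  return lowerToLoops(input.cols, input.rows,
                      [&](std::size_t i, std::size_t j) { return input.load(j, i); });
}

// Element-wise add: load lhs and rhs at the same indices and combine them.
Buffer add(const Buffer &lhs, const Buffer &rhs) {
  assert(lhs.rows == rhs.rows && lhs.cols == rhs.cols);
  return lowerToLoops(lhs.rows, lhs.cols, [&](std::size_t i, std::size_t j) {
    return lhs.load(i, j) + rhs.load(i, j);
  });
}
```

Capturing by reference is safe here for the same reason it is in the patch. The callback only runs while `lowerToLoops` (in the patch, `lowerOpToLoops`) is on the stack, so the captured operands are still alive when it executes.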
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md
index c750c07ddfc04..39ec02d04c8ce 100644
--- a/mlir/docs/Tutorials/Toy/Ch-5.md
+++ b/mlir/docs/Tutorials/Toy/Ch-5.md
@@ -91,53 +91,37 @@ doesn't matter. See `ConversionTarget::getOpInfo` for the details.
After the conversion target has been defined, we can define how to convert the
*illegal* operations into *legal* ones. Similarly to the canonicalization
framework introduced in [chapter 3](Ch-3.md), the
-[`DialectConversion` framework](../../DialectConversion.md) also uses
-[RewritePatterns](../QuickstartRewrites.md) to perform the conversion logic.
-These patterns may be the `RewritePatterns` seen before or a new type of pattern
-specific to the conversion framework `ConversionPattern`. `ConversionPatterns`
-are different from traditional `RewritePatterns` in that they accept an
-additional `operands` parameter containing operands that have been
-remapped/replaced. This is used when dealing with type conversions, as the
-pattern will want to operate on values of the new type but match against the
-old. For our lowering, this invariant will be useful as it translates from the
-[TensorType](../../Dialects/Builtin.md/#rankedtensortype) currently being
-operated on to the [MemRefType](../../Dialects/Builtin.md/#memreftype). Let's
-look at a snippet of lowering the `toy.transpose` operation:
+[`DialectConversion` framework](../../DialectConversion.md) also uses a special
+kind of pattern, `ConversionPattern`, to perform the conversion logic.
+`ConversionPatterns` are different from traditional `RewritePatterns` in that
+they accept an additional `operands` (or `adaptor`) parameter containing
+operands that have been remapped/replaced. This is used when dealing with type
+conversions, as the pattern will want to operate on values of the new type but
+match against the old. For our lowering, this invariant will be useful as it
+translates from the [TensorType](../../Dialects/Builtin.md/#rankedtensortype)
+currently being operated on to the
+[MemRefType](../../Dialects/Builtin.md/#memreftype). Let's look at a snippet of
+lowering the `toy.transpose` operation:
```c++
/// Lower the `toy.transpose` operation to an affine loop nest.
-struct TransposeOpLowering : public mlir::ConversionPattern {
- TransposeOpLowering(mlir::MLIRContext *ctx)
- : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}
-
- /// Match and rewrite the given `toy.transpose` operation, with the given
- /// operands that have been remapped from `tensor<...>` to `memref<...>`.
- llvm::LogicalResult
- matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
- mlir::ConversionPatternRewriter &rewriter) const final {
- auto loc = op->getLoc();
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
- // Call to a helper function that will lower the current operation to a set
- // of affine loops. We provide a functor that operates on the remapped
- // operands, as well as the loop induction variables for the inner most
- // loop body.
- lowerOpToLoops(
- op, operands, rewriter,
- [loc](mlir::PatternRewriter &rewriter,
- ArrayRef<mlir::Value> memRefOperands,
- ArrayRef<mlir::Value> loopIvs) {
- // Generate an adaptor for the remapped operands of the TransposeOp.
- // This allows for using the nice named accessors that are generated
- // by the ODS. This adaptor is automatically provided by the ODS
- // framework.
- TransposeOpAdaptor transposeAdaptor(memRefOperands);
- mlir::Value input = transposeAdaptor.input();
-
- // Transpose the elements by generating a load from the reverse
- // indices.
- SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return mlir::AffineLoadOp::create(rewriter, loc, input, reverseIvs);
- });
+ LogicalResult
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ auto loc = op->getLoc();
+ lowerOpToLoops(op, rewriter,
+ [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input,
+ reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 529de55304206..1ef9351422a59 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -81,6 +81,14 @@ enough for our use case.
LLVMTypeConverter typeConverter(&getContext());
```
+For the `toy.print` lowering, we need a special type converter to ensure that
+the pattern receives a `memref` value in its adaptor. If we were to use the
+LLVM type converter, it would receive an `llvm.struct`, which is the normal
+lowering of a `memref` type to LLVM. If we were to use no type converter at
+all, it would receive a value with the original tensor type. (Note: The dialect
+conversion driver currently passes the "most recently mapped value", i.e., a
+value of unspecified type. This is a bug in the conversion driver.)
+
### Conversion Patterns
Now that the conversion target has been defined, we need to provide the patterns
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..2969d3a795779 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
- // and the loop induction variables. This function will return the value
- // to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ // Call the processing function with the rewriter and the loop
+ // induction variables. This function will return the value to store at
+ // the current index.
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- // that are generated by the ODS.
- typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
- // Generate loads for the element of 'lhs' and 'rhs' at the
- // inner loop.
- auto loadedLhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getLhs(), loopIvs);
- auto loadedRhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
- // Create the binary operation performed on the loaded
- // values.
- return LoweredBinaryOp::create(builder, loc, loadedLhs,
- loadedRhs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ // Generate loads for the element of 'lhs' and 'rhs' at the
+ // inner loop.
+ auto loadedLhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+ auto loadedRhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+ // Create the binary operation performed on the loaded
+ // values.
+ return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+ });
return success();
}
};
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
//===----------------------------------------------------------------------===//
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
- using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+ using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
//===----------------------------------------------------------------------===//
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
//===----------------------------------------------------------------------===//
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
- using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+ using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
//===----------------------------------------------------------------------===//
-struct TransposeOpLowering : public ConversionPattern {
- TransposeOpLowering(MLIRContext *ctx)
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // TransposeOp. This allows for using the nice named
- // accessors that are generated by the ODS.
- toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
- Value input = transposeAdaptor.getInput();
-
- // Transpose the elements by generating a load from the
- // reverse indices.
- SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return affine::AffineLoadOp::create(builder, loc, input,
- reverseIvs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..2969d3a795779 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
- // and the loop induction variables. This function will return the value
- // to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ // Call the processing function with the rewriter and the loop
+ // induction variables. This function will return the value to store at
+ // the current index.
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- // that are generated by the ODS.
- typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
- // Generate loads for the element of 'lhs' and 'rhs' at the
- // inner loop.
- auto loadedLhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getLhs(), loopIvs);
- auto loadedRhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
- // Create the binary operation performed on the loaded
- // values.
- return LoweredBinaryOp::create(builder, loc, loadedLhs,
- loadedRhs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ // Generate loads for the element of 'lhs' and 'rhs' at the
+ // inner loop.
+ auto loadedLhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+ auto loadedRhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+ // Create the binary operation performed on the loaded
+ // values.
+ return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+ });
return success();
}
};
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
//===----------------------------------------------------------------------===//
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
- using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+ using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
//===----------------------------------------------------------------------===//
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
//===----------------------------------------------------------------------===//
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
- using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+ using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
//===----------------------------------------------------------------------===//
-struct TransposeOpLowering : public ConversionPattern {
- TransposeOpLowering(MLIRContext *ctx)
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // TransposeOp. This allows for using the nice named
- // accessors that are generated by the ODS.
- toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
- Value input = transposeAdaptor.getInput();
-
- // Transpose the elements by generating a load from the
- // reverse indices.
- SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return affine::AffineLoadOp::create(builder, loc, input,
- reverseIvs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index e0950ef56f4fe..feb7ab33b76ce 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -55,19 +55,18 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToLLVM RewritePatterns
+// ToyToLLVM Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
/// elements of the array.
-class PrintOpLowering : public ConversionPattern {
+class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
public:
- explicit PrintOpLowering(MLIRContext *context)
- : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+ using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *context = rewriter.getContext();
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
@@ -108,9 +107,8 @@ class PrintOpLowering : public ConversionPattern {
}
// Generate a call to printf for the current element of the loop.
- auto printOp = cast<toy::PrintOp>(op);
auto elementLoad =
- memref::LoadOp::create(rewriter, loc, printOp.getInput(), loopIvs);
+ memref::LoadOp::create(rewriter, loc, adaptor.getInput(), loopIvs);
LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
@@ -223,8 +221,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
- // PrintOp.
- patterns.add<PrintOpLowering>(&getContext());
+ // PrintOp. An identity converter is needed because the PrintOp lowering
+ // operates on MemRefType instead of the lowered LLVM struct type.
+ TypeConverter identityConverter;
+ identityConverter.addConversion([](Type type) { return type; });
+ patterns.add<PrintOpLowering>(identityConverter, &getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index d65c89c3fcfa6..cbe4236050f57 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
//===----------------------------------------------------------------------===//
/// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
}
/// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
- OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
- PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+ function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
affine::buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
- // Call the processing function with the rewriter, the memref operands,
+ // Call the processing function with the rewriter
// and the loop induction variables. This function will return the value
// to store at the current index.
- Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+ Value valueToStore = processIteration(nestedBuilder, ivs);
affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
ivs);
});
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
namespace {
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
- BinaryOpLowering(MLIRContext *ctx)
- : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+ using OpConversionPattern<BinaryOp>::OpConversionPattern;
+ using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // BinaryOp. This allows for using the nice named accessors
- // that are generated by the ODS.
- typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
- // Generate loads for the element of 'lhs' and 'rhs' at the
- // inner loop.
- auto loadedLhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getLhs(), loopIvs);
- auto loadedRhs = affine::AffineLoadOp::create(
- builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
- // Create the binary operation performed on the loaded
- // values.
- return LoweredBinaryOp::create(builder, loc, loadedLhs,
- loadedRhs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ // Generate loads for the element of 'lhs' and 'rhs' at the
+ // inner loop.
+ auto loadedLhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+ auto loadedRhs =
+ affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+ // Create the binary operation performed on the loaded
+ // values.
+ return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+ });
return success();
}
};
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
//===----------------------------------------------------------------------===//
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
- using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+ using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ConstantOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.getValue();
Location loc = op.getLoc();
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
//===----------------------------------------------------------------------===//
struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
//===----------------------------------------------------------------------===//
struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
//===----------------------------------------------------------------------===//
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
- using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+ using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
- LogicalResult matchAndRewrite(toy::ReturnOp op,
- PatternRewriter &rewriter) const final {
+ LogicalResult
+ matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
};
//===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
//===----------------------------------------------------------------------===//
-struct TransposeOpLowering : public ConversionPattern {
- TransposeOpLowering(MLIRContext *ctx)
- : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+ using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
- lowerOpToLoops(op, operands, rewriter,
- [loc](OpBuilder &builder, ValueRange memRefOperands,
- ValueRange loopIvs) {
- // Generate an adaptor for the remapped operands of the
- // TransposeOp. This allows for using the nice named
- // accessors that are generated by the ODS.
- toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
- Value input = transposeAdaptor.getInput();
-
- // Transpose the elements by generating a load from the
- // reverse indices.
- SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
- return affine::AffineLoadOp::create(builder, loc, input,
- reverseIvs);
- });
+ lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+ Value input = adaptor.getInput();
+
+ // Transpose the elements by generating a load from the
+ // reverse indices.
+ SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+ return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+ });
return success();
}
};
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 43a84da88e189..af502950e03ff 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -55,19 +55,18 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// ToyToLLVM RewritePatterns
+// ToyToLLVM Conversion Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
/// elements of the array.
-class PrintOpLowering : public ConversionPattern {
+class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
public:
- explicit PrintOpLowering(MLIRContext *context)
- : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+ using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *context = rewriter.getContext();
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
@@ -108,9 +107,8 @@ class PrintOpLowering : public ConversionPattern {
}
// Generate a call to printf for the current element of the loop.
- auto printOp = cast<toy::PrintOp>(op);
auto elementLoad =
- memref::LoadOp::create(rewriter, loc, printOp.getInput(), loopIvs);
+ memref::LoadOp::create(rewriter, loc, adaptor.getInput(), loopIvs);
LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
ArrayRef<Value>({formatSpecifierCst, elementLoad}));
@@ -224,8 +222,11 @@ void ToyToLLVMLoweringPass::runOnOperation() {
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
- // PrintOp.
- patterns.add<PrintOpLowering>(&getContext());
+ // PrintOp. An identity converter is needed because the PrintOp lowering
+ // operates on MemRefType instead of the lowered LLVM struct type.
+ TypeConverter identityConverter;
+ identityConverter.addConversion([](Type type) { return type; });
+ patterns.add<PrintOpLowering>(identityConverter, &getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.