[Mlir-commits] [mlir] b0afda7 - [mlir] Enable perfect forwarding in RewritePatternSet::add
lorenzo chelini
llvmlistbot at llvm.org
Fri Jul 15 10:09:29 PDT 2022
Author: lorenzo chelini
Date: 2022-07-15T19:08:23+02:00
New Revision: b0afda78f007740371307bfacbe4a486a4b77a3e
URL: https://github.com/llvm/llvm-project/commit/b0afda78f007740371307bfacbe4a486a4b77a3e
DIFF: https://github.com/llvm/llvm-project/commit/b0afda78f007740371307bfacbe4a486a4b77a3e.diff
LOG: [mlir] Enable perfect forwarding in RewritePatternSet::add
This patch modifies the implementation of `RewritePatternSet::add` to perfectly forward its arguments to the pattern constructors. Without this, code like the following compiles but can produce unexpected behavior: the temporary TypeConverter is destroyed at the end of the statement, while the patterns constructed from it may still refer to it.
```
RewritePatternSet patterns(context);
patterns.add<SomeOpConversion, OtherOpConversion>(TypeConverter(), context);
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
  return signalPassFailure();
```
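The patch also updates the Linalg fusion test pass (TestLinalgFusionTransforms.cpp) so that it fills its pattern set correctly under the new behavior. When a single `add` call constructs several patterns, rvalue arguments are effectively usable only by the first pattern. The test therefore now adds each pattern in its own call, with freshly constructed options and filters.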
Author: Laszlo Kindrat <laszlokindrat at gmail.com>
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D129601
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0858908ffb73e..12bf196bb58e5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1416,7 +1416,10 @@ class RewritePatternSet {
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
- 0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
+ 0, (addImpl<Ts>(/*debugLabels=*/llvm::None,
+ std::forward<ConstructorArg>(arg),
+ std::forward<ConstructorArgs>(args)...),
+ 0)...};
return *this;
}
/// An overload of the above `add` method that allows for attaching a set
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 63ebe4024f7fc..67651a98f7947 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -23,91 +23,60 @@
using namespace mlir;
using namespace mlir::linalg;
-template <LinalgTilingLoopType LoopType>
-static void fillFusionPatterns(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- RewritePatternSet &patterns) {
- patterns.add<LinalgTileAndFusePattern<MatmulOp>,
- LinalgTileAndFusePattern<Conv2DOp>>(
+/// Use this to safely fill patterns for this test, since RewritePatternSet::add
+/// forwards Rvalues only to the first pattern.
+template <typename OpTy, LinalgTilingLoopType LoopType>
+static void fillFusionPattern(MLIRContext *context,
+ const LinalgDependenceGraph &dependenceGraph,
+ RewritePatternSet &patterns,
+ const Twine &testCase,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<int64_t> indicesToFuse) {
+ patterns.add<LinalgTileAndFusePattern<OpTy>>(
context, dependenceGraph,
- LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse({2}),
- LinalgTransformationFilter(
- StringAttr::get(context, "basic_fusion"),
- StringAttr::get(context, "after_basic_fusion")),
+ LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_basic_fusion_producer")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_basic_fusion_original")));
-
- patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
- context, dependenceGraph,
- LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse({0}),
- LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
- StringAttr::get(context, "after_lhs_fusion")),
+ StringAttr::get(context, testCase + "_fusion"),
+ StringAttr::get(context, "after_" + testCase + "_fusion")),
LinalgTransformationFilter(
ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_lhs_fusion_producer")),
+ StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
LinalgTransformationFilter(
ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_lhs_fusion_original")));
-
- patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
- context, dependenceGraph,
- LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse({2}),
- LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
- StringAttr::get(context, "after_out_fusion")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_out_fusion_producer")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_out_fusion_original")));
+ StringAttr::get(context, "after_" + testCase + "_fusion_original")));
+}
- patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
- context, dependenceGraph,
- LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse({1}),
- LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
- StringAttr::get(context, "after_rhs_fusion")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_rhs_fusion_producer")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_rhs_fusion_original")));
+template <LinalgTilingLoopType LoopType>
+static void fillFusionPatterns(MLIRContext *context,
+ const LinalgDependenceGraph &dependenceGraph,
+ RewritePatternSet &patterns) {
+ fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns,
+ /*testCase=*/"basic",
+ /*tileSizes=*/{32, 64, 16},
+ /*indicesToFuse=*/{2});
- patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
- context, dependenceGraph,
- LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse({0, 2}),
- LinalgTransformationFilter(
- StringAttr::get(context, "two_operand_fusion"),
- StringAttr::get(context, "after_two_operand_fusion")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_two_operand_fusion_producer")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_two_operand_fusion_original")));
+ auto fillMatmulPattern = [&](const Twine &testCase,
+ ArrayRef<int64_t> indicesToFuse) {
+ fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
+ testCase, /*tileSizes=*/{32, 64, 16},
+ indicesToFuse);
+ };
+ fillMatmulPattern(/*testCase=*/"basic",
+ /*indicesToFuse=*/{2});
+ fillMatmulPattern(/*testCase=*/"lhs",
+ /*indicesToFuse=*/{0});
+ fillMatmulPattern(/*testCase=*/"out",
+ /*indicesToFuse=*/{2});
+ fillMatmulPattern(/*testCase=*/"rhs",
+ /*indicesToFuse=*/{1});
+ fillMatmulPattern(/*testCase=*/"two_operand",
+ /*indicesToFuse=*/{0, 2});
- patterns.add<LinalgTileAndFusePattern<GenericOp>>(
- context, dependenceGraph,
- LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse({0, 1}),
- LinalgTransformationFilter(
- StringAttr::get(context, "transpose_fusion"),
- StringAttr::get(context, "after_transpose_fusion")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_transpose_fusion_producer")),
- LinalgTransformationFilter(
- ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_transpose_fusion_original")));
+ fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns,
+ /*testCase=*/"transpose",
+ /*tileSizes=*/{32, 64},
+ /*indicesToFuse=*/{0, 1});
}
namespace {
More information about the Mlir-commits
mailing list