[Mlir-commits] [mlir] 12929c2 - Revert "[mlir] Enable perfect forwarding in RewritePatternSet::add"
lorenzo chelini
llvmlistbot at llvm.org
Fri Jul 15 10:16:33 PDT 2022
Author: lorenzo chelini
Date: 2022-07-15T19:16:09+02:00
New Revision: 12929c241af385f2a655d0926cc227f8b0b0ccb8
URL: https://github.com/llvm/llvm-project/commit/12929c241af385f2a655d0926cc227f8b0b0ccb8
DIFF: https://github.com/llvm/llvm-project/commit/12929c241af385f2a655d0926cc227f8b0b0ccb8.diff
LOG: Revert "[mlir] Enable perfect forwarding in RewritePatternSet::add"
Reverted because the original commit did not preserve author information.
This reverts commit b0afda78f007740371307bfacbe4a486a4b77a3e.
Added:
Modified:
mlir/include/mlir/IR/PatternMatch.h
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 12bf196bb58e5..0858908ffb73e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1416,10 +1416,7 @@ class RewritePatternSet {
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
- 0, (addImpl<Ts>(/*debugLabels=*/llvm::None,
- std::forward<ConstructorArg>(arg),
- std::forward<ConstructorArgs>(args)...),
- 0)...};
+ 0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
return *this;
}
/// An overload of the above `add` method that allows for attaching a set
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 67651a98f7947..63ebe4024f7fc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -23,60 +23,91 @@
using namespace mlir;
using namespace mlir::linalg;
-/// Use this to safely fill patterns for this test, since RewritePatternSet::add
-/// forwards Rvalues only to the first pattern.
-template <typename OpTy, LinalgTilingLoopType LoopType>
-static void fillFusionPattern(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- RewritePatternSet &patterns,
- const Twine &testCase,
- ArrayRef<int64_t> tileSizes,
- ArrayRef<int64_t> indicesToFuse) {
- patterns.add<LinalgTileAndFusePattern<OpTy>>(
+template <LinalgTilingLoopType LoopType>
+static void fillFusionPatterns(MLIRContext *context,
+ const LinalgDependenceGraph &dependenceGraph,
+ RewritePatternSet &patterns) {
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>,
+ LinalgTileAndFusePattern<Conv2DOp>>(
context, dependenceGraph,
- LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
- LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse({2}),
LinalgTransformationFilter(
- StringAttr::get(context, testCase + "_fusion"),
- StringAttr::get(context, "after_" + testCase + "_fusion")),
+ StringAttr::get(context, "basic_fusion"),
+ StringAttr::get(context, "after_basic_fusion")),
LinalgTransformationFilter(
ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
+ StringAttr::get(context, "after_basic_fusion_producer")),
LinalgTransformationFilter(
ArrayRef<StringAttr>(),
- StringAttr::get(context, "after_" + testCase + "_fusion_original")));
-}
+ StringAttr::get(context, "after_basic_fusion_original")));
-template <LinalgTilingLoopType LoopType>
-static void fillFusionPatterns(MLIRContext *context,
- const LinalgDependenceGraph &dependenceGraph,
- RewritePatternSet &patterns) {
- fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns,
- /*testCase=*/"basic",
- /*tileSizes=*/{32, 64, 16},
- /*indicesToFuse=*/{2});
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+ context, dependenceGraph,
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse({0}),
+ LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
+ StringAttr::get(context, "after_lhs_fusion")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_lhs_fusion_producer")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_lhs_fusion_original")));
+
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+ context, dependenceGraph,
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse({2}),
+ LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
+ StringAttr::get(context, "after_out_fusion")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_out_fusion_producer")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_out_fusion_original")));
- auto fillMatmulPattern = [&](const Twine &testCase,
- ArrayRef<int64_t> indicesToFuse) {
- fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
- testCase, /*tileSizes=*/{32, 64, 16},
- indicesToFuse);
- };
- fillMatmulPattern(/*testCase=*/"basic",
- /*indicesToFuse=*/{2});
- fillMatmulPattern(/*testCase=*/"lhs",
- /*indicesToFuse=*/{0});
- fillMatmulPattern(/*testCase=*/"out",
- /*indicesToFuse=*/{2});
- fillMatmulPattern(/*testCase=*/"rhs",
- /*indicesToFuse=*/{1});
- fillMatmulPattern(/*testCase=*/"two_operand",
- /*indicesToFuse=*/{0, 2});
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+ context, dependenceGraph,
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse({1}),
+ LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
+ StringAttr::get(context, "after_rhs_fusion")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_rhs_fusion_producer")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_rhs_fusion_original")));
- fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns,
- /*testCase=*/"transpose",
- /*tileSizes=*/{32, 64},
- /*indicesToFuse=*/{0, 1});
+ patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+ context, dependenceGraph,
+ LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse({0, 2}),
+ LinalgTransformationFilter(
+ StringAttr::get(context, "two_operand_fusion"),
+ StringAttr::get(context, "after_two_operand_fusion")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_two_operand_fusion_producer")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_two_operand_fusion_original")));
+
+ patterns.add<LinalgTileAndFusePattern<GenericOp>>(
+ context, dependenceGraph,
+ LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
+ LinalgFusionOptions().setIndicesToFuse({0, 1}),
+ LinalgTransformationFilter(
+ StringAttr::get(context, "transpose_fusion"),
+ StringAttr::get(context, "after_transpose_fusion")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_transpose_fusion_producer")),
+ LinalgTransformationFilter(
+ ArrayRef<StringAttr>(),
+ StringAttr::get(context, "after_transpose_fusion_original")));
}
namespace {
More information about the Mlir-commits mailing list