[Mlir-commits] [mlir] 12929c2 - Revert "[mlir] Enable perfect forwarding in RewritePatternSet::add"

lorenzo chelini llvmlistbot at llvm.org
Fri Jul 15 10:16:33 PDT 2022


Author: lorenzo chelini
Date: 2022-07-15T19:16:09+02:00
New Revision: 12929c241af385f2a655d0926cc227f8b0b0ccb8

URL: https://github.com/llvm/llvm-project/commit/12929c241af385f2a655d0926cc227f8b0b0ccb8
DIFF: https://github.com/llvm/llvm-project/commit/12929c241af385f2a655d0926cc227f8b0b0ccb8.diff

LOG: Revert "[mlir] Enable perfect forwarding in RewritePatternSet::add"

Reverted because the original commit did not preserve the author information.

This reverts commit b0afda78f007740371307bfacbe4a486a4b77a3e.

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 12bf196bb58e5..0858908ffb73e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1416,10 +1416,7 @@ class RewritePatternSet {
     // that a parameter pack can be expanded in c++11.
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
     (void)std::initializer_list<int>{
-        0, (addImpl<Ts>(/*debugLabels=*/llvm::None,
-                        std::forward<ConstructorArg>(arg),
-                        std::forward<ConstructorArgs>(args)...),
-            0)...};
+        0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
     return *this;
   }
   /// An overload of the above `add` method that allows for attaching a set

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 67651a98f7947..63ebe4024f7fc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -23,60 +23,91 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Use this to safely fill patterns for this test, since RewritePatternSet::add
-/// forwards Rvalues only to the first pattern.
-template <typename OpTy, LinalgTilingLoopType LoopType>
-static void fillFusionPattern(MLIRContext *context,
-                              const LinalgDependenceGraph &dependenceGraph,
-                              RewritePatternSet &patterns,
-                              const Twine &testCase,
-                              ArrayRef<int64_t> tileSizes,
-                              ArrayRef<int64_t> indicesToFuse) {
-  patterns.add<LinalgTileAndFusePattern<OpTy>>(
+template <LinalgTilingLoopType LoopType>
+static void fillFusionPatterns(MLIRContext *context,
+                               const LinalgDependenceGraph &dependenceGraph,
+                               RewritePatternSet &patterns) {
+  patterns.add<LinalgTileAndFusePattern<MatmulOp>,
+               LinalgTileAndFusePattern<Conv2DOp>>(
       context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({2}),
       LinalgTransformationFilter(
-          StringAttr::get(context, testCase + "_fusion"),
-          StringAttr::get(context, "after_" + testCase + "_fusion")),
+          StringAttr::get(context, "basic_fusion"),
+          StringAttr::get(context, "after_basic_fusion")),
       LinalgTransformationFilter(
           ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
+          StringAttr::get(context, "after_basic_fusion_producer")),
       LinalgTransformationFilter(
           ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_" + testCase + "_fusion_original")));
-}
+          StringAttr::get(context, "after_basic_fusion_original")));
 
-template <LinalgTilingLoopType LoopType>
-static void fillFusionPatterns(MLIRContext *context,
-                               const LinalgDependenceGraph &dependenceGraph,
-                               RewritePatternSet &patterns) {
-  fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns,
-                                        /*testCase=*/"basic",
-                                        /*tileSizes=*/{32, 64, 16},
-                                        /*indicesToFuse=*/{2});
+  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({0}),
+      LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
+                                 StringAttr::get(context, "after_lhs_fusion")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_lhs_fusion_producer")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_lhs_fusion_original")));
+
+  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({2}),
+      LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
+                                 StringAttr::get(context, "after_out_fusion")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_out_fusion_producer")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_out_fusion_original")));
 
-  auto fillMatmulPattern = [&](const Twine &testCase,
-                               ArrayRef<int64_t> indicesToFuse) {
-    fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
-                                          testCase, /*tileSizes=*/{32, 64, 16},
-                                          indicesToFuse);
-  };
-  fillMatmulPattern(/*testCase=*/"basic",
-                    /*indicesToFuse=*/{2});
-  fillMatmulPattern(/*testCase=*/"lhs",
-                    /*indicesToFuse=*/{0});
-  fillMatmulPattern(/*testCase=*/"out",
-                    /*indicesToFuse=*/{2});
-  fillMatmulPattern(/*testCase=*/"rhs",
-                    /*indicesToFuse=*/{1});
-  fillMatmulPattern(/*testCase=*/"two_operand",
-                    /*indicesToFuse=*/{0, 2});
+  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({1}),
+      LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
+                                 StringAttr::get(context, "after_rhs_fusion")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_rhs_fusion_producer")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_rhs_fusion_original")));
 
-  fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns,
-                                         /*testCase=*/"transpose",
-                                         /*tileSizes=*/{32, 64},
-                                         /*indicesToFuse=*/{0, 1});
+  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({0, 2}),
+      LinalgTransformationFilter(
+          StringAttr::get(context, "two_operand_fusion"),
+          StringAttr::get(context, "after_two_operand_fusion")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_two_operand_fusion_producer")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_two_operand_fusion_original")));
+
+  patterns.add<LinalgTileAndFusePattern<GenericOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({0, 1}),
+      LinalgTransformationFilter(
+          StringAttr::get(context, "transpose_fusion"),
+          StringAttr::get(context, "after_transpose_fusion")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_transpose_fusion_producer")),
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>(),
+          StringAttr::get(context, "after_transpose_fusion_original")));
 }
 
 namespace {


        


More information about the Mlir-commits mailing list