[Mlir-commits] [mlir] ac14d5a - [mlir] Enable perfect forwarding in RewritePatternSet::add

lorenzo chelini llvmlistbot at llvm.org
Fri Jul 15 10:44:36 PDT 2022


Author: Laszlo Kindrat
Date: 2022-07-15T19:44:18+02:00
New Revision: ac14d5a1db4b498cf38e5d79e09fa90a8715909b

URL: https://github.com/llvm/llvm-project/commit/ac14d5a1db4b498cf38e5d79e09fa90a8715909b
DIFF: https://github.com/llvm/llvm-project/commit/ac14d5a1db4b498cf38e5d79e09fa90a8715909b.diff

LOG: [mlir] Enable perfect forwarding in RewritePatternSet::add

This patch modifies the implementation of `RewritePatternSet::add` to perfectly forward its arguments (preserving their value category) to the pattern constructors. Without this, code like the following compiles but can produce unexpected behavior, because the temporary TypeConverter is destroyed at the end of the full-expression while the patterns added to the set may still refer to it:
```
RewritePatternSet patterns(context);
patterns.add<SomeOpConversion, OtherOpConversion>(TypeConverter(), context);

if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
  return signalPassFailure();
```

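The patch also updates the Linalg fusion test (TestLinalgFusionTransforms.cpp) so that each `add` call constructs a single pattern type. `add` now forwards rvalue arguments, so adding several pattern types in one call could let the first pattern consume the temporary arguments. The remaining patterns would then receive moved-from objects.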

Author: Laszlo Kindrat <laszlokindrat at gmail.com>

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129601

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0858908ffb73..12bf196bb58e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1416,7 +1416,10 @@ class RewritePatternSet {
     // that a parameter pack can be expanded in c++11.
     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
     (void)std::initializer_list<int>{
-        0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
+        0, (addImpl<Ts>(/*debugLabels=*/llvm::None,
+                        std::forward<ConstructorArg>(arg),
+                        std::forward<ConstructorArgs>(args)...),
+            0)...};
     return *this;
   }
   /// An overload of the above `add` method that allows for attaching a set

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
index 63ebe4024f7f..67651a98f794 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -23,91 +23,60 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-template <LinalgTilingLoopType LoopType>
-static void fillFusionPatterns(MLIRContext *context,
-                               const LinalgDependenceGraph &dependenceGraph,
-                               RewritePatternSet &patterns) {
-  patterns.add<LinalgTileAndFusePattern<MatmulOp>,
-               LinalgTileAndFusePattern<Conv2DOp>>(
+/// Use this to safely fill patterns for this test, since RewritePatternSet::add
+/// forwards Rvalues only to the first pattern.
+template <typename OpTy, LinalgTilingLoopType LoopType>
+static void fillFusionPattern(MLIRContext *context,
+                              const LinalgDependenceGraph &dependenceGraph,
+                              RewritePatternSet &patterns,
+                              const Twine &testCase,
+                              ArrayRef<int64_t> tileSizes,
+                              ArrayRef<int64_t> indicesToFuse) {
+  patterns.add<LinalgTileAndFusePattern<OpTy>>(
       context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse({2}),
-      LinalgTransformationFilter(
-          StringAttr::get(context, "basic_fusion"),
-          StringAttr::get(context, "after_basic_fusion")),
+      LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
       LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_basic_fusion_producer")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_basic_fusion_original")));
-
-  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
-      context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse({0}),
-      LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"),
-                                 StringAttr::get(context, "after_lhs_fusion")),
+          StringAttr::get(context, testCase + "_fusion"),
+          StringAttr::get(context, "after_" + testCase + "_fusion")),
       LinalgTransformationFilter(
           ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_lhs_fusion_producer")),
+          StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
       LinalgTransformationFilter(
           ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_lhs_fusion_original")));
-
-  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
-      context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse({2}),
-      LinalgTransformationFilter(StringAttr::get(context, "out_fusion"),
-                                 StringAttr::get(context, "after_out_fusion")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_out_fusion_producer")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_out_fusion_original")));
+          StringAttr::get(context, "after_" + testCase + "_fusion_original")));
+}
 
-  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
-      context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse({1}),
-      LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"),
-                                 StringAttr::get(context, "after_rhs_fusion")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_rhs_fusion_producer")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_rhs_fusion_original")));
+template <LinalgTilingLoopType LoopType>
+static void fillFusionPatterns(MLIRContext *context,
+                               const LinalgDependenceGraph &dependenceGraph,
+                               RewritePatternSet &patterns) {
+  fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns,
+                                        /*testCase=*/"basic",
+                                        /*tileSizes=*/{32, 64, 16},
+                                        /*indicesToFuse=*/{2});
 
-  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
-      context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse({0, 2}),
-      LinalgTransformationFilter(
-          StringAttr::get(context, "two_operand_fusion"),
-          StringAttr::get(context, "after_two_operand_fusion")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_two_operand_fusion_producer")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_two_operand_fusion_original")));
+  auto fillMatmulPattern = [&](const Twine &testCase,
+                               ArrayRef<int64_t> indicesToFuse) {
+    fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
+                                          testCase, /*tileSizes=*/{32, 64, 16},
+                                          indicesToFuse);
+  };
+  fillMatmulPattern(/*testCase=*/"basic",
+                    /*indicesToFuse=*/{2});
+  fillMatmulPattern(/*testCase=*/"lhs",
+                    /*indicesToFuse=*/{0});
+  fillMatmulPattern(/*testCase=*/"out",
+                    /*indicesToFuse=*/{2});
+  fillMatmulPattern(/*testCase=*/"rhs",
+                    /*indicesToFuse=*/{1});
+  fillMatmulPattern(/*testCase=*/"two_operand",
+                    /*indicesToFuse=*/{0, 2});
 
-  patterns.add<LinalgTileAndFusePattern<GenericOp>>(
-      context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse({0, 1}),
-      LinalgTransformationFilter(
-          StringAttr::get(context, "transpose_fusion"),
-          StringAttr::get(context, "after_transpose_fusion")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_transpose_fusion_producer")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_transpose_fusion_original")));
+  fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns,
+                                         /*testCase=*/"transpose",
+                                         /*tileSizes=*/{32, 64},
+                                         /*indicesToFuse=*/{0, 1});
 }
 
 namespace {


        


More information about the Mlir-commits mailing list