[Mlir-commits] [mlir] e349fb7 - [mlir][Linalg] NFC - Make markers use Identifier instead of StringRef

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jun 3 02:55:21 PDT 2020


Author: Nicolas Vasilache
Date: 2020-06-03T05:52:32-04:00
New Revision: e349fb70a23f3a39e058605e4e2db66da5e5ea4a

URL: https://github.com/llvm/llvm-project/commit/e349fb70a23f3a39e058605e4e2db66da5e5ea4a
DIFF: https://github.com/llvm/llvm-project/commit/e349fb70a23f3a39e058605e4e2db66da5e5ea4a.diff

LOG: [mlir][Linalg] NFC - Make markers use Identifier instead of StringRef

Summary: This removes string ownership worries by putting everything into the context, and makes it easier to construct identifiers programmatically.

Reviewers: ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D81027

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fb8fc4cbe949..2437ac557799 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/SmallBitVector.h"
 
@@ -206,15 +207,16 @@ struct LinalgTransforms {
 
 /// Helper class to control common attribute matching and setting behavior.
 struct LinalgMarker {
-  LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
-               Optional<StringRef> replacement = None);
-  LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
+  explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
+                        Optional<Identifier> replacement = None);
+  LinalgMarker(LinalgMarker &&) = default;
+  LinalgMarker(const LinalgMarker &) = default;
   LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
   void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
 
 private:
-  SmallVector<StringRef, 4> matchDisjunction;
-  Optional<StringRef> replacement;
+  SmallVector<Identifier, 4> matchDisjunction;
+  Optional<Identifier> replacement;
 };
 
 ///

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 5b4fec4bbf20..0dac95739679 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -459,8 +459,8 @@ class RewritePatternList<OpTy, OpTypes...> {
 public:
   static void insert(OwningRewritePatternList &patterns,
                      const LinalgTilingOptions &options, MLIRContext *ctx) {
-    patterns.insert<LinalgTilingPattern<OpTy>>(ctx, options,
-                                               LinalgMarker({}, "tiled"));
+    patterns.insert<LinalgTilingPattern<OpTy>>(
+        ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
     RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
   }
 };

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 76e118e482f0..1aefb8f42f72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -46,15 +46,11 @@ using llvm::dbgs;
 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
     "__internal_linalg_transform__";
 
-mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
-                                         Optional<StringRef> replacement)
+mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
+                                         Optional<Identifier> replacement)
     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
       replacement(replacement) {}
 
-mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
-                                         StringRef replacement)
-    : LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
-
 LogicalResult
 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
                                            Operation *op) const {
@@ -66,12 +62,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
     if (matchDisjunction.empty())
       return success();
 
-    // 2. Has no marker and matchDisjuntion matches the no-moarker case.
-    for (auto marker : matchDisjunction)
-      if (marker.empty())
-        return success();
-
-    // 3. Has no marker but was expecting a marker.
+    // 2. Has no marker but was expecting a marker.
     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
       diag << " does not have any marker from list: ";
       interleaveComma(matchDisjunction, diag);

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 41fa3fd95d93..9a022082b3be 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -14,9 +14,10 @@
 func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
           %y: memref<?xf32, offset: ?, strides: [1]>,
           %v: memref<f32>) {
-  linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
-                           memref<?xf32, offset: ?, strides: [1]>,
-                           memref<f32>
+  linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } :
+             memref<?xf32, offset: ?, strides: [1]>,
+             memref<?xf32, offset: ?, strides: [1]>,
+             memref<f32>
   return
 }
 // CHECK-LABEL: func @dot
@@ -35,9 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
 func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %x: memref<?xf32, offset: ?, strides: [1]>,
              %y: memref<?xf32, offset: ?, strides: [1]>) {
-  linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                              memref<?xf32, offset: ?, strides: [1]>,
-                              memref<?xf32, offset: ?, strides: [1]>
+  linalg.matvec(%A, %x, %y) :
+                memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?xf32, offset: ?, strides: [1]>,
+                memref<?xf32, offset: ?, strides: [1]>
   return
 }
 // CHECK-LABEL: func @matvec
@@ -51,9 +53,10 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                              memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                              memref<?x?xf32, offset: ?, strides: [?, 1]>
+  linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
+                memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
 // CHECK-LABEL: func @matmul

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 31189f47f9ae..4b1b5bdc1366 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -66,26 +66,29 @@ static void applyPatterns(FuncOp funcOp) {
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
       ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
-      LinalgMarker({"MEM", {}}, "L3"));
+      LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
       ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
-      LinalgMarker({"L3"}, "L2"));
+      LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
-      LinalgMarker({"L2"}, "L1"));
+      LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
       ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
-      LinalgMarker({"L1"}, "REG"));
+      LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
 
   patterns.insert<LinalgTilingPattern<MatvecOp>>(
       ctx,
       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
           LinalgTilingLoopType::ParallelLoops),
-      LinalgMarker({}, "L1"));
+      LinalgMarker({}, Identifier::get("L1", ctx)));
 
   patterns.insert<LinalgTilingPattern<DotOp>>(
       ctx, LinalgTilingOptions().setTileSizes(8000),
-      LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
+      LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
+                                        Identifier::get("L3", ctx),
+                                        Identifier::get("L2", ctx)},
+                   Identifier::get("REG", ctx)));
 
   //===--------------------------------------------------------------------===//
   // Linalg tiling and permutation patterns.
@@ -95,20 +98,24 @@ static void applyPatterns(FuncOp funcOp) {
       LinalgTilingOptions()
           .setTileSizes({2000, 3000, 4000})
           .setInterchange({1, 2, 0}),
-      LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
+      LinalgMarker(Identifier::get("__with_perm__", ctx),
+                   Identifier::get("L2__with_perm__", ctx)));
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
       ctx,
       LinalgTilingOptions()
           .setTileSizes({200, 300, 400})
           .setInterchange({1, 0, 2}),
-      LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
+      LinalgMarker(Identifier::get("L2__with_perm__", ctx),
+                   Identifier::get("L1__with_perm__", ctx)));
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
       ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
-      LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
+      LinalgMarker(Identifier::get("L1__with_perm__", ctx),
+                   Identifier::get("REG__with_perm__", ctx)));
 
   patterns.insert<LinalgTilingPattern<MatvecOp>>(
       ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
-      LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
+      LinalgMarker(Identifier::get("__with_perm__", ctx),
+                   Identifier::get("L1__with_perm__", ctx)));
 
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
       ctx,
@@ -116,14 +123,16 @@ static void applyPatterns(FuncOp funcOp) {
           .setTileSizes({16, 8, 4})
           .setInterchange({1, 2, 0})
           .setLoopType(LinalgTilingLoopType::ParallelLoops),
-      LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
+      LinalgMarker(Identifier::get("par__with_perm__", ctx),
+                   Identifier::get("after_par__with_perm__", ctx)));
 
   //===--------------------------------------------------------------------===//
   // Linalg to loops patterns.
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgLoweringPattern<DotOp>>(
       ctx,
-      /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
+      /*loweringType=*/LinalgLoweringType::Loops,
+      LinalgMarker(Identifier::get("REG", ctx)));
 
   //===--------------------------------------------------------------------===//
   // Linalg to vector contraction patterns.
@@ -131,7 +140,7 @@ static void applyPatterns(FuncOp funcOp) {
   patterns.insert<LinalgVectorizationPattern<MatmulOp>,
                   LinalgVectorizationPattern<FillOp>,
                   LinalgVectorizationPattern<GenericOp>>(
-      ctx, LinalgMarker({"VECTORIZE"}));
+      ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
 
   //===--------------------------------------------------------------------===//
   // Linalg generic permutation patterns.
@@ -139,31 +148,34 @@ static void applyPatterns(FuncOp funcOp) {
   patterns.insert<LinalgInterchangePattern<GenericOp>>(
       ctx,
       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
-      LinalgMarker({}, "PERMUTED"));
+      LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
   patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
       ctx,
       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
-      LinalgMarker({}, "PERMUTED"));
+      LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
 
   //===--------------------------------------------------------------------===//
   // Linalg subview operands promotion.
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
       ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
-      LinalgMarker({"_promote_views_"}, "_views_promoted_"));
+      LinalgMarker(Identifier::get("_promote_views_", ctx),
+                   Identifier::get("_views_promoted_", ctx)));
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
       ctx,
       LinalgPromotionOptions()
           .setOperandsToPromote({0})
           .useFullTileBuffersByDefault(),
-      LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
+      LinalgMarker(Identifier::get("_promote_first_view_", ctx),
+                   Identifier::get("_first_view_promoted_", ctx)));
   patterns.insert<LinalgPromotionPattern<FillOp>>(
       ctx,
       LinalgPromotionOptions()
           .setOperandsToPromote({0})
           .setUseFullTileBuffers({true})
           .setAlignment(32),
-      LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
+      LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
+                   Identifier::get("_views_aligned_promoted_", ctx)));
 
   applyPatternsAndFoldGreedily(funcOp, patterns);
 
@@ -176,21 +188,22 @@ static void applyPatterns(FuncOp funcOp) {
 static void fillL1TilingAndMatmulToVectorPatterns(
     FuncOp funcOp, StringRef startMarker,
     SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
-  MLIRContext *context = funcOp.getContext();
+  MLIRContext *ctx = funcOp.getContext();
   patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
-      context,
+      ctx,
       LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
-      LinalgMarker({startMarker}, "L1")));
+      LinalgMarker(Identifier::get(startMarker, ctx),
+                   Identifier::get("L1", ctx))));
 
   patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
-      context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
-      LinalgMarker({"L1"}, "VEC")));
+      ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
+      LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
 
-  patternsVector.emplace_back(
-      LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
+  patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
+      ctx, LinalgMarker(Identifier::get("VEC", ctx))));
   patternsVector.back()
       .insert<LinalgVectorizationPattern<FillOp>,
-              LinalgVectorizationPattern<CopyOp>>(context);
+              LinalgVectorizationPattern<CopyOp>>(ctx);
 }
 
 //===----------------------------------------------------------------------===//
@@ -231,13 +244,14 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
   return success();
 }
 
-void fillPromotionCallBackPatterns(MLIRContext *context,
+void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                    OwningRewritePatternList &patterns) {
   patterns.insert<LinalgTilingPattern<MatmulOp>>(
-      context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
-      LinalgMarker({"START"}, "PROMOTE"));
+      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
+      LinalgMarker(Identifier::get("START", ctx),
+                   Identifier::get("PROMOTE", ctx)));
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      context,
+      ctx,
       LinalgPromotionOptions()
           .setOperandsToPromote({0, 2})
           .setUseFullTileBuffers({false, false})
@@ -251,7 +265,7 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
                 copyCallBackFn(b, src, dst, true);
                 return success();
               }),
-      LinalgMarker({"PROMOTE"}));
+      LinalgMarker(Identifier::get("PROMOTE", ctx)));
 }
 
 static void
@@ -261,15 +275,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
   MLIRContext *ctx = funcOp.getContext();
   SmallVector<OwningRewritePatternList, 4> stage1Patterns;
   if (testMatmulToVectorPatterns1dTiling) {
-    fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
+                                          stage1Patterns);
   } else if (testMatmulToVectorPatterns2dTiling) {
-    stage1Patterns.emplace_back(
-        LinalgTilingPattern<MatmulOp>(ctx,
-                                      LinalgTilingOptions()
-                                          .setTileSizes({768, 264, 768})
-                                          .setInterchange({1, 2, 0}),
-                                      LinalgMarker({"START"}, "L2")));
-    fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
+    stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
+        ctx,
+        LinalgTilingOptions()
+            .setTileSizes({768, 264, 768})
+            .setInterchange({1, 2, 0}),
+        LinalgMarker(Identifier::get("START", ctx),
+                     Identifier::get("L2", ctx))));
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
+                                          stage1Patterns);
   }
   OwningRewritePatternList stage2Patterns =
       getLinalgTilingCanonicalizationPatterns(ctx);


        


More information about the Mlir-commits mailing list