[Mlir-commits] [mlir] e349fb7 - [mlir][Linalg] NFC - Make markers use Identifier instead of StringRef
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jun 3 02:55:21 PDT 2020
Author: Nicolas Vasilache
Date: 2020-06-03T05:52:32-04:00
New Revision: e349fb70a23f3a39e058605e4e2db66da5e5ea4a
URL: https://github.com/llvm/llvm-project/commit/e349fb70a23f3a39e058605e4e2db66da5e5ea4a
DIFF: https://github.com/llvm/llvm-project/commit/e349fb70a23f3a39e058605e4e2db66da5e5ea4a.diff
LOG: [mlir][Linalg] NFC - Make markers use Identifier instead of StringRef
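Summary: This removes string-ownership concerns by interning all marker strings in the MLIRContext (as Identifiers), and makes it easier to construct identifiers programmatically. Note that the special case where an empty marker in the match list also matched operations carrying no marker at all is removed; the affected test operations are now tagged with an explicit "MEM" marker instead.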
Reviewers: ftynse
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul
Tags: #mlir
Differential Revision: https://reviews.llvm.org/D81027
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index fb8fc4cbe949..2437ac557799 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Identifier.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -206,15 +207,16 @@ struct LinalgTransforms {
/// Helper class to control common attribute matching and setting behavior.
struct LinalgMarker {
- LinalgMarker(ArrayRef<StringRef> matchDisjunction = {},
- Optional<StringRef> replacement = None);
- LinalgMarker(ArrayRef<StringRef> matchDisjunction, StringRef replacement);
+ explicit LinalgMarker(ArrayRef<Identifier> matchDisjunction = {},
+ Optional<Identifier> replacement = None);
+ LinalgMarker(LinalgMarker &&) = default;
+ LinalgMarker(const LinalgMarker &) = default;
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
void replaceLinalgMarker(PatternRewriter &rewriter, Operation *op) const;
private:
- SmallVector<StringRef, 4> matchDisjunction;
- Optional<StringRef> replacement;
+ SmallVector<Identifier, 4> matchDisjunction;
+ Optional<Identifier> replacement;
};
///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 5b4fec4bbf20..0dac95739679 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -459,8 +459,8 @@ class RewritePatternList<OpTy, OpTypes...> {
public:
static void insert(OwningRewritePatternList &patterns,
const LinalgTilingOptions &options, MLIRContext *ctx) {
- patterns.insert<LinalgTilingPattern<OpTy>>(ctx, options,
- LinalgMarker({}, "tiled"));
+ patterns.insert<LinalgTilingPattern<OpTy>>(
+ ctx, options, LinalgMarker({}, Identifier::get("tiled", ctx)));
RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 76e118e482f0..1aefb8f42f72 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -46,15 +46,11 @@ using llvm::dbgs;
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
-mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
- Optional<StringRef> replacement)
+mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
+ Optional<Identifier> replacement)
: matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
replacement(replacement) {}
-mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<StringRef> matchDisjunction,
- StringRef replacement)
- : LinalgMarker(matchDisjunction, Optional<StringRef>{replacement}) {}
-
LogicalResult
mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
Operation *op) const {
@@ -66,12 +62,7 @@ mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
if (matchDisjunction.empty())
return success();
- // 2. Has no marker and matchDisjuntion matches the no-moarker case.
- for (auto marker : matchDisjunction)
- if (marker.empty())
- return success();
-
- // 3. Has no marker but was expecting a marker.
+ // 2. Has no marker but was expecting a marker.
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << " does not have any marker from list: ";
interleaveComma(matchDisjunction, diag);
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 41fa3fd95d93..9a022082b3be 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -14,9 +14,10 @@
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>,
%v: memref<f32>) {
- linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<f32>
+ linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } :
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<f32>
return
}
// CHECK-LABEL: func @dot
@@ -35,9 +36,10 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {
- linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?xf32, offset: ?, strides: [1]>,
- memref<?xf32, offset: ?, strides: [1]>
+ linalg.matvec(%A, %x, %y) :
+ memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>
return
}
// CHECK-LABEL: func @matvec
@@ -51,9 +53,10 @@ func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>
+ linalg.matmul(%A, %B, %C) { __internal_linalg_transform__ = "MEM" } :
+ memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL: func @matmul
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 31189f47f9ae..4b1b5bdc1366 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -66,26 +66,29 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
- LinalgMarker({"MEM", {}}, "L3"));
+ LinalgMarker(Identifier::get("MEM", ctx), Identifier::get("L3", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
- LinalgMarker({"L3"}, "L2"));
+ LinalgMarker(Identifier::get("L3", ctx), Identifier::get("L2", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
- LinalgMarker({"L2"}, "L1"));
+ LinalgMarker(Identifier::get("L2", ctx), Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
- LinalgMarker({"L1"}, "REG"));
+ LinalgMarker(Identifier::get("L1", ctx), Identifier::get("REG", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx,
LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
LinalgTilingLoopType::ParallelLoops),
- LinalgMarker({}, "L1"));
+ LinalgMarker({}, Identifier::get("L1", ctx)));
patterns.insert<LinalgTilingPattern<DotOp>>(
ctx, LinalgTilingOptions().setTileSizes(8000),
- LinalgMarker({"MEM", "L3", "L2", {}}, "REG"));
+ LinalgMarker(ArrayRef<Identifier>{Identifier::get("MEM", ctx),
+ Identifier::get("L3", ctx),
+ Identifier::get("L2", ctx)},
+ Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
@@ -95,20 +98,24 @@ static void applyPatterns(FuncOp funcOp) {
LinalgTilingOptions()
.setTileSizes({2000, 3000, 4000})
.setInterchange({1, 2, 0}),
- LinalgMarker({"__with_perm__"}, "L2__with_perm__"));
+ LinalgMarker(Identifier::get("__with_perm__", ctx),
+ Identifier::get("L2__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
LinalgTilingOptions()
.setTileSizes({200, 300, 400})
.setInterchange({1, 0, 2}),
- LinalgMarker({"L2__with_perm__"}, "L1__with_perm__"));
+ LinalgMarker(Identifier::get("L2__with_perm__", ctx),
+ Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
- LinalgMarker({"L1__with_perm__"}, "REG__with_perm__"));
+ LinalgMarker(Identifier::get("L1__with_perm__", ctx),
+ Identifier::get("REG__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatvecOp>>(
ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
- LinalgMarker({"__with_perm__"}, "L1__with_perm__"));
+ LinalgMarker(Identifier::get("__with_perm__", ctx),
+ Identifier::get("L1__with_perm__", ctx)));
patterns.insert<LinalgTilingPattern<MatmulOp>>(
ctx,
@@ -116,14 +123,16 @@ static void applyPatterns(FuncOp funcOp) {
.setTileSizes({16, 8, 4})
.setInterchange({1, 2, 0})
.setLoopType(LinalgTilingLoopType::ParallelLoops),
- LinalgMarker({"par__with_perm__"}, "after_par__with_perm__"));
+ LinalgMarker(Identifier::get("par__with_perm__", ctx),
+ Identifier::get("after_par__with_perm__", ctx)));
//===--------------------------------------------------------------------===//
// Linalg to loops patterns.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgLoweringPattern<DotOp>>(
ctx,
- /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker({"REG"}));
+ /*loweringType=*/LinalgLoweringType::Loops,
+ LinalgMarker(Identifier::get("REG", ctx)));
//===--------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
@@ -131,7 +140,7 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
LinalgVectorizationPattern<FillOp>,
LinalgVectorizationPattern<GenericOp>>(
- ctx, LinalgMarker({"VECTORIZE"}));
+ ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
//===--------------------------------------------------------------------===//
// Linalg generic permutation patterns.
@@ -139,31 +148,34 @@ static void applyPatterns(FuncOp funcOp) {
patterns.insert<LinalgInterchangePattern<GenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
- LinalgMarker({}, "PERMUTED"));
+ LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
- LinalgMarker({}, "PERMUTED"));
+ LinalgMarker({}, Identifier::get("PERMUTED", ctx)));
//===--------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===--------------------------------------------------------------------===//
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
- LinalgMarker({"_promote_views_"}, "_views_promoted_"));
+ LinalgMarker(Identifier::get("_promote_views_", ctx),
+ Identifier::get("_views_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.useFullTileBuffersByDefault(),
- LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
+ LinalgMarker(Identifier::get("_promote_first_view_", ctx),
+ Identifier::get("_first_view_promoted_", ctx)));
patterns.insert<LinalgPromotionPattern<FillOp>>(
ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0})
.setUseFullTileBuffers({true})
.setAlignment(32),
- LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
+ LinalgMarker(Identifier::get("_promote_views_aligned_", ctx),
+ Identifier::get("_views_aligned_promoted_", ctx)));
applyPatternsAndFoldGreedily(funcOp, patterns);
@@ -176,21 +188,22 @@ static void applyPatterns(FuncOp funcOp) {
static void fillL1TilingAndMatmulToVectorPatterns(
FuncOp funcOp, StringRef startMarker,
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
- MLIRContext *context = funcOp.getContext();
+ MLIRContext *ctx = funcOp.getContext();
patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
- context,
+ ctx,
LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
- LinalgMarker({startMarker}, "L1")));
+ LinalgMarker(Identifier::get(startMarker, ctx),
+ Identifier::get("L1", ctx))));
patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
- context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
- LinalgMarker({"L1"}, "VEC")));
+ ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
+ LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx))));
- patternsVector.emplace_back(
- LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
+ patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>(
+ ctx, LinalgMarker(Identifier::get("VEC", ctx))));
patternsVector.back()
.insert<LinalgVectorizationPattern<FillOp>,
- LinalgVectorizationPattern<CopyOp>>(context);
+ LinalgVectorizationPattern<CopyOp>>(ctx);
}
//===----------------------------------------------------------------------===//
@@ -231,13 +244,14 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
return success();
}
-void fillPromotionCallBackPatterns(MLIRContext *context,
+void fillPromotionCallBackPatterns(MLIRContext *ctx,
OwningRewritePatternList &patterns) {
patterns.insert<LinalgTilingPattern<MatmulOp>>(
- context, LinalgTilingOptions().setTileSizes({16, 16, 16}),
- LinalgMarker({"START"}, "PROMOTE"));
+ ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
+ LinalgMarker(Identifier::get("START", ctx),
+ Identifier::get("PROMOTE", ctx)));
patterns.insert<LinalgPromotionPattern<MatmulOp>>(
- context,
+ ctx,
LinalgPromotionOptions()
.setOperandsToPromote({0, 2})
.setUseFullTileBuffers({false, false})
@@ -251,7 +265,7 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
copyCallBackFn(b, src, dst, true);
return success();
}),
- LinalgMarker({"PROMOTE"}));
+ LinalgMarker(Identifier::get("PROMOTE", ctx)));
}
static void
@@ -261,15 +275,18 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
if (testMatmulToVectorPatterns1dTiling) {
- fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
+ fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
+ stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
- stage1Patterns.emplace_back(
- LinalgTilingPattern<MatmulOp>(ctx,
- LinalgTilingOptions()
- .setTileSizes({768, 264, 768})
- .setInterchange({1, 2, 0}),
- LinalgMarker({"START"}, "L2")));
- fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
+ stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
+ ctx,
+ LinalgTilingOptions()
+ .setTileSizes({768, 264, 768})
+ .setInterchange({1, 2, 0}),
+ LinalgMarker(Identifier::get("START", ctx),
+ Identifier::get("L2", ctx))));
+ fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
+ stage1Patterns);
}
OwningRewritePatternList stage2Patterns =
getLinalgTilingCanonicalizationPatterns(ctx);
More information about the Mlir-commits mailing list