[Mlir-commits] [mlir] d12d05a - [mlir][Linalg] Introduce a helper function for staged pattern application
Nicolas Vasilache
llvmlistbot at llvm.org
Mon May 11 13:48:47 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-11T16:46:30-04:00
New Revision: d12d05a73142af287a698e5df62d7b7b1eefdf3e
URL: https://github.com/llvm/llvm-project/commit/d12d05a73142af287a698e5df62d7b7b1eefdf3e
DIFF: https://github.com/llvm/llvm-project/commit/d12d05a73142af287a698e5df62d7b7b1eefdf3e.diff
LOG: [mlir][Linalg] Introduce a helper function for staged pattern application
Summary:
This revision introduces a helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion:
1. the first stage consists of an OwningRewritePatternList. The RewritePattern in this list are applied once and in order.
2. the second stage consists of a single OwningRewritePattern that is applied greedily until convergence.
3. the third stage consists of applying a lambda, generally used for non-local transformation effects.
This allows creating custom fused transformations where patterns can be ordered and applied at a finer granularity than a sequence of traditional compiler passes.
A test that exercises these behaviors is added.
Differential Revision: https://reviews.llvm.org/D79518
Added:
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 896b31835fb4..f5bf5892199f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -367,6 +367,23 @@ struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringType loweringType;
};
+//===----------------------------------------------------------------------===//
+// Support for staged pattern application.
+//===----------------------------------------------------------------------===//
+/// Helper function to allow applying rewrite patterns, interleaved with more
+/// global transformations, in a staged fashion:
+/// 1. the first stage consists of a list of OwningRewritePatternList. Each
+/// OwningRewritePatternList in this list is applied once, in order.
+/// 2. the second stage consists of a single OwningRewritePattern that is
+/// applied greedily until convergence.
+/// 3. the third stage consists of applying a lambda, generally used for
+/// non-local transformation effects. This allows creating custom fused
+/// transformations where patterns can be ordered and applied at a finer
+/// granularity than a sequence of traditional compiler passes.
+LogicalResult applyStagedPatterns(
+ Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
+ const OwningRewritePatternList &stage2Patterns,
+ llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0d125b3d8148..6b124e0ecdfa 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -388,6 +388,15 @@ class OwningRewritePatternList {
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
+ OwningRewritePatternList() = default;
+
+ /// Construct a OwningRewritePatternList populated with the pattern `t` of
+ /// type `T`.
+ template <typename T>
+ OwningRewritePatternList(T &&t) {
+ patterns.emplace_back(std::make_unique<T>(t));
+ }
+
PatternListT::iterator begin() { return patterns.begin(); }
PatternListT::iterator end() { return patterns.end(); }
PatternListT::const_iterator begin() const { return patterns.begin(); }
@@ -399,12 +408,13 @@ class OwningRewritePatternList {
//===--------------------------------------------------------------------===//
/// Add an instance of each of the pattern types 'Ts' to the pattern list with
- /// the given arguments.
+ /// the given arguments. Return a reference to `this` for chaining insertions.
/// Note: ConstructorArg is necessary here to separate the two variadic lists.
template <typename... Ts, typename ConstructorArg,
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
- void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
+ OwningRewritePatternList &insert(ConstructorArg &&arg,
+ ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
@@ -412,6 +422,7 @@ class OwningRewritePatternList {
using dummy = int[];
(void)dummy{
0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
+ return *this;
}
private:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 175c6c8ef096..96d97809a9d3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -198,3 +198,24 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
rewriter.eraseOp(op);
return success();
}
+
+LogicalResult mlir::linalg::applyStagedPatterns(
+ Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
+ const OwningRewritePatternList &stage2Patterns,
+ llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
+ for (const auto &patterns : stage1Patterns) {
+ if (!applyPatternsAndFoldGreedily(op, patterns)) {
+ llvm::dbgs() << "Underlying first stage rewrite did not converge";
+ return failure();
+ }
+ if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
+ llvm::dbgs() << "Underlying second stage rewrite did not converge";
+ return failure();
+ }
+ if (stage3Lambda) {
+ if (failed(stage3Lambda(op)))
+ return failure();
+ }
+ }
+ return success();
+}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
new file mode 100644
index 000000000000..29ea43aa540b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -0,0 +1,34 @@
+// TODO: this needs a fix to land before being reactivated.
+// RUN: ls
+// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
+// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+
+func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
+ linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
+ return
+}
+
+// CHECK-LABEL:func @matmul
+// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
+//
+// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
+// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
+//
+// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
+// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
+//
+// CHECK: linalg.copy
+// CHECK: linalg.copy
+// CHECK: linalg.copy
+//
+// CHECK: vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
+//
+// CHECK: linalg.copy
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index eb27a7ae0034..0390ac945d2f 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -33,6 +33,18 @@ struct TestLinalgTransforms
Option<bool> testPatterns{*this, "test-patterns",
llvm::cl::desc("Test a mixed set of patterns"),
llvm::cl::init(false)};
+ Option<bool> testMatmulToVectorPatterns1dTiling{
+ *this, "test-matmul-to-vector-patterns-tile-1d",
+ llvm::cl::desc(
+ "Test a fused pass that applies patterns from matmul to vectors via "
+ "1-d tiling"),
+ llvm::cl::init(false)};
+ Option<bool> testMatmulToVectorPatterns2dTiling{
+ *this, "test-matmul-to-vector-patterns-tile-2d",
+ llvm::cl::desc(
+ "Test a fused pass that applies patterns from matmul to vectors via "
+ "2-d tiling"),
+ llvm::cl::init(false)};
};
} // end anonymous namespace
@@ -137,10 +149,65 @@ static void applyPatterns(FuncOp funcOp) {
});
}
+OwningRewritePatternList
+getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
+ OwningRewritePatternList patterns;
+ AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+ AffineMinOp::getCanonicalizationPatterns(patterns, context);
+ AffineMaxOp::getCanonicalizationPatterns(patterns, context);
+ AllocOp::getCanonicalizationPatterns(patterns, context);
+ SubViewOp::getCanonicalizationPatterns(patterns, context);
+ ViewOp::getCanonicalizationPatterns(patterns, context);
+ MatmulOp::getCanonicalizationPatterns(patterns, context);
+ return patterns;
+}
+
+void fillL1TilingAndMatmulToVectorPatterns(
+ MLIRContext *context, StringRef startMarker,
+ SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
+ patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
+ context,
+ LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
+ LinalgMarker({startMarker}, "L1")));
+
+ patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
+ context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
+
+ patternsVector.emplace_back(
+ LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
+ patternsVector.back()
+ .insert<LinalgVectorizationPattern<FillOp>,
+ LinalgVectorizationPattern<CopyOp>>(context);
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
- if (testPatterns)
- return applyPatterns(getFunction());
+ if (testPatterns) {
+ applyPatterns(getFunction());
+ } else {
+ SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+ if (testMatmulToVectorPatterns1dTiling) {
+ fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START",
+ stage1Patterns);
+ } else if (testMatmulToVectorPatterns2dTiling) {
+ stage1Patterns.emplace_back(
+ LinalgTilingPattern<MatmulOp>(&getContext(),
+ LinalgTilingOptions()
+ .setTileSizes({768, 264, 768})
+ .setInterchange({1, 2, 0}),
+ LinalgMarker({"START"}, "L2")));
+ fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2",
+ stage1Patterns);
+ }
+ OwningRewritePatternList stage2Patterns =
+ getMatmulToVectorCanonicalizationPatterns(&getContext());
+ applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
+ }
+
+ // Drop the marker.
+ getFunction().walk([](LinalgOp op) {
+ op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+ });
}
namespace mlir {
More information about the Mlir-commits
mailing list