[Mlir-commits] [mlir] d12d05a - [mlir][Linalg] Introduce a helper function for staged pattern application

Nicolas Vasilache llvmlistbot at llvm.org
Mon May 11 13:48:47 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-11T16:46:30-04:00
New Revision: d12d05a73142af287a698e5df62d7b7b1eefdf3e

URL: https://github.com/llvm/llvm-project/commit/d12d05a73142af287a698e5df62d7b7b1eefdf3e
DIFF: https://github.com/llvm/llvm-project/commit/d12d05a73142af287a698e5df62d7b7b1eefdf3e.diff

LOG: [mlir][Linalg] Introduce a helper function for staged pattern application

Summary:
This revision introduces a helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion:
1. the first stage consists of a list of OwningRewritePatternList. Each list is applied greedily, in order, and each one is followed by stages 2 and 3.
2. the second stage consists of a single OwningRewritePatternList that is applied greedily until convergence.
3. the third stage consists of applying a lambda, generally used for non-local transformation effects.

This allows creating custom fused transformations where patterns can be ordered and applied at a finer granularity than a sequence of traditional compiler passes.

A test that exercises these behaviors is added; its RUN lines are currently disabled until a pending fix lands.

Differential Revision: https://reviews.llvm.org/D79518

Added: 
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 896b31835fb4..f5bf5892199f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -367,6 +367,23 @@ struct LinalgLoweringPattern : public RewritePattern {
   LinalgLoweringType loweringType;
 };
 
+//===----------------------------------------------------------------------===//
+// Support for staged pattern application.
+//===----------------------------------------------------------------------===//
+/// Helper function to allow applying rewrite patterns, interleaved with more
+/// global transformations, in a staged fashion. For each pattern list in
+/// `stage1Patterns`, in order:
+///   1. that stage-1 list is applied greedily to `op`;
+///   2. `stage2Patterns` is then applied greedily until convergence;
+///   3. `stage3Lambda`, if non-null, is invoked on `op` (generally used for
+///   non-local transformation effects).
+/// Returns failure as soon as a greedy rewrite does not converge or the lambda
+/// fails; if `stage1Patterns` is empty, nothing is applied. This allows fusing
+/// ordered rewrites at a finer granularity than a sequence of compiler passes.
+LogicalResult applyStagedPatterns(
+    Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
+    const OwningRewritePatternList &stage2Patterns,
+    llvm::function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0d125b3d8148..6b124e0ecdfa 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -388,6 +388,15 @@ class OwningRewritePatternList {
   using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
 
 public:
+  OwningRewritePatternList() = default;
+  /// Construct a list holding the single pattern `t`, moved when an rvalue.
+  /// Only enabled for RewritePattern types, so it never hijacks list copies.
+  template <typename T, typename = std::enable_if_t<std::is_base_of<
+                            RewritePattern, std::decay_t<T>>::value>>
+  OwningRewritePatternList(T &&t) {
+    patterns.push_back(std::make_unique<std::decay_t<T>>(std::forward<T>(t)));
+  }
+
   PatternListT::iterator begin() { return patterns.begin(); }
   PatternListT::iterator end() { return patterns.end(); }
   PatternListT::const_iterator begin() const { return patterns.begin(); }
@@ -399,12 +408,13 @@ class OwningRewritePatternList {
   //===--------------------------------------------------------------------===//
 
   /// Add an instance of each of the pattern types 'Ts' to the pattern list with
-  /// the given arguments.
+  /// the given arguments. Return a reference to `this` for chaining insertions.
   /// Note: ConstructorArg is necessary here to separate the two variadic lists.
   template <typename... Ts, typename ConstructorArg,
             typename... ConstructorArgs,
             typename = std::enable_if_t<sizeof...(Ts) != 0>>
-  void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
+  OwningRewritePatternList &insert(ConstructorArg &&arg,
+                                   ConstructorArgs &&... args) {
     // The following expands a call to emplace_back for each of the pattern
     // types 'Ts'. This magic is necessary due to a limitation in the places
     // that a parameter pack can be expanded in c++11.
@@ -412,6 +422,7 @@ class OwningRewritePatternList {
     using dummy = int[];
     (void)dummy{
         0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
+    return *this;
   }
 
 private:

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 175c6c8ef096..96d97809a9d3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -198,3 +198,24 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
   rewriter.eraseOp(op);
   return success();
 }
+
+LogicalResult mlir::linalg::applyStagedPatterns(
+    Operation *op, ArrayRef<OwningRewritePatternList> stage1Patterns,
+    const OwningRewritePatternList &stage2Patterns,
+    llvm::function_ref<LogicalResult(Operation *)> stage3Lambda) {
+  // Stages 2 and 3 run after *each* stage-1 pattern list, in order.
+  for (const auto &patterns : stage1Patterns) {
+    // A false result means the greedy rewrite did not converge.
+    if (!applyPatternsAndFoldGreedily(op, patterns)) {
+      llvm::dbgs() << "Underlying first stage rewrite did not converge\n";
+      return failure();
+    }
+    if (!applyPatternsAndFoldGreedily(op, stage2Patterns)) {
+      llvm::dbgs() << "Underlying second stage rewrite did not converge\n";
+      return failure();
+    }
+    if (stage3Lambda && failed(stage3Lambda(op)))
+      return failure();
+  }
+  return success();
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
new file mode 100644
index 000000000000..29ea43aa540b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -0,0 +1,34 @@
+// TODO: the R_UN lines below are disabled until a pending fix lands; the "ls" run line is only a placeholder.
+// RUN: ls
+// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
+// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+
+func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+                  %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+                  %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
+  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "START"} :
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
+  return
+}
+
+// CHECK-LABEL:func @matmul
+//      CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+//      CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
+//
+//      CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
+//      CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
+//
+//      CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
+//      CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
+//
+//      CHECK: linalg.copy
+//      CHECK: linalg.copy
+//      CHECK: linalg.copy
+//
+//      CHECK: vector.contract
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
+//
+//      CHECK: linalg.copy

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index eb27a7ae0034..0390ac945d2f 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -33,6 +33,18 @@ struct TestLinalgTransforms
   Option<bool> testPatterns{*this, "test-patterns",
                             llvm::cl::desc("Test a mixed set of patterns"),
                             llvm::cl::init(false)};
+  // Tile matmul by {8, 12, 16}, promote, then vectorize (one tiling level).
+  Option<bool> testMatmulToVectorPatterns1dTiling{
+      *this, "test-matmul-to-vector-patterns-tile-1d",
+      llvm::cl::desc("Test a fused pass that applies patterns from matmul to "
+                     "vectors via 1-d tiling"),
+      llvm::cl::init(false)};
+  // As above, preceded by an outer {768, 264, 768} tiling of the matmul.
+  Option<bool> testMatmulToVectorPatterns2dTiling{
+      *this, "test-matmul-to-vector-patterns-tile-2d",
+      llvm::cl::desc("Test a fused pass that applies patterns from matmul to "
+                     "vectors via 2-d tiling"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -137,10 +149,79 @@ static void applyPatterns(FuncOp funcOp) {
   });
 }
 
+/// Returns the canonicalization patterns that applyStagedPatterns applies
+/// greedily (as stage 2) after each stage-1 matmul-to-vector pattern list.
+static OwningRewritePatternList
+getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  AffineMinOp::getCanonicalizationPatterns(patterns, context);
+  AffineMaxOp::getCanonicalizationPatterns(patterns, context);
+  AllocOp::getCanonicalizationPatterns(patterns, context);
+  SubViewOp::getCanonicalizationPatterns(patterns, context);
+  ViewOp::getCanonicalizationPatterns(patterns, context);
+  MatmulOp::getCanonicalizationPatterns(patterns, context);
+  return patterns;
+}
+
+/// Appends three stage-1 pattern lists to `patternsVector`, to be applied in
+/// order to a MatmulOp marked `startMarker`:
+///   1. tiling by {8, 12, 16} with interchange {1, 0, 2} (marker -> "L1");
+///   2. promotion with default options (marker "L1" -> "VEC");
+///   3. vectorization of the "VEC" matmul, plus FillOp/CopyOp vectorization
+///   patterns (constructed without an explicit marker).
+static void fillL1TilingAndMatmulToVectorPatterns(
+    MLIRContext *context, StringRef startMarker,
+    SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
+  patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
+      context,
+      LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
+      LinalgMarker({startMarker}, "L1")));
+
+  patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
+      context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
+
+  patternsVector.emplace_back(
+      LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));
+  patternsVector.back()
+      .insert<LinalgVectorizationPattern<FillOp>,
+              LinalgVectorizationPattern<CopyOp>>(context);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
-  if (testPatterns)
-    return applyPatterns(getFunction());
+  if (testPatterns) {
+    applyPatterns(getFunction());
+  } else {
+    // Each stage-1 list is applied in order and followed by the
+    // canonicalization patterns (stage 2). If neither matmul-to-vector option
+    // is set, `stage1Patterns` stays empty and nothing is applied.
+    SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+    if (testMatmulToVectorPatterns1dTiling) {
+      fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START",
+                                            stage1Patterns);
+    } else if (testMatmulToVectorPatterns2dTiling) {
+      stage1Patterns.emplace_back(
+          LinalgTilingPattern<MatmulOp>(&getContext(),
+                                        LinalgTilingOptions()
+                                            .setTileSizes({768, 264, 768})
+                                            .setInterchange({1, 2, 0}),
+                                        LinalgMarker({"START"}, "L2")));
+      fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2",
+                                            stage1Patterns);
+    }
+    OwningRewritePatternList stage2Patterns =
+        getMatmulToVectorCanonicalizationPatterns(&getContext());
+    // Surface non-convergence instead of silently dropping the result.
+    if (failed(applyStagedPatterns(getFunction(), stage1Patterns,
+                                   stage2Patterns)))
+      return signalPassFailure();
+  }
+
+  // Drop the marker.
+  getFunction().walk([](LinalgOp op) {
+    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+  });
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list