[Mlir-commits] [mlir] 6fb6a4d - [mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors

Nicolas Vasilache llvmlistbot at llvm.org
Wed Apr 8 13:58:08 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-08T16:54:40-04:00
New Revision: 6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915

URL: https://github.com/llvm/llvm-project/commit/6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915
DIFF: https://github.com/llvm/llvm-project/commit/6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915.diff

LOG: [mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors

This revision builds a simple "fused pass" consisting of 2 levels of tiling, memory promotion and vectorization using linalg transformations written as composable pattern rewrites.

Added: 
    mlir/test/Dialect/Linalg/matmul-to-vector.mlir
    mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
    mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp

Modified: 
    mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/Dialect/Linalg/matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
new file mode 100644
index 000000000000..351b2041d8c0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -linalg-matmul-to-vector | FileCheck %s
+
+// The "__with_perm__" marker makes this matmul the entry point of the
+// pattern pipeline in TestLinalgMatmulToVectorPatterns.td (two levels of
+// tiling, promotion, vectorization). The innermost 8x12x16 tile must be
+// rewritten to a vector.contract of vector<8x16xf32> by vector<16x12xf32>
+// into vector<8x12xf32>.
+func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+                  %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+                  %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
+  linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
+  return
+}
+
+// CHECK-LABEL:func @matmul_perm
+//      CHECK:        vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>

diff  --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
index 9672edb4c493..f06854289abb 100644
--- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
+++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
@@ -5,3 +5,7 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
 set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
 mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
 add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
+mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)

diff  --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
new file mode 100644
index 000000000000..7fa4a3db6128
--- /dev/null
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
@@ -0,0 +1,53 @@
+//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarative (DRR) patterns for the -linalg-matmul-to-vector test pass: a
+// linalg.matmul marked "__with_perm__" is tiled twice (with loop
+// permutations), then the resulting tiles have their subview operands
+// promoted and are vectorized. The stages are chained through
+// transformation markers: "__with_perm__" -> "L2__with_perm__" ->
+// "L1__with_perm__".
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+
+include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
+include "mlir/Dialect/Vector/VectorTransformPatterns.td"
+
+//===----------------------------------------------------------------------===//
+// Linalg tiling (with loop permutation) and promotion patterns.
+//===----------------------------------------------------------------------===//
+// Outer tiling level: 768x264x768 tiles, loop permutation [1, 2, 0].
+def : Pat<(MatmulOp:$op $_, $_, $_),
+          (TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>),
+         [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
+// Inner tiling level: 8x12x16 tiles, loop permutation [1, 0, 2]. The vector
+// shapes checked in matmul-to-vector.mlir (8x16 by 16x12 into 8x12)
+// correspond to these tile sizes.
+def : Pat<(MatmulOp:$op $_, $_, $_),
+          (TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>),
+         [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
+// Promote the subview operands of the inner tiles.
+def : Pat<(MatmulOp:$op $_, $_, $_),
+          (PromoteSubviewsLinalgOp),
+         [(Constraint<HasOperandsOfType<"SubViewOp">>),
+          (Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
+
+//===----------------------------------------------------------------------===//
+// Linalg to vector contraction patterns.
+//===----------------------------------------------------------------------===//
+// Vectorize inner tiles that satisfy the vectorization precondition.
+def : Pattern<(MatmulOp:$op $_, $_, $_),
+              [(VectorizeLinalgOp)],
+              [(Constraint<And<[
+                HasLinalgTransformMarker<"L1__with_perm__">,
+                PreconditionVectorizeLinalgOp]>>)]>;
+
+#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 904a47221ac1..23107f223b9c 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
   TestInlining.cpp
+  TestLinalgMatmulToVector.cpp
   TestLinalgTransforms.cpp
   TestLiveness.cpp
   TestLoopMapping.cpp
@@ -24,6 +25,7 @@ add_llvm_library(MLIRTestTransforms
 
   DEPENDS
   MLIRStandardOpsIncGen
+  MLIRTestLinalgMatmulToVectorPatternsIncGen
   MLIRTestLinalgTransformPatternsIncGen
   MLIRTestVectorTransformPatternsIncGen
 )

diff  --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
new file mode 100644
index 000000000000..6f49fabc192a
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
@@ -0,0 +1,71 @@
+//===- TestLinalgMatmulToVector.cpp - Test matmul-to-vector pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test pass exercising the declarative (DRR) Linalg transformation patterns in
+// TestLinalgMatmulToVectorPatterns.td: two levels of tiling, promotion of
+// subview operands and vectorization of a marked linalg.matmul.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::vector;
+
+namespace {
+// DRR patterns and the `populateWithGenerated` helper emitted by mlir-tblgen
+// (-gen-rewriters) from TestLinalgMatmulToVectorPatterns.td.
+#include "TestLinalgMatmulToVectorPatterns.h.inc"
+
+/// Function pass that greedily applies the generated matmul-to-vector patterns
+/// together with canonicalization patterns for affine.apply/min/max, alloc,
+/// subview and view ops.
+struct DeclarativeTransforms
+    : public PassWrapper<DeclarativeTransforms, FunctionPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    auto *context = &getContext();
+    // Canonicalization patterns, applied in the same greedy rewrite as the
+    // transformation patterns below.
+    AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+    AffineMinOp::getCanonicalizationPatterns(patterns, context);
+    AffineMaxOp::getCanonicalizationPatterns(patterns, context);
+    AllocOp::getCanonicalizationPatterns(patterns, context);
+    SubViewOp::getCanonicalizationPatterns(patterns, context);
+    ViewOp::getCanonicalizationPatterns(patterns, context);
+    // Tiling, promotion and vectorization patterns generated from
+    // TestLinalgMatmulToVectorPatterns.td.
+    populateWithGenerated(context, &patterns);
+    // NOTE(review): the return value of applyPatternsGreedily is discarded,
+    // so the pass succeeds even if the rewrite reports non-convergence;
+    // confirm this is intended (the FileCheck test validates the output).
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+/// Registers the test pass under the `-linalg-matmul-to-vector` flag; called
+/// from mlir-opt's test pass registration.
+void registerTestLinalgMatmulToVectorPass() {
+  PassRegistration<DeclarativeTransforms> pass(
+      "linalg-matmul-to-vector",
+      "Test declarative transform patterns for matmul 3-D tiling + promotion"
+      " + vectorization");
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e8b2f3dc49f5..50a929616f27 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -39,6 +39,7 @@ void registerSimpleParametricTilingPass();
 void registerSymbolTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAllReduceLoweringPass();
+void registerTestLinalgMatmulToVectorPass();
 void registerTestLoopPermutationPass();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
@@ -101,6 +102,7 @@ void registerTestPasses() {
   registerSymbolTestPasses();
   registerTestAffineDataCopyPass();
   registerTestAllReduceLoweringPass();
+  registerTestLinalgMatmulToVectorPass();
   registerTestLoopPermutationPass();
   registerTestCallGraphPass();
   registerTestConstantFold();


        


More information about the Mlir-commits mailing list