[Mlir-commits] [mlir] 6fb6a4d - [mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Apr 8 13:58:08 PDT 2020
Author: Nicolas Vasilache
Date: 2020-04-08T16:54:40-04:00
New Revision: 6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915
URL: https://github.com/llvm/llvm-project/commit/6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915
DIFF: https://github.com/llvm/llvm-project/commit/6fb6a4d7f972d8faacd6b2646fe15f2eea1e4915.diff
LOG: [mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go from matmul to vectors
This revision builds a simple "fused pass" consisting of two levels of tiling, memory promotion, and vectorization, using Linalg transformations written as composable pattern rewrites.
Added:
mlir/test/Dialect/Linalg/matmul-to-vector.mlir
mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
Modified:
mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Linalg/matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
new file mode 100644
index 000000000000..351b2041d8c0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-vector.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -linalg-matmul-to-vector | FileCheck %s
+// Input for the DRR pipeline in TestLinalgMatmulToVectorPatterns.td (tile x2, promote, vectorize).
+func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
+ linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
+ return
+}
+// Inner tile sizes [8, 12, 16] as (i, j, k) give the 8x16 * 16x12 -> 8x12 shapes below.
+// CHECK-LABEL:func @matmul_perm
+// CHECK: vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
diff --git a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
index 9672edb4c493..f06854289abb 100644
--- a/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
+++ b/mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
@@ -5,3 +5,7 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)
+# DRR rewriters #included by mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp.
+set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
+mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
new file mode 100644
index 000000000000..7fa4a3db6128
--- /dev/null
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgMatmulToVectorPatterns.td
@@ -0,0 +1,43 @@
+//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarative (DRR) patterns used by the -linalg-matmul-to-vector test pass:
+// two levels of matmul tiling, subview promotion, then vectorization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
+
+include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
+include "mlir/Dialect/Vector/VectorTransformPatterns.td"
+
+//===----------------------------------------------------------------------===//
+// Linalg tiling and permutation patterns.
+//===----------------------------------------------------------------------===//
+def : Pat<(MatmulOp:$op $_, $_, $_), // Outer tiling: "__with_perm__" -> "L2__with_perm__".
+ (TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>), // (sizes, new marker, loop permutation)
+ [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_), // Inner tiling of the L2 tiles: -> "L1__with_perm__".
+ (TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>),
+ [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_), // Promote subview operands of L1 tiles.
+ (PromoteSubviewsLinalgOp),
+ [(Constraint<HasOperandsOfType<"SubViewOp">>),
+ (Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>; // NOTE(review): same marker as vectorization below; confirm intended order.
+
+//===----------------------------------------------------------------------===//
+// Linalg to vector contraction patterns.
+//===----------------------------------------------------------------------===//
+def : Pattern<(MatmulOp:$op $_, $_, $_), // Vectorize L1 tiles when the precondition holds.
+ [(VectorizeLinalgOp)],
+ [(Constraint<And<[
+ HasLinalgTransformMarker<"L1__with_perm__">,
+ PreconditionVectorizeLinalgOp]>>)]>;
+
+#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 904a47221ac1..23107f223b9c 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms
TestGpuMemoryPromotion.cpp
TestGpuParallelLoopMapping.cpp
TestInlining.cpp
+ TestLinalgMatmulToVector.cpp
TestLinalgTransforms.cpp
TestLiveness.cpp
TestLoopMapping.cpp
@@ -24,6 +25,7 @@ add_llvm_library(MLIRTestTransforms
DEPENDS
MLIRStandardOpsIncGen
+ MLIRTestLinalgMatmulToVectorPatternsIncGen
MLIRTestLinalgTransformPatternsIncGen
MLIRTestVectorTransformPatternsIncGen
)
diff --git a/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
new file mode 100644
index 000000000000..6f49fabc192a
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
@@ -0,0 +1,51 @@
+//===- TestLinalgMatmulToVector.cpp - Test matmul to vector ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::vector;
+
+namespace {
+#include "TestLinalgMatmulToVectorPatterns.h.inc"
+/// Greedily applies a few canonicalizations together with the DRR patterns.
+struct DeclarativeTransforms
+    : public PassWrapper<DeclarativeTransforms, FunctionPass> {
+  void runOnFunction() override {
+    auto *ctx = &getContext();
+    OwningRewritePatternList rewrites;
+    AffineApplyOp::getCanonicalizationPatterns(rewrites, ctx);
+    AffineMinOp::getCanonicalizationPatterns(rewrites, ctx);
+    AffineMaxOp::getCanonicalizationPatterns(rewrites, ctx);
+    AllocOp::getCanonicalizationPatterns(rewrites, ctx);
+    SubViewOp::getCanonicalizationPatterns(rewrites, ctx);
+    ViewOp::getCanonicalizationPatterns(rewrites, ctx);
+    populateWithGenerated(ctx, &rewrites);
+    applyPatternsGreedily(getFunction(), rewrites);
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+/// Hook called from mlir-opt to make -linalg-matmul-to-vector available.
+void registerTestLinalgMatmulToVectorPass() {
+  PassRegistration<DeclarativeTransforms> entry("linalg-matmul-to-vector",
+      "Test declarative transform patterns for matmul 3-D tiling + promotion"
+      " + vectorization");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e8b2f3dc49f5..50a929616f27 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -39,6 +39,7 @@ void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAllReduceLoweringPass();
+void registerTestLinalgMatmulToVectorPass();
void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
void registerTestConstantFold();
@@ -101,6 +102,7 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();
+ registerTestLinalgMatmulToVectorPass();
registerTestLoopPermutationPass();
registerTestCallGraphPass();
registerTestConstantFold();
More information about the Mlir-commits
mailing list