[Mlir-commits] [mlir] bf561dd - [mlir][Vector] Vectorize integer matmuls
Benjamin Kramer
llvmlistbot at llvm.org
Wed Jul 22 10:45:26 PDT 2020
Author: Benjamin Kramer
Date: 2020-07-22T19:39:56+02:00
New Revision: bf561dd2eb138e5c5a78adcd429b7f79fd58b0d2
URL: https://github.com/llvm/llvm-project/commit/bf561dd2eb138e5c5a78adcd429b7f79fd58b0d2
DIFF: https://github.com/llvm/llvm-project/commit/bf561dd2eb138e5c5a78adcd429b7f79fd58b0d2.diff
LOG: [mlir][Vector] Vectorize integer matmuls
The underlying infrastructure already supports this; just add the
pattern matching for linalg.generic.
Differential Revision: https://reviews.llvm.org/D84335
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d923ea1bea76..8e5da6ae539d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,9 +52,17 @@ static bool hasMultiplyAddBody(Region &r) {
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+ auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
+ auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
+ auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
+ auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
return pattern1.match(&r.front().back()) ||
pattern2.match(&r.front().back()) ||
- pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
+ pattern3.match(&r.front().back()) ||
+ pattern4.match(&r.front().back()) ||
+ pattern5.match(&r.front().back()) ||
+ pattern6.match(&r.front().back()) ||
+ pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}
// TODO: Should be Tablegen'd from a single source that generates the op itself.
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 9eedc31ef43a..819b3b764137 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -118,6 +118,23 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
+func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
+ %C: memref<8x32xi32>) {
+ linalg.generic #matmul_trait %A, %B, %C {
+ ^bb(%a: i32, %b: i32, %c: i32) :
+ %d = muli %a, %b: i32
+ %e = addi %c, %d: i32
+ linalg.yield %e : i32
+ } : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32>
+ return
+}
+// CHECK-LABEL: func @vectorization_test_integer
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
+
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :
More information about the Mlir-commits
mailing list