[Mlir-commits] [mlir] bf561dd - [mlir][Vector] Vectorize integer matmuls

Benjamin Kramer llvmlistbot at llvm.org
Wed Jul 22 10:45:26 PDT 2020


Author: Benjamin Kramer
Date: 2020-07-22T19:39:56+02:00
New Revision: bf561dd2eb138e5c5a78adcd429b7f79fd58b0d2

URL: https://github.com/llvm/llvm-project/commit/bf561dd2eb138e5c5a78adcd429b7f79fd58b0d2
DIFF: https://github.com/llvm/llvm-project/commit/bf561dd2eb138e5c5a78adcd429b7f79fd58b0d2.diff

LOG: [mlir][Vector] Vectorize integer matmuls

The underlying infrastructure already supports this, so this change only
adds the pattern matching for linalg.generic.

Differential Revision: https://reviews.llvm.org/D84335

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/transform-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d923ea1bea76..8e5da6ae539d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,9 +52,17 @@ static bool hasMultiplyAddBody(Region &r) {
   auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
   auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
   auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
+  auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
+  auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
+  auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
+  auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
   return pattern1.match(&r.front().back()) ||
          pattern2.match(&r.front().back()) ||
-         pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
+         pattern3.match(&r.front().back()) ||
+         pattern4.match(&r.front().back()) ||
+         pattern5.match(&r.front().back()) ||
+         pattern6.match(&r.front().back()) ||
+         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
 }
 
 // TODO: Should be Tablegen'd from a single source that generates the op itself.

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 9eedc31ef43a..819b3b764137 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -118,6 +118,23 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 //       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
 //       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
 
+func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
+                                 %C: memref<8x32xi32>) {
+  linalg.generic #matmul_trait %A, %B, %C {
+    ^bb(%a: i32, %b: i32, %c: i32) :
+      %d = muli %a, %b: i32
+      %e = addi %c, %d: i32
+      linalg.yield %e : i32
+  } : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32>
+  return
+}
+// CHECK-LABEL: func @vectorization_test_integer
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+//       CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
+
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
   linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :


        


More information about the Mlir-commits mailing list