[PATCH] D80809: [mlir][Linalg] Make contraction vectorization use vector transfers

Nicolas Vasilache via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri May 29 12:34:29 PDT 2020


This revision was automatically updated to reflect the committed changes.
Closed by commit rG9534192c3bfd: [mlir][Linalg] Make contraction vectorization use vector transfers (authored by nicolasvasilache).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D80809/new/

https://reviews.llvm.org/D80809

Files:
  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
  mlir/test/Dialect/Linalg/transform-patterns.mlir


Index: mlir/test/Dialect/Linalg/transform-patterns.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -106,14 +106,11 @@
   return
 }
 // CHECK-LABEL: func @vectorization_test
-//       CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
 //       CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-//       CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
 
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -120,14 +120,30 @@
   // Vectorize other ops as vector contraction (currently only matmul).
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  auto extractVectorTypeFromScalarView = [](Value v) {
+    MemRefType mt = v.getType().cast<MemRefType>();
+    return VectorType::get(mt.getShape(), mt.getElementType());
+  };
   auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
-  Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
-  Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
-  Value c = std_load(memref);
+  Value viewA = linalgOp.getInput(0);
+  Value viewB = linalgOp.getInput(1);
+  Value viewC = linalgOp.getOutputBuffer(0);
+  Value zero = std_constant_index(0);
+  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+                                 zero);
+  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+                                 indicesA);
+  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+                                 indicesB);
+  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+                                 indicesC);
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  std_store(res, memref);
+  vector_transfer_write(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D80809.267333.patch
Type: text/x-patch
Size: 3585 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200529/00a089f2/attachment.bin>


More information about the llvm-commits mailing list