[PATCH] D80809: [mlir][Linalg] Make contraction vectorization use vector transfers

Nicolas Vasilache via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri May 29 09:13:00 PDT 2020


nicolasvasilache created this revision.
nicolasvasilache added reviewers: aartbik, ftynse.
Herald added subscribers: llvm-commits, jurahul, Kayjukh, frgossen, grosul1, Joonsoo, stephenneuendorffer, liufengdb, lucyrfox, mgester, arpith-jacob, antiagainst, shauheen, jpienaar, rriddle, mehdi_amini.
Herald added a project: LLVM.
ftynse accepted this revision.
ftynse added inline comments.
This revision is now accepted and ready to land.


================
Comment at: mlir/test/Dialect/Linalg/transform-patterns.mlir:2
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns
+//| FileCheck %s
 
----------------
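Something went wrong here: `| FileCheck %s` is now on a separate line without a `RUN:` prefix, so it is not part of the RUN command. FileCheck therefore never runs, and none of the CHECK lines below are verified. Either keep the whole command on one line, or end the previous line with `\` and start the continuation line with `// RUN:`.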


This revision replaces the `vector.type_cast` + `load`/`store` sequence with the corresponding vector transfer
operations (`vector.transfer_read` / `vector.transfer_write`). Transfer ops compose better with other vector
abstractions and canonicalization patterns, and they lower to plain or masked loads/stores as appropriate.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D80809

Files:
  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
  mlir/test/Dialect/Linalg/transform-patterns.mlir


Index: mlir/test/Dialect/Linalg/transform-patterns.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns
+//| FileCheck %s
 
 // CHECK-DAG: #[[STRIDED_1D:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1  + s0)>
 // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1.
@@ -106,14 +107,11 @@
   return
 }
 // CHECK-LABEL: func @vectorization_test
-//       CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32> to vector<8x16xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32> to vector<16x32xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32> to vector<8x32xf32>
 //       CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-//       CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}}[] : vector<8x32xf32>
 
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {
Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -120,14 +120,30 @@
   // Vectorize other ops as vector contraction (currently only matmul).
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  auto extractVectorTypeFromScalarView = [](Value v) {
+    MemRefType mt = v.getType().cast<MemRefType>();
+    return VectorType::get(mt.getShape(), mt.getElementType());
+  };
   auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
-  Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
-  Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
-  Value c = std_load(memref);
+  Value viewA = linalgOp.getInput(0);
+  Value viewB = linalgOp.getInput(1);
+  Value viewC = linalgOp.getOutputBuffer(0);
+  Value zero = std_constant_index(0);
+  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+                                 zero);
+  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+                                 indicesA);
+  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+                                 indicesB);
+  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+                                 indicesC);
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  std_store(res, memref);
+  vector_transfer_write(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D80809.267260.patch
Type: text/x-patch
Size: 3938 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200529/6e1965ba/attachment.bin>


More information about the llvm-commits mailing list