[Mlir-commits] [mlir] 9534192 - [mlir][Linalg] Make contraction vectorization use vector transfers

Nicolas Vasilache llvmlistbot at llvm.org
Fri May 29 12:06:43 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-29T15:03:46-04:00
New Revision: 9534192c3bfd861f8082843c57dfee0a7881d266

URL: https://github.com/llvm/llvm-project/commit/9534192c3bfd861f8082843c57dfee0a7881d266
DIFF: https://github.com/llvm/llvm-project/commit/9534192c3bfd861f8082843c57dfee0a7881d266.diff

LOG: [mlir][Linalg] Make contraction vectorization use vector transfers

This revision replaces the load + vector.type_cast sequence with the appropriate
vector transfer operations. These play more nicely with other vector abstractions and
canonicalization patterns, and they lower to load/store with or without masks when
appropriate.

Differential Revision: https://reviews.llvm.org/D80809

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/transform-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8fa0aa35a874..763961311d0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -120,14 +120,30 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   // Vectorize other ops as vector contraction (currently only matmul).
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  auto extractVectorTypeFromScalarView = [](Value v) {
+    MemRefType mt = v.getType().cast<MemRefType>();
+    return VectorType::get(mt.getShape(), mt.getElementType());
+  };
   auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
-  Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
-  Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
-  Value c = std_load(memref);
+  Value viewA = linalgOp.getInput(0);
+  Value viewB = linalgOp.getInput(1);
+  Value viewC = linalgOp.getOutputBuffer(0);
+  Value zero = std_constant_index(0);
+  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+                                 zero);
+  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+                                 indicesA);
+  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+                                 indicesB);
+  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+                                 indicesC);
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  std_store(res, memref);
+  vector_transfer_write(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 4c46c74fe490..41fa3fd95d93 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -106,14 +106,11 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   return
 }
 // CHECK-LABEL: func @vectorization_test
-//       CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
 //       CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-//       CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
 
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {


        


More information about the Mlir-commits mailing list