[PATCH] D80809: [mlir][Linalg] Make contraction vectorization use vector transfers
Nicolas Vasilache via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Fri May 29 12:34:29 PDT 2020
This revision was automatically updated to reflect the committed changes.
Closed by commit rG9534192c3bfd: [mlir][Linalg] Make contraction vectorization use vector transfers (authored by nicolasvasilache).
Repository: rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D80809/new/
https://reviews.llvm.org/D80809
Files:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
Index: mlir/test/Dialect/Linalg/transform-patterns.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -106,14 +106,11 @@
return
}
// CHECK-LABEL: func @vectorization_test
-// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -120,14 +120,30 @@
// Vectorize other ops as vector contraction (currently only matmul).
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
+ auto extractVectorTypeFromScalarView = [](Value v) {
+ MemRefType mt = v.getType().cast<MemRefType>();
+ return VectorType::get(mt.getShape(), mt.getElementType());
+ };
auto linalgOp = cast<linalg::LinalgOp>(op);
- Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
- Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
- Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
- Value c = std_load(memref);
+ Value viewA = linalgOp.getInput(0);
+ Value viewB = linalgOp.getInput(1);
+ Value viewC = linalgOp.getOutputBuffer(0);
+ Value zero = std_constant_index(0);
+ SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+ zero);
+ SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+ zero);
+ SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+ zero);
+ Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+ indicesA);
+ Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+ indicesB);
+ Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+ indicesC);
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
- std_store(res, memref);
+ vector_transfer_write(res, viewC, indicesC);
}
/// Check whether there is any interleaved use of any `values` between `firstOp`
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D80809.267333.patch
Type: text/x-patch
Size: 3585 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200529/00a089f2/attachment.bin>
More information about the llvm-commits mailing list