[Mlir-commits] [mlir] 9534192 - [mlir][Linalg] Make contraction vectorization use vector transfers
Nicolas Vasilache
llvmlistbot at llvm.org
Fri May 29 12:06:43 PDT 2020
Author: Nicolas Vasilache
Date: 2020-05-29T15:03:46-04:00
New Revision: 9534192c3bfd861f8082843c57dfee0a7881d266
URL: https://github.com/llvm/llvm-project/commit/9534192c3bfd861f8082843c57dfee0a7881d266
DIFF: https://github.com/llvm/llvm-project/commit/9534192c3bfd861f8082843c57dfee0a7881d266.diff
LOG: [mlir][Linalg] Make contraction vectorization use vector transfers
This revision replaces the vector.type_cast + load sequence with appropriate vector transfer
operations. These compose better with other vector abstractions and canonicalization
patterns, and lower to load/store with or without masks when appropriate.
Differential Revision: https://reviews.llvm.org/D80809
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8fa0aa35a874..763961311d0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -120,14 +120,30 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
// Vectorize other ops as vector contraction (currently only matmul).
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
+ auto extractVectorTypeFromScalarView = [](Value v) {
+ MemRefType mt = v.getType().cast<MemRefType>();
+ return VectorType::get(mt.getShape(), mt.getElementType());
+ };
auto linalgOp = cast<linalg::LinalgOp>(op);
- Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
- Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
- Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
- Value c = std_load(memref);
+ Value viewA = linalgOp.getInput(0);
+ Value viewB = linalgOp.getInput(1);
+ Value viewC = linalgOp.getOutputBuffer(0);
+ Value zero = std_constant_index(0);
+ SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+ zero);
+ SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+ zero);
+ SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+ zero);
+ Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+ indicesA);
+ Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+ indicesB);
+ Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+ indicesC);
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
- std_store(res, memref);
+ vector_transfer_write(res, viewC, indicesC);
}
/// Check whether there is any interleaved use of any `values` between `firstOp`
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 4c46c74fe490..41fa3fd95d93 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -106,14 +106,11 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
return
}
// CHECK-LABEL: func @vectorization_test
-// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
More information about the Mlir-commits
mailing list