[Mlir-commits] [mlir] 29d1fba - [mlir][vector] Make linalg FillOp vectorization use Transfer op
Thomas Raoux
llvmlistbot at llvm.org
Tue Nov 3 14:40:55 PST 2020
Author: Thomas Raoux
Date: 2020-11-03T14:35:26-08:00
New Revision: 29d1fba7b5335d969e3e5daa84b7a25cd1fa75ef
URL: https://github.com/llvm/llvm-project/commit/29d1fba7b5335d969e3e5daa84b7a25cd1fa75ef
DIFF: https://github.com/llvm/llvm-project/commit/29d1fba7b5335d969e3e5daa84b7a25cd1fa75ef.diff
LOG: [mlir][vector] Make linalg FillOp vectorization use Transfer op
Differential Revision: https://reviews.llvm.org/D90474
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
mlir/test/Dialect/Linalg/transform-patterns.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4430c34af1e9..8860674ef847 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -106,17 +106,6 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
(void)dbgPref;
edsc::ScopedContext scope(builder, op->getLoc());
- if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
- // Vectorize fill as a vector.broadcast.
- LLVM_DEBUG(dbgs() << dbgPref
- << "Rewrite linalg.fill as vector.broadcast: " << *op);
- Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
- Value dst = std_load(memref);
- Value res = vector_broadcast(dst.getType(), fillOp.value());
- std_store(res, memref);
- return;
- }
-
// In the case of 0-D memrefs, return null and special case to scalar load or
// store later.
auto extractVectorTypeFromScalarView = [](Value v) {
@@ -125,7 +114,24 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
? VectorType()
: VectorType::get(mt.getShape(), mt.getElementType());
};
-
+ if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+ // Vectorize fill as vector.broadcast + vector.transfer_write (0-D: store).
+ LLVM_DEBUG(dbgs() << dbgPref
+ << "Rewrite linalg.fill as vector.broadcast: " << *op);
+ Value viewOutput = fillOp.output();
+ if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
+ auto vecType =
+ VectorType::get(fillOp.getOutputBufferType(0).getShape(),
+ fillOp.getOutputBufferType(0).getElementType());
+ Value vector = vector_broadcast(vecType, fillOp.value());
+ Value zero = std_constant_index(0);
+ SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
+ vector_transfer_write(vector, viewOutput, indicesOutput);
+ } else {
+ std_store(fillOp.value(), viewOutput);
+ }
+ return;
+ }
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
LLVM_DEBUG(dbgs() << dbgPref
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index bc3b7477885a..155247a53806 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -13,9 +13,9 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
}
// CHECK-LABEL:func @matmul
-// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
-// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
-// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
+// CHECK: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32>
+// CHECK: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32>
//
// CHECK: linalg.copy
// CHECK: linalg.copy
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 1d9c4f9bdb66..9bdc4ad54826 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -157,7 +157,16 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
return
}
// CHECK-LABEL: func @test_vectorize_fill
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+ linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref<f32>, f32
+ return
+}
+// CHECK-LABEL: func @test_vectorize_fill_scalar
+// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
+// CHECK: store %[[V]], %[[M]][] : memref<f32>
func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, memref<8x16xf32>
More information about the Mlir-commits mailing list