[Mlir-commits] [mlir] 29d1fba - [mlir][vector] Make linalg FillOp vectorization use Transfer op

Thomas Raoux llvmlistbot at llvm.org
Tue Nov 3 14:40:55 PST 2020


Author: Thomas Raoux
Date: 2020-11-03T14:35:26-08:00
New Revision: 29d1fba7b5335d969e3e5daa84b7a25cd1fa75ef

URL: https://github.com/llvm/llvm-project/commit/29d1fba7b5335d969e3e5daa84b7a25cd1fa75ef
DIFF: https://github.com/llvm/llvm-project/commit/29d1fba7b5335d969e3e5daa84b7a25cd1fa75ef.diff

LOG: [mlir][vector] Make linalg FillOp vectorization use Transfer op

Differential Revision: https://reviews.llvm.org/D90474

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4430c34af1e9..8860674ef847 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -106,17 +106,6 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
   (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
-  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
-    // Vectorize fill as a vector.broadcast.
-    LLVM_DEBUG(dbgs() << dbgPref
-                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
-    Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
-    Value dst = std_load(memref);
-    Value res = vector_broadcast(dst.getType(), fillOp.value());
-    std_store(res, memref);
-    return;
-  }
-
   // In the case of 0-D memrefs, return null and special case to scalar load or
   // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
@@ -125,7 +114,24 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                ? VectorType()
                : VectorType::get(mt.getShape(), mt.getElementType());
   };
-
+  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+    // Vectorize fill as a vector.broadcast.
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
+    Value viewOutput = fillOp.output();
+    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
+      auto vecType =
+          VectorType::get(fillOp.getOutputBufferType(0).getShape(),
+                          fillOp.getOutputBufferType(0).getElementType());
+      Value vector = vector_broadcast(vecType, fillOp.value());
+      Value zero = std_constant_index(0);
+      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
+      vector_transfer_write(vector, viewOutput, indicesOutput);
+    } else {
+      std_store(fillOp.value(), viewOutput);
+    }
+    return;
+  }
   if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
     // Vectorize copy as a vector.transfer_read+vector.transfer_write.
     LLVM_DEBUG(dbgs() << dbgPref

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index bc3b7477885a..155247a53806 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -13,9 +13,9 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
 }
 
 // CHECK-LABEL:func @matmul
-//      CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
-//      CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
-//      CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
+//      CHECK: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+//      CHECK: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32>
+//      CHECK: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32>
 //
 //      CHECK: linalg.copy
 //      CHECK: linalg.copy

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 1d9c4f9bdb66..9bdc4ad54826 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -157,7 +157,16 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   return
 }
 // CHECK-LABEL: func @test_vectorize_fill
-//       CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+//       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+  linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<f32>, f32
+  return
+}
+// CHECK-LABEL: func @test_vectorize_fill_scalar
+//  CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
+//       CHECK:   store %[[V]], %[[M]][] : memref<f32>
 
 func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
   linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<8x16xf32>, memref<8x16xf32>


        


More information about the Mlir-commits mailing list