[Mlir-commits] [mlir] 7eba5cd - [mlir][vector] Relax transfer_write vector distribution pattern

Thomas Raoux llvmlistbot at llvm.org
Fri Jun 24 12:03:35 PDT 2022


Author: Thomas Raoux
Date: 2022-06-24T19:03:14Z
New Revision: 7eba5cdf9ce4214a3f2a1f29cb5790c1790e66a3

URL: https://github.com/llvm/llvm-project/commit/7eba5cdf9ce4214a3f2a1f29cb5790c1790e66a3
DIFF: https://github.com/llvm/llvm-project/commit/7eba5cdf9ce4214a3f2a1f29cb5790c1790e66a3.diff

LOG: [mlir][vector] Relax transfer_write vector distribution pattern

Small change to relax the pattern so that it supports any vector containing
exactly one element (e.g. vector<1x1xf32>), not only the 1-D vector<1xf32>.

Differential Revision: https://reviews.llvm.org/D128545

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 08ea44225d4f..c71c31654105 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -346,8 +346,9 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
     Location loc = writeOp.getLoc();
     VectorType vecType = writeOp.getVectorType();
 
-    // Only vector<1x> is supported at the moment.
-    if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
+    // Only sink out vector of 1 element for now to not serialize large vector
+    // store. This can later be controlled by user.
+    if (vecType.getNumElements() != 1)
       return failure();
 
     // Do not process warp ops that contain only TransferWriteOps.

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index eca9ad5b512a..a2c988cae33e 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -109,22 +109,25 @@ func.func @warp(%laneid: index, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>
 // -----
 
 // CHECK-D-LABEL: func @warp_extract(
-//       CHECK-D:   %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
+//       CHECK-D:   %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1x1xf32>)
 //       CHECK-D:     "test.dummy_op"
-//       CHECK-D:     vector.yield %{{.*}} : vector<1xf32>
+//       CHECK-D:     "test.dummy_op"
+//       CHECK-D:     vector.yield %{{.*}}, %{{.*}} : vector<1xf32>, vector<1x1xf32>
 //       CHECK-D:   }
 //       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
-//       CHECK-D:     vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
+//       CHECK-D:     vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<1x1xf32>
+//       CHECK-D:   }
+//       CHECK-D:   vector.warp_execute_on_lane_0(%{{.*}})[32] {
+//       CHECK-D:     vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
 //       CHECK-D:   }
 
-#map2 =  affine_map<(d0)[s0] -> (d0 + s0)>
-
-func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
+func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
   vector.warp_execute_on_lane_0(%laneid)[32] {
-    %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2>
     %c0 = arith.constant 0 : index
     %v = "test.dummy_op"() : () -> (vector<1xf32>)
-    vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
+    %v1 = "test.dummy_op"() : () -> (vector<1x1xf32>)
+    vector.transfer_write %v1, %arg1[%c0, %c0] : vector<1x1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %v, %arg1[%c0, %c0] : vector<1xf32>, memref<1024x1024xf32>
   }
   return
 }


        


More information about the Mlir-commits mailing list