[Mlir-commits] [mlir] 7eba5cd - [mlir][vector] Relax transfer_write vector distribution pattern
Thomas Raoux
llvmlistbot at llvm.org
Fri Jun 24 12:03:35 PDT 2022
Author: Thomas Raoux
Date: 2022-06-24T19:03:14Z
New Revision: 7eba5cdf9ce4214a3f2a1f29cb5790c1790e66a3
URL: https://github.com/llvm/llvm-project/commit/7eba5cdf9ce4214a3f2a1f29cb5790c1790e66a3
DIFF: https://github.com/llvm/llvm-project/commit/7eba5cdf9ce4214a3f2a1f29cb5790c1790e66a3.diff
LOG: [mlir][vector] Relax transfer_write vector distribution pattern
Small change to relax the pattern so that it supports any vector containing a
single element (e.g. vector<1x1xf32>), not only the 1-D vector<1xf32>.
Differential Revision: https://reviews.llvm.org/D128545
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 08ea44225d4f..c71c31654105 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -346,8 +346,9 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
Location loc = writeOp.getLoc();
VectorType vecType = writeOp.getVectorType();
- // Only vector<1x> is supported at the moment.
- if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
+ // Only sink out vectors with a single element for now, to avoid serializing
+ // large vector stores. This could later be made user-controllable.
+ if (vecType.getNumElements() != 1)
return failure();
// Do not process warp ops that contain only TransferWriteOps.
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index eca9ad5b512a..a2c988cae33e 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -109,22 +109,25 @@ func.func @warp(%laneid: index, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>
// -----
// CHECK-D-LABEL: func @warp_extract(
-// CHECK-D: %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
+// CHECK-D: %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1x1xf32>)
// CHECK-D: "test.dummy_op"
-// CHECK-D: vector.yield %{{.*}} : vector<1xf32>
+// CHECK-D: "test.dummy_op"
+// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<1xf32>, vector<1x1xf32>
// CHECK-D: }
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
-// CHECK-D: vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
+// CHECK-D: vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<1x1xf32>
+// CHECK-D: }
+// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
+// CHECK-D: vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
// CHECK-D: }
-#map2 = affine_map<(d0)[s0] -> (d0 + s0)>
-
-func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
+func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
vector.warp_execute_on_lane_0(%laneid)[32] {
- %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2>
%c0 = arith.constant 0 : index
%v = "test.dummy_op"() : () -> (vector<1xf32>)
- vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
+ %v1 = "test.dummy_op"() : () -> (vector<1x1xf32>)
+ vector.transfer_write %v1, %arg1[%c0, %c0] : vector<1x1xf32>, memref<1024x1024xf32>
+ vector.transfer_write %v, %arg1[%c0, %c0] : vector<1xf32>, memref<1024x1024xf32>
}
return
}
More information about the Mlir-commits mailing list