[Mlir-commits] [mlir] [mlir][vector] Distribute all non-permutation or broadcasted masked transfer reads (PR #73539)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 08:37:44 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

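The main difficulty in distributing masked transfer reads arises when the permutation map actually permutes the vector dimensions: the distribution logic must then make sure the correct mask elements end up with the distributed transfer. A map that merely leaves some source dimensions unused does not have this problem, so the condition for distribution is relaxed. Instead of requiring a minor identity map (`isMinorIdentity()`), the pattern now accepts any map that is an identity once its unused dimensions are compressed (`compressUnusedDims(...).isIdentity()`).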
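
To make the relaxed condition concrete, here is a minimal standalone sketch of the check. It does not use the MLIR API: `PermMap`, `kBroadcast`, `compressUnusedDimsSketch`, `isIdentitySketch`, and `canDistributeMaskedRead` are hypothetical names introduced only for illustration, and the handling of broadcast results is an assumption about how an identity test treats non-dimension results.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a transfer permutation map (not the MLIR AffineMap
// class): `numDims` input dimensions and one entry per result. An entry >= 0
// names an input dimension; kBroadcast marks a broadcast (constant) result.
constexpr int kBroadcast = -1;

struct PermMap {
  int numDims;
  std::vector<int> results;
};

// Drop input dimensions that no result uses and renumber the remaining ones,
// e.g. (d0, d1) -> (d0) becomes (d0) -> (d0).
PermMap compressUnusedDimsSketch(const PermMap &map) {
  std::vector<bool> used(map.numDims, false);
  for (int r : map.results)
    if (r >= 0)
      used[r] = true;

  std::vector<int> renumber(map.numDims, -1);
  int next = 0;
  for (int d = 0; d < map.numDims; ++d)
    if (used[d])
      renumber[d] = next++;

  PermMap out{next, {}};
  for (int r : map.results)
    out.results.push_back(r >= 0 ? renumber[r] : kBroadcast);
  return out;
}

// Identity means: as many results as dimensions, and result i is dimension i.
// Broadcast entries never match, so maps containing them are rejected here
// (an assumption of this sketch).
bool isIdentitySketch(const PermMap &map) {
  if (static_cast<std::size_t>(map.numDims) != map.results.size())
    return false;
  for (std::size_t i = 0; i < map.results.size(); ++i)
    if (map.results[i] != static_cast<int>(i))
      return false;
  return true;
}

// The relaxed precondition, modeled on the one-line change in this PR.
bool canDistributeMaskedRead(const PermMap &map) {
  return isIdentitySketch(compressUnusedDimsSketch(map));
}
```

Under this model, `canDistributeMaskedRead(PermMap{2, {0}})`, which corresponds to the `(d0, d1) -> (d0)` map used in the new test, returns true, while a swapped map such as `PermMap{2, {1, 0}}` is still rejected.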

---
Full diff: https://github.com/llvm/llvm-project/pull/73539.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+1-1) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 0ad2c71cf3a6a11..07ecd8857520338 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -837,7 +837,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       // of which lane is responsible for which element is captured strictly
       // by shape information on the warp op, and thus requires materializing
       // the permutation in IR.
-      if (!read.getPermutationMap().isMinorIdentity())
+      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
         return failure();
       VectorType maskType =
           getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8056260f4610977..ab175effa3dfb80 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1351,6 +1351,32 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
 
 // -----
 
+func.func @warp_propagate_nontrivial_map_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> vector<2xf32> {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>) {
+    %mask = "mask_def_0"() : () -> (vector<128xi1>)
+    %0 = vector.transfer_read %src[%index, %c0], %f0, %mask {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<4096x4096xf32>, vector<128xf32>
+    vector.yield %0 : vector<128xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+//   CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+//   CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_nontrivial_map_masked_transfer_read
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>) {
+//       CHECK-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-PROP:     vector.yield %[[M0]] : vector<128xi1>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[C0]]], {{.*}}, %[[R]]
+//  CHECK-PROP-SAME:   permutation_map = #[[$MAP1]]} {{.*}} vector<2xf32>
+
+// -----
+
 func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
   %f0 = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index

``````````
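
The index arithmetic that the new test checks can be sketched outside MLIR as follows. With a 128-element vector and a warp size of 64, each lane reads 2 elements, and the `affine_map<()[s0, s1] -> (s0 + s1 * 2)>` in the CHECK lines offsets the read index by `laneid * 2`. The function names `elementsPerLane` and `distributedReadIndex` are hypothetical, and returning 0 for a vector size that does not divide evenly is a convention of this sketch, not a statement about what the pattern does in that case.

```cpp
#include <cstdint>

// Number of elements each lane owns when a 1-D vector is split evenly across
// the warp. Returns 0 when the split is not even (sketch convention only).
int64_t elementsPerLane(int64_t vectorSize, int64_t warpSize) {
  if (warpSize <= 0 || vectorSize % warpSize != 0)
    return 0;
  return vectorSize / warpSize;
}

// Start index of a lane's slice along the distributed dimension. With
// perLane == 2 this matches s0 + s1 * 2, where s0 is the original index and
// s1 is the lane id.
int64_t distributedReadIndex(int64_t baseIndex, int64_t laneId,
                             int64_t perLane) {
  return baseIndex + laneId * perLane;
}
```

For the test above, `elementsPerLane(128, 64)` gives 2, which is why the distributed read and mask have types `vector<2xf32>` and `vector<2xi1>`.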

</details>


https://github.com/llvm/llvm-project/pull/73539

