[Mlir-commits] [mlir] f385f6c - [mlir][vector] Distribute all non-permutation or broadcasted masked transfer reads (#73539)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 13:23:52 PST 2023


Author: Quinn Dawkins
Date: 2023-11-27T16:23:48-05:00
New Revision: f385f6c93b33eff211523e7838ea023d7beee3f2

URL: https://github.com/llvm/llvm-project/commit/f385f6c93b33eff211523e7838ea023d7beee3f2
DIFF: https://github.com/llvm/llvm-project/commit/f385f6c93b33eff211523e7838ea023d7beee3f2.diff

LOG: [mlir][vector] Distribute all non-permutation or broadcasted masked transfer reads (#73539)

The primary difficulty with distributing masked transfers arises when the
permutation map permutes the vector dimensions, in which case the
distribution logic must ensure that the correct mask elements end up with
the distributed transfer. This is only a concern when the map actually
contains a permutation, so the condition for distribution can be relaxed
from requiring a minor identity map to requiring that the map be an
identity once its unused dimensions are compressed away.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 0ad2c71cf3a6a11..07ecd8857520338 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -837,7 +837,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
       // of which lane is responsible for which element is captured strictly
       // by shape information on the warp op, and thus requires materializing
       // the permutation in IR.
-      if (!read.getPermutationMap().isMinorIdentity())
+      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
         return failure();
       VectorType maskType =
           getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8056260f4610977..ab175effa3dfb80 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1351,6 +1351,32 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
 
 // -----
 
+func.func @warp_propagate_nontrivial_map_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> vector<2xf32> {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>) {
+    %mask = "mask_def_0"() : () -> (vector<128xi1>)
+    %0 = vector.transfer_read %src[%index, %c0], %f0, %mask {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<4096x4096xf32>, vector<128xf32>
+    vector.yield %0 : vector<128xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+//   CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+//   CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_nontrivial_map_masked_transfer_read
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>) {
+//       CHECK-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-PROP:     vector.yield %[[M0]] : vector<128xi1>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[C0]]], {{.*}}, %[[R]]
+//  CHECK-PROP-SAME:   permutation_map = #[[$MAP1]]} {{.*}} vector<2xf32>
+
+// -----
+
 func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
   %f0 = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list