[Mlir-commits] [mlir] [mlir][vector] Distribute all non-permutation or broadcasted masked transfer reads (PR #73539)
Quinn Dawkins
llvmlistbot at llvm.org
Mon Nov 27 08:37:17 PST 2023
https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/73539
The primary difficulty in distributing masked transfers arises when the permutation map permutes the vector: the distribution logic must then make sure the correct mask elements end up with each distributed transfer. Only maps that actually permute dimensions are tricky. So we can relax the distribution condition from requiring a minor identity permutation map to requiring a map that is an identity once its unused dimensions are dropped (`compressUnusedDims`).
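The relaxed condition can be illustrated with a small standalone C++ sketch. It does not use MLIR. `SimpleMap`, `isMinorIdentity`, `isIdentityAfterCompression`, `elementsPerLane` and `distributedReadIndex` are hypothetical stand-ins written for illustration. They model only maps whose results are plain dimensions, so constant (broadcast) results are not represented. The first two functions mimic the old `isMinorIdentity()` check and the new `compressUnusedDims(...).isIdentity()` check from the diff. The last two reproduce the per-lane offset arithmetic that the new test's CHECK lines expect (`s0 + s1 * 2`).

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified model of an affine map: `numDims` input dimensions,
// and each result names the input dimension it reads, e.g. (d0, d1) -> (d0) is
// SimpleMap{2, {0}}. Assumes every result index is in [0, numDims).
struct SimpleMap {
  int numDims;
  std::vector<int> results;
};

// Mimics the old check: the results are the trailing (minor) input
// dimensions in order, e.g. (d0, d1) -> (d1).
bool isMinorIdentity(const SimpleMap &m) {
  int n = static_cast<int>(m.results.size());
  if (n > m.numDims) return false;
  for (int i = 0; i < n; ++i)
    if (m.results[i] != m.numDims - n + i) return false;
  return true;
}

// Mimics the new check: drop input dimensions no result uses, renumber the
// remaining ones in order, then require result i to be (renumbered) dim i.
bool isIdentityAfterCompression(const SimpleMap &m) {
  std::vector<bool> used(m.numDims, false);
  for (int r : m.results) used[r] = true;
  std::vector<int> newIndex(m.numDims, -1);
  int next = 0;
  for (int d = 0; d < m.numDims; ++d)
    if (used[d]) newIndex[d] = next++;
  // An identity map needs as many (compressed) dims as results.
  if (next != static_cast<int>(m.results.size())) return false;
  for (int i = 0; i < next; ++i)
    if (newIndex[m.results[i]] != i) return false;
  return true;
}

// Number of elements of the distributed dimension owned by each lane.
long elementsPerLane(long vecSize, long warpSize) {
  assert(warpSize > 0 && vecSize % warpSize == 0 &&
         "vector must split evenly across lanes");
  return vecSize / warpSize;
}

// Index along the distributed memref dimension where lane `laneId` starts
// reading: base + laneId * elementsPerLane. When the compressed map is an
// identity, mask dim i lines up with vector dim i, so the lane also takes
// mask elements [laneId * per, laneId * per + per) of the original mask.
long distributedReadIndex(long base, long laneId, long vecSize, long warpSize) {
  return base + laneId * elementsPerLane(vecSize, warpSize);
}
```

For the new test's map `(d0, d1) -> (d0)` (modeled as `SimpleMap{2, {0}}`), `isMinorIdentity` is false but `isIdentityAfterCompression` is true, so the read now distributes. A transposing map such as `(d0, d1) -> (d1, d0)` is still rejected. With 128 elements over 64 lanes, lane 5 with `%index = 100` would start reading at index 110 along the first memref dimension.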
From c55d118e247bffdabedd6fe247933aa74e009a22 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 22 Nov 2023 11:58:28 -0500
Subject: [PATCH] [mlir][vector] Distribute all non-permutation or broadcasted
masked transfer reads
The primary difficulty with distribution of masked transfers is when the
permutation map permutes the vector, in which case the distribution
logic needs to make sure the correct mask elements end up with the
distributed transfer. This is only tricky when the permutation map
has a permutation in it, so we can relax the condition for
distribution.
---
.../Vector/Transforms/VectorDistribute.cpp | 2 +-
.../Vector/vector-warp-distribute.mlir | 26 +++++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 0ad2c71cf3a6a11..07ecd8857520338 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -837,7 +837,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
// of which lane is responsible for which element is captured strictly
// by shape information on the warp op, and thus requires materializing
// the permutation in IR.
- if (!read.getPermutationMap().isMinorIdentity())
+ if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
return failure();
VectorType maskType =
getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 8056260f4610977..ab175effa3dfb80 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1351,6 +1351,32 @@ func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096
// -----
+func.func @warp_propagate_nontrivial_map_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> vector<2xf32> {
+ %f0 = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>) {
+ %mask = "mask_def_0"() : () -> (vector<128xi1>)
+ %0 = vector.transfer_read %src[%index, %c0], %f0, %mask {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<4096x4096xf32>, vector<128xf32>
+ vector.yield %0 : vector<128xf32>
+ }
+ return %r : vector<2xf32>
+}
+
+// CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_nontrivial_map_masked_transfer_read
+// CHECK-PROP-SAME: %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
+// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>) {
+// CHECK-PROP: %[[M0:.*]] = "mask_def_0"
+// CHECK-PROP: vector.yield %[[M0]] : vector<128xi1>
+// CHECK-PROP: }
+// CHECK-PROP: %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG0]]]
+// CHECK-PROP: vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[C0]]], {{.*}}, %[[R]]
+// CHECK-PROP-SAME: permutation_map = #[[$MAP1]]} {{.*}} vector<2xf32>
+
+// -----
+
func.func @warp_propagate_masked_transfer_read_shared_mask(%laneid: index, %src: memref<4096x4096xf32>, %index: index, %index2: index, %mask_ub: index) -> (vector<2xf32>, vector<2xf32>) {
%f0 = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index