[Mlir-commits] [mlir] 771f575 - [mlir][vector] Add pattern to distribute masked reads (#71610)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 9 06:24:30 PST 2023


Author: Quinn Dawkins
Date: 2023-11-09T09:24:26-05:00
New Revision: 771f5759df7429e0431b3f5a1c661302fb75cbb7

URL: https://github.com/llvm/llvm-project/commit/771f5759df7429e0431b3f5a1c661302fb75cbb7
DIFF: https://github.com/llvm/llvm-project/commit/771f5759df7429e0431b3f5a1c661302fb75cbb7.diff

LOG: [mlir][vector] Add pattern to distribute masked reads (#71610)

Because the distribution is based on types, supporting general masked
reads requires first materializing the permutation map in IR to align
the elements of the mask with the elements read by the transfer op. For
now, only support cases with a trivial (minor identity) permutation map.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..1975ba9c92d9988 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -818,15 +818,38 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto distributedType = cast<VectorType>(distributedVal.getType());
     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
+
+    // Distribute the mask if present.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointAfter(warpOp);
+    WarpExecuteOnLane0Op newWarpOp = warpOp;
+    Value newMask = read.getMask();
+    if (read.getMask()) {
+      // TODO: Distribution of masked reads with non-trivial permutation maps
+      // requires the distribution of the mask to elementwise match the
+      // distribution of the permuted read vector. Currently the details
+      // of which lane is responsible for which element are captured strictly
+      // by shape information on the warp op, and thus this requires
+      // materializing the permutation in IR.
+      if (!read.getPermutationMap().isMinorIdentity())
+        return failure();
+      VectorType maskType =
+          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
+      SmallVector<size_t> newRetIndices;
+      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, ValueRange{read.getMask()}, TypeRange{maskType},
+          newRetIndices);
+      newMask = newWarpOp.getResult(newRetIndices[0]);
+      distributedVal = newWarpOp.getResult(operandIndex);
+    }
+
+    rewriter.setInsertionPointAfter(newWarpOp);
 
     // Try to delinearize the lane ID to match the rank expected for
     // distribution.
     SmallVector<Value> delinearizedIds;
     if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
-                           distributedType.getShape(), warpOp.getWarpSize(),
-                           warpOp.getLaneid(), delinearizedIds))
+                           distributedType.getShape(), newWarpOp.getWarpSize(),
+                           newWarpOp.getLaneid(), delinearizedIds))
       return rewriter.notifyMatchFailure(
           read, "cannot delinearize lane ID for distribution");
     assert(!delinearizedIds.empty() || map.getNumResults() == 0);
@@ -846,7 +869,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     }
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
-        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
+        read.getPermutationMapAttr(), read.getPadding(), newMask,
         read.getInBoundsAttr());
 
     // Check that the produced operation is legal.
@@ -854,18 +877,19 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // warpOp's body, which is illegal.
     // We do the check late because incdices may be changed by
     // makeComposeAffineApply. This rewrite may remove dependencies from
-    // warOp's body.
-    // E.g., warop {
+    // warpOp's body.
+    // E.g., warpop {
     //   %idx = affine.apply...[%outsideDef]
     //   ... = transfer_read ...[%idx]
     // }
     // will be rewritten in:
-    // warop {
+    // warpop {
     // }
     //  %new_idx = affine.apply...[%outsideDef]
     //   ... = transfer_read ...[%new_idx]
     if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
-          return warpOp.isDefinedOutsideOfRegion(value);
+          return (newRead.getMask() && value == newRead.getMask()) ||
+                 newWarpOp.isDefinedOutsideOfRegion(value);
         }))
       return failure();
 

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..2a1007fbbe86435 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1281,3 +1281,33 @@ func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>)
 //       CHECK-DIST-AND-PROP:   }
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#2, {{.*}}, %[[W]]#3 {in_bounds = [true]} : vector<128xf32>, memref<4096xf32>
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]]#0, {{.*}}, %[[W]]#1 {in_bounds = [true]} : vector<1xf32>, memref<4096xf32>
+
+// -----
+
+func.func @warp_propagate_masked_transfer_read(%laneid: index, %src: memref<4096x4096xf32>, %index: index) -> (vector<2xf32>, vector<2x2xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<2xf32>, vector<2x2xf32>) {
+    %mask = "mask_def_0"() : () -> (vector<128xi1>)
+    %0 = vector.transfer_read %src[%c0, %index], %f0, %mask {in_bounds = [true]} : memref<4096x4096xf32>, vector<128xf32>
+    %mask2 = "mask_def_1"() : () -> (vector<128x2xi1>)
+    %1 = vector.transfer_read %src[%c0, %index], %f0, %mask2 {in_bounds = [true, true]} : memref<4096x4096xf32>, vector<128x2xf32>
+    vector.yield %0, %1 : vector<128xf32>, vector<128x2xf32>
+  }
+  return %r#0, %r#1 : vector<2xf32>, vector<2x2xf32>
+}
+
+//   CHECK-PROP-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 2)>
+//   CHECK-PROP-DAG: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_masked_transfer_read
+//  CHECK-PROP-SAME:   %[[ARG0:.+]]: index, {{.*}}, %[[ARG2:.+]]: index
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[64] -> (vector<2xi1>, vector<2x2xi1>) {
+//       CHECK-PROP:     %[[M0:.*]] = "mask_def_0"
+//       CHECK-PROP:     %[[M1:.*]] = "mask_def_1"
+//       CHECK-PROP:     vector.yield %[[M0]], %[[M1]] : vector<128xi1>, vector<128x2xi1>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[DIST_READ_IDX0:.+]] = affine.apply #[[$MAP0]]()[%[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[DIST_READ_IDX0]], %[[ARG2]]], {{.*}}, %[[R]]#1 {{.*}} vector<2x2xf32>
+//       CHECK-PROP:   %[[DIST_READ_IDX1:.+]] = affine.apply #[[$MAP1]]()[%[[ARG2]], %[[ARG0]]]
+//       CHECK-PROP:   vector.transfer_read {{.*}}[%[[C0]], %[[DIST_READ_IDX1]]], {{.*}}, %[[R]]#0 {{.*}} vector<2xf32>


        


More information about the Mlir-commits mailing list