[Mlir-commits] [mlir] [mlir][vector] Add pattern to distribute masked reads (PR #71610)

Lei Zhang llvmlistbot at llvm.org
Wed Nov 8 21:48:45 PST 2023


================
@@ -818,15 +818,39 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     auto distributedType = cast<VectorType>(distributedVal.getType());
     AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
+
+    // Distribute the mask if present.
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointAfter(warpOp);
+    WarpExecuteOnLane0Op newWarpOp = warpOp;
+    Value newMask = read.getMask();
+    if (read.getMask()) {
+      // TODO: Distribution of masked reads with non-trivial permutation maps
+      // requires the distribution of the mask to elementwise match the
+      // distribution of the permuted written vector. Currently the details
+      // of which lane is responsible for which element is captured strictly
+      // by shape information on the warp op, and thus requires materializing
+      // the permutation in IR.
+      if (!read.getPermutationMap().isMinorIdentity())
+        return failure();
+      VectorType maskType =
+          getDistributedType(read.getMask().getType().cast<VectorType>(), map,
----------------
antiagainst wrote:

Nit: use `read.getMaskType()` here instead of `read.getMask().getType().cast<VectorType>()`.

https://github.com/llvm/llvm-project/pull/71610
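
For readers following the TODO in the quoted hunk: the restriction to minor-identity permutation maps is about which lane owns which mask element. Below is a minimal standalone sketch in plain C++ (no MLIR) of that ownership question for the 1-D case. The helpers `distributedSize` and `laneOwning` are hypothetical names used only for illustration, and the contiguous per-lane block layout is an assumption here, not code taken from the PR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: number of elements each lane holds when a 1-D vector
// of `sequentialSize` elements is split evenly across `warpSize` lanes.
int64_t distributedSize(int64_t sequentialSize, int64_t warpSize) {
  assert(warpSize > 0 && sequentialSize % warpSize == 0 &&
         "sketch only handles even distribution");
  return sequentialSize / warpSize;
}

// Hypothetical helper: the lane that owns element `i`, assuming each lane
// holds one contiguous block of the sequential vector.
int64_t laneOwning(int64_t i, int64_t sequentialSize, int64_t warpSize) {
  return i / distributedSize(sequentialSize, warpSize);
}
```

Under these assumptions, a mask and a read vector of the same sequential shape give the same `laneOwning` result for every element, so each lane's slice of the mask lines up with its slice of the data. A non-trivial permutation would reorder the data elements relative to the mask, which is the case the TODO leaves unsupported.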

