[Mlir-commits] [mlir] [mlir][vector] Add pattern to distribute masked reads (PR #71610)
Lei Zhang
llvmlistbot at llvm.org
Wed Nov 8 21:48:45 PST 2023
================
@@ -818,15 +818,39 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto distributedType = cast<VectorType>(distributedVal.getType());
AffineMap map = calculateImplicitMap(sequentialType, distributedType);
AffineMap indexMap = map.compose(read.getPermutationMap());
+
+ // Distribute the mask if present.
OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPointAfter(warpOp);
+ WarpExecuteOnLane0Op newWarpOp = warpOp;
+ Value newMask = read.getMask();
+ if (read.getMask()) {
+ // TODO: Distribution of masked reads with non-trivial permutation maps
+ // requires the distribution of the mask to elementwise match the
+ // distribution of the permuted written vector. Currently the details
+ // of which lane is responsible for which element is captured strictly
+ // by shape information on the warp op, and thus requires materializing
+ // the permutation in IR.
+ if (!read.getPermutationMap().isMinorIdentity())
+ return failure();
+ VectorType maskType =
+ getDistributedType(read.getMask().getType().cast<VectorType>(), map,
----------------
antiagainst wrote:
Use `read.getMaskType()` here instead of `read.getMask().getType().cast<VectorType>()`.
https://github.com/llvm/llvm-project/pull/71610
More information about the Mlir-commits
mailing list