[Mlir-commits] [mlir] [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (PR #168626)
Jianhui Li
llvmlistbot at llvm.org
Thu Nov 20 11:15:12 PST 2025
================
@@ -1471,6 +1485,234 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};
+// Distribute a `vector.extract_strided_slice` op feeding into the yield op of
+// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
+// advanced cases where the distributed dimension is partially extracted, which
+// the generic vector distribution patterns do not currently support.
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    auto distributedDims = getDistributedDims(yieldedType, distributedType);
+    // Only single dimension distribution is supported.
+    if (distributedDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t distributedDim = distributedDims[0];
+    // Check if the distributed dimension is fully extracted. If so, exit early
+    // because this case is already handled by vector distribution patterns.
+    // The distributed dimension is fully extracted if either:
+    // 1) The distributed dim comes after all the extracted dimensions, or
+    // 2) The size extracted along the distributed dim equals the size of that
+    //    dim in the source vector.
+    auto extractedSizes = extractOp.getSizes();
+    if (distributedDim >= static_cast<int64_t>(extractedSizes.size()))
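A minimal C++ sketch of what `getDistributedDims` is assumed to compute in the hunk above. Its real implementation is not part of this excerpt. The sketch assumes that a dimension counts as distributed when its size in the yielded (per-warp) vector type differs from its size in the distributed (per-lane) result type. `distributedDimsSketch` is a hypothetical name, and plain `std::vector<int64_t>` shapes stand in for `VectorType`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for getDistributedDims (name and behavior assumed):
// returns every dimension whose size differs between the yielded (per-warp)
// shape and the distributed (per-lane) shape.
std::vector<int64_t>
distributedDimsSketch(const std::vector<int64_t> &yieldedShape,
                      const std::vector<int64_t> &distributedShape) {
  // Distribution keeps the rank and only shrinks per-dimension sizes.
  assert(yieldedShape.size() == distributedShape.size() && "rank mismatch");
  std::vector<int64_t> dims;
  for (std::size_t i = 0; i < yieldedShape.size(); ++i)
    if (yieldedShape[i] != distributedShape[i])
      dims.push_back(static_cast<int64_t>(i));
  return dims;
}
```

Consider a yielded `vector<16x32xf32>` that is distributed to a per-lane `vector<16x2xf32>`. The sketch returns the single dimension `1`, so the pattern proceeds. If the shape changes in two dimensions, the sketch returns two entries, and the pattern bails out with the "single dimension" match failure.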
----------------
Jianhui-Li wrote:
Hard to reason about without a code example. The condition reads to me as if the extracted shape can be nD, but the distributedDim is only 1D?
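A concrete example may help. The sketch below restates the early-exit condition as a standalone C++ predicate. `isDistributedDimFullyExtracted` is a hypothetical name, and shapes are plain `std::vector<int64_t>` rather than MLIR types. It relies on `vector.extract_strided_slice` semantics: `sizes` may have fewer entries than the source rank, and any trailing dimensions are taken in full. So the slice itself can be nD while only one of its dimensions is distributed. The check only asks whether that one dimension is sliced.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical restatement of the early-exit check in the hunk.
//   sourceShape    - shape of the yielded (per-warp) source vector
//   extractedSizes - `sizes` of vector.extract_strided_slice; may be shorter
//                    than the source rank (trailing dims are taken whole)
//   distributedDim - the single dimension distributed across lanes
bool isDistributedDimFullyExtracted(const std::vector<int64_t> &sourceShape,
                                    const std::vector<int64_t> &extractedSizes,
                                    int64_t distributedDim) {
  assert(distributedDim >= 0 &&
         distributedDim < static_cast<int64_t>(sourceShape.size()) &&
         "distributed dim out of range");
  // 1) The distributed dim is not listed in `sizes`, so it is taken whole.
  if (distributedDim >= static_cast<int64_t>(extractedSizes.size()))
    return true;
  // 2) The slice covers the full extent of the distributed dim.
  auto dim = static_cast<std::size_t>(distributedDim);
  return extractedSizes[dim] == sourceShape[dim];
}
```

As an illustration, assume a `vector<32x16xf32>` source distributed along dim 1 (for example, 16 lanes with one element each):

- `sizes = [4]` returns true via rule 1, because dim 1 is not listed and is therefore taken whole.
- `sizes = [4, 16]` returns true via rule 2.
- `sizes = [32, 8]` returns false, because only half of the distributed dimension is extracted. This is the partial case the new pattern targets.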
https://github.com/llvm/llvm-project/pull/168626