[Mlir-commits] [mlir] [mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. (PR #145421)
Jianhui Li
llvmlistbot at llvm.org
Tue Jun 24 11:30:49 PDT 2025
================
@@ -1076,6 +1094,195 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
}
};
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+/// ...
+/// %src = ... : vector<4x16xf32>
+/// %dest = ... : vector<8x16xf32>
+/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+/// strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
+/// gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+/// ...
+/// %src = ... : vector<4x16xf32>
+/// %dest = ... : vector<8x16xf32>
+/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both the src and dest vectors are
+/// distributed to lanes and that sinking the insert op requires no
+/// cross-lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp =
+ operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Distributed type must be 2D or higher.
+ // TODO: Support 1D distributed types.
+ if (distributedType.getRank() < 2)
+ return rewriter.notifyMatchFailure(
+ insertOp, "result vector type must be 2D or higher");
+ // Find the distributed dimension of the dest vector. There should be
+ // exactly one.
+ auto yieldedType = cast<VectorType>(operand->get().getType());
+ int64_t destDistributedDim =
+ getDistributedDim(yieldedType, distributedType);
+ assert(destDistributedDim != -1 && "could not find distributed dimension");
+ VectorType srcType = insertOp.getSourceVectorType();
+ VectorType destType = insertOp.getDestVectorType();
+ // Currently we require that both source (kD) and dest (nD) vectors are
+ // distributed. This requires that distributedDim (d) is contained in the
+ // last k dims of the dest vector (d >= n - k).
+ // TODO: Add support for case where source vector is not distributed.
+ int64_t sourceDistributedDim =
+ destDistributedDim - (destType.getRank() - srcType.getRank());
+ if (sourceDistributedDim < 0)
+ return rewriter.notifyMatchFailure(
+ insertOp, "distributed dimension must be in the last k dims");
+ // Distributed dimension must be fully inserted.
+ if (srcType.getDimSize(sourceDistributedDim) !=
----------------
Jianhui-Li wrote:
What is the reason we disallow distributing the following case? I think the distribution should work as long as the offsets are a multiple of the subgroup size:
/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 32],
/// strides = [1, 1] : vector<8x32xf32> into vector<8x64xf32>
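For concreteness, here is a rough sketch of the distributed form I would expect, assuming a warp size of 32. The per-lane types (vector<8x1xf32> and vector<8x2xf32>) and the rescaled offset (32 / 32 = 1) are my assumptions, and they presuppose a cyclic layout where lane i owns dest elements i and i + 32 along the distributed dim:
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>,
///                                            vector<8x2xf32>) {
///   ...
///   gpu.yield %src, %dest : vector<8x32xf32>, vector<8x64xf32>
/// }
/// %insert = vector.insert_strided_slice %0#0, %0#1, offsets = [0, 1],
///   strides = [1, 1] : vector<8x1xf32> into vector<8x2xf32>
Under a blocked layout (lane i owning dest elements 2i and 2i + 1) the same insert would need cross-lane shuffles, so perhaps that is what the restriction is guarding against.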
https://github.com/llvm/llvm-project/pull/145421