[Mlir-commits] [mlir] [mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. (PR #145421)

Jianhui Li llvmlistbot at llvm.org
Tue Jun 24 11:30:49 PDT 2025


================
@@ -1076,6 +1094,195 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+///     strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
+///   gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: The current implementation assumes that both the src and dest vectors
+/// are distributed to lanes, so that sinking the insert op does not require any
+/// cross-lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector. There should be
+    // exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+    (void)destDistributedDim;
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for case where source vector is not distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be in the last k dims");
+    // Distributed dimension must be fully inserted.
+    if (srcType.getDimSize(sourceDistributedDim) !=
----------------
Jianhui-Li wrote:

What is the reason we disallow distributing the following case? I think the distribution should work as long as the offsets along the distributed dimension are multiples of the subgroup size.
///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 32],
///     strides = [1, 1] : vector<8x32xf32> into vector<8x64xf32>

https://github.com/llvm/llvm-project/pull/145421


More information about the Mlir-commits mailing list