[Mlir-commits] [mlir] [mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. (PR #145421)

Charitha Saumya llvmlistbot at llvm.org
Wed Jun 25 12:33:32 PDT 2025


================
@@ -1076,6 +1094,195 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
   }
 };
 
+/// Sink out an insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
+///     strides = [1, 1] : vector<4x16xf32> into vector<8x16xf32>
+///   gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+///     vector<8x1xf32>) {
+///   ...
+///   %src = ... : vector<4x16xf32>
+///   %dest = ... : vector<8x16xf32>
+///   gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1,
+///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both src and dest vectors are
+/// distributed to lanes and that sinking the insert op does not require any
+/// cross-lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp =
+        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // The distributed (per-lane) result type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          insertOp, "result vector type must be 2D or higher");
+    // Find the distributed dimension of the dest vector by comparing the
+    // yielded type with the distributed type. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t destDistributedDim =
+        getDistributedDim(yieldedType, distributedType);
+    assert(destDistributedDim != -1 && "could not find distributed dimension");
+    (void)destDistributedDim; // NOTE(review): redundant; value is used below.
+    VectorType srcType = insertOp.getSourceVectorType();
+    VectorType destType = insertOp.getDestVectorType();
+    // Currently we require that both source (kD) and dest (nD) vectors are
+    // distributed. This requires that distributedDim (d) is contained in the
+    // last k dims of the dest vector (d >= n - k).
+    // TODO: Add support for case where source vector is not distributed.
+    int64_t sourceDistributedDim =
+        destDistributedDim - (destType.getRank() - srcType.getRank());
+    if (sourceDistributedDim < 0)
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be in the last k dims");
+    // Distributed dimension must be fully inserted (no cross-lane movement).
+    if (srcType.getDimSize(sourceDistributedDim) !=
+        destType.getDimSize(destDistributedDim))
+      return rewriter.notifyMatchFailure(
+          insertOp, "distributed dimension must be fully inserted");
+    SmallVector<int64_t> newSourceDistShape(
+        insertOp.getSourceVectorType().getShape()),
+        newDestDistShape(insertOp.getDestVectorType().getShape());
+    newSourceDistShape[sourceDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    newDestDistShape[destDistributedDim] =
+        distributedType.getDimSize(destDistributedDim);
+    auto newSourceTy =
+        VectorType::get(newSourceDistShape, distributedType.getElementType());
+    auto newDestTy =
+        VectorType::get(newDestDistShape, distributedType.getElementType());
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+        {newSourceTy, newDestTy}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
+    auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
+    // Create a new insert_strided_slice op outside the warp op that inserts
+    // the distributed source into the distributed dest.
+    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+        insertOp.getLoc(), distributedDest.getType(), distributedSource,
+        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+    return success();
+  }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
+///     strides = [1] : vector<32x16xf32> to vector<16x16xf32>
+///   gpu.yield %extract : vector<16x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+///   ...
+///   %src = ... : vector<32x16xf32>
+///   gpu.yield %src : vector<32x16xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
+///   strides = [1] : vector<32x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (does not require cross-lane communication).
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp =
+        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    // Distributed type must be 2D or higher.
+    // TODO: Support 1D distributed types.
+    if (distributedType.getRank() < 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "result vector type must be 2D or higher");
+
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+    assert(distributedDim != -1 && "could not find distributed dimension");
+    (void)distributedDim;
+
+    // Distributed dimension must be fully extracted.
+    // TODO: Partial extraction from distributed dimension require cross lane
+    // communication.
+    if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
----------------
charithaintc wrote:

I felt like `numOfExtractedDims` is a more appropriate name, so I changed it again.

https://github.com/llvm/llvm-project/pull/145421


More information about the Mlir-commits mailing list