[Mlir-commits] [mlir] [mlir][vector] Add support for vector extract/insert_strided_slice in vector distribution. (PR #145421)
Charitha Saumya
llvmlistbot at llvm.org
Wed Jun 25 12:33:32 PDT 2025
================
@@ -1076,6 +1094,195 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
}
};
+/// Sink out insert_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
+/// ...
+/// %src = ... : vector<4x16xf32>
+/// %dest = ... : vector<8x16xf32>
+/// %insert = vector.insert_strided_slice %src, %dest {offsets = [0, 0],
+/// strides = [1, 1]} : vector<4x16xf32> into vector<8x16xf32>
+/// gpu.yield %insert : vector<8x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
+/// vector<8x1xf32>) {
+/// ...
+/// %src = ... : vector<4x16xf32>
+/// %dest = ... : vector<8x16xf32>
+/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
+/// }
+/// %insert = vector.insert_strided_slice %0#0, %0#1 {offsets = [0, 0],
+/// strides = [1, 1]} : vector<4x1xf32> into vector<8x1xf32>
+/// ```
+/// NOTE: Current support assumes that both the src and dest vectors are
+/// distributed to lanes and that sinking the insert op does not require any
+/// cross-lane communication.
+struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto insertOp =
+ operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Distributed type must be 2D or higher.
+ // TODO: Support 1D distributed types.
+ if (distributedType.getRank() < 2)
+ return rewriter.notifyMatchFailure(
+ insertOp, "result vector type must be 2D or higher");
+ // Find the distributed dimension of the dest vector. There should be
+ // exactly one.
+ auto yieldedType = cast<VectorType>(operand->get().getType());
+ int64_t destDistributedDim =
+ getDistributedDim(yieldedType, distributedType);
+ assert(destDistributedDim != -1 && "could not find distributed dimension");
+ (void)destDistributedDim;
+ VectorType srcType = insertOp.getSourceVectorType();
+ VectorType destType = insertOp.getDestVectorType();
+ // Currently we require that both source (kD) and dest (nD) vectors are
+ // distributed. This requires that distributedDim (d) is contained in the
+ // last k dims of the dest vector (d >= n - k).
+ // TODO: Add support for case where source vector is not distributed.
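+ // Example: with dest vector<8x16xf32> distributed to vector<8x1xf32>
+ // (destDistributedDim = 1) and source vector<4x16xf32>,
+ // sourceDistributedDim = 1 - (2 - 2) = 1.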
+ int64_t sourceDistributedDim =
+ destDistributedDim - (destType.getRank() - srcType.getRank());
+ if (sourceDistributedDim < 0)
+ return rewriter.notifyMatchFailure(
+ insertOp, "distributed dimension must be in the last k dims");
+ // Distributed dimension must be fully inserted.
+ if (srcType.getDimSize(sourceDistributedDim) !=
+ destType.getDimSize(destDistributedDim))
+ return rewriter.notifyMatchFailure(
+ insertOp, "distributed dimension must be fully inserted");
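+ // Compute the distributed shapes by shrinking the distributed dim of both
+ // vectors to its per-lane size (e.g. vector<4x16xf32> -> vector<4x1xf32>
+ // and vector<8x16xf32> -> vector<8x1xf32> in the example above).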
+ SmallVector<int64_t> newSourceDistShape(
+ insertOp.getSourceVectorType().getShape()),
+ newDestDistShape(insertOp.getDestVectorType().getShape());
+ newSourceDistShape[sourceDistributedDim] =
+ distributedType.getDimSize(destDistributedDim);
+ newDestDistShape[destDistributedDim] =
+ distributedType.getDimSize(destDistributedDim);
+ auto newSourceTy =
+ VectorType::get(newSourceDistShape, distributedType.getElementType());
+ auto newDestTy =
+ VectorType::get(newDestDistShape, distributedType.getElementType());
+ SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
+ {newSourceTy, newDestTy}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
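+ // The rewritten warp op now additionally yields the source and dest
+ // vectors; newRetIndices holds the positions of their distributed results.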
+ auto distributedSource = newWarpOp->getResult(newRetIndices[0]);
+ auto distributedDest = newWarpOp->getResult(newRetIndices[1]);
+ // Create a new insert_strided_slice op that inserts the distributed source
+ // into the distributed dest.
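+ // Reusing the original offsets and strides is safe: the distributed dim is
+ // fully inserted, so its offset must be 0, and all other dims keep their
+ // original sizes.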
+ Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
+ insertOp.getLoc(), distributedDest.getType(), distributedSource,
+ distributedDest, insertOp.getOffsets(), insertOp.getStrides());
+ rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
+ return success();
+ }
+};
+
+/// Sink out extract_strided_slice op feeding into a warp op yield.
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
+/// ...
+/// %src = ... : vector<32x16xf32>
+/// %extract = vector.extract_strided_slice %src {offsets = [0], sizes = [16],
+/// strides = [1]} : vector<32x16xf32> to vector<16x16xf32>
+/// gpu.yield %extract : vector<16x16xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x1xf32>) {
+/// ...
+/// %src = ... : vector<32x16xf32>
+/// gpu.yield %src : vector<32x16xf32>
+/// }
+/// %extract = vector.extract_strided_slice %0 {offsets = [0], sizes = [16],
+/// strides = [1]} : vector<32x1xf32> to vector<16x1xf32>
+/// ```
+/// NOTE: Current support assumes that the extraction happens only on
+/// non-distributed dimensions (i.e., it does not require cross-lane
+/// communication).
+struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
+ using Base::Base;
+ LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ unsigned int operandNumber = operand->getOperandNumber();
+ auto extractOp =
+ operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ // Distributed type must be 2D or higher.
+ // TODO: Support 1D distributed types.
+ if (distributedType.getRank() < 2)
+ return rewriter.notifyMatchFailure(
+ extractOp, "result vector type must be 2D or higher");
+
+ // Find the distributed dimension. There should be exactly one.
+ auto yieldedType = cast<VectorType>(operand->get().getType());
+ int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
+ assert(distributedDim != -1 && "could not find distributed dimension");
+ (void)distributedDim;
+
+ // Distributed dimension must be fully extracted.
+ // TODO: Partial extraction from the distributed dimension requires
+ // cross-lane communication.
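+ // Note: offsets/sizes/strides cover only the leading dims of the source
+ // vector, so the check below applies only when the distributed dim is among
+ // them. In the example above, sizes = [16] touches dim 0 only, and the
+ // distributed dim 1 is implicitly copied in full.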
+ if (distributedDim < static_cast<int64_t>(extractOp.getSizes().size())) {
----------------
charithaintc wrote:
I felt like `numOfExtractedDims` is a more appropriate name, so I changed it again.
https://github.com/llvm/llvm-project/pull/145421
More information about the Mlir-commits mailing list