[Mlir-commits] [mlir] [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (PR #168626)
Artem Kroviakov
llvmlistbot at llvm.org
Tue Nov 25 07:05:58 PST 2025
================
@@ -1469,6 +1485,227 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};
+// Distribute a `vector.extract_strided_slice` op feeding into the yield op of
+// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
+// advanced cases where the distributed dimension is partially extracted,
+// which are currently not supported by the generic vector distribution
+// patterns.
+struct VectorExtractStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ auto extractOp =
+ cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+ unsigned operandIdx = operand->getOperandNumber();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ // Find the distributed dimension. There should be exactly one.
+ auto extractResultType = cast<VectorType>(operand->get().getType());
+ auto distributedDims =
+ getDistributedDims(extractResultType, distributedType);
+ // Collect updated source type, sizes and offsets. They may be adjusted
+ // later if the data is distributed to lanes (as opposed to being owned by
+ // all lanes uniformly).
+ VectorType updatedSourceType = extractOp.getSourceVectorType();
+ SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+ extractOp.getSizes(), [](Attribute attr) { return attr; });
+ SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+ extractOp.getOffsets(), [](Attribute attr) { return attr; });
+ // If the result is distributed, it must be distributed in exactly one
+ // dimension. In this case, we adjust the sourceDistType, distributedSizes
+ // and distributedOffsets accordingly.
+ if (distributedDims.size() > 0) {
+ if (distributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Source can not be distributed in multiple dimensions.");
+ int64_t distributedDim = distributedDims[0];
+ int sourceDistrDimSize =
+ extractOp.getSourceVectorType().getShape()[distributedDim];
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+ if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source of extract_strided_slice op lacks distribution "
+ "layout");
+ auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+ // Because only single dimension distribution is supported, lane layout
+ // size at the distributed dim must be the subgroup size.
+ int subgroupSize = sourceLaneLayout[distributedDim];
+ // Check if the source size in the distributed dimension is a multiple of
+ // subgroup size.
+ if (sourceDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Source size along distributed dimension is not a multiple of "
+ "subgroup size.");
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ // We expect lane data to be all ones in this case.
+ if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source layout");
+ // The offsets in the distributed dimention must be a multiple of subgroup
+ // size.
+ int64_t distrDimOffset =
+ cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+ if (distrDimOffset % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Offset along distributed dimension "
+ "is not a multiple of subgroup size.");
+ updatedSourceType = getDistVecTypeBasedOnLaneLayout(
+ sourceLayout, extractOp.getSourceVectorType())
+ .value();
+ // Update the distributed sizes to match the distributed type.
+ updatedSizes[distributedDim] = rewriter.getI64IntegerAttr(
+ distributedType.getDimSize(distributedDim));
+ // Update the distributed offsets to match round robin distribution (i.e.
+ // each lane owns data at `subgroupSize` stride given unit lane data).
+ updatedOffsets[distributedDim] =
+ rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
+ }
+ // Do the distribution by yielding the source of the extract op from
+ // the warp op and creating a new extract op outside the warp op.
+ SmallVector<size_t> newRetIndices;
+ auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, {extractOp.getSource()}, {updatedSourceType},
+ newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value source = newWarpOp.getResult(newRetIndices[0]);
+ // Create a new extract op outside the warp op.
+ Value newExtractOp = vector::ExtractStridedSliceOp::create(
+ rewriter, extractOp.getLoc(), distributedType, source,
+ ArrayAttr::get(rewriter.getContext(), updatedOffsets),
+ ArrayAttr::get(rewriter.getContext(), updatedSizes),
+ extractOp.getStrides());
----------------
akroviakov wrote:
What about `resolveDistributedTy`? How do we know it always matches?
https://github.com/llvm/llvm-project/pull/168626
More information about the Mlir-commits
mailing list