[Mlir-commits] [mlir] [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (PR #168626)
Charitha Saumya
llvmlistbot at llvm.org
Thu Nov 20 11:54:42 PST 2025
================
@@ -1471,6 +1485,234 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
}
};
+// Distribute a `vector.extract_strided_slice` op feeding into the yield op of an
+// enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
+// advanced cases where the distributed dimension is partially extracted, which
+// are currently not supported by the generic vector distribution patterns.
+struct VectorExtractStridedSliceDistribution
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+ if (!operand)
+ return failure();
+ auto extractOp =
+ cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+ unsigned operandIdx = operand->getOperandNumber();
+ auto distributedType =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ // Find the distributed dimension. There should be exactly one.
+ auto yieldedType = cast<VectorType>(operand->get().getType());
+ auto distributedDims = getDistributedDims(yieldedType, distributedType);
+ // Only single dimension distribution is supported.
+ if (distributedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting source to be distributed in a single dimension.");
+ int64_t distributedDim = distributedDims[0];
+ // Check if the distributed dimension is fully extracted. If so, we exit
+ // early because this case is already handled by vector distribution patterns.
+ // The distributed dimension is fully extracted if:
+ // 1) the distributed dim comes after all the extracted dimensions, or
+ // 2) the size extracted along the distributed dimension is equal to the
+ // size of that dim in the source vector.
+ auto extractedSizes = extractOp.getSizes();
+ if (distributedDim >= static_cast<int64_t>(extractedSizes.size()))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Distributed dimension is fully extracted, skipping.");
+
+ int distrDimExtractedSize =
+ cast<IntegerAttr>(extractOp.getSizes()[distributedDim]).getInt();
+ int sourceDistrDimSize =
+ extractOp.getSourceVectorType().getShape()[distributedDim];
+ if (distrDimExtractedSize == sourceDistrDimSize)
+ return rewriter.notifyMatchFailure(
+ warpOp, "Distributed dimension is fully extracted, skipping.");
+
+ auto sourceLayout =
+ xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+ if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+ return rewriter.notifyMatchFailure(
+ warpOp, "the source of extract_strided_slice op lacks distribution "
+ "layout");
+ auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+ // Because only single dimension distribution is supported, lane layout size
+ // at the distributed dim must be the subgroup size.
+ int subgroupSize = sourceLaneLayout[distributedDim];
+ // Check if the source size in the distributed dimension is a multiple of
+ // subgroup size.
+ if (sourceDistrDimSize % subgroupSize != 0)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "Source size along distributed dimension is not a multiple of "
+ "subgroup size.");
+ auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+ // We expect lane data to be all ones in this case.
+ if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+ return rewriter.notifyMatchFailure(
+ warpOp, "Expecting unit lane data in source layout");
----------------
charithaintc wrote:
We can add that support in another PR; I would like to keep things simple for the current use cases. This check ensures that the updated offset calculation is a simple division of the current offset by the subgroup (SG) size. Currently we don't have any fp8 test cases.
https://github.com/llvm/llvm-project/pull/168626
More information about the Mlir-commits
mailing list