[Mlir-commits] [mlir] [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (PR #168626)

Jianhui Li llvmlistbot at llvm.org
Thu Nov 20 11:15:12 PST 2025


================
@@ -1471,6 +1485,234 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+// Distribute a `vector.extract_strided_slice` op feeding into the yield op of
+// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
+// advanced cases where the distributed dimension is partially extracted, which
+// are currently not supported by the generic vector distribution patterns.
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimension. There should be exactly one.
+    auto yieldedType = cast<VectorType>(operand->get().getType());
+    auto distributedDims = getDistributedDims(yieldedType, distributedType);
+    // Only single dimension distribution is supported.
+    if (distributedDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t distributedDim = distributedDims[0];
+    // Check if the distributed dimension is fully extracted. If so, we exit
+    // early becuase this case already handled by vector distribution patterns.
+    // Distributed dimension is fully extracted if:
+    //  1) Distributed dim comes after all the extracted dimensions.
+    //  2) Or, the size extacted along the distributed dimension is equal the
+    //  size of that dim in source vector.
+    auto extractedSizes = extractOp.getSizes();
+    if (distributedDim >= static_cast<int64_t>(extractedSizes.size()))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Distributed dimension is fully extracted, skipping.");
+
+    int distrDimExtractedSize =
+        cast<IntegerAttr>(extractOp.getSizes()[distributedDim]).getInt();
+    int sourceDistrDimSize =
+        extractOp.getSourceVectorType().getShape()[distributedDim];
+    if (distrDimExtractedSize == sourceDistrDimSize)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Distributed dimension is fully extracted, skipping.");
+
+    auto sourceLayout =
+        xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+    if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+      return rewriter.notifyMatchFailure(
+          warpOp, "the source of extract_strided_slice op lacks distribution "
+                  "layout");
+    auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+    // Because only single dimension distribution is supported, lane layout size
+    // at the distributed dim must be the subgroup size.
+    int subgroupSize = sourceLaneLayout[distributedDim];
+    // Check if the source size in the distributed dimension is a multiple of
+    // subgroup size.
+    if (sourceDistrDimSize % subgroupSize != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Source size along distributed dimension is not a multiple of "
+          "subgroup size.");
+    auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+    // We expect lane data to be all ones in this case.
+    if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting unit lane data in source layout");
----------------
Jianhui-Li wrote:

Why is this check necessary? The extraction should work even if lane data is not unit. Consider a use case like load_nd of an fp8 matrix A using an array block load.

https://github.com/llvm/llvm-project/pull/168626


More information about the Mlir-commits mailing list