[Mlir-commits] [mlir] [mlir][xegpu] Add layout based SIMT distribution support for `vector.extract/insert_strided_slice` (PR #168626)

Artem Kroviakov llvmlistbot at llvm.org
Tue Nov 25 07:05:57 PST 2025


================
@@ -1469,6 +1485,227 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+// Distribute a `vector.extract_strided_slice` op feeding into the yield op of
+// an enclosing `gpu.warp_execute_on_lane_0` region. This pattern only handles
+// advanced cases where the distributed dimension is only partially extracted,
+// which are currently not supported by the generic vector distribution patterns.
+struct VectorExtractStridedSliceDistribution
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
+    if (!operand)
+      return failure();
+    auto extractOp =
+        cast<vector::ExtractStridedSliceOp>(operand->get().getDefiningOp());
+    unsigned operandIdx = operand->getOperandNumber();
+    auto distributedType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    // Find the distributed dimension. There should be exactly one.
+    auto extractResultType = cast<VectorType>(operand->get().getType());
+    auto distributedDims =
+        getDistributedDims(extractResultType, distributedType);
+    // Collect updated source type, sizes and offsets. They may be adjusted
+    // later if the data is distributed to lanes (as opposed to being owned by
+    // all lanes uniformly).
+    VectorType updatedSourceType = extractOp.getSourceVectorType();
+    SmallVector<Attribute> updatedSizes = llvm::map_to_vector(
+        extractOp.getSizes(), [](Attribute attr) { return attr; });
+    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
+        extractOp.getOffsets(), [](Attribute attr) { return attr; });
+    // If the result is distributed, it must be distributed in exactly one
+    // dimension. In this case, we adjust the sourceDistType, distributedSizes
+    // and distributedOffsets accordingly.
+    if (distributedDims.size() > 0) {
+      if (distributedDims.size() != 1)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source can not be distributed in multiple dimensions.");
+      int64_t distributedDim = distributedDims[0];
+      int sourceDistrDimSize =
+          extractOp.getSourceVectorType().getShape()[distributedDim];
+      auto sourceLayout =
+          xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
+      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            warpOp, "the source of extract_strided_slice op lacks distribution "
+                    "layout");
+      auto sourceLaneLayout = sourceLayout.getEffectiveLaneLayoutAsInt();
+      // Because only single dimension distribution is supported, lane layout
+      // size at the distributed dim must be the subgroup size.
+      int subgroupSize = sourceLaneLayout[distributedDim];
+      // Check if the source size in the distributed dimension is a multiple of
+      // subgroup size.
+      if (sourceDistrDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Source size along distributed dimension is not a multiple of "
+            "subgroup size.");
+      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
+      // We expect lane data to be all ones in this case.
+      if (!llvm::all_of(sourceLaneData, [](int64_t v) { return v == 1; }))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Expecting unit lane data in source layout");
+      // The offsets in the distributed dimention must be a multiple of subgroup
+      // size.
+      int64_t distrDimOffset =
+          cast<IntegerAttr>(extractOp.getOffsets()[distributedDim]).getInt();
+      if (distrDimOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Offset along distributed dimension "
+                    "is not a multiple of subgroup size.");
+      updatedSourceType = getDistVecTypeBasedOnLaneLayout(
----------------
akroviakov wrote:

I suppose the checks above ensure that the distribution always succeeds.

https://github.com/llvm/llvm-project/pull/168626


More information about the Mlir-commits mailing list