[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution patterns for vector step, shape_cast & broadcast from sg-to-wi (PR #185960)

Igor Zamyatin llvmlistbot at llvm.org
Mon Mar 23 07:36:03 PDT 2026


================
@@ -828,22 +828,216 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   }
 };
 
-/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
-struct SgToWiConvertLayout
-    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
-  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+/// Distribute a vector::StepOp with a sliced result layout.
+/// The sliced layout must have exactly 1 effective lane dimension.
+/// We completely resolve the vector::StepOp by computing the lane_data-sized
+/// subranges.
+struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
+  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto inputLayout = op.getInputLayoutAttr();
-    auto targetLayout = op.getTargetLayoutAttr();
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the step op lacks subgroup layout");
+    auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
+    if (!sliceLayout)
+      return rewriter.notifyMatchFailure(
+          op, "the result layout must be a slice layout");
+    if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "expecting 1 dim in the effective result layout");
+
+    auto loc = op.getLoc();
+    auto stepResultVecTy = op.getResult().getType();
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+    VectorType newVecTy = wiShapeOrFailure.value();
 
-    if (!inputLayout.isCompatibleWith(targetLayout, xegpu::LayoutKind::Lane)) {
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
+        rewriter, loc, laneId, stepResultVecTy.getShape());
+    if (failed(laneDataBlockCoords))
       return rewriter.notifyMatchFailure(
-          op, "lowering incompatible convert_layout not yet supported");
+          op, "failed to compute lane data block coordinates");
+
+    auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
+    auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
+    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
+           newVecTy.getNumElements() / laneDataBlockLength);
+    SmallVector<Value> stepVals;
+    // For each lane_data block, reconstruct its sub-range
+    // from the range of SG-level vector.step.
+    for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
+      auto laneDataBlockStartCoord = laneDataBlockCoords[0];
+      stepVals.push_back(laneDataBlockStartCoord);
+      for (int i = 1; i < laneDataBlockLength; ++i) {
+        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
+        stepVals.push_back(arith::AddIOp::create(
+            rewriter, loc, laneDataBlockStartCoord, offset));
+      }
     }
-    rewriter.replaceOp(op, adaptor.getSource());
+    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
+           "Expecting the number of step values to match the number of "
+           "elements in the vector");
+    auto stepOpVal =
+        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
+    rewriter.replaceOp(op, stepOpVal);
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
+struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
----------------
Garra1980 wrote:

Will/should this support shape casts such as 32 -> 4x8? If not, let's make sure we fail gracefully; if so, let's add a test for such case(s).

https://github.com/llvm/llvm-project/pull/185960


More information about the Mlir-commits mailing list