[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution patterns for vector step, shape_cast & broadcast from sg-to-wi (PR #185960)
Igor Zamyatin
llvmlistbot at llvm.org
Mon Mar 23 07:36:03 PDT 2026
================
@@ -828,22 +828,216 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
}
};
-/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
-struct SgToWiConvertLayout
- : public OpConversionPattern<xegpu::ConvertLayoutOp> {
- using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+/// Distribute a vector::StepOp with a sliced result layout.
+/// The sliced layout must have exactly 1 effective lane dimension.
+/// We completely resolve the vector::StepOp by computing the lane_data-sized
+/// subranges.
+struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern<vector::StepOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
+ matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto inputLayout = op.getInputLayoutAttr();
- auto targetLayout = op.getTargetLayoutAttr();
+ xegpu::DistributeLayoutAttr resultLayout =
+ xegpu::getTemporaryLayout(op->getResult(0));
+ if (!resultLayout || !resultLayout.isForSubgroup())
+ return rewriter.notifyMatchFailure(
+ op, "the result vector of the step op lacks subgroup layout");
+ auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
+ if (!sliceLayout)
+ return rewriter.notifyMatchFailure(
+ op, "the result layout must be a slice layout");
+ if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "expecting 1 dim in the effective result layout");
+
+ auto loc = op.getLoc();
+ auto stepResultVecTy = op.getResult().getType();
+ auto wiShapeOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
+ if (failed(wiShapeOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "unable to compute workitem vector type from the layout");
+ VectorType newVecTy = wiShapeOrFailure.value();
- if (!inputLayout.isCompatibleWith(targetLayout, xegpu::LayoutKind::Lane)) {
+ Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+ /*upperBound=*/mlir::IntegerAttr());
+ auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
+ rewriter, loc, laneId, stepResultVecTy.getShape());
+ if (failed(laneDataBlockCoords))
return rewriter.notifyMatchFailure(
- op, "lowering incompatible convert_layout not yet supported");
+ op, "failed to compute lane data block coordinates");
+
+ auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
+ auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
+ assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
+ newVecTy.getNumElements() / laneDataBlockLength);
+ SmallVector<Value> stepVals;
+ // For each lane_data block, reconstruct its sub-range
+ // from the range of SG-level vector.step.
+ for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
+ auto laneDataBlockStartCoord = laneDataBlockCoords[0];
+ stepVals.push_back(laneDataBlockStartCoord);
+ for (int i = 1; i < laneDataBlockLength; ++i) {
+ auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
+ stepVals.push_back(arith::AddIOp::create(
+ rewriter, loc, laneDataBlockStartCoord, offset));
+ }
}
- rewriter.replaceOp(op, adaptor.getSource());
+ assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
+ "Expecting the number of step values to match the number of "
+ "elements in the vector");
+ auto stepOpVal =
+ vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
+ rewriter.replaceOp(op, stepOpVal);
+ return success();
+ }
+};
+
+/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
+struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
----------------
Garra1980 wrote:
Will/should this support a shape cast such as 32 -> 4x8? If not, let's make sure we fail gracefully; if so, let's add a test for such case(s).
https://github.com/llvm/llvm-project/pull/185960
More information about the Mlir-commits mailing list