[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg (PR #155443)
Chao Chen
llvmlistbot at llvm.org
Thu Sep 4 08:09:41 PDT 2025
================
@@ -919,6 +916,90 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
}
};
+// Distributes a workgroup-level vector.step op to subgroup level.
+//
+// A single subgroup-shaped vector.step is created. Then, for every offset
+// list the layout returns for the current subgroup id, the first offset is
+// broadcast to the subgroup vector type and added to that step vector. The
+// original op is replaced by all of the resulting values (1:N replacement).
+struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+  // Returns failure() when the result carries no workgroup-level layout, when
+  // no subgroup shape is available for it, or when the layout cannot produce
+  // offsets for the subgroup id; otherwise rewrites `op` and returns
+  // success().
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+    std::optional<SmallVector<int64_t>> sgShape =
+        getSgShapeAndCount(wgShape, layout).first;
+    if (!sgShape)
+      return failure();
+
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(maybeOffsets))
+      return failure();
+
+    // Same element type as the workgroup vector, but with the subgroup shape.
+    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+    // The base step vector is shared by every distributed result below.
+    Value base = vector::StepOp::create(rewriter, loc, newTy);
+    SmallVector<Value> newOps;
+    // Bind by const reference: each offset list is only read, so copying the
+    // container on every iteration is unnecessary.
+    for (const auto &offsets : *maybeOffsets) {
+      // Only the first offset is used.
+      // NOTE(review): presumably because vector.step results are 1-D; confirm.
+      Value bcast =
+          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+      newOps.push_back(add);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newOps});
+    return success();
+  }
+};
+
+// This pattern transforms vector.shape_cast ops to work at subgroup level.
+struct WgToSgVectorShapeCastOp
+ : public OpConversionPattern<vector::ShapeCastOp> {
+ using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return failure();
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ SmallVector<Value> newShapeCastOps;
+ for (auto src : adaptor.getSource()) {
+ auto newShapeCast =
+ rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+ if (sliceAttr.isForSubgroup())
+ xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+ sliceAttr.dropSgLayoutAndData());
----------------
chencha3 wrote:
Same here
https://github.com/llvm/llvm-project/pull/155443
More information about the Mlir-commits
mailing list