[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] (PR #157554)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Sep 16 02:25:36 PDT 2025
================
@@ -1027,6 +1027,62 @@ struct WgToSgVectorShapeCastOp
}
};
+// Pattern for lowering vector.multi_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+ : public OpConversionPattern<vector::MultiDimReductionOp> {
+ using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+ VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!srcType || !dstType)
+ return failure();
+
+ // TODO: generalize it
+ auto srcShape = srcType.getShape();
+ auto dstShape = dstType.getShape();
+ if (srcShape.size() != 2 || dstShape.size() != 1)
+ return failure();
+
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+
+ auto reductionDims = op.getReductionDims();
+ if (reductionDims.size() != 1)
+ return failure();
+
+ SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
+ .getParent()
+ .getEffectiveSgLayoutAsInt();
+ // Check that the sgLayout in the reduced dimension is 1.
+ if (sgLayout[reductionDims[0]] != 1)
+ return failure();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+
+ VectorType newDstType =
+ VectorType::get({sgShape}, dstType.getElementType());
+
+ SmallVector<Value> newReductions;
+ for (auto [sgSrc, sgAcc] :
+ llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+ auto newOp = rewriter.create<vector::MultiDimReductionOp>(
----------------
adam-smnk wrote:
nit: use the new API `vector::MultiDimReductionOp::create`
https://github.com/llvm/llvm-project/pull/157554
More information about the Mlir-commits
mailing list