[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] (PR #157554)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Sep 16 02:25:36 PDT 2025


================
@@ -1027,6 +1027,62 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
+// Pattern for lowering vector.multi_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!srcType || !dstType)
+      return failure();
+
+    // TODO: generalize it
+    auto srcShape = srcType.getShape();
+    auto dstShape = dstType.getShape();
+    if (srcShape.size() != 2 || dstShape.size() != 1)
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    auto reductionDims = op.getReductionDims();
+    if (reductionDims.size() != 1)
+      return failure();
+
+    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
+                                        .getParent()
+                                        .getEffectiveSgLayoutAsInt();
+    // Check that the sgLayout in the reduced dimension is 1.
+    if (sgLayout[reductionDims[0]] != 1)
+      return failure();
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+
+    VectorType newDstType =
+        VectorType::get({sgShape}, dstType.getElementType());
+
+    SmallVector<Value> newReductions;
+    for (auto [sgSrc, sgAcc] :
+         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
----------------
adam-smnk wrote:

nit: use the new builder API `vector::MultiDimReductionOp::create(rewriter, loc, ...)` instead of `rewriter.create<vector::MultiDimReductionOp>(loc, ...)`.

https://github.com/llvm/llvm-project/pull/157554


More information about the Mlir-commits mailing list