[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] (PR #157554)

Nishant Patel llvmlistbot at llvm.org
Wed Sep 24 10:46:57 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/157554

>From d55504607f2a3341146119b0cd893474bc5e583a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 8 Sep 2025 18:05:00 +0000
Subject: [PATCH 1/6] Add pattern for reduction

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 67 ++++++++++++++++++-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 28 ++++++++
 2 files changed, 92 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 5d0f1d18402f2..fab2b8773a6b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -757,8 +757,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (!layout.getLaneLayoutAsInt().empty() ||
+        !layout.getLaneDataAsInt().empty())
+      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -919,6 +921,59 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   }
 };
 
+// Pattern for lowering vector.multi_reduction op to subgroup level.
+struct WgToSgMultiDimReductionOp
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!srcType || !dstType)
+      return failure();
+
+    // Only handle [m,1]->[m] or [1,m]->[m]
+    // TODO: generalize it
+    auto srcShape = srcType.getShape();
+    auto dstShape = dstType.getShape();
+    if (srcShape.size() != 2 || dstShape.size() != 1)
+      return failure();
+
+    if (!((srcShape[1] == 1 && srcShape[0] == dstShape[0]) ||
+          (srcShape[0] == 1 && srcShape[1] == dstShape[0])))
+      return failure();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getSource());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+    VectorType newDstType;
+    if (op.getReductionDims() == ArrayRef<int64_t>({0}))
+      newDstType = VectorType::get({sgShape[1]}, dstType.getElementType());
+    else
+      newDstType = VectorType::get({sgShape[0]}, dstType.getElementType());
+
+    SmallVector<Value> newReductions;
+    for (auto [sgSrc, sgAcc] :
+         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
+          op.getReductionDims());
+      if (!layout.getLaneLayoutAsInt().empty() ||
+          !layout.getLaneDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newOp->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newReductions.push_back(newOp.getResult());
+    }
+    rewriter.replaceOpWithMultiple(op, {newReductions});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -932,7 +987,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-           WgToSgStoreMatrixOp>(patterns.getContext());
+           WgToSgStoreMatrixOp, WgToSgMultiDimReductionOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1077,6 +1133,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index afb2bf876c18f..47e6f4cfd6d08 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -365,4 +365,32 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: @vector_reduce_dim_0
+  gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+      -> vector<1x128xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} [0]
+      : vector<1x128xf32> to vector<128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+      -> vector<256x1xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
+      : vector<256x1xf32> to vector<256xf32>
+    gpu.return
+  }
 }

>From 924a5e1c444f837ef6f808dd2d2dc56b7aa80708 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 12 Sep 2025 20:10:05 +0000
Subject: [PATCH 2/6] Fix

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp    | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index efa39cbb43255..0aafb0139a095 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -757,8 +757,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getLaneLayoutAsInt().empty() ||
-        !layout.getLaneDataAsInt().empty())
+    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+        !layout.getEffectiveInstDataAsInt().empty())
       xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
                                      layout.dropSgLayoutAndData());
     SmallVector<Value> newConsts(count, cstOp);
@@ -963,8 +963,8 @@ struct WgToSgMultiDimReductionOp
       auto newOp = rewriter.create<vector::MultiDimReductionOp>(
           op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
           op.getReductionDims());
-      if (!layout.getLaneLayoutAsInt().empty() ||
-          !layout.getLaneDataAsInt().empty())
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
         xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                        layout.dropSgLayoutAndData());
       newReductions.push_back(newOp.getResult());

>From b4761ded9bf5a39c67f09e19821e1c4d452ab249 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 12 Sep 2025 21:28:58 +0000
Subject: [PATCH 3/6] sg local reduction

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 25 +++++++------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 36 +++++++++----------
 2 files changed, 32 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aafb0139a095..29209446c57d8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -934,28 +934,31 @@ struct WgToSgMultiDimReductionOp
     if (!srcType || !dstType)
       return failure();
 
-    // Only handle [m,1]->[m] or [1,m]->[m]
     // TODO: generalize it
     auto srcShape = srcType.getShape();
     auto dstShape = dstType.getShape();
     if (srcShape.size() != 2 || dstShape.size() != 1)
       return failure();
 
-    if (!((srcShape[1] == 1 && srcShape[0] == dstShape[0]) ||
-          (srcShape[0] == 1 && srcShape[1] == dstShape[0])))
-      return failure();
-
     xegpu::DistributeLayoutAttr layout =
-        xegpu::getDistributeLayoutAttr(op.getSource());
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
+    auto reductionDims = op.getReductionDims();
+    if (reductionDims.size() != 1)
+      return failure();
+
+    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
+                                        .getParent()
+                                        .getEffectiveSgLayoutAsInt();
+    // Check that the sgLayout in the reduced dimension is 1.
+    if (sgLayout[reductionDims[0]] != 1)
+      return failure();
     SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
-    VectorType newDstType;
-    if (op.getReductionDims() == ArrayRef<int64_t>({0}))
-      newDstType = VectorType::get({sgShape[1]}, dstType.getElementType());
-    else
-      newDstType = VectorType::get({sgShape[0]}, dstType.getElementType());
+
+    VectorType newDstType =
+        VectorType::get({sgShape}, dstType.getElementType());
 
     SmallVector<Value> newReductions;
     for (auto [sgSrc, sgAcc] :
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 47e6f4cfd6d08..e6db091174dcb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -367,30 +367,30 @@ gpu.module @test_distribution {
   }
 
   // CHECK-LABEL: @vector_reduce_dim_0
-  gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) {
-    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
-    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
-      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+  gpu.func @vector_reduce_dim_0(%src: memref<4x128xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<4x128xf32>
+      -> !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
     %load =  xegpu.load_nd %tdesc[0, 0]
-      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
-      -> vector<1x128xf32>
-    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32>
-    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} [0]
-      : vector<1x128xf32> to vector<128xf32>
+      : !xegpu.tensor_desc<4x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>>
+      -> vector<4x128xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<4x4xf32> to vector<4xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [4, 4]>, dims = [0]>} [0]
+      : vector<4x128xf32> to vector<128xf32>
     gpu.return
   }
 
   // CHECK-LABEL: @vector_reduce_dim_1
-  gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) {
-    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} dense<1.0> : vector<256xf32>
-    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
-      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
     %load =  xegpu.load_nd %tdesc[0, 0]
-      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
-      -> vector<256x1xf32>
-    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32>
-    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
-      : vector<256x1xf32> to vector<256xf32>
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
     gpu.return
   }
 }

>From 9be2284ec7938bc2bca850ae15d5ed64f24262bd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 16 Sep 2025 18:56:08 +0000
Subject: [PATCH 4/6] Address feedback

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 25 ++++++++++++-------
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    | 15 +++++++++++
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 658e10d12e38b..dced2e5351b1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1027,7 +1027,12 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-// Pattern for lowering vector.multi_reduction op to subgroup level.
+/// Pattern for lowering vector.multi_reduction op to subgroup level.
+/// Current limitation: only supports 2D->1D reductions with a single reduction
+/// dimension where the sg_layout in the reduced dimension is 1, so that the
+/// reduction is local to each subgroup and no cross-subgroup communication is
+/// needed.
+/// TODO: Add cases to handle more general situations which require SLM access.
 struct WgToSgMultiDimReductionOp
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
@@ -1035,14 +1040,15 @@ struct WgToSgMultiDimReductionOp
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
+    VectorType srcType = op.getSourceVectorType();
     VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
-    if (!srcType || !dstType)
+    if (!dstType)
       return failure();
 
-    // TODO: generalize it
-    auto srcShape = srcType.getShape();
-    auto dstShape = dstType.getShape();
+    SmallVector<int64_t> srcShape(srcType.getShape().begin(),
+                                  srcType.getShape().end());
+    SmallVector<int64_t> dstShape(dstType.getShape().begin(),
+                                  dstType.getShape().end());
     if (srcShape.size() != 2 || dstShape.size() != 1)
       return failure();
 
@@ -1051,7 +1057,8 @@ struct WgToSgMultiDimReductionOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    auto reductionDims = op.getReductionDims();
+    SmallVector<int64_t> reductionDims(op.getReductionDims().begin(),
+                                       op.getReductionDims().end());
     if (reductionDims.size() != 1)
       return failure();
 
@@ -1069,8 +1076,8 @@ struct WgToSgMultiDimReductionOp
     SmallVector<Value> newReductions;
     for (auto [sgSrc, sgAcc] :
          llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
-      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
-          op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
+      auto newOp = vector::MultiDimReductionOp::create(
+          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
           op.getReductionDims());
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty())
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 6ff7a94d678a3..b4a6aef0b4206 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -82,4 +82,19 @@ gpu.module @test_distribution {
       : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_reduce_dim_1
+  gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
+      -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
+      -> vector<256x64xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-NOT: vector.multi_reduction
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
+      : vector<256x64xf32> to vector<256xf32>
+    gpu.return
+  }
 }

>From be94e2fa0d5529a1c92e9f5c8cb60406d17d1edc Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 17 Sep 2025 04:07:44 +0000
Subject: [PATCH 5/6] CHECK

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index b4a6aef0b4206..dce73dee507e1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -85,13 +85,14 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: vector_reduce_dim_1
   gpu.func @vector_reduce_dim_1(%src: memref<256x64xf32>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} dense<1.0> : vector<256xf32>
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x64xf32>
       -> !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
     %load =  xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
       -> vector<256x64xf32>
-    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
     // CHECK-NOT: vector.multi_reduction
     %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
       : vector<256x64xf32> to vector<256xf32>

>From a402217ce1d604eef889ad5e51dacfcb63ede840 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 24 Sep 2025 17:41:58 +0000
Subject: [PATCH 6/6] restrict chained reduction

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index dced2e5351b1e..9e070693f2d64 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1045,10 +1045,8 @@ struct WgToSgMultiDimReductionOp
     if (!dstType)
       return failure();
 
-    SmallVector<int64_t> srcShape(srcType.getShape().begin(),
-                                  srcType.getShape().end());
-    SmallVector<int64_t> dstShape(dstType.getShape().begin(),
-                                  dstType.getShape().end());
+    auto srcShape = srcType.getShape();
+    auto dstShape = dstType.getShape();
     if (srcShape.size() != 2 || dstShape.size() != 1)
       return failure();
 
@@ -1057,27 +1055,32 @@ struct WgToSgMultiDimReductionOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    SmallVector<int64_t> reductionDims(op.getReductionDims().begin(),
-                                       op.getReductionDims().end());
+    auto reductionDims = llvm::to_vector(op.getReductionDims());
     if (reductionDims.size() != 1)
       return failure();
 
     SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
                                         .getParent()
                                         .getEffectiveSgLayoutAsInt();
-    // Check that the sgLayout in the reduced dimension is 1.
-    if (sgLayout[reductionDims[0]] != 1)
+    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
+                                      .getParent()
+                                      .getEffectiveSgDataAsInt();
+
+    // Check that the sgLayout in the reduced dimension is 1 and
+    // each sg gets the entire slice to reduce.
+    if (sgLayout[reductionDims[0]] != 1 ||
+        sgData[reductionDims[0]] != srcShape[reductionDims[0]])
       return failure();
+
     SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
 
     VectorType newDstType =
         VectorType::get({sgShape}, dstType.getElementType());
 
     SmallVector<Value> newReductions;
-    for (auto [sgSrc, sgAcc] :
-         llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
-      auto newOp = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
+    for (auto sgSrc : adaptor.getSource()) {
+      auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+          op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
           op.getReductionDims());
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty())
@@ -1085,6 +1088,7 @@ struct WgToSgMultiDimReductionOp
                                        layout.dropSgLayoutAndData());
       newReductions.push_back(newOp.getResult());
     }
+
     rewriter.replaceOpWithMultiple(op, {newReductions});
     return success();
   }



More information about the Mlir-commits mailing list