[Mlir-commits] [mlir] [mlir][xegpu] Add support for `vector.reduction` and `vector.multi_reduction` subgroup to work-item distribution. (PR #180308)

Charitha Saumya llvmlistbot at llvm.org
Mon Feb 9 15:49:33 PST 2026


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/180308

>From 4e252eea5664746ee2a10c27ae0539b4b9e0b283 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Feb 2026 00:28:12 +0000
Subject: [PATCH 1/8] save work

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  3 +
 .../XeGPUSgToWiDistributeExperimental.cpp     | 65 ++++++++++++++++++-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 17 +----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 15 +++++
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 28 ++++++++
 5 files changed, 110 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 700db5f9dd9be..5d54bbc5aca66 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -129,6 +129,9 @@ SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
 
+Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
+                        vector::CombiningKind kind, uint32_t size);
+
 /// Helper Function to find a proper instruction multiple for the user-supplied
 /// sg-level data shape (diven by `dim`). `candidates` are uArch allowed shapes.
 /// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 8e530642d9c7a..9a2d72d7db6a4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/LogicalResult.h"
@@ -362,6 +363,59 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+
+    // If no layout, nothing to do.
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    VectorType vectorType = op.getSourceVectorType();
+
+    // Only rank 1 vectors supported.
+    if (vectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only rank 1 reductions can be distributed.");
+    // Lane layout must have the same rank as the vector.
+    if (layout.getRank() != vectorType.getRank())
+      return rewriter.notifyMatchFailure(
+          op, "Layout rank does not match vector rank.");
+
+    // Get the subgroup size from the layout.
+    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
+
+    // The vector dimension must be a multiple of the subgroup size.
+    if (vectorType.getShape()[0] % sgSize != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Reduction vector dimension must match subgroup size.");
+
+    if (!op.getType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Reduction distribution currently only supports floats and "
+              "integer types.");
+
+    // Get the distributed vector (per work-item portion).
+    Value laneValVec = adaptor.getVector();
+
+    // Distribute and reduce across work-items in the subgroup.
+    Value fullReduce = xegpu::subgroupReduction(
+        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
+
+    // If there's an accumulator, combine it with the reduced value.
+    if (adaptor.getAcc())
+      fullReduce = vector::makeArithReduction(
+          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
+
+    rewriter.replaceOp(op, fullReduce);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -551,8 +605,15 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
         }
         return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
       });
+  // VectorReductionOp is legal only if its source has no distribute layout
+  // attribute.
+  target.addDynamicallyLegalOp<vector::ReductionOp>(
+      [=](vector::ReductionOp op) -> bool {
+        auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+        return !layout;
+      });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
-               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
-      typeConverter, patterns.getContext());
+               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+               SgToWiVectorReduction>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index a8ed5a289f84a..11e3bc8eb1b27 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -2112,23 +2112,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
 
-  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
-                          vector::CombiningKind kind, uint32_t size) {
-    // First reduce on a single thread to get per lane reduction value.
-    Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
-    // Parallel reduction using butterfly shuffles.
-    for (uint64_t i = 1; i < size; i <<= 1) {
-      Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
-                                              /*width=*/size,
-                                              /*mode=*/gpu::ShuffleMode::XOR)
-                           .getShuffleResult();
-      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
-    }
-    return laneVal;
-  };
-
   vector::populateDistributeReduction(
-      patterns, warpReduction,
+      patterns, xegpu::subgroupReduction,
       /*pattern benefit=*/PatternHierarchy::Regular);
 
   vector::populatePropagateWarpVectorDistributionPatterns(
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 7e28c756f2d72..ff19a881c7bbc 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -756,6 +756,21 @@ int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
   return largest;
 }
 
+Value xegpu::subgroupReduction(Location loc, OpBuilder &builder, Value input,
+                               vector::CombiningKind kind, uint32_t size) {
+  // First reduce on a single thread to get per lane reduction value.
+  Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
+  // Parallel reduction using butterfly shuffles.
+  for (uint64_t i = 1; i < size; i <<= 1) {
+    Value shuffled =
+        gpu::ShuffleOp::create(builder, loc, laneVal, i, /**  width = **/ size,
+                               /**  mode = **/ gpu::ShuffleMode::XOR)
+            .getShuffleResult();
+    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+  }
+  return laneVal;
+};
+
 /// Explicit instantiations
 template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
                                            ArrayRef<int> candidateMultiples);
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0e9843f4626d4..aea4defd0cde5 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -149,4 +149,32 @@ gpu.func @prefetch_nd() {
     : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   gpu.return
 }
+
+// CHECK-LABEL: gpu.func @vector_reduction
+// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[LANE_RED:.*]] = vector.reduction <add>, %[[CAST:.*]] : vector<2xf32> into f32
+// CHECK: %[[C16_1:.*]] = arith.constant 16 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[SHUFFLE1:.*]], %{{.*}} = gpu.shuffle  xor %[[LANE_RED]], %[[C1]], %[[C16_1]] : f32
+// CHECK: %[[ADD1:.*]] = arith.addf %[[LANE_RED]], %[[SHUFFLE1]] : f32
+// CHECK: %[[C16_2:.*]] = arith.constant 16 : i32
+// CHECK: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK: %[[SHUFFLE2:.*]], %{{.*}} = gpu.shuffle  xor %[[ADD1]], %[[C2]], %[[C16_2]] : f32
+// CHECK: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[SHUFFLE2]] : f32
+// CHECK: %[[C16_3:.*]] = arith.constant 16 : i32
+// CHECK: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK: %[[SHUFFLE3:.*]], %{{.*}} = gpu.shuffle  xor %[[ADD2]], %[[C4]], %[[C16_3]] : f32
+// CHECK: %[[ADD3:.*]] = arith.addf %[[ADD2]], %[[SHUFFLE3]] : f32
+// CHECK: %[[C16_4:.*]] = arith.constant 16 : i32
+// CHECK: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK: %[[SHUFFLE4:.*]], %{{.*}} = gpu.shuffle  xor %[[ADD3]], %[[C8]], %[[C16_4]] : f32
+// CHECK: %[[ADD4:.*]] = arith.addf %[[ADD3]], %[[SHUFFLE4]] : f32
+// CHECK: %[[FINAL:.*]] = arith.addf %[[ADD4]], %[[CST]] : f32
+gpu.func @vector_reduction() {
+  %acc = arith.constant 1.0 : f32
+  %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : () -> vector<32xf32>
+  %2 = vector.reduction <add>, %0, %acc : vector<32xf32> into f32
+  gpu.return
+}
+
 }

>From b8607418e603e7b0d6b482cb965b356a297d67e5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 3 Feb 2026 00:36:40 +0000
Subject: [PATCH 2/8] save work

---
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h           | 5 +++++
 .../XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp   | 4 ++++
 2 files changed, 9 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 5d54bbc5aca66..6eaad8a499986 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -129,6 +129,11 @@ SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
 
+/// Given an `input` value representing per-lane data, this function returns the
+/// result after performing a reduction on the input over all lanes (number of
+/// lanes given by `size`). This uses butterfly shuffles to perform the
+/// reduction in a log2(size) number of steps.
+/// NOTE: Implementation taken from TestVectorTransforms.cpp
 Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
                         vector::CombiningKind kind, uint32_t size);
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 9a2d72d7db6a4..5fd75e205bf18 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -363,6 +363,10 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern distributes a subgroup-level vector.reduction op to
+/// work-item level. This requires shuffling data across the work-items (using
+/// gpu::ShuffleOp) and reducing in stages until all work-items hold the final
+/// result.
 struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
   using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
 

>From 522c9858e1b424f33ffcd5800956a173ac2f779f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 5 Feb 2026 00:47:50 +0000
Subject: [PATCH 3/8] save work

---
 .../Dialect/XeGPU/Transforms/Transforms.h     |  5 +
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     | 10 ++
 .../XeGPUSgToWiDistributeExperimental.cpp     | 92 +++++++++++++++++++
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 67 +-------------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 61 ++++++++++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 37 +++++++-
 6 files changed, 206 insertions(+), 66 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index fede329990be4..ce75be245ffc4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -82,6 +82,11 @@ void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter);
 void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target);
+/// Appends patterns to rewrite a 2D-to-1D vector::MultiDimReductionOp that
+/// has a subgroup distribute layout into a sequence of vector::ReductionOps.
+/// Both lane-local and cross-lane reductions are rewritten this way.
+void populateXeGPUSgToWiRewriteVectorMultiReductionToVectorReductionPatterns(
+    RewritePatternSet &patterns);
 
 /// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
 /// Users can control whether an operation to be unrolled or not, as well as
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 6eaad8a499986..45a78f64eb869 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -137,6 +137,16 @@ SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
 Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
                         vector::CombiningKind kind, uint32_t size);
 
+/// Given the `src` and `acc` arguments of a vector::MultiDimReductionOp,
+/// lower to a set of vector::ReductionOp ops over 1D slices extracted from
+/// `src`. The reduction is performed along `reductionDim`. The result is a
+/// vector with the same shape as `acc`.
+/// TODO: Only 2D to 1D reduction is supported for now.
+Value lowerToVectorReductions(TypedValue<VectorType> src,
+                              TypedValue<VectorType> acc,
+                              vector::CombiningKind kind, int64_t reductionDim,
+                              Location loc, PatternRewriter &rewriter);
+
 /// Helper Function to find a proper instruction multiple for the user-supplied
 /// sg-level data shape (diven by `dim`). `candidates` are uArch allowed shapes.
 /// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 5fd75e205bf18..1eceff48348e2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -420,6 +420,92 @@ struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
   }
 };
 
+class SgToWiRewriteVectorMultiReduction
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = xegpu::getDistributeLayoutAttr(op.getSource());
+
+    // If no layout, nothing to do.
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    VectorType sourceType = op.getSourceVectorType();
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op,
+                                         "Only 2D reductions are supported.");
+
+    ArrayRef<int64_t> reductionDims = op.getReductionDims();
+    // Only 1 reduction dimension supported.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only 1 reduction dimension is supported.");
+
+    int64_t reductionDim = reductionDims[0];
+
+    // Get the distributed source type based on layout.
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sourceType);
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute the source vector type.");
+
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+
+    // Determine which dimension is distributed.
+    bool dim0Distributed =
+        sourceDistType.getShape()[0] != sourceType.getShape()[0];
+    bool dim1Distributed =
+        sourceDistType.getShape()[1] != sourceType.getShape()[1];
+
+    if (dim0Distributed && dim1Distributed)
+      return rewriter.notifyMatchFailure(
+          op, "Expecting source to be distributed in a single dimension.");
+
+    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+    if (sourceDistDim == -1)
+      return rewriter.notifyMatchFailure(
+          op, "Expecting a distributed source vector.");
+
+    // Check if reduction is lane-local or requires cross-lane shuffling.
+    // Lane-local: reduction dimension is NOT the distributed dimension.
+    // Cross-lane: reduction dimension IS the distributed dimension.
+    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
+                                (sourceDistDim == 1 && reductionDim == 0);
+
+    // Get the distributed source and accumulator.
+    Value distributedSource = adaptor.getSource();
+    Value distributedAcc = adaptor.getAcc();
+
+    if (isReductionLaneLocal) {
+      // For lane-local reduction, each lane reduces its local data.
+      // Use lowerToVectorReductions to create a sequence of 1D reductions.
+      Value result = xegpu::lowerToVectorReductions(
+          cast<TypedValue<VectorType>>(distributedSource),
+          cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
+          reductionDim, op.getLoc(), rewriter);
+
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    // For non-lane-local (cross-lane) reduction, we also lower to a sequence
+    // of 1D reductions. The individual 1D reductions will be handled by
+    // SgToWiVectorReduction pattern which performs the cross-lane shuffles.
+    Value result = xegpu::lowerToVectorReductions(
+        cast<TypedValue<VectorType>>(distributedSource),
+        cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
+        reductionDim, op.getLoc(), rewriter);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -621,3 +707,9 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
                SgToWiVectorReduction>(typeConverter, patterns.getContext());
 }
+
+void xegpu::
+    populateXeGPUSgToWiRewriteVectorMultiReductionToVectorReductionPatterns(
+        RewritePatternSet &patterns) {
+  patterns.add<SgToWiRewriteVectorMultiReduction>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 11e3bc8eb1b27..346c31b5a8a7e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1205,69 +1205,6 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
   }
 };
 
-/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
-/// VectorReductionOps. We also insert layouts for the newly created ops.
-static Value lowerToVectorReductions(TypedValue<VectorType> src,
-                                     TypedValue<VectorType> acc,
-                                     vector::CombiningKind kind,
-                                     int64_t reductionDim, Location loc,
-                                     PatternRewriter &rewriter) {
-  // Expecting a 2D source vector.
-  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
-  VectorType sourceType = src.getType();
-  int64_t sourceH = sourceType.getShape()[0];
-  int64_t sourceW = sourceType.getShape()[1];
-  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
-  // Create a constant vector to hold the result of the reduction.
-  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
-  Value reductionResult = arith::ConstantOp::create(
-      rewriter, loc, acc.getType(),
-      DenseElementsAttr::get(acc.getType(), zeroAttr));
-  // Reduction result should have the same layout as the accumulator.
-  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult),
-                            xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc)));
-  // For each slice of the source, extract the slice vector, do a reduction
-  // and, insert the reduced value back to the result vector.
-  for (int i = 0; i < nSlices; ++i) {
-    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-    if (reductionDim == 1) {
-      sliceOffsets = {i, 0};
-      sliceSizes = {1, sourceW};
-    } else {
-      sliceOffsets = {0, i};
-      sliceSizes = {sourceH, 1};
-    }
-    vector::ExtractStridedSliceOp extractOp =
-        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
-                                              sliceSizes, {1, 1});
-
-    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
-
-    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
-        rewriter, loc,
-        VectorType::get({nSliceElements}, sourceType.getElementType()),
-        extractOp.getResult());
-
-    // Shape cast is currently handled in xegpu side. So layouts must be
-    // retained during lowering. Shape cast output has the same layout as the
-    // accumulator. Shape cast source has the same layout as the original
-    // reduction source.
-    // TODO: other ops generated here may also need layout attributes.
-    auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
-    auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
-
-    xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
-    xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
-    // Extract and reduction results in scalars, so no result layout is needed.
-    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
-    Value reduction = vector::ReductionOp::create(
-        rewriter, loc, kind, slice.getResult(), accExtract);
-    reductionResult =
-        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
-  }
-  return reductionResult;
-}
-
 /// This patterns distribute the `vector.multi_reduction` operation across
 /// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
 /// layouts for the source and accumulator vectors,
@@ -1405,7 +1342,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
           {sourceDistType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
-      Value result = lowerToVectorReductions(
+      Value result = xegpu::lowerToVectorReductions(
           cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
           cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
           reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
@@ -1417,7 +1354,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
     // of multiple ReductionOps. Actual distribution is done by the
     // WarpOpReduction pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    Value result = lowerToVectorReductions(
+    Value result = xegpu::lowerToVectorReductions(
         cast<TypedValue<VectorType>>(reductionOp.getSource()),
         cast<TypedValue<VectorType>>(reductionOp.getAcc()),
         reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index ff19a881c7bbc..c4bab915d9e64 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -771,6 +771,67 @@ Value xegpu::subgroupReduction(Location loc, OpBuilder &builder, Value input,
   return laneVal;
 };
 
+Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
+                                     TypedValue<VectorType> acc,
+                                     vector::CombiningKind kind,
+                                     int64_t reductionDim, Location loc,
+                                     PatternRewriter &rewriter) {
+  // Expecting a 2D source vector.
+  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+  VectorType sourceType = src.getType();
+  int64_t sourceH = sourceType.getShape()[0];
+  int64_t sourceW = sourceType.getShape()[1];
+  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  // Create a constant vector to hold the result of the reduction.
+  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+  Value reductionResult = arith::ConstantOp::create(
+      rewriter, loc, acc.getType(),
+      DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // Reduction result should have the same layout as the accumulator.
+  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult),
+                            xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc)));
+  // For each slice of the source, extract the slice vector, do a reduction
+  // and insert the reduced value back into the result vector.
+  for (int i = 0; i < nSlices; ++i) {
+    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+    if (reductionDim == 1) {
+      sliceOffsets = {i, 0};
+      sliceSizes = {1, sourceW};
+    } else {
+      sliceOffsets = {0, i};
+      sliceSizes = {sourceH, 1};
+    }
+    vector::ExtractStridedSliceOp extractOp =
+        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+                                              sliceSizes, {1, 1});
+
+    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+
+    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get({nSliceElements}, sourceType.getElementType()),
+        extractOp.getResult());
+
+    // Shape cast is currently handled in xegpu side. So layouts must be
+    // retained during lowering. Shape cast output has the same layout as the
+    // accumulator. Shape cast source has the same layout as the original
+    // reduction source.
+    // TODO: other ops generated here may also need layout attributes.
+    auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+    auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
+
+    xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
+    xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
+    // Extract and reduction results are scalars, so no result layout is needed.
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    Value reduction = vector::ReductionOp::create(
+        rewriter, loc, kind, slice.getResult(), accExtract);
+    reductionResult =
+        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+  }
+  return reductionResult;
+}
+
 /// Explicit instantiations
 template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
                                            ArrayRef<int> candidateMultiples);
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 405e974500e08..bddb06e92da66 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -272,6 +272,12 @@ struct TestXeGPUSgToWiDistributeExperimental
            "Work-item Distribution";
   }
 
+  Option<bool> enablePartialReductionLowering{
+      *this, "enable-partial-multi-reduction-lowering",
+      llvm::cl::desc(
+          "Partially lower multi-reduction ops to vector.reduction and stop."),
+      llvm::cl::init(false)};
+
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
     registry.insert<arith::ArithDialect>();
     registry.insert<memref::MemRefDialect>();
@@ -283,7 +289,8 @@ struct TestXeGPUSgToWiDistributeExperimental
 
   TestXeGPUSgToWiDistributeExperimental() = default;
   TestXeGPUSgToWiDistributeExperimental(
-      const TestXeGPUSgToWiDistributeExperimental &pass) = default;
+      const TestXeGPUSgToWiDistributeExperimental &pass)
+      : PassWrapper(pass) {}
 
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
@@ -297,6 +304,34 @@ struct TestXeGPUSgToWiDistributeExperimental
     };
     typeConverter.addSourceMaterialization(materializeCast);
     typeConverter.addTargetMaterialization(materializeCast);
+
+    // If `enablePartialReductionLowering` is set, only focus on testing the
+    // partial lowering of vector::MultiDimReductionOp.
+    if (enablePartialReductionLowering) {
+      xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+      ConversionTarget target(*ctx);
+      RewritePatternSet patterns(ctx);
+      xegpu::populateXeGPUSgToWiRewriteMultiReductionToReductionPatterns(
+          patterns);
+      // Mark 2D to 1D vector::MultiDimReductionOp as illegal.
+      target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+          [&](vector::MultiDimReductionOp op) {
+            int64_t sourceRank = op.getSourceVectorType().getRank();
+            VectorType resultType =
+                dyn_cast<VectorType>(op.getResult().getType());
+            if (!resultType)
+              return true;
+            int64_t resultRank = resultType.getRank();
+            return sourceRank != 2 || resultRank != 1;
+          });
+      // vector::ReductionOp is legal.
+      target.addDynamicallyLegalOp<vector::ReductionOp>(
+          [&](vector::ReductionOp op) { return true; });
+      target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+      (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+      return;
+    }
+
     ConversionTarget target(*ctx);
     RewritePatternSet patterns(ctx);
     xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(

>From 1955d3c04b5a19d8c862835d33435632103c5092 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 7 Feb 2026 00:55:00 +0000
Subject: [PATCH 4/8] save work

---
 .../Dialect/XeGPU/Transforms/Transforms.h     |   6 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |   7 +
 .../XeGPUSgToWiDistributeExperimental.cpp     | 211 +++++++++++++-----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  20 +-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     |  98 ++++++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  33 +--
 6 files changed, 276 insertions(+), 99 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index ce75be245ffc4..9967c4b2740bc 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -85,8 +85,8 @@ void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
 /// Appends patterns to rewrite vector::MultiDimReductionOp in terms of
 /// vector::ReductionOps if the multi-reduction involves cross-lane data
 /// movement.
-void populateXeGPUSgToWiRewriteVectorMultiReductionToVectorReductionPatterns(
-    RewritePatternSet &patterns);
+void populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
+    RewritePatternSet &patterns, ConversionTarget &target);
 
 /// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
 /// Users can control whether an operation to be unrolled or not, as well as
@@ -98,7 +98,7 @@ void populateXeGPUSgToWiRewriteVectorMultiReductionToVectorReductionPatterns(
 ///   1. the unrolled type `unrolledType` and number of unrolled instances
 ///   `numUnrolledInstances` are computed from the `targetShape`.
 ///   2. pack each operand. ExtractStridedSlice are created to break-up the
-///   vector operands. And BuiltinUnrealizedCastop are created to break-up
+///   vector operands. And BuiltinUnrealizedCastOp are created to break-up
 ///    the TensorDesc operands.
 ///   3. the original op is cloned `numUnrolledInstances` times, once for each
 ///   result.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index faafb7e8cee61..78f56cda281f0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -840,6 +840,9 @@ void LayoutInfoPropagation::visitDpasOp(
       // Step 1. Get all valid layouts for A, B, and C operands.
       // All operands must have at least one valid subgroup layout.
       LayoutInfo layoutD = results[0]->getValue();
+      llvm::errs() << "DPAS result layout: ";
+      layoutD.print(llvm::errs());
+      llvm::errs() << "\n";
       SmallVector<int> sgLayoutD = layoutD.getSgLayout();
       assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
       auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
@@ -994,6 +997,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
             "Unable to determine the number of subgroups for the operation.");
         return;
       }
+      llvm::errs() << "Num SG: " << numSgOrErr.value() << "\n";
       auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
                                        instData, numSgOrErr.value());
       if (sgLayouts.empty()) {
@@ -1011,6 +1015,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
           DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
+      llvm::errs() << "Chosen SG Layout:";
+      storeLayout.print(llvm::errs());
+      llvm::errs() << "\n";
     }
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 1eceff48348e2..f116b20eea4aa 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -420,24 +420,121 @@ struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
   }
 };
 
-class SgToWiRewriteVectorMultiReduction
+// class SgToWiRewriteVectorMultiReduction
+//     : public OpConversionPattern<vector::MultiDimReductionOp> {
+//   using
+//   OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+//   LogicalResult
+//   matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+//                   ConversionPatternRewriter &rewriter) const override {
+//     auto layout = xegpu::getDistributeLayoutAttr(op.getSource());
+
+//     // If no layout, nothing to do.
+//     if (!layout || !layout.isForSubgroup())
+//       return failure();
+
+//     VectorType sourceType = op.getSourceVectorType();
+//     // Only 2D vectors are supported.
+//     if (sourceType.getRank() != 2)
+//       return rewriter.notifyMatchFailure(op,
+//                                          "Only 2D reductions are
+//                                          supported.");
+
+//     ArrayRef<int64_t> reductionDims = op.getReductionDims();
+//     // Only 1 reduction dimension supported.
+//     if (reductionDims.size() != 1)
+//       return rewriter.notifyMatchFailure(
+//           op, "Only 1 reduction dimension is supported.");
+
+//     int64_t reductionDim = reductionDims[0];
+
+//     // Get the distributed source type based on layout.
+//     FailureOr<VectorType> sourceDistTypeOrFailure =
+//         getDistVecTypeBasedOnLaneLayout(layout, sourceType);
+//     if (failed(sourceDistTypeOrFailure))
+//       return rewriter.notifyMatchFailure(
+//           op, "Failed to distribute the source vector type.");
+
+//     VectorType sourceDistType = sourceDistTypeOrFailure.value();
+
+//     // Determine which dimension is distributed.
+//     bool dim0Distributed =
+//         sourceDistType.getShape()[0] != sourceType.getShape()[0];
+//     bool dim1Distributed =
+//         sourceDistType.getShape()[1] != sourceType.getShape()[1];
+
+//     if (dim0Distributed && dim1Distributed)
+//       return rewriter.notifyMatchFailure(
+//           op, "Expecting source to be distributed in a single dimension.");
+
+//     int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+//     if (sourceDistDim == -1)
+//       return rewriter.notifyMatchFailure(
+//           op, "Expecting a distributed source vector.");
+
+//     // Check if reduction is lane-local or requires cross-lane shuffling.
+//     // Lane-local: reduction dimension is NOT the distributed dimension.
+//     // Cross-lane: reduction dimension IS the distributed dimension.
+//     bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
+//                                 (sourceDistDim == 1 && reductionDim == 0);
+
+//     // Get the distributed source and accumulator.
+//     Value distributedSource = adaptor.getSource();
+//     Value distributedAcc = adaptor.getAcc();
+
+//     if (isReductionLaneLocal) {
+//       // For lane-local reduction, each lane reduces its local data.
+//       // Use lowerToVectorReductions to create a sequence of 1D reductions.
+//       Value result = xegpu::lowerToVectorReductions(
+//           cast<TypedValue<VectorType>>(distributedSource),
+//           cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
+//           reductionDim, op.getLoc(), rewriter);
+
+//       rewriter.replaceOp(op, result);
+//       return success();
+//     }
+
+//     // For non-lane-local (cross-lane) reduction, we also lower to a sequence
+//     // of 1D reductions. The individual 1D reductions will be handled by
+//     // SgToWiVectorReduction pattern which performs the cross-lane shuffles.
+//     Value result = xegpu::lowerToVectorReductions(
+//         cast<TypedValue<VectorType>>(distributedSource),
+//         cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
+//         reductionDim, op.getLoc(), rewriter);
+
+//     rewriter.replaceOp(op, result);
+//     return success();
+//   }
+// };
+
+struct LowerVectorMultiReductionPattern
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto layout = xegpu::getDistributeLayoutAttr(op.getSource());
-
+    auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
     // If no layout, nothing to do.
-    if (!layout || !layout.isForSubgroup())
+    if (!resLayout || !resLayout.isForSubgroup())
       return failure();
-
     VectorType sourceType = op.getSourceVectorType();
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Only 2D reductions are supported.");
+      return rewriter.notifyMatchFailure(
+          op, "Expecting 2D source vector in vector::MultiDimReductionOp at "
+              "subgroup level.");
+    VectorType resType = dyn_cast<VectorType>(op.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(op, "Expecting vector result type");
+    // Compute the distributed vector type based on the layout.
+    FailureOr<VectorType> resDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(resLayout, resType);
+    if (failed(resDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to compute the distributed vector type based on the "
+              "layout.");
 
     ArrayRef<int64_t> reductionDims = op.getReductionDims();
     // Only 1 reduction dimension supported.
@@ -447,59 +544,19 @@ class SgToWiRewriteVectorMultiReduction
 
     int64_t reductionDim = reductionDims[0];
 
-    // Get the distributed source type based on layout.
-    FailureOr<VectorType> sourceDistTypeOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layout, sourceType);
-    if (failed(sourceDistTypeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "Failed to distribute the source vector type.");
-
-    VectorType sourceDistType = sourceDistTypeOrFailure.value();
-
-    // Determine which dimension is distributed.
-    bool dim0Distributed =
-        sourceDistType.getShape()[0] != sourceType.getShape()[0];
-    bool dim1Distributed =
-        sourceDistType.getShape()[1] != sourceType.getShape()[1];
-
-    if (dim0Distributed && dim1Distributed)
-      return rewriter.notifyMatchFailure(
-          op, "Expecting source to be distributed in a single dimension.");
-
-    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
-    if (sourceDistDim == -1)
-      return rewriter.notifyMatchFailure(
-          op, "Expecting a distributed source vector.");
-
-    // Check if reduction is lane-local or requires cross-lane shuffling.
-    // Lane-local: reduction dimension is NOT the distributed dimension.
-    // Cross-lane: reduction dimension IS the distributed dimension.
-    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
-                                (sourceDistDim == 1 && reductionDim == 0);
-
-    // Get the distributed source and accumulator.
-    Value distributedSource = adaptor.getSource();
-    Value distributedAcc = adaptor.getAcc();
-
-    if (isReductionLaneLocal) {
-      // For lane-local reduction, each lane reduces its local data.
-      // Use lowerToVectorReductions to create a sequence of 1D reductions.
-      Value result = xegpu::lowerToVectorReductions(
-          cast<TypedValue<VectorType>>(distributedSource),
-          cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
-          reductionDim, op.getLoc(), rewriter);
-
-      rewriter.replaceOp(op, result);
-      return success();
-    }
+    // A rewrite is only needed when the reduction is NOT lane-local. If it is
+    // lane-local, the result is distributed to lanes (not shared), so the
+    // result shape changes under distribution; the check below detects that.
+    bool reductionDimDistributed =
+        resType.getShape() != resDistTypeOrFailure.value().getShape();
+    if (reductionDimDistributed)
+      return failure();
 
-    // For non-lane-local (cross-lane) reduction, we also lower to a sequence
-    // of 1D reductions. The individual 1D reductions will be handled by
-    // SgToWiVectorReduction pattern which performs the cross-lane shuffles.
+    // Rewrite MultiDimReductionOp into a sequence of ReductionOps.
     Value result = xegpu::lowerToVectorReductions(
-        cast<TypedValue<VectorType>>(distributedSource),
-        cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
-        reductionDim, op.getLoc(), rewriter);
+        cast<TypedValue<VectorType>>(op.getSource()),
+        cast<TypedValue<VectorType>>(op.getAcc()), op.getKind(), reductionDim,
+        op.getLoc(), rewriter);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -708,8 +765,38 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
                SgToWiVectorReduction>(typeConverter, patterns.getContext());
 }
 
-void xegpu::
-    populateXeGPUSgToWiRewriteVectorMultiReductionToVectorReductionPatterns(
-        RewritePatternSet &patterns) {
-  patterns.add<SgToWiRewriteVectorMultiReduction>(patterns.getContext());
+void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
+    RewritePatternSet &patterns, ConversionTarget &target) {
+  // vector::MultiDimReductionOp is legal only if the reduction dimension is
+  // not distributed.
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [&](vector::MultiDimReductionOp op) {
+        auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+        // If no layout, mark legal.
+        if (!resLayout || !resLayout.isForSubgroup())
+          return true;
+        VectorType resTy = dyn_cast<VectorType>(op.getType());
+        if (!resTy)
+          return true;
+        // Compute the distributed result vector type based on the layout.
+        FailureOr<VectorType> resDistTypeOrFailure =
+            getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+        if (failed(resDistTypeOrFailure))
+          return true;
+
+        ArrayRef<int64_t> reductionDims = op.getReductionDims();
+        // Only 1 reduction dimension supported.
+        if (reductionDims.size() != 1)
+          return true;
+        // Op is legal if the reduction dimension is not distributed. In that
+        // case each lane does its own local reduction and the result is
+        // distributed, so the result types before and after distribution
+        // differ.
+        return resTy != resDistTypeOrFailure.value();
+      });
+  // vector::ReductionOp is legal.
+  target.addDynamicallyLegalOp<vector::ReductionOp>(
+      [&](vector::ReductionOp op) { return true; });
+  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+  patterns.add<LowerVectorMultiReductionPattern>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index c4bab915d9e64..90460d9a5837d 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -787,9 +787,10 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
   Value reductionResult = arith::ConstantOp::create(
       rewriter, loc, acc.getType(),
       DenseElementsAttr::get(acc.getType(), zeroAttr));
+  auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+  auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
   // Reduction result should have the same layout as the accumulator.
-  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult),
-                            xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc)));
+  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   // For each slice of the source, extract the slice vector, do a reduction
   // and, insert the reduced value back to the result vector.
   for (int i = 0; i < nSlices; ++i) {
@@ -801,9 +802,12 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
       sliceOffsets = {0, i};
       sliceSizes = {sourceH, 1};
     }
+
     vector::ExtractStridedSliceOp extractOp =
         vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
                                               sliceSizes, {1, 1});
+    // Extract strided slice has the same layout as src.
+    xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
 
     int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
 
@@ -812,14 +816,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         VectorType::get({nSliceElements}, sourceType.getElementType()),
         extractOp.getResult());
 
-    // Shape cast is currently handled in xegpu side. So layouts must be
-    // retained during lowering. Shape cast output has the same layout as the
-    // accumulator. Shape cast source has the same layout as the original
-    // reduction source.
-    // TODO: other ops generated here may also need layout attributes.
-    auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
-    auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
-
+    // Shape cast output has the same layout as the accumulator. Shape cast
+    // source has the same layout as the original reduction source.
     xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
     xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
     // Extract and reduction results in scalars, so no result layout is needed.
@@ -828,6 +826,8 @@ Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
         rewriter, loc, kind, slice.getResult(), accExtract);
     reductionResult =
         vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+    // Insert op should have the same layout as the accumulator.
+    xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
   }
   return reductionResult;
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index aea4defd0cde5..5e2dac9bbb79a 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -2,6 +2,10 @@
 // RUN: mlir-opt  --xevm-attach-target='module=xevm_* chip=pvc' --allow-unregistered-dialect \
 // RUN: --test-xegpu-sg-to-wi-distribute-experimental --split-input-file %s | FileCheck %s
 
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-xegpu-sg-to-wi-distribute-experimental="enable-REWRITE-to-reductions" \
+// RUN: --split-input-file  %s | FileCheck --check-prefix=CHECK-REWRITE %s
+
 
 
 gpu.module @xevm_module {
@@ -177,4 +181,98 @@ gpu.func @vector_reduction() {
   gpu.return
 }
 
+
+// CHECK-REWRITE-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
+// CHECK-REWRITE-DAG:     %[[SRC:.*]] = "some_def"() {layout_result_0 =
+// CHECK-REWRITE-SAME:      #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<2x16xf32>
+// CHECK-REWRITE-DAG:     %[[ACC:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE-DAG:     %[[ZERO:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE:         %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST0:.*]] = vector.shape_cast %[[SLICE0]]
+// CHECK-REWRITE-SAME:      {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : vector<1x16xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC0:.*]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %[[ACC0]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS0:.*]] = vector.insert %[[RED0]], %[[ZERO]] [0]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST1:.*]] = vector.shape_cast %[[SLICE1]]
+// CHECK-REWRITE-SAME:      {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : vector<1x16xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC1:.*]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %[[ACC1]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS1:.*]] = vector.insert %[[RED1]], %[[INS0]] [1]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = "some_def"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : () -> (vector<2x16xf32>)
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+      dense<0.0>  : vector<2xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
+      }
+      [1] : vector<2x16xf32> to vector<2xf32>
+  gpu.return
+}
+
+// CHECK-REWRITE-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
+// CHECK-REWRITE-DAG:     %[[SRC:.*]] = "some_def"() {layout_result_0 =
+// CHECK-REWRITE-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<16x2xf32>
+// CHECK-REWRITE-DAG:     %[[ACC:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE-DAG:     %[[ZERO:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE:         %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST0:.*]] = vector.shape_cast %[[SLICE0]]
+// CHECK-REWRITE-SAME:      {{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      : vector<16x1xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC0:.*]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %[[ACC0]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS0:.*]] = vector.insert %[[RED0]], %[[ZERO]] [0]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST1:.*]] = vector.shape_cast %[[SLICE1]]
+// CHECK-REWRITE-SAME:      {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+// CHECK-REWRITE-SAME:      : vector<16x1xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC1:.*]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %[[ACC1]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS1:.*]] = vector.insert %[[RED1]], %[[INS0]] [1]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = "some_def"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+      : () -> (vector<16x2xf32>)
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+      dense<0.0>  : vector<2xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+      }
+      [0] : vector<16x2xf32> to vector<2xf32>
+  gpu.return
+}
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index bddb06e92da66..7302865bb61ca 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -272,10 +272,10 @@ struct TestXeGPUSgToWiDistributeExperimental
            "Work-item Distribution";
   }
 
-  Option<bool> enablePartialReductionLowering{
-      *this, "enable-partial-multi-reduction-lowering",
-      llvm::cl::desc(
-          "Partially lower multi-reduction ops to vector.reduction and stop."),
+  Option<bool> enableRewriteMultiReductionToReductions{
+      *this, "enable-rewrite-multi-reduction-to-reductions",
+      llvm::cl::desc("Partially lower multi-reduction ops to reduction ops if "
+                     "the reduction dimension is distributed."),
       llvm::cl::init(false)};
 
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
@@ -305,29 +305,14 @@ struct TestXeGPUSgToWiDistributeExperimental
     typeConverter.addSourceMaterialization(materializeCast);
     typeConverter.addTargetMaterialization(materializeCast);
 
-    // If `enablePartialReductionLowering` is set, only focus on testing the
-    // partial lowering of vector::MultiReductionOp.
-    if (enablePartialReductionLowering) {
+    // If `enableRewriteMultiReductionToReductions` is set, only focus on
+    // testing the partial lowering of vector::MultiDimReductionOp.
+    if (enableRewriteMultiReductionToReductions) {
       xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
       ConversionTarget target(*ctx);
       RewritePatternSet patterns(ctx);
-      xegpu::populateXeGPUSgToWiRewriteMultiReductionToReductionPatterns(
-          patterns);
-      // Mark 2D to 1D vector::MultiDimReductionOp as illegal.
-      target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
-          [&](vector::MultiDimReductionOp op) {
-            int64_t sourceRank = op.getSourceVectorType().getRank();
-            VectorType resultType =
-                dyn_cast<VectorType>(op.getResult().getType());
-            if (!resultType)
-              return true;
-            int64_t resultRank = resultType.getRank();
-            return sourceRank != 2 || resultRank != 1;
-          });
-      // vector::ReductionOp is legal.
-      target.addDynamicallyLegalOp<vector::ReductionOp>(
-          [&](vector::ReductionOp op) { return true; });
-      target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+      xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(patterns,
+                                                                     target);
       (void)applyPartialConversion(getOperation(), target, std::move(patterns));
       return;
     }

>From 8d8800a74e4c964b5c594e759f85ba1cc3f59e76 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 7 Feb 2026 01:33:42 +0000
Subject: [PATCH 5/8] save work

---
 .../XeGPU/sg-to-wi-experimental-unit.mlir     |  2 +-
 .../XeGPU/subgroup-distribute-unit.mlir       | 41 ++++++++++---------
 2 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 5e2dac9bbb79a..d28fe51ab91b1 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -3,7 +3,7 @@
 // RUN: --test-xegpu-sg-to-wi-distribute-experimental --split-input-file %s | FileCheck %s
 
 // RUN: mlir-opt --allow-unregistered-dialect \
-// RUN: --test-xegpu-sg-to-wi-distribute-experimental="enable-REWRITE-to-reductions" \
+// RUN: --test-xegpu-sg-to-wi-distribute-experimental="enable-rewrite-multi-reduction-to-reductions"  \
 // RUN: --split-input-file  %s | FileCheck --check-prefix=CHECK-REWRITE %s
 
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 81f25cc85359f..95a89e2edc84a 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -288,13 +288,9 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
 // CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
 // CHECK-NEXT:   %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
 // CHECK-NEXT:   %[[T2:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT:   %[[T3:.*]] = vector.reduction <add>, %[[T2]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
+// CHECK-NEXT:   %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
 // CHECK-NEXT:   %[[T5:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT:   %[[T6:.*]] = vector.reduction <add>, %[[T5]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK-NEXT:   gpu.yield %[[T7]]
-// CHECK-NEXT: }
+// CHECK-NEXT:   %[[T6:.*]] = vector.reduction <add>, %[[T5]], %{{.*}} : vector<16xf32> into f32
 gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -356,20 +352,27 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
 
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
-// CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>{{.*}}) {
-// CHECK:       %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
+// CHECK:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:       %[[SRC:.*]] = "some_def"()
+// CHECK-SAME:    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:    : () -> vector<16x2xf32>
 // CHECK:       %[[T1:.*]] = vector.extract_strided_slice %[[SRC]]
-// CHECK-SAME:    {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK:       %[[T2:.*]] = vector.shape_cast %[[T1]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK:       %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
-// CHECK:       %[[T5:.*]] = vector.extract_strided_slice %[[SRC]]
-// CHECK-SAME:     {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK:       %[[T6:.*]] = vector.shape_cast %[[T5]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK:       %[[T7:.*]] = vector.reduction <add>, %[[T6]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T8:.*]] = vector.insert %[[T7]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK:       gpu.yield %[[T8]]
-// CHECK:     }
+// CHECK-SAME:    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK:       %[[T2:.*]] = vector.shape_cast %[[T1]]
+// CHECK-SAME:    {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-SAME:    : vector<16x1xf32> to vector<16xf32>
+// CHECK:       %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[CST]] : vector<16xf32> into f32
+// CHECK:       %[[T4:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME:    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK:       %[[T5:.*]] = vector.shape_cast %[[T4]]
+// CHECK-SAME:    {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-SAME:    : vector<16x1xf32> to vector<16xf32>
+// CHECK:       %[[T6:.*]] = vector.reduction <add>, %[[T5]], %[[CST]] : vector<16xf32> into f32
 gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {

>From 7e43badadf2d46101be07e45428612e43dee0be5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 7 Feb 2026 01:36:22 +0000
Subject: [PATCH 6/8] save work

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 78f56cda281f0..faafb7e8cee61 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -840,9 +840,6 @@ void LayoutInfoPropagation::visitDpasOp(
       // Step 1. Get all valid layouts for A, B, and C operands.
       // All operands must have at least one valid subgroup layout.
       LayoutInfo layoutD = results[0]->getValue();
-      llvm::errs() << "DPAS result layout: ";
-      layoutD.print(llvm::errs());
-      llvm::errs() << "\n";
       SmallVector<int> sgLayoutD = layoutD.getSgLayout();
       assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
       auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
@@ -997,7 +994,6 @@ void LayoutInfoPropagation::visitStoreNdOp(
             "Unable to determine the number of subgroups for the operation.");
         return;
       }
-      llvm::errs() << "Num SG: " << numSgOrErr.value() << "\n";
       auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
                                        instData, numSgOrErr.value());
       if (sgLayouts.empty()) {
@@ -1015,9 +1011,6 @@ void LayoutInfoPropagation::visitStoreNdOp(
           DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
-      llvm::errs() << "Chosen SG Layout:";
-      storeLayout.print(llvm::errs());
-      llvm::errs() << "\n";
     }
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));

>From 3e53afc2e50999209132dff019e2a1ba616aac3b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 7 Feb 2026 01:37:59 +0000
Subject: [PATCH 7/8] save work

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 89 -------------------
 1 file changed, 89 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index f116b20eea4aa..6af3eb6eb4ef7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -23,7 +23,6 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/LogicalResult.h"
@@ -420,94 +419,6 @@ struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
   }
 };
 
-// class SgToWiRewriteVectorMultiReduction
-//     : public OpConversionPattern<vector::MultiDimReductionOp> {
-//   using
-//   OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
-
-//   LogicalResult
-//   matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
-//                   ConversionPatternRewriter &rewriter) const override {
-//     auto layout = xegpu::getDistributeLayoutAttr(op.getSource());
-
-//     // If no layout, nothing to do.
-//     if (!layout || !layout.isForSubgroup())
-//       return failure();
-
-//     VectorType sourceType = op.getSourceVectorType();
-//     // Only 2D vectors are supported.
-//     if (sourceType.getRank() != 2)
-//       return rewriter.notifyMatchFailure(op,
-//                                          "Only 2D reductions are
-//                                          supported.");
-
-//     ArrayRef<int64_t> reductionDims = op.getReductionDims();
-//     // Only 1 reduction dimension supported.
-//     if (reductionDims.size() != 1)
-//       return rewriter.notifyMatchFailure(
-//           op, "Only 1 reduction dimension is supported.");
-
-//     int64_t reductionDim = reductionDims[0];
-
-//     // Get the distributed source type based on layout.
-//     FailureOr<VectorType> sourceDistTypeOrFailure =
-//         getDistVecTypeBasedOnLaneLayout(layout, sourceType);
-//     if (failed(sourceDistTypeOrFailure))
-//       return rewriter.notifyMatchFailure(
-//           op, "Failed to distribute the source vector type.");
-
-//     VectorType sourceDistType = sourceDistTypeOrFailure.value();
-
-//     // Determine which dimension is distributed.
-//     bool dim0Distributed =
-//         sourceDistType.getShape()[0] != sourceType.getShape()[0];
-//     bool dim1Distributed =
-//         sourceDistType.getShape()[1] != sourceType.getShape()[1];
-
-//     if (dim0Distributed && dim1Distributed)
-//       return rewriter.notifyMatchFailure(
-//           op, "Expecting source to be distributed in a single dimension.");
-
-//     int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
-//     if (sourceDistDim == -1)
-//       return rewriter.notifyMatchFailure(
-//           op, "Expecting a distributed source vector.");
-
-//     // Check if reduction is lane-local or requires cross-lane shuffling.
-//     // Lane-local: reduction dimension is NOT the distributed dimension.
-//     // Cross-lane: reduction dimension IS the distributed dimension.
-//     bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
-//                                 (sourceDistDim == 1 && reductionDim == 0);
-
-//     // Get the distributed source and accumulator.
-//     Value distributedSource = adaptor.getSource();
-//     Value distributedAcc = adaptor.getAcc();
-
-//     if (isReductionLaneLocal) {
-//       // For lane-local reduction, each lane reduces its local data.
-//       // Use lowerToVectorReductions to create a sequence of 1D reductions.
-//       Value result = xegpu::lowerToVectorReductions(
-//           cast<TypedValue<VectorType>>(distributedSource),
-//           cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
-//           reductionDim, op.getLoc(), rewriter);
-
-//       rewriter.replaceOp(op, result);
-//       return success();
-//     }
-
-//     // For non-lane-local (cross-lane) reduction, we also lower to a sequence
-//     // of 1D reductions. The individual 1D reductions will be handled by
-//     // SgToWiVectorReduction pattern which performs the cross-lane shuffles.
-//     Value result = xegpu::lowerToVectorReductions(
-//         cast<TypedValue<VectorType>>(distributedSource),
-//         cast<TypedValue<VectorType>>(distributedAcc), op.getKind(),
-//         reductionDim, op.getLoc(), rewriter);
-
-//     rewriter.replaceOp(op, result);
-//     return success();
-//   }
-// };
-
 struct LowerVectorMultiReductionPattern
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

>From 80561e563e5da373aee71c2a3aa75d2c4a71664e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 9 Feb 2026 23:48:46 +0000
Subject: [PATCH 8/8] save work

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 81 +++++++++++++++++--
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 42 ++++++++++
 2 files changed, 118 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 6af3eb6eb4ef7..7c641f788e842 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -419,6 +419,55 @@ struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
   }
 };
 
+struct SgToWiMultiDimReduction
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+  /// Rewrites a lane-local 2D vector.multi_reduction on distributed types.
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+    // Only ops whose result carries a subgroup-level layout are distributed.
+    if (!resLayout || !resLayout.isForSubgroup())
+      return failure();
+    VectorType sourceType = op.getSourceVectorType();
+    // Only 2D source vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          op, "Expecting 2D source vector in vector::MultiDimReductionOp at "
+              "subgroup level.");
+    VectorType resVecTy = dyn_cast<VectorType>(op.getType());
+    if (!resVecTy)
+      return rewriter.notifyMatchFailure(op, "Expecting vector result type");
+    // Compute the distributed result type based on the result layout.
+    FailureOr<VectorType> resDistVecTyOrFailure =
+        getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
+    if (failed(resDistVecTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to compute the distributed vector type based on the "
+              "layout.");
+    // Only single-dimension reductions are supported.
+    ArrayRef<int64_t> reductionDims = op.getReductionDims();
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only 1 reduction dimension is supported.");
+    // A lane-local reduction (reduction dim not distributed) changes the result
+    // type under distribution. If the types match, the reduction is cross-lane
+    // and is handled by `LowerVectorMultiReductionPattern` instead.
+    if (resVecTy == resDistVecTyOrFailure.value())
+      return rewriter.notifyMatchFailure(
+          op, "Reduction is not lane-local, expected reduction dimension to be "
+              "not distributed.");
+    // Lane-local: each lane reduces its own slice, so re-create the op on the
+    // already-distributed adaptor operands with the distributed result type.
+    auto newOp = vector::MultiDimReductionOp::create(
+        rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
+        adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
 struct LowerVectorMultiReductionPattern
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
@@ -663,17 +712,42 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
         }
         return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
       });
-  // VectorReductionOp is legal only if its source has no distribute layout
+  // vector::ReductionOp is legal only if its source has no distribute layout
   // attribute.
   target.addDynamicallyLegalOp<vector::ReductionOp>(
       [=](vector::ReductionOp op) -> bool {
         auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
         return !layout;
       });
+  // vector::MultiDimReductionOp needs conversion only if its result has a
+  // subgroup layout and the reduction is lane-local (result type changes).
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+        // If no layout, mark legal.
+        if (!resLayout || !resLayout.isForSubgroup())
+          return true;
+        VectorType resTy = dyn_cast<VectorType>(op.getType());
+        if (!resTy)
+          return true;
+        // Compute the distributed result vector type based on the layout.
+        FailureOr<VectorType> resDistTypeOrFailure =
+            getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+        if (failed(resDistTypeOrFailure))
+          return true;
+
+        ArrayRef<int64_t> reductionDims = op.getReductionDims();
+        // Only 1 reduction dimension supported.
+        if (reductionDims.size() != 1)
+          return true;
+        // Op is legal if the reduction dimension is distributed.
+        return resTy == resDistTypeOrFailure.value();
+      });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
-               SgToWiVectorReduction>(typeConverter, patterns.getContext());
+               SgToWiVectorReduction, SgToWiMultiDimReduction>(
+      typeConverter, patterns.getContext());
 }
 
 void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
@@ -700,9 +774,6 @@ void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
         if (reductionDims.size() != 1)
           return true;
         // Op is legal if the reduction dimension is not distributed.
-        // If the reduction dim is distributed, each lane does its own local
-        // reduction and result is distributed. So result types before and after
-        // distribution should not match.
         return resTy != resDistTypeOrFailure.value();
       });
   // vector::ReductionOp is legal.
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index d28fe51ab91b1..077d92335c089 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -275,4 +275,46 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
       [0] : vector<16x2xf32> to vector<2xf32>
   gpu.return
 }
+
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
+// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x1xf32>
+// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [0] : vector<4x1xf32> to vector<1xf32>
+// CHECK:         gpu.return
+gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      dense<0.0>  : vector<4x16xf32>
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+      dense<0.0>  : vector<16xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+      }
+      [0] : vector<4x16xf32> to vector<16xf32>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
+// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x12xf32>
+// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x12xf32> to vector<1xf32>
+// CHECK:         gpu.return
+gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+      dense<0.0>  : vector<16x12xf32>
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>}
+      dense<0.0>  : vector<16xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
+      }
+      [1] : vector<16x12xf32> to vector<16xf32>
+  gpu.return
+}
 }



More information about the Mlir-commits mailing list