[Mlir-commits] [mlir] 6b5c440 - [mlir][xegpu] Add support for `vector.reduction` and `vector.multi_reduction` subgroup to work-item distribution. (#180308)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 13 11:50:00 PST 2026


Author: Charitha Saumya
Date: 2026-02-13T11:49:55-08:00
New Revision: 6b5c440a676defbbbef3f6a42404b0be2b3da54c

URL: https://github.com/llvm/llvm-project/commit/6b5c440a676defbbbef3f6a42404b0be2b3da54c
DIFF: https://github.com/llvm/llvm-project/commit/6b5c440a676defbbbef3f6a42404b0be2b3da54c.diff

LOG: [mlir][xegpu] Add support for `vector.reduction` and `vector.multi_reduction` subgroup to work-item distribution. (#180308)

This PR adds support for lowering of `vector.reduction` and
`vector.multi_reduction` ops in subgroup to work-item distribution.

The following cases are currently considered (more support will be added
later):

* `vector.reduction`: This assumes the source vector is distributed to
all lanes and lanes must shuffle data to do a collaborative reduction.
The result is shared among all lanes. This is done by emitting
`gpu::ShuffleOp`s and doing a butterfly reduction. Refer to
`VectorDistribution` for more details.
* `vector.multi_reduction`: 2 cases are considered,

1. **Reduction is lane-local**: simply lower to a lane-local multi
reduction op. Each lane does its own reduction. The result is distributed.
2. **Reduction is not lane-local:** This one is handled indirectly. In
this case, we rewrite the reduction in terms of `vector.reduction` ops
(plus extract/insert) before the WI distribution even begins. Then the whole
thing is distributed using `gpu::ShuffleOp`s later (not fully
supported yet).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
    mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
    mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
    mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index ea01975da582f..6f6d58d4ab605 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -82,6 +82,14 @@ void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter);
 void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target);
+/// Appends a pattern that rewrites vector::MultiDimReductionOp in terms of
+/// vector::ReductionOps when the multi-reduction requires cross-lane data
+/// movement, and sets the matching legality rules on `target`. This is used
+/// as a pre-processing step before applying the subgroup to workitem
+/// distribution patterns: the multi-reduction becomes a series of simpler
+/// extract, reduction and insert ops. Lane-local reductions are left as is.
+void populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
+    RewritePatternSet &patterns, ConversionTarget &target);
 
 /// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
 /// Users can control whether an operation to be unrolled or not, as well as
@@ -93,7 +101,7 @@ void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
 ///   1. the unrolled type `unrolledType` and number of unrolled instances
 ///   `numUnrolledInstances` are computed from the `targetShape`.
 ///   2. pack each operand. ExtractStridedSlice are created to break-up the
-///   vector operands. And BuiltinUnrealizedCastop are created to break-up
+///   vector operands. And BuiltinUnrealizedCastOp are created to break-up
 ///    the TensorDesc operands.
 ///   3. the original op is cloned `numUnrolledInstances` times, once for each
 ///   result.

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 4443f86d1e4e2..ebf50c4cd57de 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -129,6 +129,24 @@ SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
 
+/// Given an `input` value representing per-lane data, this function returns the
+/// result after performing a reduction on the input over all lanes (number of
+/// lanes given by `size`). This uses butterfly shuffles to perform the
+/// reduction in a log2(size) number of steps.
+/// NOTE: Implementation taken from TestVectorTransforms.cpp
+Value subgroupReduction(Location loc, OpBuilder &builder, Value input,
+                        vector::CombiningKind kind, uint32_t size);
+
+/// Given the `src` and `acc` operands of a vector::MultiDimReductionOp,
+/// lower to a set of vector::ReductionOp ops over 1D slices extracted from
+/// `src`. The reduction is performed along `reductionDim`. The result is a
+/// vector with the same type as `acc`.
+/// TODO: Only 2D to 1D reduction is supported for now.
+Value lowerToVectorReductions(TypedValue<VectorType> src,
+                              TypedValue<VectorType> acc,
+                              vector::CombiningKind kind, int64_t reductionDim,
+                              Location loc, PatternRewriter &rewriter);
+
 /// Helper Function to find a proper instruction multiple for the user-supplied
 /// sg-level data shape (diven by `dim`). `candidates` are uArch allowed shapes.
 /// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 8e530642d9c7a..3787fbb44e1b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -88,6 +88,39 @@ static LogicalResult verifyLayouts(Operation *root) {
   return walkResult.wasInterrupted() ? failure() : success();
 }
 
+/// A vector::MultiDimReductionOp at subgroup level is in the expected form if
+/// it has exactly 1 reduction dimension, it has a valid result layout
+/// attribute, and its result type can be distributed to lanes using the layout.
+static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
+  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+  // Without a layout that is for a subgroup, the op is not valid.
+  if (!resLayout || !resLayout.isForSubgroup())
+    return false;
+  VectorType resTy = dyn_cast<VectorType>(op.getType());
+  if (!resTy)
+    return false;
+  // Compute the distributed result vector type based on the layout.
+  FailureOr<VectorType> resDistTypeOrFailure =
+      getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+  if (failed(resDistTypeOrFailure))
+    return false;
+  return op.getReductionDims().size() == 1;
+}
+
+/// A vector::MultiDimReductionOp is doing lane-local reduction if each workitem
+/// is doing its own local reduction. In this case the result layout ensures
+/// that result vector is distributed to lanes, i.e. the result vector type is
+/// different from the distributed result vector type.
+static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
+  // Must be valid MultiDimReductionOp.
+  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
+                                                "MultiDimReductionOp");
+  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+  VectorType resTy = dyn_cast<VectorType>(op.getType());
+  auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+  return resTy != resDistTypeOrFailure.value();
+}
+
 /// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
 /// op. This simply drops the layout attribute from the tensor descriptor type.
 struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
@@ -362,6 +395,133 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern distributes a subgroup-level vector.reduction op to
+/// workitem-level. This requires shuffling the data across the workitems (using
+/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
+/// result.
+struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+
+    // If no layout, nothing to do.
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    VectorType srcVecType = op.getSourceVectorType();
+    // Only rank 1 vectors supported.
+    if (srcVecType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only rank 1 reductions can be distributed.");
+    // Lane layout must have the same rank as the vector.
+    if (layout.getRank() != srcVecType.getRank())
+      return rewriter.notifyMatchFailure(
+          op, "Layout rank does not match vector rank.");
+
+    // Get the subgroup size from the layout.
+    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
+    const auto *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          op, "xegpu::ReductionOp require target attribute attached to "
+              "determine subgroup size");
+
+    // Lane layout must match the uArch subgroup size and divide the vector.
+    if (sgSize != uArch->getSubgroupSize() ||
+        srcVecType.getShape()[0] % sgSize != 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "Invalid layout or reduction vector "
+                                         "dimension must match subgroup size.");
+
+    if (!op.getType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Reduction distribution currently only supports floats and "
+              "integer types.");
+
+    // Get the distributed vector (per work-item portion).
+    Value laneValVec = adaptor.getVector();
+
+    // Distribute and reduce across work-items in the subgroup.
+    Value fullReduce = xegpu::subgroupReduction(
+        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
+
+    // If there's an accumulator, combine it with the reduced value.
+    if (adaptor.getAcc())
+      fullReduce = vector::makeArithReduction(
+          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
+
+    rewriter.replaceOp(op, fullReduce);
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level vector.multi_reduction op to
+/// workitem-level only if the reduction is lane-local. This means that
+/// reduction dimension is not distributed to lanes and each lane does its own
+/// local reduction.
+struct SgToWiMultiDimReduction
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only lane-local reduction is handled here.
+    if (!isReductionLaneLocal(op))
+      return rewriter.notifyMatchFailure(
+          op, "Only lane-local reduction is supported, expected reduction "
+              "dimension to be "
+              "not distributed.");
+    auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+    VectorType resVecTy = dyn_cast<VectorType>(op.getType());
+    auto resDistVecTyOrFailure =
+        getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
+    // Simply create a new MultiDimReductionOp using adaptor operands and the
+    // new result type.
+    auto newOp = vector::MultiDimReductionOp::create(
+        rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
+        adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// This pattern rewrites a subgroup-level vector.multi_reduction op into a
+/// series of per-slice vector.reduction ops (plus the extract and insert ops
+/// around them; see xegpu::lowerToVectorReductions). This is used when the
+/// reduction is not lane-local, so a naive (lane-local) distribution is not
+/// possible. Later on, these partially lowered subgroup-level ops are
+/// further lowered to workitem-level by their respective patterns.
+struct LowerVectorMultiReductionPattern
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only non-lane-local reduction is handled here.
+    if (isReductionLaneLocal(op))
+      return rewriter.notifyMatchFailure(
+          op, "Reduction is lane-local, it does not require rewrite.");
+    ArrayRef<int64_t> reductionDims = op.getReductionDims();
+    assert(
+        reductionDims.size() == 1 &&
+        "Expecting single reduction dimension for subgroup multi reduction op");
+
+    // Rewrite MultiDimReductionOp into a sequence of ReductionOps.
+    Value result = xegpu::lowerToVectorReductions(
+        cast<TypedValue<VectorType>>(op.getSource()),
+        cast<TypedValue<VectorType>>(op.getAcc()), op.getKind(),
+        reductionDims[0], op.getLoc(), rewriter);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -551,8 +711,44 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
         }
         return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
       });
+  // vector::ReductionOp is legal only if its source has no distribute layout
+  // attribute.
+  target.addDynamicallyLegalOp<vector::ReductionOp>(
+      [=](vector::ReductionOp op) -> bool {
+        auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+        return !layout;
+      });
+  // vector::MultiDimReductionOp op legality.
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        // Check common conditions for subgroup multi reduction op.
+        if (!isValidSubgroupMultiReductionOp(op))
+          return true;
+        // Lane local reductions are illegal at this point and must be lowered.
+        return !isReductionLaneLocal(op);
+      });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
-               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
+               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+               SgToWiVectorReduction, SgToWiMultiDimReduction>(
       typeConverter, patterns.getContext());
 }
+
+void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
+    RewritePatternSet &patterns, ConversionTarget &target) {
+  // vector::MultiDimReductionOp legality.
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [&](vector::MultiDimReductionOp op) {
+        // Check common conditions for subgroup multi reduction op.
+        if (!isValidSubgroupMultiReductionOp(op))
+          return true;
+        // Lane local reductions are legal. We only rewrite non-lane-local
+        // reductions.
+        return isReductionLaneLocal(op);
+      });
+  // vector::ReductionOp is legal.
+  target.addDynamicallyLegalOp<vector::ReductionOp>(
+      [&](vector::ReductionOp op) { return true; });
+  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+  patterns.add<LowerVectorMultiReductionPattern>(patterns.getContext());
+}

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7671e2bbc3322..b8c4a309b8eb2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1253,69 +1253,6 @@ struct SinkUniformOps final : public gpu::WarpDistributionPattern {
   }
 };
 
-/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
-/// VectorReductionOps. We also insert layouts for the newly created ops.
-static Value lowerToVectorReductions(TypedValue<VectorType> src,
-                                     TypedValue<VectorType> acc,
-                                     vector::CombiningKind kind,
-                                     int64_t reductionDim, Location loc,
-                                     PatternRewriter &rewriter) {
-  // Expecting a 2D source vector.
-  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
-  VectorType sourceType = src.getType();
-  int64_t sourceH = sourceType.getShape()[0];
-  int64_t sourceW = sourceType.getShape()[1];
-  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
-  // Create a constant vector to hold the result of the reduction.
-  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
-  Value reductionResult = arith::ConstantOp::create(
-      rewriter, loc, acc.getType(),
-      DenseElementsAttr::get(acc.getType(), zeroAttr));
-  // Reduction result should have the same layout as the accumulator.
-  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult),
-                            xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc)));
-  // For each slice of the source, extract the slice vector, do a reduction
-  // and, insert the reduced value back to the result vector.
-  for (int i = 0; i < nSlices; ++i) {
-    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-    if (reductionDim == 1) {
-      sliceOffsets = {i, 0};
-      sliceSizes = {1, sourceW};
-    } else {
-      sliceOffsets = {0, i};
-      sliceSizes = {sourceH, 1};
-    }
-    vector::ExtractStridedSliceOp extractOp =
-        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
-                                              sliceSizes, {1, 1});
-
-    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
-
-    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
-        rewriter, loc,
-        VectorType::get({nSliceElements}, sourceType.getElementType()),
-        extractOp.getResult());
-
-    // Shape cast is currently handled in xegpu side. So layouts must be
-    // retained during lowering. Shape cast output has the same layout as the
-    // accumulator. Shape cast source has the same layout as the original
-    // reduction source.
-    // TODO: other ops generated here may also need layout attributes.
-    auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
-    auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
-
-    xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
-    xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
-    // Extract and reduction results in scalars, so no result layout is needed.
-    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
-    Value reduction = vector::ReductionOp::create(
-        rewriter, loc, kind, slice.getResult(), accExtract);
-    reductionResult =
-        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
-  }
-  return reductionResult;
-}
-
 /// This patterns distribute the `vector.multi_reduction` operation across
 /// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
 /// layouts for the source and accumulator vectors,
@@ -1453,7 +1390,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
           {sourceDistType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
-      Value result = lowerToVectorReductions(
+      Value result = xegpu::lowerToVectorReductions(
           cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
           cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
           reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
@@ -1465,7 +1402,7 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
     // of multiple ReductionOps. Actual distribution is done by the
     // WarpOpReduction pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    Value result = lowerToVectorReductions(
+    Value result = xegpu::lowerToVectorReductions(
         cast<TypedValue<VectorType>>(reductionOp.getSource()),
         cast<TypedValue<VectorType>>(reductionOp.getAcc()),
         reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
@@ -2151,23 +2088,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
 
-  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
-                          vector::CombiningKind kind, uint32_t size) {
-    // First reduce on a single thread to get per lane reduction value.
-    Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
-    // Parallel reduction using butterfly shuffles.
-    for (uint64_t i = 1; i < size; i <<= 1) {
-      Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
-                                              /*width=*/size,
-                                              /*mode=*/gpu::ShuffleMode::XOR)
-                           .getShuffleResult();
-      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
-    }
-    return laneVal;
-  };
-
   vector::populateDistributeReduction(
-      patterns, warpReduction,
+      patterns, xegpu::subgroupReduction,
       /*pattern benefit=*/PatternHierarchy::Regular);
 
   vector::populatePropagateWarpVectorDistributionPatterns(

diff  --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index c47fd92fe46d7..5fdab1e759deb 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -651,6 +651,82 @@ int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
   return largest;
 }
 
+Value xegpu::subgroupReduction(Location loc, OpBuilder &builder, Value input,
+                               vector::CombiningKind kind, uint32_t size) {
+  // First reduce on a single thread to get per lane reduction value.
+  Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
+  // Parallel reduction using butterfly shuffles.
+  for (uint64_t i = 1; i < size; i <<= 1) {
+    Value shuffled =
+        gpu::ShuffleOp::create(builder, loc, laneVal, i, /*width=*/size,
+                               /*mode=*/gpu::ShuffleMode::XOR)
+            .getShuffleResult();
+    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+  }
+  return laneVal;
+}
+
+Value xegpu::lowerToVectorReductions(TypedValue<VectorType> src,
+                                     TypedValue<VectorType> acc,
+                                     vector::CombiningKind kind,
+                                     int64_t reductionDim, Location loc,
+                                     PatternRewriter &rewriter) {
+  // Expecting a 2D source vector.
+  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+  VectorType sourceType = src.getType();
+  int64_t sourceH = sourceType.getShape()[0];
+  int64_t sourceW = sourceType.getShape()[1];
+  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  // Create a zero-filled constant vector to hold the reduction result.
+  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+  Value reductionResult = arith::ConstantOp::create(
+      rewriter, loc, acc.getType(),
+      DenseElementsAttr::get(acc.getType(), zeroAttr));
+  auto srcLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(src));
+  auto accLayout = xegpu::getTemporaryLayout(dyn_cast<OpResult>(acc));
+  // Reduction result should have the same layout as the accumulator.
+  xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+  // For each slice of the source, extract the slice vector, do a reduction,
+  // and insert the reduced value back into the result vector.
+  for (int i = 0; i < nSlices; ++i) {
+    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+    if (reductionDim == 1) {
+      sliceOffsets = {i, 0};
+      sliceSizes = {1, sourceW};
+    } else {
+      sliceOffsets = {0, i};
+      sliceSizes = {sourceH, 1};
+    }
+
+    vector::ExtractStridedSliceOp extractOp =
+        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+                                              sliceSizes, {1, 1});
+    // Extract strided slice has the same layout as src.
+    xegpu::setTemporaryLayout(extractOp->getOpResult(0), srcLayout);
+
+    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+
+    vector::ShapeCastOp slice = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get({nSliceElements}, sourceType.getElementType()),
+        extractOp.getResult());
+
+    // Shape cast output has the same layout as the accumulator. Shape cast
+    // source has the same layout as the original reduction source.
+    xegpu::setTemporaryLayout(slice->getOpOperand(0), srcLayout);
+    xegpu::setTemporaryLayout(slice->getOpResult(0), accLayout);
+    // Extract and reduction produce scalars, so no result layout is needed.
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    Value reduction = vector::ReductionOp::create(
+        rewriter, loc, kind, slice.getResult(), accExtract);
+    reductionResult =
+        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+    // Insert op should have the same layout as the accumulator.
+    xegpu::setTemporaryLayout(cast<OpResult>(reductionResult), accLayout);
+  }
+  return reductionResult;
+}
+
 /// Explicit instantiations
 template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
                                            ArrayRef<int> candidateMultiples);

diff  --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0e9843f4626d4..1ec0879d4fb47 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -2,6 +2,10 @@
 // RUN: mlir-opt  --xevm-attach-target='module=xevm_* chip=pvc' --allow-unregistered-dialect \
 // RUN: --test-xegpu-sg-to-wi-distribute-experimental --split-input-file %s | FileCheck %s
 
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-xegpu-sg-to-wi-distribute-experimental="enable-rewrite-multi-reduction-to-reductions"  \
+// RUN: --split-input-file  %s | FileCheck --check-prefix=CHECK-REWRITE %s
+
 
 
 gpu.module @xevm_module {
@@ -149,4 +153,168 @@ gpu.func @prefetch_nd() {
     : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   gpu.return
 }
+
+// CHECK-LABEL: gpu.func @vector_reduction
+// CHECK:     %[[CST:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK:     %[[LANE_RED:.*]] = vector.reduction <add>, %[[CAST:.*]] : vector<2xf32> into f32
+// CHECK-DAG: %[[C16_1:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK:     %[[SHUFFLE1:.*]], %{{.*}} = gpu.shuffle  xor %[[LANE_RED]], %[[C1]], %[[C16_1]] : f32
+// CHECK:     %[[ADD1:.*]] = arith.addf %[[LANE_RED]], %[[SHUFFLE1]] : f32
+// CHECK-DAG: %[[C16_2:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK:     %[[SHUFFLE2:.*]], %{{.*}} = gpu.shuffle  xor %[[ADD1]], %[[C2]], %[[C16_2]] : f32
+// CHECK:     %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[SHUFFLE2]] : f32
+// CHECK-DAG: %[[C16_3:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK:     %[[SHUFFLE3:.*]], %{{.*}} = gpu.shuffle  xor %[[ADD2]], %[[C4]], %[[C16_3]] : f32
+// CHECK:     %[[ADD3:.*]] = arith.addf %[[ADD2]], %[[SHUFFLE3]] : f32
+// CHECK-DAG: %[[C16_4:.*]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK:     %[[SHUFFLE4:.*]], %{{.*}} = gpu.shuffle  xor %[[ADD3]], %[[C8]], %[[C16_4]] : f32
+// CHECK:     %[[ADD4:.*]] = arith.addf %[[ADD3]], %[[SHUFFLE4]] : f32
+// CHECK:     %[[FINAL:.*]] = arith.addf %[[ADD4]], %[[CST]] : f32
+gpu.func @vector_reduction() {
+  %acc = arith.constant 1.0 : f32
+  %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : () -> vector<32xf32>
+  %2 = vector.reduction <add>, %0, %acc : vector<32xf32> into f32
+  gpu.return
+}
+
+
+// CHECK-REWRITE-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
+// CHECK-REWRITE-DAG:     %[[SRC:.*]] = "some_def"() {layout_result_0 =
+// CHECK-REWRITE-SAME:      #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<2x16xf32>
+// CHECK-REWRITE-DAG:     %[[ACC:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE-DAG:     %[[ZERO:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE:         %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST0:.*]] = vector.shape_cast %[[SLICE0]]
+// CHECK-REWRITE-SAME:      {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : vector<1x16xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC0:.*]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %[[ACC0]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS0:.*]] = vector.insert %[[RED0]], %[[ZERO]] [0]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [1, 0], sizes = [1, 16], strides = [1, 1]} : vector<2x16xf32> to vector<1x16xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST1:.*]] = vector.shape_cast %[[SLICE1]]
+// CHECK-REWRITE-SAME:      {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : vector<1x16xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC1:.*]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %[[ACC1]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS1:.*]] = vector.insert %[[RED1]], %[[INS0]] [1]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = "some_def"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : () -> (vector<2x16xf32>)
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+      dense<0.0>  : vector<2xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
+      }
+      [1] : vector<2x16xf32> to vector<2xf32>
+  gpu.return
+}
+
+// CHECK-REWRITE-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
+// CHECK-REWRITE-DAG:     %[[SRC:.*]] = "some_def"() {layout_result_0 =
+// CHECK-REWRITE-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<16x2xf32>
+// CHECK-REWRITE-DAG:     %[[ACC:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE-DAG:     %[[ZERO:.*]] = arith.constant
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      dense<0.000000e+00> : vector<2xf32>
+// CHECK-REWRITE:         %[[SLICE0:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST0:.*]] = vector.shape_cast %[[SLICE0]]
+// CHECK-REWRITE-SAME:      {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      : vector<16x1xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC0:.*]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %[[ACC0]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS0:.*]] = vector.insert %[[RED0]], %[[ZERO]] [0]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[SLICE1:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-REWRITE-SAME:       offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REWRITE-NEXT:    %[[CAST1:.*]] = vector.shape_cast %[[SLICE1]]
+// CHECK-REWRITE-SAME:      {{{.*}}, layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      : vector<16x1xf32> to vector<16xf32>
+// CHECK-REWRITE-NEXT:    %[[ACC1:.*]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32>
+// CHECK-REWRITE-NEXT:    %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %[[ACC1]] : vector<16xf32> into f32
+// CHECK-REWRITE-NEXT:    %[[INS1:.*]] = vector.insert %[[RED1]], %[[INS0]] [1]
+// CHECK-REWRITE-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-REWRITE-SAME:      : f32 into vector<2xf32>
+gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = "some_def"()
+      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+      : () -> (vector<16x2xf32>)
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+      dense<0.0>  : vector<2xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+      }
+      [0] : vector<16x2xf32> to vector<2xf32>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
+// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x1xf32>
+// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [0] : vector<4x1xf32> to vector<1xf32>
+// CHECK:         gpu.return
+gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      dense<0.0>  : vector<4x16xf32>
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+      dense<0.0>  : vector<16xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+      }
+      [0] : vector<4x16xf32> to vector<16xf32>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
+// CHECK:         %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x12xf32>
+// CHECK:         %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x12xf32> to vector<1xf32>
+// CHECK:         gpu.return
+gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) {
+  %c0 = arith.constant 0 : index
+    %src = arith.constant
+      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+      dense<0.0>  : vector<16x12xf32>
+    %acc = arith.constant
+      {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>}
+      dense<0.0>  : vector<16xf32>
+    %1 = vector.multi_reduction <add>, %src, %acc
+      {
+        layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
+      }
+      [1] : vector<16x12xf32> to vector<16xf32>
+  gpu.return
+}
 }

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 55f8b870b6238..645e889d40657 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -288,13 +288,9 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
 // CHECK:      %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> ({{.*}}) {
 // CHECK-NEXT:   %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<2x16xf32>
 // CHECK-NEXT:   %[[T2:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT:   %[[T3:.*]] = vector.reduction <add>, %[[T2]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
+// CHECK-NEXT:   %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
 // CHECK-NEXT:   %[[T5:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
-// CHECK-NEXT:   %[[T6:.*]] = vector.reduction <add>, %[[T5]], %cst : vector<16xf32> into f32
-// CHECK-NEXT:   %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK-NEXT:   gpu.yield %[[T7]]
-// CHECK-NEXT: }
+// CHECK-NEXT:   %[[T6:.*]] = vector.reduction <add>, %[[T5]], %{{.*}} : vector<16xf32> into f32
 gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {
@@ -356,20 +352,27 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
 
 
 // CHECK-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
-// CHECK:     %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>{{.*}}) {
-// CHECK:       %[[SRC:.*]] = "some_def"() {{.*}} : () -> vector<16x2xf32>
+// CHECK:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:       %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>) {
+// CHECK:       %[[SRC:.*]] = "some_def"()
+// CHECK-SAME:    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:    : () -> vector<16x2xf32>
 // CHECK:       %[[T1:.*]] = vector.extract_strided_slice %[[SRC]]
-// CHECK-SAME:    {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK:       %[[T2:.*]] = vector.shape_cast %[[T1]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK:       %[[T3:.*]] = vector.reduction <add>, %[[T2]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T4:.*]] = vector.insert %[[T3]], %cst_0 [0] : f32 into vector<2xf32>
-// CHECK:       %[[T5:.*]] = vector.extract_strided_slice %[[SRC]]
-// CHECK-SAME:     {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
-// CHECK:       %[[T6:.*]] = vector.shape_cast %[[T5]] {{.*}} : vector<16x1xf32> to vector<16xf32>
-// CHECK:       %[[T7:.*]] = vector.reduction <add>, %[[T6]], %{{.*}} : vector<16xf32> into f32
-// CHECK:       %[[T8:.*]] = vector.insert %[[T7]], %[[T4]] [1] : f32 into vector<2xf32>
-// CHECK:       gpu.yield %[[T8]]
-// CHECK:     }
+// CHECK-SAME:    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK:       %[[T2:.*]] = vector.shape_cast %[[T1]]
+// CHECK-SAME:    {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-SAME:    : vector<16x1xf32> to vector<16xf32>
+// CHECK:       %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[CST]] : vector<16xf32> into f32
+// CHECK:       %[[T4:.*]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME:    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK:       %[[T5:.*]] = vector.shape_cast %[[T4]]
+// CHECK-SAME:    {layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+// CHECK-SAME:     layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+// CHECK-SAME:    : vector<16x1xf32> to vector<16xf32>
+// CHECK:       %[[T6:.*]] = vector.reduction <add>, %[[T5]], %[[CST]] : vector<16xf32> into f32
 gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index) {
   %c0 = arith.constant 0 : index
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2xf32>) {

diff  --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 20bcb24a301e6..33af2c5b33d89 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -273,6 +273,12 @@ struct TestXeGPUSgToWiDistributeExperimental
            "Work-item Distribution";
   }
 
+  Option<bool> enableRewriteMultiReductionToReductions{
+      *this, "enable-rewrite-multi-reduction-to-reductions",
+      llvm::cl::desc("Partially lower multi-reduction ops to reduction ops if "
+                     "the reduction dimension is distributed."),
+      llvm::cl::init(false)};
+
   void getDependentDialects(::mlir::DialectRegistry &registry) const override {
     registry.insert<arith::ArithDialect>();
     registry.insert<memref::MemRefDialect>();
@@ -284,7 +290,8 @@ struct TestXeGPUSgToWiDistributeExperimental
 
   TestXeGPUSgToWiDistributeExperimental() = default;
   TestXeGPUSgToWiDistributeExperimental(
-      const TestXeGPUSgToWiDistributeExperimental &pass) = default;
+      const TestXeGPUSgToWiDistributeExperimental &pass)
+      : PassWrapper(pass) {}
 
   void runOnOperation() override {
     MLIRContext *ctx = &getContext();
@@ -298,6 +305,19 @@ struct TestXeGPUSgToWiDistributeExperimental
     };
     typeConverter.addSourceMaterialization(materializeCast);
     typeConverter.addTargetMaterialization(materializeCast);
+
+    // If `enableRewriteMultiReductionToReductions` is set, only focus on
+    // testing the partial lowering of vector::MultiReductionOp.
+    if (enableRewriteMultiReductionToReductions) {
+      xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+      ConversionTarget target(*ctx);
+      RewritePatternSet patterns(ctx);
+      xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(patterns,
+                                                                     target);
+      (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+      return;
+    }
+
     ConversionTarget target(*ctx);
     RewritePatternSet patterns(ctx);
     xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(


        


More information about the Mlir-commits mailing list