[Mlir-commits] [mlir] [mlir][xegpu] Add support for `vector.multi_reduction` and `vector.shape_cast` SIMT distribution. (PR #157560)

Charitha Saumya llvmlistbot at llvm.org
Tue Sep 9 11:33:25 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/157560

>From eaaca7f54a9333b1841283b4483cb9c8f91f9f6b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 16:42:52 +0000
Subject: [PATCH 01/19] save

---
 .../Vector/Transforms/VectorDistribute.cpp    | 242 +++++++++++++++---
 1 file changed, 213 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a91cba7..2d9fcaee37282 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,13 +15,19 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
 #include <utility>
 
 using namespace mlir;
@@ -939,8 +945,40 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
   }
 };
 
+static VectorType
+tryFindDistributedType(TypedValue<VectorType> source,
+                       WarpExecuteOnLane0Op warpOp,
+                       const DistributionMapFn &distributionMapFn) {
+  VectorType distributedType = source.getType();
+  // Check if the source is yielded from the warp op.
+  gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+    return operand.get() == source;
+  });
+
+  if (it != yieldOp->getOpOperands().end()) {
+    // If the source is yielded from the warp op, we can use the matching
+    // warp result type as the distributed source type.
+    distributedType =
+        cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+  } else {
+    // If the source is not yielded from the warp op, we need to compute
+    // the distributed source type based on the distribution map and the
+    // warp size.
+    AffineMap map = distributionMapFn(source);
+    VectorType computed =
+        getDistributedType(source.getType(), map, warpOp.getWarpSize());
+    if (!computed)
+      return source.getType();
+    distributedType = computed;
+  }
+  return distributedType;
+}
+
 struct WarpOpBroadcast : public WarpDistributionPattern {
-  using Base::Base;
+  WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -953,18 +991,23 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
     auto destVecType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
     Value broadcastSrc = broadcastOp.getSource();
-    Type broadcastSrcType = broadcastSrc.getType();
+    Type srcDistributedType = broadcastSrc.getType();
+
+    if (isa<VectorType>(srcDistributedType))
+      srcDistributedType =
+          tryFindDistributedType(cast<TypedValue<VectorType>>(broadcastSrc),
+                                 warpOp, distributionMapFn);
 
     // Check that the broadcast actually spans a set of values uniformly across
     // all threads. In other words, check that each thread can reconstruct
     // their own broadcast.
     // For that we simply check that the broadcast we want to build makes sense.
-    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
+    if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
         vector::BroadcastableToResult::Success)
       return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = vector::BroadcastOp::create(
         rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
@@ -972,49 +1015,83 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
                                 broadcasted);
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using Base::Base;
+
+  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
-
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
     unsigned int operandNumber = operand->getOperandNumber();
-    auto castDistributedType =
+    VectorType sourceType = oldCastOp.getSourceVectorType();
+    VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType castOriginalType = oldCastOp.getSourceVectorType();
-    VectorType castResultType = castDistributedType;
-
-    // We expect the distributed type to have a smaller rank than the original
-    // type. Prepend with size-one dimensions to make them the same.
-    unsigned castDistributedRank = castDistributedType.getRank();
-    unsigned castOriginalRank = castOriginalType.getRank();
-    if (castDistributedRank < castOriginalRank) {
-      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
-      llvm::append_range(shape, castDistributedType.getShape());
-      castDistributedType =
-          VectorType::get(shape, castDistributedType.getElementType());
+    VectorType distributedSourceType = sourceType;
+    bool isResultDistributed = distributedResultType.getNumElements() <
+                               oldCastOp.getResultVectorType().getNumElements();
+
+    // If the result is not distributed, the distributed source type is the
+    // same as the source type. If the result is distributed, we need to compute
+    // the distributed source type according to the following rules:
+    // 1. If the source type is yielded from the warp op, we can use the
+    //    matching warp result type as the distributed source type.
+    // 2. If the source type is not yielded from the warp op, we need
+    //    to compute the distributed source type based on the distribution map
+    //    and the warp size.
+    if (isResultDistributed) {
+      // Check if the source is yielded from the warp op.
+      gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      auto *it =
+          llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
+            return operand.get() == oldCastOp.getSource();
+          });
+
+      if (it != yieldOp->getOpOperands().end()) {
+        // If the source is yielded from the warp op, we can use the matching
+        // warp result type as the distributed source type.
+        distributedSourceType =
+            cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
+      } else {
+        // If the source is not yielded from the warp op, we need to compute
+        // the distributed source type based on the distribution map and the
+        // warp size.
+        AffineMap map = distributionMapFn(oldCastOp.getSource());
+        distributedSourceType =
+            getDistributedType(sourceType, map, warpOp.getWarpSize());
+        if (!distributedSourceType)
+          return rewriter.notifyMatchFailure(
+              oldCastOp,
+              "cannot compute distributed source type for shape cast");
+      }
     }
 
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+        rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value newCast = vector::ShapeCastOp::create(
-        rewriter, oldCastOp.getLoc(), castResultType,
+        rewriter, oldCastOp.getLoc(), distributedResultType,
         newWarpOp->getResult(newRetIndices[0]));
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// Sink out vector.create_mask op feeding into a warp op yield.
@@ -1996,6 +2073,114 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
+  VectorMultiDimReductionDistribution(MLIRContext *context,
+                                      PatternBenefit benefit = 1)
+      : WarpDistributionPattern(context, benefit) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension supported.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+
+    // Col reduction.
+    if (reductionDims[0] == 0) {
+      // Yield the source vector and the accumulator.
+      if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Source vector dimension must be divisible by warp size.");
+      SmallVector<int64_t> shape(sourceType.getShape());
+      shape[1] = shape[1] / warpOp.getWarpSize();
+      auto sourceDistributedType = VectorType::get(shape, elementType);
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistributedType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+      // Create new reduction op.
+      // auto newOp = vector::MultiDimReductionOp::create(
+      //     rewriter, reductionOp.getLoc(), distributedResultType,
+      //     reductionOp.getKind(),
+      //     /** source = **/ newWarpOp.getResult(newRetIndices[0]),
+      //     /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]),
+      //     reductionDims);
+      // Create a constant zero value for storing the reduction result.
+      // rewriter.setInsertionPointAfter(reductionOp);
+      auto zeroAttr =
+          rewriter.getZeroAttr(distributedResultType.getElementType());
+      Value result = arith::ConstantOp::create(
+          rewriter, reductionOp->getLoc(), distributedResultType,
+          DenseElementsAttr::get(distributedResultType, zeroAttr));
+      int nCols = sourceDistributedType.getShape()[1];
+      Value source = newWarpOp.getResult(newRetIndices[0]);
+      Value acc = newWarpOp.getResult(newRetIndices[1]);
+      for (int i = 0; i < nCols; ++i) {
+        Value col = vector::ExtractStridedSliceOp::create(
+            rewriter, reductionOp.getLoc(), source, {0, i},
+            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+        col = vector::ShapeCastOp::create(
+            rewriter, reductionOp.getLoc(),
+            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+            col);
+        Value accCol =
+            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+        Value colReduce = vector::ReductionOp::create(
+            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+        // Insert the reduced column into the result.
+        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                          colReduce, result, i);
+      }
+      // Replace the warp op result with the new reduction op.
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      return success();
+    }
+    // Row reduction.
+    // Create a constant zero value for storing the reduction result.
+    rewriter.setInsertionPointAfter(reductionOp);
+    auto zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+    // Value result = arith::ConstantOp::create(
+    //     rewriter, reductionOp.getLoc(),
+    //     rewriter.getIntegerAttr(reductionOp.getType(), 0));
+    int nRows = sourceType.getShape()[0];
+    // For each row, do a vector reduction.
+    for (int i = 0; i < nRows; ++i) {
+      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getSource(), i);
+      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                            reductionOp.getAcc(), i);
+      Value rowReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                        rowReduce, result, i);
+    }
+    // Replace the warp op result with the final result.
+    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
@@ -2016,16 +2201,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
-           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
-           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
-          patterns.getContext(), benefit);
+  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpExtract,
+               WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
+               WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
+               WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
-  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
-                               benefit);
+  patterns.add<WarpOpScfForOp, WarpOpShapeCast, WarpOpBroadcast>(
+      patterns.getContext(), distributionMapFn, benefit);
 }
 
 void mlir::vector::populateDistributeReduction(

>From 56c3441e9443660788e51064f8206c5e4ac9fbaf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:07:10 +0000
Subject: [PATCH 02/19] save

---
 .../Vector/Transforms/VectorDistribute.cpp    | 110 +++++------------
 .../Vector/vector-warp-distribute.mlir        | 111 ++++++++++++++++++
 2 files changed, 143 insertions(+), 78 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2d9fcaee37282..6410a895fc9ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -945,40 +945,8 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
   }
 };
 
-static VectorType
-tryFindDistributedType(TypedValue<VectorType> source,
-                       WarpExecuteOnLane0Op warpOp,
-                       const DistributionMapFn &distributionMapFn) {
-  VectorType distributedType = source.getType();
-  // Check if the source is yielded from the warp op.
-  gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  auto *it = llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
-    return operand.get() == source;
-  });
-
-  if (it != yieldOp->getOpOperands().end()) {
-    // If the source is yielded from the warp op, we can use the matching
-    // warp result type as the distributed source type.
-    distributedType =
-        cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
-  } else {
-    // If the source is not yielded from the warp op, we need to compute
-    // the distributed source type based on the distribution map and the
-    // warp size.
-    AffineMap map = distributionMapFn(source);
-    VectorType computed =
-        getDistributedType(source.getType(), map, warpOp.getWarpSize());
-    if (!computed)
-      return source.getType();
-    distributedType = computed;
-  }
-  return distributedType;
-}
-
 struct WarpOpBroadcast : public WarpDistributionPattern {
-  WarpOpBroadcast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
-      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -991,23 +959,18 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
     auto destVecType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
     Value broadcastSrc = broadcastOp.getSource();
-    Type srcDistributedType = broadcastSrc.getType();
-
-    if (isa<VectorType>(srcDistributedType))
-      srcDistributedType =
-          tryFindDistributedType(cast<TypedValue<VectorType>>(broadcastSrc),
-                                 warpOp, distributionMapFn);
+    Type broadcastSrcType = broadcastSrc.getType();
 
     // Check that the broadcast actually spans a set of values uniformly across
     // all threads. In other words, check that each thread can reconstruct
     // their own broadcast.
     // For that we simply check that the broadcast we want to build makes sense.
-    if (vector::isBroadcastableTo(srcDistributedType, destVecType) !=
+    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
         vector::BroadcastableToResult::Success)
       return failure();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastSrc}, {srcDistributedType}, newRetIndices);
+        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value broadcasted = vector::BroadcastOp::create(
         rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
@@ -1015,9 +978,6 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
                                 broadcasted);
     return success();
   }
-
-private:
-  DistributionMapFn distributionMapFn;
 };
 
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
@@ -2100,37 +2060,37 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
           warpOp, "Only 1 reduction dimension is supported.");
 
+    // Create a constant vector to store the result of the reduction per lane.
+    TypedAttr zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+
     // Col reduction.
     if (reductionDims[0] == 0) {
-      // Yield the source vector and the accumulator.
+      // Source vector must be distributable to lanes in the col dimension.
       if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
         return rewriter.notifyMatchFailure(
             warpOp, "Source vector dimension must be divisible by warp size.");
+      // Compute source distributed type.
       SmallVector<int64_t> shape(sourceType.getShape());
       shape[1] = shape[1] / warpOp.getWarpSize();
       auto sourceDistributedType = VectorType::get(shape, elementType);
+
+      // Yield the source and acc vectors from the WarpOp.
       SmallVector<size_t> newRetIndices;
       auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
           {sourceDistributedType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
-      // Create new reduction op.
-      // auto newOp = vector::MultiDimReductionOp::create(
-      //     rewriter, reductionOp.getLoc(), distributedResultType,
-      //     reductionOp.getKind(),
-      //     /** source = **/ newWarpOp.getResult(newRetIndices[0]),
-      //     /** accumulator = **/ newWarpOp.getResult(newRetIndices[1]),
-      //     reductionDims);
-      // Create a constant zero value for storing the reduction result.
-      // rewriter.setInsertionPointAfter(reductionOp);
-      auto zeroAttr =
-          rewriter.getZeroAttr(distributedResultType.getElementType());
-      Value result = arith::ConstantOp::create(
-          rewriter, reductionOp->getLoc(), distributedResultType,
-          DenseElementsAttr::get(distributedResultType, zeroAttr));
+
       int nCols = sourceDistributedType.getShape()[1];
       Value source = newWarpOp.getResult(newRetIndices[0]);
       Value acc = newWarpOp.getResult(newRetIndices[1]);
+      // For each column owned by a lane, extract the column (of size nRows x
+      // 1), shape cast it to 1D (nRows), do a vector.reduction, and insert the
+      // result back into the result vector.
       for (int i = 0; i < nCols; ++i) {
         Value col = vector::ExtractStridedSliceOp::create(
             rewriter, reductionOp.getLoc(), source, {0, i},
@@ -2143,7 +2103,6 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
             vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
         Value colReduce = vector::ReductionOp::create(
             rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
-        // Insert the reduced column into the result.
         result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
                                           colReduce, result, i);
       }
@@ -2151,19 +2110,13 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
       rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
       return success();
     }
-    // Row reduction.
-    // Create a constant zero value for storing the reduction result.
+    // For row reductions, we simply rewrite the MultiReductionOp in terms of
+    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+    // pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    auto zeroAttr =
-        rewriter.getZeroAttr(distributedResultType.getElementType());
-    Value result = arith::ConstantOp::create(
-        rewriter, reductionOp->getLoc(), distributedResultType,
-        DenseElementsAttr::get(distributedResultType, zeroAttr));
-    // Value result = arith::ConstantOp::create(
-    //     rewriter, reductionOp.getLoc(),
-    //     rewriter.getIntegerAttr(reductionOp.getType(), 0));
     int nRows = sourceType.getShape()[0];
-    // For each row, do a vector reduction.
+    // For each row of the source, extract the row vector, do a reduction, and
+    // insert the result back into the result vector.
     for (int i = 0; i < nRows; ++i) {
       Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
                                                reductionOp.getSource(), i);
@@ -2201,15 +2154,16 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpExtract,
-               WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
-               WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
-               WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
+           WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
+           WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
+           WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
-  patterns.add<WarpOpScfForOp, WarpOpShapeCast, WarpOpBroadcast>(
-      patterns.getContext(), distributionMapFn, benefit);
+  patterns.add<WarpOpScfForOp, WarpOpShapeCast>(patterns.getContext(),
+                                                distributionMapFn, benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 4d2c964a6df3c..bf70fbbd27244 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -850,6 +850,83 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
   return %r : f32
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
+// CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP:     %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP:     %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP:     gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP:   %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP:   %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP:   %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP:   %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP:   %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP:   %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP:   %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP:   %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP:   return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %acc = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
+    gpu.yield %1 : vector<64xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL:  func.func @vector_multi_reduction_row_reduce
+// CHECK-PROP:    %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP:    %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP:    %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP:    %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP:    %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP:    %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP:    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP:    %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP:      %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP:      gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP:    }
+// CHECK-PROP:    %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:    %[[SR:.*]], %{{.*}} = gpu.shuffle  xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:    %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP:    %[[SR0:.*]], %{{.*}} = gpu.shuffle  xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:    %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP:    %[[SR2:.*]], %{{.*}} = gpu.shuffle  xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:    %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP:    %[[SR4:.*]], %{{.*}} = gpu.shuffle  xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:    %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP:    %[[SR6:.*]], %{{.*}} = gpu.shuffle  xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:    %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP:    %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP:    %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:    %[[SR8:.*]], %{{.*}} = gpu.shuffle  xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:    %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP:    %[[SR10:.*]], %{{.*}} = gpu.shuffle  xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:    %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP:    %[[SR12:.*]], %{{.*}} = gpu.shuffle  xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:    %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP:    %[[SR14:.*]], %{{.*}} = gpu.shuffle  xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:    %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP:    %[[SR16:.*]], %{{.*}} = gpu.shuffle  xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:    %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP:    %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP:    %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP:    return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+  %zero = arith.constant dense<0.0> : vector<2xf32>
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<2x32xf32>)
+    %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
+    gpu.yield %1 : vector<2xf32>
+  }
+  return %r : vector<2xf32>
+}
+
 // -----
 
 // CHECK-PROP-LABEL:   func @warp_duplicate_yield(
@@ -1567,6 +1644,40 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 // CHECK-PROP:   %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
 // CHECK-PROP:   return %[[CAST]] : vector<4xf32>
 
+// -----
+func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32>
+    %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32>
+    gpu.yield %3 : vector<32x64xf32>
+  }
+  return %r : vector<32x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
+// CHECK-PROP:    %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
+// CHECK-PROP:    %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
+// CHECK-PROP:    return %[[CAST]] : vector<32x2xf32>
+
+// -----
+func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) {
+    %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32>
+    %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32>
+    gpu.yield %3 : vector<8x4x2xf32>
+  }
+  return %r : vector<8x4x2xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result
+// CHECK-PROP:    %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32>
+// CHECK-PROP:    %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32>
+// CHECK-PROP:    return %[[CAST]] : vector<8x4x2xf32>
+
 // -----
 
 func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {

>From 01880b561e94c6cb752e6eddb16957e00dbdc97f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:26:49 +0000
Subject: [PATCH 03/19] save

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6410a895fc9ae..8dc1418e09006 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2033,6 +2033,12 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
+// This patterns distribute the `vector.multi_reduction` operation across
+// lanes in a warp. Currently only 2D to 1D reductions are supported and assumes
+// that source vector is distributed in column dimension (i.e. Each lane owns
+// complete column(s) of the source vector.
+// TODO: Add support for the case where source rows are distributed accross
+// lanes. Requires DistributionMapFn to express the data distribution.
 struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
   VectorMultiDimReductionDistribution(MLIRContext *context,
                                       PatternBenefit benefit = 1)

>From 53da9928117634d6eb929f81cbfa59ed4c06d884 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:28:13 +0000
Subject: [PATCH 04/19] save

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8dc1418e09006..c88c001f34843 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2036,9 +2036,9 @@ struct WarpOpReduction : public WarpDistributionPattern {
 // This patterns distribute the `vector.multi_reduction` operation across
 // lanes in a warp. Currently only 2D to 1D reductions are supported and assumes
 // that source vector is distributed in column dimension (i.e. Each lane owns
-// complete column(s) of the source vector.
-// TODO: Add support for the case where source rows are distributed accross
-// lanes. Requires DistributionMapFn to express the data distribution.
+// complete column(s) of the source vector).
+// TODO: Add support for the case where source rows are distributed across
+// lanes. Requires `DistributionMapFn` to express the data distribution.
 struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
   VectorMultiDimReductionDistribution(MLIRContext *context,
                                       PatternBenefit benefit = 1)

>From affd4aadb2e0f3f7cd19b0805b34067c1fa65371 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 19 Aug 2025 23:48:08 +0000
Subject: [PATCH 05/19] save

---
 .../Vector/vector-warp-distribute.mlir        | 74 +++++++++----------
 1 file changed, 37 insertions(+), 37 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bf70fbbd27244..bf0191655d654 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -879,44 +879,44 @@ func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
 
 // -----
 // CHECK-PROP-LABEL:  func.func @vector_multi_reduction_row_reduce
-// CHECK-PROP:    %[[C16:.*]] = arith.constant 16 : i32
-// CHECK-PROP:    %[[C8:.*]] = arith.constant 8 : i32
-// CHECK-PROP:    %[[C4:.*]] = arith.constant 4 : i32
-// CHECK-PROP:    %[[C2:.*]] = arith.constant 2 : i32
-// CHECK-PROP:    %[[C1:.*]] = arith.constant 1 : i32
-// CHECK-PROP:    %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-PROP:    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-PROP:    %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
-// CHECK-PROP:      %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
-// CHECK-PROP:      gpu.yield %[[SRC]] : vector<2x32xf32>
-// CHECK-PROP:    }
-// CHECK-PROP:    %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP:    %[[SR:.*]], %{{.*}} = gpu.shuffle  xor %[[T1]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP:    %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
-// CHECK-PROP:    %[[SR0:.*]], %{{.*}} = gpu.shuffle  xor %[[T2]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP:    %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
-// CHECK-PROP:    %[[SR2:.*]], %{{.*}} = gpu.shuffle  xor %[[T3]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP:    %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
-// CHECK-PROP:    %[[SR4:.*]], %{{.*}} = gpu.shuffle  xor %[[T4]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP:    %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
-// CHECK-PROP:    %[[SR6:.*]], %{{.*}} = gpu.shuffle  xor %[[T5]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP:    %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
-// CHECK-PROP:    %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+// CHECK-PROP-DAG:    %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP-DAG:    %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP-DAG:    %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP-DAG:    %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP-DAG:    %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP-DAG:    %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP-DAG:    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP:        %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP:          %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP:          gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP:        }
+// CHECK-PROP:        %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:        %[[SR:.*]], %{{.*}} = gpu.shuffle  xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:        %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP:        %[[SR0:.*]], %{{.*}} = gpu.shuffle  xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:        %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP:        %[[SR2:.*]], %{{.*}} = gpu.shuffle  xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:        %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP:        %[[SR4:.*]], %{{.*}} = gpu.shuffle  xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:        %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP:        %[[SR6:.*]], %{{.*}} = gpu.shuffle  xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:        %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP:        %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
 //
-// CHECK-PROP:    %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP:    %[[SR8:.*]], %{{.*}} = gpu.shuffle  xor %[[T8]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP:    %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
-// CHECK-PROP:    %[[SR10:.*]], %{{.*}} = gpu.shuffle  xor %[[T9]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP:    %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
-// CHECK-PROP:    %[[SR12:.*]], %{{.*}} = gpu.shuffle  xor %[[T10]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP:    %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
-// CHECK-PROP:    %[[SR14:.*]], %{{.*}} = gpu.shuffle  xor %[[T11]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP:    %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
-// CHECK-PROP:    %[[SR16:.*]], %{{.*}} = gpu.shuffle  xor %[[T12]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP:    %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
-// CHECK-PROP:    %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
-// CHECK-PROP:    %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
-// CHECK-PROP:    return %[[R]] : vector<2xf32>
+// CHECK-PROP:        %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:        %[[SR8:.*]], %{{.*}} = gpu.shuffle  xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:        %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP:        %[[SR10:.*]], %{{.*}} = gpu.shuffle  xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:        %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP:        %[[SR12:.*]], %{{.*}} = gpu.shuffle  xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:        %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP:        %[[SR14:.*]], %{{.*}} = gpu.shuffle  xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:        %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP:        %[[SR16:.*]], %{{.*}} = gpu.shuffle  xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:        %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP:        %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP:        %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP:        return %[[R]] : vector<2xf32>
 func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
   %zero = arith.constant dense<0.0> : vector<2xf32>
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {

>From df59c20f5d8020ab9ba78f1c360334c738a60404 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 20 Aug 2025 00:01:32 +0000
Subject: [PATCH 06/19] save

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index c88c001f34843..b0b52919c69ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -15,19 +15,13 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
-#include <cstddef>
 #include <utility>
 
 using namespace mlir;

>From 55797318492b6a38801aa27bf9ec97d26523322e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 20 Aug 2025 23:31:35 +0000
Subject: [PATCH 07/19] save

---
 .../Vector/Transforms/VectorDistribute.cpp    | 52 +++++++++++++------
 1 file changed, 37 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index b0b52919c69ce..ab0f1b55d04da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2033,9 +2033,8 @@ struct WarpOpReduction : public WarpDistributionPattern {
 // complete column(s) of the source vector).
 // TODO: Add support for the case where source rows are distributed across
 // lanes. Requires `DistributionMapFn` to express the data distribution.
-struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
-  VectorMultiDimReductionDistribution(MLIRContext *context,
-                                      PatternBenefit benefit = 1)
+struct WarpOpMultiReduction : public WarpDistributionPattern {
+  WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1)
       : WarpDistributionPattern(context, benefit) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -2047,18 +2046,46 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
     unsigned operandNumber = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
-    VectorType distributedResultType =
-        cast<VectorType>(warpOp.getResult(operandNumber).getType());
-    Type elementType = distributedResultType.getElementType();
+
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
       return rewriter.notifyMatchFailure(warpOp,
                                          "Only 2D reductions are supported.");
     ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
-    // Only 1 reduction dimension supported.
+    // Only 1 reduction dimension supported. This also ensures that result is
+    // also vector type.
     if (reductionDims.size() != 1)
       return rewriter.notifyMatchFailure(
           warpOp, "Only 1 reduction dimension is supported.");
+    int64_t reductionDim = reductionDims[0];
+    auto resultType = cast<VectorType>(reductionOp.getType());
+    auto distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+
+    // Currently we make the following assumptions.
+    // 1. The source vector is distributed in the column dimension. Each lane
+    // owns complete column(s) of the source vector.
+    // 2. If the reduction dim == 0, its a lane-local col reduction. In this
+    // case each lane owns its portion of the result (i.e. result is also
+    // distributed).
+    // 3. If reduction dim == 1, its a row reduction that require cross lanes
+    // shuffles. In this case result is not distributed and broadcasted instead.
+    // TODO: These assumptions are fairly restrictive. For example, source
+    // vector can have row distributed layout. Improve support for such cases.
+    if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Source vector dimension must be divisible by warp size.");
+    bool isResultDistributed =
+        distributedResultType.getNumElements() < resultType.getNumElements();
+    if (reductionDim == 0 && !isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be distributed in a col reduction.");
+    if (reductionDim == 1 && isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be broadcasted in a row reduction.");
 
     // Create a constant vector to store the result of the reduction per lane.
     TypedAttr zeroAttr =
@@ -2066,14 +2093,9 @@ struct VectorMultiDimReductionDistribution : public WarpDistributionPattern {
     Value result = arith::ConstantOp::create(
         rewriter, reductionOp->getLoc(), distributedResultType,
         DenseElementsAttr::get(distributedResultType, zeroAttr));
-
     // Col reduction.
-    if (reductionDims[0] == 0) {
-      // Source vector must be distributable to lanes in the col dimension.
-      if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
-        return rewriter.notifyMatchFailure(
-            warpOp, "Source vector dimension must be divisible by warp size.");
-      // Compute source distributed type.
+    if (reductionDim == 0) {
+      // Compute source distributed type assuming each lane owns cols.
       SmallVector<int64_t> shape(sourceType.getShape());
       shape[1] = shape[1] / warpOp.getWarpSize();
       auto sourceDistributedType = VectorType::get(shape, elementType);
@@ -2158,7 +2180,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
       .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
            WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
            WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
-           WarpOpInsertStridedSlice, VectorMultiDimReductionDistribution>(
+           WarpOpInsertStridedSlice, WarpOpMultiReduction>(
           patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);

>From 07c0364d64109faf740023107ab68ec0f242d9ca Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 20 Aug 2025 23:36:26 +0000
Subject: [PATCH 08/19] save

---
 .../Vector/Transforms/VectorDistribute.cpp    |  3 +-
 .../Vector/vector-warp-distribute.mlir        | 32 ++++++++++---------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index ab0f1b55d04da..aecb6a11a7b36 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2034,8 +2034,7 @@ struct WarpOpReduction : public WarpDistributionPattern {
 // TODO: Add support for the case where source rows are distributed across
 // lanes. Requires `DistributionMapFn` to express the data distribution.
 struct WarpOpMultiReduction : public WarpDistributionPattern {
-  WarpOpMultiReduction(MLIRContext *context, PatternBenefit benefit = 1)
-      : WarpDistributionPattern(context, benefit) {}
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index bf0191655d654..95b8a48404f20 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -852,21 +852,23 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
 
 // -----
 // CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
-// CHECK-PROP:   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
-// CHECK-PROP:     %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
-// CHECK-PROP:     %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
-// CHECK-PROP:     gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
-// CHECK-PROP:   }
-// CHECK-PROP:   %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP:   %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP:   %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
-// CHECK-PROP:   %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
-// CHECK-PROP:   %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0 {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP:   %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP:   %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
-// CHECK-PROP:   %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
-// CHECK-PROP:   %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
-// CHECK-PROP:   return %[[R]] : vector<2xf32>
+// CHECK-PROP:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP:           %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP:           %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP:           gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP:         }
+// CHECK-PROP:         %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME:      {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP:         %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP:         %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP:         %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP:         %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME:      {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP:         %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP:         %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP:         %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP:         %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP:         return %[[R]] : vector<2xf32>
 func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
     %0 = "some_def"() : () -> (vector<32x64xf32>)

>From 4ed74d89b7808c17bed035a906acf15d2a96c51f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 28 Aug 2025 22:21:49 +0000
Subject: [PATCH 09/19] save work

---
 .../Vector/Transforms/VectorDistribute.cpp    | 58 ++++++++++++++++---
 1 file changed, 51 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2cd743d1ee8e8..4a6ea07c1c236 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1031,7 +1031,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
       return failure();
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
-    unsigned int operandNumber = operand->getOperandNumber();
+    unsigned operandNumber = operand->getOperandNumber();
     VectorType sourceType = oldCastOp.getSourceVectorType();
     VectorType distributedResultType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
@@ -2069,12 +2069,56 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
-// This patterns distribute the `vector.multi_reduction` operation across
-// lanes in a warp. Currently only 2D to 1D reductions are supported and assumes
-// that source vector is distributed in column dimension (i.e. Each lane owns
-// complete column(s) of the source vector).
-// TODO: Add support for the case where source rows are distributed across
-// lanes. Requires `DistributionMapFn` to express the data distribution.
+/// This patterns distribute the `vector.multi_reduction` operation across
+/// lanes in a warp. Currently only 2D to 1D reductions are supported and
+/// assumes that source vector is distributed in column dimension (i.e. Each
+/// lane owns complete column(s) of the source vector).
+/// TODO: Add support for the case where source rows are distributed across
+/// lanes. Requires `DistributionMapFn` to express the data distribution.
+/// Example 1 (Col reduction):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
+///   vector<32xf32> gpu.yield %1 : vector<32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
+/// vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
+/// }
+/// %c = arith.constant dense<0.0> : vector<1xf32>
+/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
+/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
+/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
+/// ```
+/// Example 2 (Row reduction):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
+///   vector<2xf32>
+///   gpu.yield %1 : vector<2xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = arith.constant dense<0.0> : vector<2xf32>
+///   %2 = vector.extract %0[0] : vector<32xf32> from <vector<2x32xf32>>
+///   %3 = ("warp.reduction %2") : f32
+///   %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
+///   ... repeat for row 1
+///   gpu.yield %1 : vector<2xf32>
+/// }
 struct WarpOpMultiReduction : public WarpDistributionPattern {
   using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

>From 116e4bceb48ebe85e35fd61dd9e52867897b0a39 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 2 Sep 2025 20:49:19 +0000
Subject: [PATCH 10/19] save work

---
 mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 4a6ea07c1c236..dddfcaf4f273d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
@@ -1039,7 +1040,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
     bool isResultDistributed = distributedResultType.getNumElements() <
                                oldCastOp.getResultVectorType().getNumElements();
 
-    // If the result is not distributed, source distribted type is the same
+    // If the result is not distributed, source distributed type is the same
     // as the source type. If the result is distributed, we need to compute the
     // distributed source type according to following rules:
     // 1. If the source type is yielded from the warp op, we can use the
@@ -1051,7 +1052,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
       // Check if the source is yielded from the warp op.
       gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
           warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-      auto *it =
+      OpOperand *it =
           llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
             return operand.get() == oldCastOp.getSource();
           });
@@ -2155,7 +2156,9 @@ struct WarpOpMultiReduction : public WarpDistributionPattern {
     // case each lane owns its portion of the result (i.e. result is also
     // distributed).
     // 3. If reduction dim == 1, its a row reduction that require cross lanes
-    // shuffles. In this case result is not distributed and broadcasted instead.
+    // shuffles. In this case, the reduction result is not distributed across
+    // lanes. Instead each lane owns a complete copy of the result
+    // (broadcasted).
     // TODO: These assumptions are fairly restrictive. For example, source
     // vector can have row distributed layout. Improve support for such cases.
     if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)

>From a78aec590ed062d9770cd747d7bc61114fc6f23e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 3 Sep 2025 23:07:45 +0000
Subject: [PATCH 11/19] move work

---
 .../Vector/Transforms/VectorDistribute.cpp    | 174 -----------------
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 183 +++++++++++++++++-
 .../Vector/vector-warp-distribute.mlir        | 113 -----------
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  79 ++++++++
 4 files changed, 258 insertions(+), 291 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index dddfcaf4f273d..aacbb4c23af3e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -2070,180 +2070,6 @@ struct WarpOpReduction : public WarpDistributionPattern {
   DistributedReductionFn distributedReductionFn;
 };
 
-/// This patterns distribute the `vector.multi_reduction` operation across
-/// lanes in a warp. Currently only 2D to 1D reductions are supported and
-/// assumes that source vector is distributed in column dimension (i.e. Each
-/// lane owns complete column(s) of the source vector).
-/// TODO: Add support for the case where source rows are distributed across
-/// lanes. Requires `DistributionMapFn` to express the data distribution.
-/// Example 1 (Col reduction):
-/// ```
-/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
-///   %0 = "some_def"() : () -> (vector<16x32xf32>)
-///   %acc = "some_def"() : () -> (vector<32xf32>)
-///   %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
-///   vector<32xf32> gpu.yield %1 : vector<32xf32>
-/// }
-/// ```
-/// is lowered to:
-/// ```
-/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
-/// vector<1xf32>) {
-///   %0 = "some_def"() : () -> (vector<16x32xf32>)
-///   %acc = "some_def"() : () -> (vector<32xf32>)
-///   gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
-/// }
-/// %c = arith.constant dense<0.0> : vector<1xf32>
-/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
-/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
-/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
-/// ```
-/// Example 2 (Row reduction):
-/// ```
-/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
-///   %0 = "some_def"() : () -> (vector<2x32xf32>)
-///   %acc = "some_def"() : () -> (vector<2xf32>)
-///   %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
-///   vector<2xf32>
-///   gpu.yield %1 : vector<2xf32>
-/// }
-/// ```
-/// is lowered to:
-/// ```
-/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
-///   %0 = "some_def"() : () -> (vector<2x32xf32>)
-///   %acc = "some_def"() : () -> (vector<2xf32>)
-///   %1 = arith.constant dense<0.0> : vector<2xf32>
-///   %2 = vector.extract %0[0] : vector<32xf32> from <vector<2x32xf32>>
-///   %3 = ("warp.reduction %2") : f32
-///   %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
-///   ... repeat for row 1
-///   gpu.yield %1 : vector<2xf32>
-/// }
-struct WarpOpMultiReduction : public WarpDistributionPattern {
-  using Base::Base;
-  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand =
-        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
-    if (!yieldOperand)
-      return failure();
-    auto reductionOp =
-        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
-    unsigned operandNumber = yieldOperand->getOperandNumber();
-    VectorType sourceType = reductionOp.getSourceVectorType();
-
-    // Only 2D vectors are supported.
-    if (sourceType.getRank() != 2)
-      return rewriter.notifyMatchFailure(warpOp,
-                                         "Only 2D reductions are supported.");
-    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
-    // Only 1 reduction dimension supported. This also ensures that result is
-    // also vector type.
-    if (reductionDims.size() != 1)
-      return rewriter.notifyMatchFailure(
-          warpOp, "Only 1 reduction dimension is supported.");
-    int64_t reductionDim = reductionDims[0];
-    auto resultType = cast<VectorType>(reductionOp.getType());
-    auto distributedResultType =
-        cast<VectorType>(warpOp.getResult(operandNumber).getType());
-    Type elementType = distributedResultType.getElementType();
-
-    // Currently we make the following assumptions.
-    // 1. The source vector is distributed in the column dimension. Each lane
-    // owns complete column(s) of the source vector.
-    // 2. If the reduction dim == 0, its a lane-local col reduction. In this
-    // case each lane owns its portion of the result (i.e. result is also
-    // distributed).
-    // 3. If reduction dim == 1, its a row reduction that require cross lanes
-    // shuffles. In this case, the reduction result is not distributed across
-    // lanes. Instead each lane owns a complete copy of the result
-    // (broadcasted).
-    // TODO: These assumptions are fairly restrictive. For example, source
-    // vector can have row distributed layout. Improve support for such cases.
-    if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
-      return rewriter.notifyMatchFailure(
-          warpOp, "Source vector dimension must be divisible by warp size.");
-    bool isResultDistributed =
-        distributedResultType.getNumElements() < resultType.getNumElements();
-    if (reductionDim == 0 && !isResultDistributed)
-      return rewriter.notifyMatchFailure(
-          warpOp,
-          "Expecting result vector to be distributed in a col reduction.");
-    if (reductionDim == 1 && isResultDistributed)
-      return rewriter.notifyMatchFailure(
-          warpOp,
-          "Expecting result vector to be broadcasted in a row reduction.");
-
-    // Create a constant vector to store the result of the reduction per lane.
-    TypedAttr zeroAttr =
-        rewriter.getZeroAttr(distributedResultType.getElementType());
-    Value result = arith::ConstantOp::create(
-        rewriter, reductionOp->getLoc(), distributedResultType,
-        DenseElementsAttr::get(distributedResultType, zeroAttr));
-    // Col reduction.
-    if (reductionDim == 0) {
-      // Compute source distributed type assuming each lane owns cols.
-      SmallVector<int64_t> shape(sourceType.getShape());
-      shape[1] = shape[1] / warpOp.getWarpSize();
-      auto sourceDistributedType = VectorType::get(shape, elementType);
-
-      // Yield the source and acc vectors from the WarpOp.
-      SmallVector<size_t> newRetIndices;
-      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
-          {sourceDistributedType, distributedResultType}, newRetIndices);
-      rewriter.setInsertionPointAfter(newWarpOp);
-
-      int nCols = sourceDistributedType.getShape()[1];
-      Value source = newWarpOp.getResult(newRetIndices[0]);
-      Value acc = newWarpOp.getResult(newRetIndices[1]);
-      // For each column owned by a lane, extract the column (of size nRows x
-      // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
-      // result back to the result vector.
-      for (int i = 0; i < nCols; ++i) {
-        Value col = vector::ExtractStridedSliceOp::create(
-            rewriter, reductionOp.getLoc(), source, {0, i},
-            {sourceDistributedType.getShape()[0], 1}, {1, 1});
-        col = vector::ShapeCastOp::create(
-            rewriter, reductionOp.getLoc(),
-            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
-            col);
-        Value accCol =
-            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
-        Value colReduce = vector::ReductionOp::create(
-            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
-        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                          colReduce, result, i);
-      }
-      // Replace the warp op result with the new reduction op.
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
-      return success();
-    }
-    // For row reductions, we simply rewrite the MultiReductionOp in terms of
-    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
-    // pattern.
-    rewriter.setInsertionPointAfter(reductionOp);
-    int nRows = sourceType.getShape()[0];
-    // For each row of the source, extract the row vector, do a reduction and,
-    // insert the result back to the result.
-    for (int i = 0; i < nRows; ++i) {
-      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                               reductionOp.getSource(), i);
-      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                            reductionOp.getAcc(), i);
-      Value rowReduce = vector::ReductionOp::create(
-          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
-      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                        rowReduce, result, i);
-    }
-    // Replace the warp op result with the final result.
-    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
-
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index dddb5eaece2cb..050fa0cd1d342 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,180 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// This pattern distributes the `vector.multi_reduction` operation across
+/// lanes in a warp. Currently only 2D to 1D reductions are supported, and the
+/// source vector is assumed to be distributed in the column dimension (i.e.
+/// each lane owns complete column(s) of the source vector).
+/// TODO: Add support for the case where source rows are distributed across
+/// lanes. Requires `DistributionMapFn` to express the data distribution.
+/// Example 1 (Col reduction):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to
+///   vector<32xf32> gpu.yield %1 : vector<32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
+/// vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
+/// }
+/// %c = arith.constant dense<0.0> : vector<1xf32>
+/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
+/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
+/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
+/// ```
+/// Example 2 (Row reduction):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
+///   vector<2xf32>
+///   gpu.yield %1 : vector<2xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = arith.constant dense<0.0> : vector<2xf32>
+///   %2 = vector.extract %0[0] : vector<32xf32> from <vector<2x32xf32>>
+///   %3 = ("warp.reduction %2") : f32
+///   %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
+///   ... repeat for row 1
+///   gpu.yield %1 : vector<2xf32>
+/// }
+struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
+  using Base::Base;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension supported. This also ensures that result is
+    // also vector type.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+    int64_t reductionDim = reductionDims[0];
+    auto resultType = cast<VectorType>(reductionOp.getType());
+    auto distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    Type elementType = distributedResultType.getElementType();
+
+    // Currently we make the following assumptions.
+    // 1. The source vector is distributed in the column dimension. Each lane
+    // owns complete column(s) of the source vector.
+    // 2. If the reduction dim == 0, it's a lane-local col reduction. In this
+    // case each lane owns its portion of the result (i.e. result is also
+    // distributed).
+    // 3. If reduction dim == 1, it's a row reduction that requires cross-lane
+    // shuffles. In this case, the reduction result is not distributed across
+    // lanes. Instead each lane owns a complete copy of the result
+    // (broadcasted).
+    // TODO: These assumptions are fairly restrictive. For example, source
+    // vector can have row distributed layout. Improve support for such cases.
+    if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Source vector dimension must be divisible by warp size.");
+    bool isResultDistributed =
+        distributedResultType.getNumElements() < resultType.getNumElements();
+    if (reductionDim == 0 && !isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be distributed in a col reduction.");
+    if (reductionDim == 1 && isResultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting result vector to be broadcasted in a row reduction.");
+
+    // Create a constant vector to store the result of the reduction per lane.
+    TypedAttr zeroAttr =
+        rewriter.getZeroAttr(distributedResultType.getElementType());
+    Value result = arith::ConstantOp::create(
+        rewriter, reductionOp->getLoc(), distributedResultType,
+        DenseElementsAttr::get(distributedResultType, zeroAttr));
+    // Col reduction.
+    if (reductionDim == 0) {
+      // Compute source distributed type assuming each lane owns cols.
+      SmallVector<int64_t> shape(sourceType.getShape());
+      shape[1] = shape[1] / warpOp.getWarpSize();
+      auto sourceDistributedType = VectorType::get(shape, elementType);
+
+      // Yield the source and acc vectors from the WarpOp.
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistributedType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+
+      int nCols = sourceDistributedType.getShape()[1];
+      Value source = newWarpOp.getResult(newRetIndices[0]);
+      Value acc = newWarpOp.getResult(newRetIndices[1]);
+      // For each column owned by a lane, extract the column (of size nRows x
+      // 1), shape cast it to 1D (nRows), do a vector.reduction, and insert the
+      // result back into the result vector.
+      for (int i = 0; i < nCols; ++i) {
+        Value col = vector::ExtractStridedSliceOp::create(
+            rewriter, reductionOp.getLoc(), source, {0, i},
+            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+        col = vector::ShapeCastOp::create(
+            rewriter, reductionOp.getLoc(),
+            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
+            col);
+        Value accCol =
+            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
+        Value colReduce = vector::ReductionOp::create(
+            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
+        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                          colReduce, result, i);
+      }
+      // Replace the warp op result with the new reduction op.
+      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      return success();
+    }
+    // For row reductions, we simply rewrite the MultiReductionOp in terms of
+    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
+    // pattern.
+    rewriter.setInsertionPointAfter(reductionOp);
+    int nRows = sourceType.getShape()[0];
+    // For each row of the source, extract the row vector, do a reduction, and
+    // insert the result back into the result vector.
+    for (int i = 0; i < nRows; ++i) {
+      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getSource(), i);
+      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                            reductionOp.getAcc(), i);
+      Value rowReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
+                                        rowReduce, result, i);
+    }
+    // Replace all uses of the multi-reduction result with the final result.
+    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -819,10 +993,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, VectorMultiReductionDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 755222494d223..8750582ef1e1f 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -850,85 +850,6 @@ func.func @vector_reduction_acc(%laneid: index) -> (f32) {
   return %r : f32
 }
 
-// -----
-// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
-// CHECK-PROP      :   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
-// CHECK-PROP      :     %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
-// CHECK-PROP      :     %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
-// CHECK-PROP      :     gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
-// CHECK-PROP      :   }
-// CHECK-PROP      :   %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0
-// CHECK-PROP-SAME :     {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP      :   %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP      :   %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
-// CHECK-PROP      :   %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
-// CHECK-PROP      :   %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0
-// CHECK-PROP-SAME :     {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP      :   %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP      :   %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
-// CHECK-PROP      :   %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
-// CHECK-PROP      :   %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
-// CHECK-PROP      :   return %[[R]] : vector<2xf32>
-func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
-  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
-    %0 = "some_def"() : () -> (vector<32x64xf32>)
-    %acc = "some_def"() : () -> (vector<64xf32>)
-    %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
-    gpu.yield %1 : vector<64xf32>
-  }
-  return %r : vector<2xf32>
-}
-
-// -----
-// CHECK-PROP-LABEL:  func.func @vector_multi_reduction_row_reduce
-// CHECK-PROP-DAG:    %[[C16:.*]] = arith.constant 16 : i32
-// CHECK-PROP-DAG:    %[[C8:.*]] = arith.constant 8 : i32
-// CHECK-PROP-DAG:    %[[C4:.*]] = arith.constant 4 : i32
-// CHECK-PROP-DAG:    %[[C2:.*]] = arith.constant 2 : i32
-// CHECK-PROP-DAG:    %[[C1:.*]] = arith.constant 1 : i32
-// CHECK-PROP-DAG:    %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-PROP-DAG:    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-PROP:        %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
-// CHECK-PROP:          %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
-// CHECK-PROP:          gpu.yield %[[SRC]] : vector<2x32xf32>
-// CHECK-PROP:        }
-// CHECK-PROP:        %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP:        %[[SR:.*]], %{{.*}} = gpu.shuffle  xor %[[T1]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP:        %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
-// CHECK-PROP:        %[[SR0:.*]], %{{.*}} = gpu.shuffle  xor %[[T2]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP:        %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
-// CHECK-PROP:        %[[SR2:.*]], %{{.*}} = gpu.shuffle  xor %[[T3]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP:        %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
-// CHECK-PROP:        %[[SR4:.*]], %{{.*}} = gpu.shuffle  xor %[[T4]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP:        %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
-// CHECK-PROP:        %[[SR6:.*]], %{{.*}} = gpu.shuffle  xor %[[T5]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP:        %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
-// CHECK-PROP:        %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
-//
-// CHECK-PROP:        %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP:        %[[SR8:.*]], %{{.*}} = gpu.shuffle  xor %[[T8]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP:        %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
-// CHECK-PROP:        %[[SR10:.*]], %{{.*}} = gpu.shuffle  xor %[[T9]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP:        %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
-// CHECK-PROP:        %[[SR12:.*]], %{{.*}} = gpu.shuffle  xor %[[T10]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP:        %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
-// CHECK-PROP:        %[[SR14:.*]], %{{.*}} = gpu.shuffle  xor %[[T11]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP:        %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
-// CHECK-PROP:        %[[SR16:.*]], %{{.*}} = gpu.shuffle  xor %[[T12]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP:        %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
-// CHECK-PROP:        %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
-// CHECK-PROP:        %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
-// CHECK-PROP:        return %[[R]] : vector<2xf32>
-func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
-  %zero = arith.constant dense<0.0> : vector<2xf32>
-  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
-    %0 = "some_def"() : () -> (vector<2x32xf32>)
-    %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
-    gpu.yield %1 : vector<2xf32>
-  }
-  return %r : vector<2xf32>
-}
-
 // -----
 
 // CHECK-PROP-LABEL:   func @warp_duplicate_yield(
@@ -1646,40 +1567,6 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 // CHECK-PROP:   %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
 // CHECK-PROP:   return %[[CAST]] : vector<4xf32>
 
-// -----
-func.func @warp_propagate_shape_cast_2d_to_2d(%laneid: index, %src: memref<64x32xf32>) -> vector<32x2xf32> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f32
-  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<32x2xf32>) {
-    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<64x32xf32>, vector<64x32xf32>
-    %3 = vector.shape_cast %2 : vector<64x32xf32> to vector<32x64xf32>
-    gpu.yield %3 : vector<32x64xf32>
-  }
-  return %r : vector<32x2xf32>
-}
-
-// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_2d_to_2d
-// CHECK-PROP:    %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [false, true]} : memref<64x32xf32>, vector<2x32xf32>
-// CHECK-PROP:    %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<2x32xf32> to vector<32x2xf32>
-// CHECK-PROP:    return %[[CAST]] : vector<32x2xf32>
-
-// -----
-func.func @warp_propagate_shape_cast_non_distributed_result(%laneid: index, %src: memref<64xf32>) -> vector<8x4x2xf32> {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0.000000e+00 : f32
-  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x4x2xf32>) {
-    %2 = vector.transfer_read %src[%c0], %cst : memref<64xf32>, vector<64xf32>
-    %3 = vector.shape_cast %2 : vector<64xf32> to vector<8x4x2xf32>
-    gpu.yield %3 : vector<8x4x2xf32>
-  }
-  return %r : vector<8x4x2xf32>
-}
-
-// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast_non_distributed_result
-// CHECK-PROP:    %[[READ:.*]] = vector.transfer_read {{.*}} {in_bounds = [true]} : memref<64xf32>, vector<64xf32>
-// CHECK-PROP:    %[[CAST:.*]] = vector.shape_cast %[[READ]] : vector<64xf32> to vector<8x4x2xf32>
-// CHECK-PROP:    return %[[CAST]] : vector<8x4x2xf32>
-
 // -----
 
 func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..cfb5428c92400 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,82 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
+// CHECK-PROP      :   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
+// CHECK-PROP      :     %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
+// CHECK-PROP      :     %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
+// CHECK-PROP      :     gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
+// CHECK-PROP      :   }
+// CHECK-PROP      :   %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME :     {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP      :   %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP      :   %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
+// CHECK-PROP      :   %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
+// CHECK-PROP      :   %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0
+// CHECK-PROP-SAME :     {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
+// CHECK-PROP      :   %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
+// CHECK-PROP      :   %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
+// CHECK-PROP      :   %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
+// CHECK-PROP      :   %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
+// CHECK-PROP      :   return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<32x64xf32>)
+    %acc = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
+    gpu.yield %1 : vector<64xf32>
+  }
+  return %r : vector<2xf32>
+}
+
+// -----
+// CHECK-PROP-LABEL:  func.func @vector_multi_reduction_row_reduce
+// CHECK-PROP-DAG:    %[[C16:.*]] = arith.constant 16 : i32
+// CHECK-PROP-DAG:    %[[C8:.*]] = arith.constant 8 : i32
+// CHECK-PROP-DAG:    %[[C4:.*]] = arith.constant 4 : i32
+// CHECK-PROP-DAG:    %[[C2:.*]] = arith.constant 2 : i32
+// CHECK-PROP-DAG:    %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-PROP-DAG:    %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP-DAG:    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-PROP:        %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
+// CHECK-PROP:          %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
+// CHECK-PROP:          gpu.yield %[[SRC]] : vector<2x32xf32>
+// CHECK-PROP:        }
+// CHECK-PROP:        %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:        %[[SR:.*]], %{{.*}} = gpu.shuffle  xor %[[T1]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:        %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
+// CHECK-PROP:        %[[SR0:.*]], %{{.*}} = gpu.shuffle  xor %[[T2]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:        %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
+// CHECK-PROP:        %[[SR2:.*]], %{{.*}} = gpu.shuffle  xor %[[T3]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:        %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
+// CHECK-PROP:        %[[SR4:.*]], %{{.*}} = gpu.shuffle  xor %[[T4]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:        %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
+// CHECK-PROP:        %[[SR6:.*]], %{{.*}} = gpu.shuffle  xor %[[T5]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:        %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
+// CHECK-PROP:        %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
+//
+// CHECK-PROP:        %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
+// CHECK-PROP:        %[[SR8:.*]], %{{.*}} = gpu.shuffle  xor %[[T8]], %[[C1]], %[[C32]] : f32
+// CHECK-PROP:        %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
+// CHECK-PROP:        %[[SR10:.*]], %{{.*}} = gpu.shuffle  xor %[[T9]], %[[C2]], %[[C32]] : f32
+// CHECK-PROP:        %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
+// CHECK-PROP:        %[[SR12:.*]], %{{.*}} = gpu.shuffle  xor %[[T10]], %[[C4]], %[[C32]] : f32
+// CHECK-PROP:        %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
+// CHECK-PROP:        %[[SR14:.*]], %{{.*}} = gpu.shuffle  xor %[[T11]], %[[C8]], %[[C32]] : f32
+// CHECK-PROP:        %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
+// CHECK-PROP:        %[[SR16:.*]], %{{.*}} = gpu.shuffle  xor %[[T12]], %[[C16]], %[[C32]] : f32
+// CHECK-PROP:        %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
+// CHECK-PROP:        %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
+// CHECK-PROP:        %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+// CHECK-PROP:        return %[[R]] : vector<2xf32>
+func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+  %zero = arith.constant dense<0.0> : vector<2xf32>
+  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+    %0 = "some_def"() : () -> (vector<2x32xf32>)
+    %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
+    gpu.yield %1 : vector<2xf32>
+  }
+  return %r : vector<2xf32>
+}

>From 2ba43fc7827a64550a038453271ab972f2e94b01 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 3 Sep 2025 23:08:26 +0000
Subject: [PATCH 12/19] save work

---
 .../Vector/Transforms/VectorDistribute.cpp    | 80 ++++++-------------
 1 file changed, 24 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index aacbb4c23af3e..c84eb2c9f8857 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,7 +17,6 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
@@ -1021,75 +1020,44 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-
-  WarpOpShapeCast(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
-      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
         getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
     if (!operand)
       return failure();
+
     auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
 
-    unsigned operandNumber = operand->getOperandNumber();
-    VectorType sourceType = oldCastOp.getSourceVectorType();
-    VectorType distributedResultType =
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto castDistributedType =
         cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
-    VectorType distributedSourceType = sourceType;
-    bool isResultDistributed = distributedResultType.getNumElements() <
-                               oldCastOp.getResultVectorType().getNumElements();
-
-    // If the result is not distributed, source distributed type is the same
-    // as the source type. If the result is distributed, we need to compute the
-    // distributed source type according to following rules:
-    // 1. If the source type is yielded from the warp op, we can use the
-    //    matching warp result type as the distributed source type.
-    // 2. If the source type is not yielded from the warp op, we need
-    //    to compute the distributed source type based on the distribution map
-    //    and the warp size.
-    if (isResultDistributed) {
-      // Check if the source is yielded from the warp op.
-      gpu::YieldOp yieldOp = cast<gpu::YieldOp>(
-          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-      OpOperand *it =
-          llvm::find_if(yieldOp->getOpOperands(), [&](OpOperand &operand) {
-            return operand.get() == oldCastOp.getSource();
-          });
-
-      if (it != yieldOp->getOpOperands().end()) {
-        // If the source is yielded from the warp op, we can use the matching
-        // warp result type as the distributed source type.
-        distributedSourceType =
-            cast<VectorType>(warpOp->getResultTypes()[it->getOperandNumber()]);
-      } else {
-        // If the source is not yielded from the warp op, we need to compute
-        // the distributed source type based on the distribution map and the
-        // warp size.
-        AffineMap map = distributionMapFn(oldCastOp.getSource());
-        distributedSourceType =
-            getDistributedType(sourceType, map, warpOp.getWarpSize());
-        if (!distributedSourceType)
-          return rewriter.notifyMatchFailure(
-              oldCastOp,
-              "cannot compute distributed source type for shape cast");
-      }
+    VectorType castOriginalType = oldCastOp.getSourceVectorType();
+    VectorType castResultType = castDistributedType;
+
+    // We expect the distributed type to have a smaller rank than the original
+    // type. Prepend with size-one dimensions to make them the same.
+    unsigned castDistributedRank = castDistributedType.getRank();
+    unsigned castOriginalRank = castOriginalType.getRank();
+    if (castDistributedRank < castOriginalRank) {
+      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
+      llvm::append_range(shape, castDistributedType.getShape());
+      castDistributedType =
+          VectorType::get(shape, castDistributedType.getElementType());
     }
 
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {oldCastOp.getSource()}, {distributedSourceType},
+        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value newCast = vector::ShapeCastOp::create(
-        rewriter, oldCastOp.getLoc(), distributedResultType,
+        rewriter, oldCastOp.getLoc(), castResultType,
         newWarpOp->getResult(newRetIndices[0]));
     rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
     return success();
   }
-
-private:
-  DistributionMapFn distributionMapFn;
 };
 
 /// Sink out vector.create_mask op feeding into a warp op yield.
@@ -2091,15 +2059,15 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
   patterns
-      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpExtract,
-           WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar,
-           WarpOpInsert, WarpOpCreateMask, WarpOpExtractStridedSlice,
-           WarpOpInsertStridedSlice, WarpOpMultiReduction, WarpOpStep>(
+      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
+           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
+           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
+           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
           patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
-  patterns.add<WarpOpScfForOp, WarpOpShapeCast>(patterns.getContext(),
-                                                distributionMapFn, benefit);
+  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
+                               benefit);
 }
 
 void mlir::vector::populateDistributeReduction(

>From 3dea80c8bbaf229235771f2e9c66d36ef075b185 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 4 Sep 2025 17:43:03 +0000
Subject: [PATCH 13/19] save test

---
 mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index cfb5428c92400..8e2e96dfc05cc 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -339,14 +339,16 @@ gpu.module @test {
 // CHECK-PROP      :   %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
 // CHECK-PROP      :   %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
 // CHECK-PROP      :   return %[[R]] : vector<2xf32>
-func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
+gpu.module @test {
+gpu.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
     %0 = "some_def"() : () -> (vector<32x64xf32>)
     %acc = "some_def"() : () -> (vector<64xf32>)
     %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
     gpu.yield %1 : vector<64xf32>
   }
-  return %r : vector<2xf32>
+  gpu.return %r : vector<2xf32>
+}
 }
 
 // -----
@@ -389,12 +391,14 @@ func.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
 // CHECK-PROP:        %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
 // CHECK-PROP:        %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
 // CHECK-PROP:        return %[[R]] : vector<2xf32>
-func.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
+gpu.module @test {
+gpu.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
   %zero = arith.constant dense<0.0> : vector<2xf32>
   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
     %0 = "some_def"() : () -> (vector<2x32xf32>)
     %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
     gpu.yield %1 : vector<2xf32>
   }
-  return %r : vector<2xf32>
+  gpu.return %r : vector<2xf32>
+}
 }

>From 3c06f28573cb36e45f152f81373a4e26f385f5b5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 5 Sep 2025 22:11:31 +0000
Subject: [PATCH 14/19] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 243 +++++++++++++-----
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  83 ------
 2 files changed, 179 insertions(+), 147 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 050fa0cd1d342..5af45d1324b3b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace xegpu {
@@ -72,27 +73,43 @@ namespace {
 /// | 32x16                 | [2, 8]      | 16x2                     |
 /// | 2x32x16               | [1, 16]     | 2x32x1                   |
 static FailureOr<VectorType>
-getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
+getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                 VectorType originalType) {
   if (!layout)
     return failure();
+  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
+         "Expecting a valid layout.");
+  SmallVector<int64_t> effectiveLaneLayout;
+  // If the layout is a slice, we need to get effective lane layout by removing
+  // sliced dims.
+  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
+    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
+    for (auto [i, dim] :
+         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
+      if (!lookUp.contains(i))
+        effectiveLaneLayout.push_back(dim);
+    }
+  } else {
+    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
+  }
 
-  auto laneLayout = layout.getLaneLayout().asArrayRef();
-  assert(originalType.getShape().size() >= laneLayout.size() &&
+  assert(originalType.getShape().size() >= effectiveLaneLayout.size() &&
          "Rank of the original vector type should be greater or equal to the "
          "size of the lane layout to distribute the vector type.");
   SmallVector<int64_t> distributedShape(originalType.getShape());
   // Only distribute the last `laneLayout.size()` dimensions. The remaining
   // dimensions are not distributed.
-  unsigned distributionStart = originalType.getRank() - laneLayout.size();
+  unsigned distributionStart =
+      originalType.getRank() - effectiveLaneLayout.size();
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
     if (i < distributionStart)
       continue;
 
     // Check if the dimension can be distributed evenly.
-    if (dim % laneLayout[i - distributionStart] != 0)
+    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
       return failure();
-    distributedShape[i] = dim / laneLayout[i - distributionStart];
+    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
   }
   return VectorType::get(distributedShape, originalType.getElementType());
 }
@@ -858,7 +875,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
 ///   gpu.yield %1 : vector<2xf32>
 /// }
 struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
-  using Base::Base;
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
@@ -869,83 +886,108 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
         cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
     unsigned operandNumber = yieldOperand->getOperandNumber();
     VectorType sourceType = reductionOp.getSourceVectorType();
-
     // Only 2D vectors are supported.
     if (sourceType.getRank() != 2)
       return rewriter.notifyMatchFailure(warpOp,
                                          "Only 2D reductions are supported.");
     ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
-    // Only 1 reduction dimension supported. This also ensures that result is
-    // also vector type.
+    // Only 1 reduction dimension is supported. This also ensures that the
+    // result is a vector type.
     if (reductionDims.size() != 1)
       return rewriter.notifyMatchFailure(
           warpOp, "Only 1 reduction dimension is supported.");
     int64_t reductionDim = reductionDims[0];
-    auto resultType = cast<VectorType>(reductionOp.getType());
-    auto distributedResultType =
+    VectorType distributedResultType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    VectorType resultType = cast<VectorType>(reductionOp.getType());
     Type elementType = distributedResultType.getElementType();
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(reductionOp.getSource());
 
-    // Currently we make the following assumptions.
-    // 1. The source vector is distributed in the column dimension. Each lane
-    // owns complete column(s) of the source vector.
-    // 2. If the reduction dim == 0, its a lane-local col reduction. In this
-    // case each lane owns its portion of the result (i.e. result is also
-    // distributed).
-    // 3. If reduction dim == 1, its a row reduction that require cross lanes
-    // shuffles. In this case, the reduction result is not distributed across
-    // lanes. Instead each lane owns a complete copy of the result
-    // (broadcasted).
-    // TODO: These assumptions are fairly restrictive. For example, source
-    // vector can have row distributed layout. Improve support for such cases.
-    if (sourceType.getShape()[1] % warpOp.getWarpSize() != 0)
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+    if (failed(sourceDistTypeOrFailure))
       return rewriter.notifyMatchFailure(
-          warpOp, "Source vector dimension must be divisible by warp size.");
-    bool isResultDistributed =
+          warpOp, "Failed to distribute the source vector type.");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Only single dimension distribution is supported.
+    bool dim0Distributed =
+        sourceDistType.getShape()[0] != sourceType.getShape()[0];
+    bool dim1Distributed =
+        sourceDistType.getShape()[1] != sourceType.getShape()[1];
+    if (dim0Distributed && dim1Distributed)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+    if (sourceDistDim == -1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting a distributed source vector.");
+    bool resultDistributed =
         distributedResultType.getNumElements() < resultType.getNumElements();
-    if (reductionDim == 0 && !isResultDistributed)
+    // If the lane owns all the data required for reduction (i.e. reduction is
+    // fully parallel across lanes), then each lane owns part of the result
+    // (i.e. result is distributed). If the reduction requires cross-lane
+    // shuffling, then the result is shared among all lanes (broadcasted).
+    // Therefore we expect the following cases:
+    //
+    // | Source vector        | Reduction dim  | Result vector  |
+    // |----------------------|----------------|----------------|
+    // |  dim-0 distributed   |       0        | broadcasted    |
+    // |  dim-0 distributed   |       1        | distributed    |
+    // |  dim-1 distributed   |       0        | distributed    |
+    // |  dim-1 distributed   |       1        | broadcasted    |
+
+    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
+                                (sourceDistDim == 1 && reductionDim == 0);
+    if (isReductionLaneLocal && !resultDistributed)
       return rewriter.notifyMatchFailure(
-          warpOp,
-          "Expecting result vector to be distributed in a col reduction.");
-    if (reductionDim == 1 && isResultDistributed)
+          warpOp, "Expecting a distributed result for lane-local reduction.");
+
+    if (!isReductionLaneLocal && resultDistributed)
       return rewriter.notifyMatchFailure(
           warpOp,
-          "Expecting result vector to be broadcasted in a row reduction.");
+          "Expecting a broadcasted result for non-lane-local reduction.");
 
     // Create a constant vector to store the result of the reduction per lane.
+    rewriter.setInsertionPoint(warpOp);
     TypedAttr zeroAttr =
         rewriter.getZeroAttr(distributedResultType.getElementType());
     Value result = arith::ConstantOp::create(
         rewriter, reductionOp->getLoc(), distributedResultType,
         DenseElementsAttr::get(distributedResultType, zeroAttr));
-    // Col reduction.
-    if (reductionDim == 0) {
-      // Compute source distributed type assuming each lane owns cols.
-      SmallVector<int64_t> shape(sourceType.getShape());
-      shape[1] = shape[1] / warpOp.getWarpSize();
-      auto sourceDistributedType = VectorType::get(shape, elementType);
 
+    // Handle lane-local reduction case. In this case we fully distribute the
+    // reduction.
+    if (isReductionLaneLocal) {
       // Yield the source and acc vectors from the WarpOp.
       SmallVector<size_t> newRetIndices;
       auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
-          {sourceDistributedType, distributedResultType}, newRetIndices);
+          {sourceDistType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
 
-      int nCols = sourceDistributedType.getShape()[1];
+      int nSlices = sourceDistType.getShape()[sourceDistDim];
       Value source = newWarpOp.getResult(newRetIndices[0]);
       Value acc = newWarpOp.getResult(newRetIndices[1]);
-      // For each column owned by a lane, extract the column (of size nRows x
-      // 1), shape cast to 1D (nRows), do a vector.reduction and, insert the
-      // result back to the result vector.
-      for (int i = 0; i < nCols; ++i) {
+      // For each slice owned by a lane, extract the slice, shape cast to 1D, do
+      // a vector.reduction, and insert the result back to the result vector.
+      for (int i = 0; i < nSlices; ++i) {
+        SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+        if (sourceDistDim == 0) {
+          sliceOffsets = {i, 0};
+          sliceSizes = {1, sourceDistType.getShape()[1]};
+        } else {
+          sliceOffsets = {0, i};
+          sliceSizes = {sourceDistType.getShape()[0], 1};
+        }
         Value col = vector::ExtractStridedSliceOp::create(
-            rewriter, reductionOp.getLoc(), source, {0, i},
-            {sourceDistributedType.getShape()[0], 1}, {1, 1});
+            rewriter, reductionOp.getLoc(), source, sliceOffsets, sliceSizes,
+            {1, 1});
+        int64_t col1DSize =
+            sourceDistType.getShape()[sourceDistDim == 1 ? 0 : 1];
         col = vector::ShapeCastOp::create(
             rewriter, reductionOp.getLoc(),
-            VectorType::get({sourceDistributedType.getShape()[0]}, elementType),
-            col);
+            VectorType::get({col1DSize}, elementType), col);
         Value accCol =
             vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
         Value colReduce = vector::ReductionOp::create(
@@ -957,26 +999,79 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
       rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
       return success();
     }
-    // For row reductions, we simply rewrite the MultiReductionOp in terms of
-    // multiple ReductionOps. Actual distribution is done by the WarpOpReduction
-    // pattern.
+    // For the non-lane-local case, we simply rewrite the MultiReductionOp in
+    // terms of multiple ReductionOps. Actual distribution is done by the
+    // WarpOpReduction pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    int nRows = sourceType.getShape()[0];
-    // For each row of the source, extract the row vector, do a reduction and,
-    // insert the result back to the result.
-    for (int i = 0; i < nRows; ++i) {
-      Value source = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                               reductionOp.getSource(), i);
-      Value acc = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                            reductionOp.getAcc(), i);
-      Value rowReduce = vector::ReductionOp::create(
-          rewriter, reductionOp.getLoc(), reductionOp.getKind(), source, acc);
+    int nSlices = sourceType.getShape()[sourceDistDim == 0 ? 1 : 0];
+    // For each slice of the source, extract the slice vector, do a reduction,
+    // and insert the result back into the result vector.
+    for (int i = 0; i < nSlices; ++i) {
+      SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+      if (sourceDistDim == 1) {
+        sliceOffsets = {i, 0};
+        sliceSizes = {1, sourceType.getShape()[1]};
+      } else {
+        sliceOffsets = {0, i};
+        sliceSizes = {sourceType.getShape()[0], 1};
+      }
+      Value col = vector::ExtractStridedSliceOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getSource(), sliceOffsets,
+          sliceSizes, {1, 1});
+      int64_t col1DSize = sourceType.getShape()[sourceDistDim];
+      col = vector::ShapeCastOp::create(
+          rewriter, reductionOp.getLoc(),
+          VectorType::get({col1DSize}, elementType), col);
+      Value accCol = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
+                                               reductionOp.getAcc(), i);
+      Value colReduce = vector::ReductionOp::create(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
       result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                        rowReduce, result, i);
+                                        colReduce, result, i);
     }
     // Replace the warp op result with the final result.
     rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+    return success();
+  }
+};
 
+struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto shapeCastOp =
+        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    auto resultDistTy =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+    if (!sourceLayout)
+      return rewriter.notifyMatchFailure(
+          warpOp, "the source of shape_cast op lacks distribution layout");
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout,
+                                        shapeCastOp.getSourceVectorType());
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          warpOp, "failed to get distributed vector type for source");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Create a new warp op that yields the source of the shape_cast op.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Create a new shape_cast op outside the warp op.
+    Value newShapeCast = vector::ShapeCastOp::create(
+        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+                                newShapeCast);
     return success();
   }
 };
@@ -998,6 +1093,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
            DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
            GpuBarrierDistribution, VectorMultiReductionDistribution>(
           patterns.getContext());
+  patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
+                                            /*benefit=*/2);
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -1012,8 +1109,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
-      auto layout =
-          xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
+      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
         op->emitError("Could not find layout attribute for operand ")
             << operand.getOperandNumber() << " of operation " << op->getName();
@@ -1074,6 +1170,25 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // TODO: shuffleFn is not used.
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
+
+  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
+                          vector::CombiningKind kind, uint32_t size) {
+    // First reduce on a single thread to get per lane reduction value.
+    Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+    // Parallel reduction using butterfly shuffles.
+    for (uint64_t i = 1; i < size; i <<= 1) {
+      Value shuffled =
+          builder
+              .create<gpu::ShuffleOp>(loc, laneVal, i,
+                                      /*width=*/size,
+                                      /*mode=*/gpu::ShuffleMode::XOR)
+              .getShuffleResult();
+      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+    }
+    return laneVal;
+  };
+
+  vector::populateDistributeReduction(patterns, warpReduction);
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 8e2e96dfc05cc..54ef56e013abb 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,86 +319,3 @@ gpu.module @test {
     gpu.return
   }
 }
-
-// -----
-// CHECK-PROP-LABEL: func.func @vector_multi_reduction_col_reduce
-// CHECK-PROP      :   %[[W:.*]]:2 = gpu.warp_execute_on_lane_0({{.*}})[32] -> (vector<32x2xf32>, vector<2xf32>) {
-// CHECK-PROP      :     %[[SOURCE:.*]] = "some_def"() : () -> vector<32x64xf32>
-// CHECK-PROP      :     %[[ACC:.*]] = "some_def"() : () -> vector<64xf32>
-// CHECK-PROP      :     gpu.yield %[[SOURCE]], %[[ACC]] : vector<32x64xf32>, vector<64xf32>
-// CHECK-PROP      :   }
-// CHECK-PROP      :   %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#0
-// CHECK-PROP-SAME :     {offsets = [0, 0], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP      :   %[[COL0CAST:.*]] = vector.shape_cast %[[COL0]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP      :   %[[ACC0:.*]] = vector.extract %[[W]]#1[0] : f32 from vector<2xf32>
-// CHECK-PROP      :   %[[REDUCE0:.*]] = vector.reduction <add>, %[[COL0CAST]], %[[ACC0]] : vector<32xf32> into f32
-// CHECK-PROP      :   %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#0
-// CHECK-PROP-SAME :     {offsets = [0, 1], sizes = [32, 1], strides = [1, 1]} : vector<32x2xf32> to vector<32x1xf32>
-// CHECK-PROP      :   %[[COL1CAST:.*]] = vector.shape_cast %[[COL1]] : vector<32x1xf32> to vector<32xf32>
-// CHECK-PROP      :   %[[ACC1:.*]] = vector.extract %[[W]]#1[1] : f32 from vector<2xf32>
-// CHECK-PROP      :   %[[REDUCE1:.*]] = vector.reduction <add>, %[[COL1CAST]], %[[ACC1]] : vector<32xf32> into f32
-// CHECK-PROP      :   %[[R:.*]] = vector.from_elements %[[REDUCE0]], %[[REDUCE1]] : vector<2xf32>
-// CHECK-PROP      :   return %[[R]] : vector<2xf32>
-gpu.module @test {
-gpu.func @vector_multi_reduction_col_reduce(%laneid: index) -> vector<2xf32> {
-  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
-    %0 = "some_def"() : () -> (vector<32x64xf32>)
-    %acc = "some_def"() : () -> (vector<64xf32>)
-    %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<32x64xf32> to vector<64xf32>
-    gpu.yield %1 : vector<64xf32>
-  }
-  gpu.return %r : vector<2xf32>
-}
-}
-
-// -----
-// CHECK-PROP-LABEL:  func.func @vector_multi_reduction_row_reduce
-// CHECK-PROP-DAG:    %[[C16:.*]] = arith.constant 16 : i32
-// CHECK-PROP-DAG:    %[[C8:.*]] = arith.constant 8 : i32
-// CHECK-PROP-DAG:    %[[C4:.*]] = arith.constant 4 : i32
-// CHECK-PROP-DAG:    %[[C2:.*]] = arith.constant 2 : i32
-// CHECK-PROP-DAG:    %[[C1:.*]] = arith.constant 1 : i32
-// CHECK-PROP-DAG:    %[[C32:.*]] = arith.constant 32 : i32
-// CHECK-PROP-DAG:    %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-PROP:        %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2x1xf32>) {
-// CHECK-PROP:          %[[SRC:.*]] = "some_def"() : () -> vector<2x32xf32>
-// CHECK-PROP:          gpu.yield %[[SRC]] : vector<2x32xf32>
-// CHECK-PROP:        }
-// CHECK-PROP:        %[[T1:.*]] = vector.extract %[[W]][0, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP:        %[[SR:.*]], %{{.*}} = gpu.shuffle  xor %[[T1]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP:        %[[T2:.*]] = arith.addf %[[T1]], %[[SR]] : f32
-// CHECK-PROP:        %[[SR0:.*]], %{{.*}} = gpu.shuffle  xor %[[T2]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP:        %[[T3:.*]] = arith.addf %[[T2]], %[[SR0]] : f32
-// CHECK-PROP:        %[[SR2:.*]], %{{.*}} = gpu.shuffle  xor %[[T3]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP:        %[[T4:.*]] = arith.addf %[[T3]], %[[SR2]] : f32
-// CHECK-PROP:        %[[SR4:.*]], %{{.*}} = gpu.shuffle  xor %[[T4]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP:        %[[T5:.*]] = arith.addf %[[T4]], %[[SR4]] : f32
-// CHECK-PROP:        %[[SR6:.*]], %{{.*}} = gpu.shuffle  xor %[[T5]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP:        %[[T6:.*]] = arith.addf %[[T5]], %[[SR6]] : f32
-// CHECK-PROP:        %[[R0:.*]] = arith.addf %[[T6]], %[[CST]] : f32
-//
-// CHECK-PROP:        %[[T8:.*]] = vector.extract %[[W]][1, 0] : f32 from vector<2x1xf32>
-// CHECK-PROP:        %[[SR8:.*]], %{{.*}} = gpu.shuffle  xor %[[T8]], %[[C1]], %[[C32]] : f32
-// CHECK-PROP:        %[[T9:.*]] = arith.addf %[[T8]], %[[SR8]] : f32
-// CHECK-PROP:        %[[SR10:.*]], %{{.*}} = gpu.shuffle  xor %[[T9]], %[[C2]], %[[C32]] : f32
-// CHECK-PROP:        %[[T10:.*]] = arith.addf %[[T9]], %[[SR10]] : f32
-// CHECK-PROP:        %[[SR12:.*]], %{{.*}} = gpu.shuffle  xor %[[T10]], %[[C4]], %[[C32]] : f32
-// CHECK-PROP:        %[[T11:.*]] = arith.addf %[[T10]], %[[SR12]] : f32
-// CHECK-PROP:        %[[SR14:.*]], %{{.*}} = gpu.shuffle  xor %[[T11]], %[[C8]], %[[C32]] : f32
-// CHECK-PROP:        %[[T12:.*]] = arith.addf %[[T11]], %[[SR14]] : f32
-// CHECK-PROP:        %[[SR16:.*]], %{{.*}} = gpu.shuffle  xor %[[T12]], %[[C16]], %[[C32]] : f32
-// CHECK-PROP:        %[[T13:.*]] = arith.addf %[[T12]], %[[SR16]] : f32
-// CHECK-PROP:        %[[R1:.*]] = arith.addf %[[T13]], %[[CST]] : f32
-// CHECK-PROP:        %[[R:.*]] = vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
-// CHECK-PROP:        return %[[R]] : vector<2xf32>
-gpu.module @test {
-gpu.func @vector_multi_reduction_row_reduce(%laneid: index) -> vector<2xf32> {
-  %zero = arith.constant dense<0.0> : vector<2xf32>
-  %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
-    %0 = "some_def"() : () -> (vector<2x32xf32>)
-    %1 = vector.multi_reduction <add>, %0, %zero [1] : vector<2x32xf32> to vector<2xf32>
-    gpu.yield %1 : vector<2xf32>
-  }
-  gpu.return %r : vector<2xf32>
-}
-}

>From 8728eee0cab9a6940d8888c0a3c0cb1d36d154a7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 8 Sep 2025 21:42:56 +0000
Subject: [PATCH 15/19] save work

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |   4 +
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 191 +++++++++---------
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 113 +++++++++++
 3 files changed, 215 insertions(+), 93 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index ddf6b4ac85a90..59dca9f0d852a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -27,6 +27,10 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                            "vector::VectorDialect"];
+  let options = [Option<
+    "enableSGReductions", "enable-sg-reductions", "bool",
+    /*default=*/"true",
+    "Enable subgroup reductions using subgroup shuffles.">];
 }
 
 def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 5af45d1324b3b..a436a3d35dba6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -58,6 +58,24 @@ namespace {
 //===----------------------------------------------------------------------===//
 // SIMT Distribution Patterns
 //===----------------------------------------------------------------------===//
+static SmallVector<int64_t>
+computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) {
+  SmallVector<int64_t> effectiveLaneLayout;
+  // If the layout is a slice, we need to get effective lane layout by removing
+  // sliced dims.
+  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
+    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
+    for (auto [i, dim] :
+         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
+      if (!lookUp.contains(i))
+        effectiveLaneLayout.push_back(dim);
+    }
+  } else {
+    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
+  }
+  return effectiveLaneLayout;
+}
 
 /// Helper function to get  distributed vector type for a source vector type
 /// according to the lane_layout. We simply divide each dimension of tensor
@@ -79,20 +97,7 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
     return failure();
   assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
          "Expecting a valid layout.");
-  SmallVector<int64_t> effectiveLaneLayout;
-  // If the layout is a slice, we need to get effective lane layout by removing
-  // sliced dims.
-  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
-    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
-    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
-    for (auto [i, dim] :
-         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
-      if (!lookUp.contains(i))
-        effectiveLaneLayout.push_back(dim);
-    }
-  } else {
-    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
-  }
+  SmallVector<int64_t> effectiveLaneLayout = computeEffectiveLaneLayout(layout);
 
   assert(originalType.getShape().size() >= effectiveLaneLayout.size() &&
          "Rank of the original vector type should be greater or equal to the "
@@ -824,13 +829,64 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Lowers a 2D vector.multi_reduction over `reductionDim` into one 1D
+/// vector.reduction per slice; slice `i` is seeded with element `i` of `acc`.
+static Value lowerToVectorReductions(TypedValue<VectorType> src,
+                                     TypedValue<VectorType> acc,
+                                     vector::CombiningKind kind,
+                                     int64_t reductionDim, Location loc,
+                                     PatternRewriter &rewriter) {
+  // Only rank-2 sources are handled; dim 0 is the height, dim 1 the width.
+  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+  VectorType sourceType = src.getType();
+  int64_t sourceH = sourceType.getShape()[0];
+  int64_t sourceW = sourceType.getShape()[1];
+  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  // Zero-filled constant of acc's type that the per-slice results go into.
+  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+  Value reductionResult = arith::ConstantOp::create(
+      rewriter, loc, acc.getType(),
+      DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // For each slice along the kept dimension: extract it, flatten it to 1D,
+  // reduce it with acc[i], and insert the scalar at index i of the result.
+  for (int i = 0; i < nSlices; ++i) {
+    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+    if (reductionDim == 1) { // Reducing dim 1: slice i is row i (1 x W).
+      sliceOffsets = {i, 0};
+      sliceSizes = {1, sourceW};
+    } else { // Any other value is treated as dim 0: slice i is column i.
+      sliceOffsets = {0, i};
+      sliceSizes = {sourceH, 1};
+    }
+    vector::ExtractStridedSliceOp extractOp =
+        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+                                              sliceSizes, {1, 1});
+    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+    Value slice = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get({nSliceElements}, sourceType.getElementType()),
+        extractOp.getResult());
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    Value reduction =
+        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+    reductionResult =
+        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+  }
+  return reductionResult;
+}
+
 /// This patterns distribute the `vector.multi_reduction` operation across
-/// lanes in a warp. Currently only 2D to 1D reductions are supported and
-/// assumes that source vector is distributed in column dimension (i.e. Each
-/// lane owns complete column(s) of the source vector).
-/// TODO: Add support for the case where source rows are distributed across
-/// lanes. Requires `DistributionMapFn` to express the data distribution.
-/// Example 1 (Col reduction):
+/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
+/// layouts for the source and accumulator vectors,
+/// * If the reduction dimension is distributed across lanes, the reduction is
+///   non-lane-local and the reduction is done using warp shuffles. Here we
+///   simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
+///   the warp op body.
+/// * If the reduction dimension is not distributed across lanes, the reduction
+///   is lane-local. In this case, we yield the source and accumulator vectors
+///   from the warp op and perform the lane-local reduction outside the warp op
+///   using a sequence of ReductionOps.
+/// Example 1 (Reduction is lane-local):
 /// ```
 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
 ///   %0 = "some_def"() : () -> (vector<16x32xf32>)
@@ -852,7 +908,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
 /// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
 /// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
 /// ```
-/// Example 2 (Row reduction):
+/// Example 2 (Reduction is non-lane-local):
 /// ```
 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
 ///   %0 = "some_def"() : () -> (vector<2x32xf32>)
@@ -900,7 +956,6 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
     VectorType distributedResultType =
         cast<VectorType>(warpOp.getResult(operandNumber).getType());
     VectorType resultType = cast<VectorType>(reductionOp.getType());
-    Type elementType = distributedResultType.getElementType();
     xegpu::DistributeLayoutAttr sourceLayout =
         xegpu::getDistributeLayoutAttr(reductionOp.getSource());
 
@@ -948,16 +1003,8 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
           warpOp,
           "Expecting a broadcasted result for non-lane-local reduction.");
 
-    // Create a constant vector to store the result of the reduction per lane.
-    rewriter.setInsertionPoint(warpOp);
-    TypedAttr zeroAttr =
-        rewriter.getZeroAttr(distributedResultType.getElementType());
-    Value result = arith::ConstantOp::create(
-        rewriter, reductionOp->getLoc(), distributedResultType,
-        DenseElementsAttr::get(distributedResultType, zeroAttr));
-
     // Handle lane-local reduction case. In this case we fully distribute the
-    // reduction.
+    // reduction result.
     if (isReductionLaneLocal) {
       // Yield the source and acc vectors from the WarpOp.
       SmallVector<size_t> newRetIndices;
@@ -965,70 +1012,22 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
           rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
           {sourceDistType, distributedResultType}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
-
-      int nSlices = sourceDistType.getShape()[sourceDistDim];
-      Value source = newWarpOp.getResult(newRetIndices[0]);
-      Value acc = newWarpOp.getResult(newRetIndices[1]);
-      // For each slice owned by a lane, extract the slice, shape cast to 1D, do
-      // a vector.reduction and, insert the result back to the result vector.
-      for (int i = 0; i < nSlices; ++i) {
-        SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-        if (sourceDistDim == 0) {
-          sliceOffsets = {i, 0};
-          sliceSizes = {1, sourceDistType.getShape()[1]};
-        } else {
-          sliceOffsets = {0, i};
-          sliceSizes = {sourceDistType.getShape()[0], 1};
-        }
-        Value col = vector::ExtractStridedSliceOp::create(
-            rewriter, reductionOp.getLoc(), source, sliceOffsets, sliceSizes,
-            {1, 1});
-        int64_t col1DSize =
-            sourceDistType.getShape()[sourceDistDim == 1 ? 0 : 1];
-        col = vector::ShapeCastOp::create(
-            rewriter, reductionOp.getLoc(),
-            VectorType::get({col1DSize}, elementType), col);
-        Value accCol =
-            vector::ExtractOp::create(rewriter, reductionOp.getLoc(), acc, i);
-        Value colReduce = vector::ReductionOp::create(
-            rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
-        result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                          colReduce, result, i);
-      }
-      // Replace the warp op result with the new reduction op.
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), result);
+      Value result = lowerToVectorReductions(
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
+          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+      // Replace the warp op result with the final result.
+      rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
       return success();
     }
     // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
     // of multiple ReductionOps. Actual distribution is done by the
     // WarpOpReduction pattern.
     rewriter.setInsertionPointAfter(reductionOp);
-    int nSlices = sourceType.getShape()[sourceDistDim == 0 ? 1 : 0];
-    // For each slice of the source, extract the slice vector, do a reduction
-    // and, insert the result back to the result.
-    for (int i = 0; i < nSlices; ++i) {
-      SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
-      if (sourceDistDim == 1) {
-        sliceOffsets = {i, 0};
-        sliceSizes = {1, sourceType.getShape()[1]};
-      } else {
-        sliceOffsets = {0, i};
-        sliceSizes = {sourceType.getShape()[0], 1};
-      }
-      Value col = vector::ExtractStridedSliceOp::create(
-          rewriter, reductionOp.getLoc(), reductionOp.getSource(), sliceOffsets,
-          sliceSizes, {1, 1});
-      int64_t col1DSize = sourceType.getShape()[sourceDistDim];
-      col = vector::ShapeCastOp::create(
-          rewriter, reductionOp.getLoc(),
-          VectorType::get({col1DSize}, elementType), col);
-      Value accCol = vector::ExtractOp::create(rewriter, reductionOp.getLoc(),
-                                               reductionOp.getAcc(), i);
-      Value colReduce = vector::ReductionOp::create(
-          rewriter, reductionOp.getLoc(), reductionOp.getKind(), col, accCol);
-      result = vector::InsertOp::create(rewriter, reductionOp.getLoc(),
-                                        colReduce, result, i);
-    }
+    Value result = lowerToVectorReductions(
+        cast<TypedValue<VectorType>>(reductionOp.getSource()),
+        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
+        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
     // Replace the warp op result with the final result.
     rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
     return success();
@@ -1082,6 +1081,11 @@ namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
+  XeGPUSubgroupDistributePass() = default;
+  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
+      default;
+  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
+      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace
@@ -1150,8 +1154,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
     // Get the layout of the vector type.
-    // TODO: support more layout types
-    auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
+    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
     // If no layout is specified, assume the inner most dimension is distributed
     // for now.
     if (!layout)
@@ -1159,7 +1162,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
           vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
     SmallVector<unsigned int> distributedDims;
     // Get the distributed dimensions based on the layout.
-    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+    SmallVector<int64_t> laneLayout = computeEffectiveLaneLayout(layout);
     for (unsigned i = 0; i < laneLayout.size(); ++i) {
       if (laneLayout[i] > 1)
         distributedDims.push_back(i);
@@ -1188,7 +1191,9 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     return laneVal;
   };
 
-  vector::populateDistributeReduction(patterns, warpReduction);
+  if (enableSGReductions)
+    vector::populateDistributeReduction(patterns, warpReduction);
+
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..77a475cef126c 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt -xegpu-subgroup-distribute -allow-unregistered-dialect -canonicalize -cse -split-input-file %s | FileCheck %s
 
+// RUN: mlir-opt -xegpu-subgroup-distribute="enable-sg-reductions=false" -allow-unregistered-dialect \
+// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION
+
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
 // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
@@ -319,3 +322,113 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
+// CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] ->
+// CHECK-SAME:    (!xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, vector<16x2xf32>) {
+// CHECK:             %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x32xf32>
+// CHECK-NEXT:        gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, vector<16x32xf32>
+// CHECK-NEXT:  }
+// CHECK:       %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-NEXT:  %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-NEXT:  %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %{{.*}} : vector<16xf32> into f32
+// CHECK:       %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-NEXT:  %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-NEXT:  %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-NEXT:  vector.from_elements %[[RED0]], %[[RED1]] : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}  : () -> (vector<16x32xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0.0>  : vector<32xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}  [0]
+    : vector<16x32xf32> to vector<32xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : vector<32xf32> to vector<1x32xf32>
+  xegpu.store_nd %3, %0 : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
+// CHECK-REDUCTION:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<2x16xf32,
+// CHECK-REDUCTION-SAME:      #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, f32, f32) {
+// CHECK-REDUCTION:           %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<2x16xf32>
+// CHECK-REDUCTION-NEXT:      %[[ROW0:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
+// CHECK-REDUCTION-NEXT:      %[[R0:.*]] = vector.reduction <add>, %[[ROW0]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:      %[[ROW1:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-REDUCTION-NEXT:      %[[R1:.*]] = vector.reduction <add>, %[[ROW1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:      gpu.yield %{{.*}}, %[[R1]], %[[R0]] : !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, f32, f32
+// CHECK-REDUCTION-NEXT:    }
+// CHECK-REDUCTION-NEXT:    vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}  : () -> (vector<2x16xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} dense<0.0>  : vector<2xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+    [1] : vector<2x16xf32> to vector<2xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : vector<2xf32> to vector<2x1xf32>
+  %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<2x1xf32> to vector<2x16xf32>
+  xegpu.store_nd %4, %0 : vector<2x16xf32>, !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL:   gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
+// CHECK:             %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] ->
+// CHECK-SAME:          (!xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<2x16xf32>) {
+// CHECK:                 %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<32x16xf32>
+// CHECK-NEXT:            gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<32x16xf32>
+// CHECK-NEXT:        }
+// CHECK:             %[[ROW0:.*]] = vector.extract %[[W]]#1[0] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:        %[[R0:.*]] = vector.reduction <add>, %[[ROW0]], %{{.*}} : vector<16xf32> into f32
+// CHECK:             %[[ROW1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:        %[[R1:.*]] = vector.reduction <add>, %[[ROW1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-NEXT:        vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}  : () -> (vector<32x16xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0.0>  : vector<32xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}  [1]
+    : vector<32x16xf32> to vector<32xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : vector<32xf32> to vector<32x1xf32>
+  xegpu.store_nd %3, %0 : vector<32x1xf32>, !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
+// CHECK-REDUCTION:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<16x2xf32,
+// CHECK-REDUCTION-SAME:    #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, f32, f32) {
+// CHECK-REDUCTION:          %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<16x2xf32>
+// CHECK-REDUCTION-NEXT:     %[[COL0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REDUCTION-NEXT:     %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-REDUCTION-NEXT:     %[[R0:.*]] = vector.reduction <add>, %[[CAST0]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:     %[[COL1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REDUCTION-NEXT:     %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-REDUCTION-NEXT:     %[[R1:.*]] = vector.reduction <add>, %[[CAST1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:     gpu.yield %{{.*}}, %[[R1]], %[[R0]] : !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, f32, f32
+// CHECK-REDUCTION-NEXT:   }
+// CHECK-REDUCTION-NEXT:   vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}  : () -> (vector<16x2xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>} dense<0.0>  : vector<2xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+    [0] : vector<16x2xf32> to vector<2xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+    : vector<2xf32> to vector<1x2xf32>
+  %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : vector<1x2xf32> to vector<16x2xf32>
+  xegpu.store_nd %4, %0 : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  gpu.return
+}
+}

>From 9b72ac0300caf075b36dfbb38164e196766addd3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 8 Sep 2025 22:17:28 +0000
Subject: [PATCH 16/19] save work

---
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp     | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index a436a3d35dba6..4faaa487f47a7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -58,6 +58,14 @@ namespace {
 //===----------------------------------------------------------------------===//
 // SIMT Distribution Patterns
 //===----------------------------------------------------------------------===//
+
+/// In certain cases, we may need to favor XeGPU specific distribution patterns
+/// over generic vector distribution patterns. In such cases, we can assign
+/// priorities to patterns.
+enum class PatternPriority : int { Regular = 1, High = 2 };
+
+/// Helper function to compute the effective lane layout from a
+/// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr.
 static SmallVector<int64_t>
 computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) {
   SmallVector<int64_t> effectiveLaneLayout;
@@ -1034,6 +1042,8 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region.
 struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
@@ -1098,7 +1108,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
            GpuBarrierDistribution, VectorMultiReductionDistribution>(
           patterns.getContext());
   patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
-                                            /*benefit=*/2);
+                                            /*benefit=*/PatternPriority::High);
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {

>From 18547136864d07cc4e85abc9bc421c1bfd79290e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 8 Sep 2025 22:44:34 +0000
Subject: [PATCH 17/19] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 0e52f8ca744ab..bf4b75f6ae2b0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/LogicalResult.h"
 
@@ -62,7 +63,8 @@ namespace {
 /// In certain cases, we may need to favor XeGPU specific distribution patterns
 /// over generic vector distribution patterns. In such cases, we can assign
 /// priorities to patterns.
-enum class PatternPriority : int { Regular = 1, High = 2 };
+static constexpr unsigned regularPatternBenefit = 1;
+static constexpr unsigned highPatternBenefit = 2;
 
 /// Helper function to compute the effective lane layout from a
 /// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr.
@@ -1300,9 +1302,12 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
       .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
            DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
            GpuBarrierDistribution, VectorMultiReductionDistribution,
-           LoadDistribution, StoreDistribution>(patterns.getContext());
-  patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
-                                            /*benefit=*/PatternPriority::High);
+           LoadDistribution, StoreDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
+  patterns.add<VectorShapeCastDistribution>(
+      patterns.getContext(),
+      /*pattern benefit=*/highPatternBenefit);
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -1396,10 +1401,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   };
 
   if (enableSGReductions)
-    vector::populateDistributeReduction(patterns, warpReduction);
+    vector::populateDistributeReduction(
+        patterns, warpReduction,
+        /*pattern benefit=*/regularPatternBenefit);
 
   vector::populatePropagateWarpVectorDistributionPatterns(
-      patterns, distributionFn, shuffleFn);
+      patterns, distributionFn, shuffleFn,
+      /*pattern benefit=*/regularPatternBenefit);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
     signalPassFailure();
     return;

>From 797aa3e3d5b63b2c250a0eaac2eb0f561aaecc2e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 8 Sep 2025 22:58:52 +0000
Subject: [PATCH 18/19] save work

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bf4b75f6ae2b0..6f06c751549d3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -33,9 +33,7 @@
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace xegpu {

>From 232808e7ad63d256240ad3c33dda4d40f489e5f8 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 9 Sep 2025 18:32:49 +0000
Subject: [PATCH 19/19] save work

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  9 ++++++
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 29 ++++---------------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 21 ++++++++++++++
 3 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 04cfd58d846a7..507d748c5f4f6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -162,6 +162,15 @@ SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
 SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
+
+/// Helper function to compute the effective lane layout from a
+/// DistributeLayoutAttr, which is either a LayoutAttr or a SliceAttr. A
+/// LayoutAttr yields its lane layout as-is; a SliceAttr yields the parent lane
+/// layout with the sliced dimensions removed. A null layout yields an empty
+/// vector.
+SmallVector<int64_t>
+computeEffectiveLaneLayout(xegpu::DistributeLayoutAttr layout);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 6f06c751549d3..0f6368583c763 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -64,27 +64,6 @@ namespace {
 static constexpr unsigned regularPatternBenefit = 1;
 static constexpr unsigned highPatternBenefit = 2;
 
-/// Helper function to compute the effective lane layout from a
-/// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr.
-static SmallVector<int64_t>
-computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) {
-  SmallVector<int64_t> effectiveLaneLayout;
-  // If the layout is a slice, we need to get effective lane layout by removing
-  // sliced dims.
-  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
-    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
-    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
-    for (auto [i, dim] :
-         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
-      if (!lookUp.contains(i))
-        effectiveLaneLayout.push_back(dim);
-    }
-  } else {
-    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
-  }
-  return effectiveLaneLayout;
-}
-
 /// Helper function to get  distributed vector type for a source vector type
 /// according to the lane_layout. We simply divide each dimension of tensor
 /// descriptor shape by corresponding lane_layout dimension. If
@@ -105,9 +84,11 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
     return failure();
   assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
          "Expecting a valid layout.");
-  SmallVector<int64_t> effectiveLaneLayout = computeEffectiveLaneLayout(layout);
+  SmallVector<int64_t> effectiveLaneLayout =
+      xegpu::computeEffectiveLaneLayout(layout);
 
-  assert(originalType.getShape().size() >= effectiveLaneLayout.size() &&
+  assert(static_cast<size_t>(originalType.getRank()) >=
+             effectiveLaneLayout.size() &&
          "Rank of the original vector type should be greater or equal to the "
          "size of the lane layout to distribute the vector type.");
   SmallVector<int64_t> distributedShape(originalType.getShape());
@@ -1369,7 +1350,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
           vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
     SmallVector<unsigned int> distributedDims;
     // Get the distributed dimensions based on the layout.
-    SmallVector<int64_t> laneLayout = computeEffectiveLaneLayout(layout);
+    SmallVector<int64_t> laneLayout = xegpu::computeEffectiveLaneLayout(layout);
     for (unsigned i = 0; i < laneLayout.size(); ++i) {
       if (laneLayout[i] > 1)
         distributedDims.push_back(i);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b72d5648b29f9..aa55ed27d6ecf 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -484,3 +484,24 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
   results.append(addElementwise(builder, loc, a, b));
   return results;
 }
+
+SmallVector<int64_t>
+xegpu::computeEffectiveLaneLayout(const xegpu::DistributeLayoutAttr layout) {
+  if (!layout)
+    return {}; // Null layout: report an empty lane layout.
+  SmallVector<int64_t> effectiveLaneLayout;
+  // For a slice, keep only the parent's lane layout entries whose dimension
+  // index is not among the dims of the flattened slice.
+  if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+    ArrayRef<int64_t> slicedDims = sliceAttr.flatten().getDims().asArrayRef();
+    llvm::DenseSet<int64_t> lookUp(slicedDims.begin(), slicedDims.end());
+    for (auto [i, dim] :
+         llvm::enumerate(sliceAttr.getParent().getLaneLayoutAsInt())) {
+      if (!lookUp.contains(i))
+        effectiveLaneLayout.push_back(dim);
+    }
+  } else { // Not a slice: take the LayoutAttr's lane layout unchanged.
+    effectiveLaneLayout = cast<xegpu::LayoutAttr>(layout).getLaneLayoutAsInt();
+  }
+  return effectiveLaneLayout;
+}



More information about the Mlir-commits mailing list