[Mlir-commits] [mlir] 9b0d7dd - [mlir][xegpu] Add support for `vector.multi_reduction` and `vector.shape_cast` SIMT distribution. (#157560)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 12 09:37:09 PDT 2025


Author: Charitha Saumya
Date: 2025-09-12T09:37:04-07:00
New Revision: 9b0d7ddb04665e76cfa90b5d69c6183b90772243

URL: https://github.com/llvm/llvm-project/commit/9b0d7ddb04665e76cfa90b5d69c6183b90772243
DIFF: https://github.com/llvm/llvm-project/commit/9b0d7ddb04665e76cfa90b5d69c6183b90772243.diff

LOG: [mlir][xegpu] Add support for `vector.multi_reduction` and `vector.shape_cast` SIMT distribution.  (#157560)

Add support for distributing the `vector.multi_reduction` operation
across lanes in a warp. Currently only 2D to 1D reductions are
supported. Given layouts for the source and accumulator vectors,
* If the reduction dimension is distributed across lanes, the reduction
is non-lane-local and the reduction is done using warp shuffles. Here we
simply rewrite the `MultiDimReductionOp` to a sequence of `ReductionOp`s
inside the warp op body. Actual distribution will be done by
`WarpOpReduction` pattern.
* If the reduction dimension is not distributed across lanes, the
reduction is lane-local. In this case, we yield the source and
accumulator vectors from the warp op and perform the lane-local
reduction outside the warp op using a sequence of `ReductionOp`s.

This PR also adds support for distributing `vector.shape_cast` based on
layouts.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index cfe3e800484ce..1f1d367118365 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -194,26 +194,29 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Get the num of effective subgroups",
                     "int64_t",
                     "getNumSubgroups", (ins), [{
-                        std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getSgLayoutAsInt();
+                        std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getEffectiveSgLayoutAsInt();
                         if (sgLayout.has_value())
                           return computeProduct(*sgLayout);
                         return 0;
                     }], [{}]>,
-    InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
+    InterfaceMethod<"Get the order of the layout attribute",
+                    "DenseI32ArrayAttr",
+                    "getOrder">,
+    InterfaceMethod<"Get the effective SgLayout of the layout attribute as integer array",
                     "SmallVector<int64_t>",
-                    "getSgLayoutAsInt">,
-    InterfaceMethod<"Get the SgData field of the attribute as integer array",
+                    "getEffectiveSgLayoutAsInt">,
+    InterfaceMethod<"Get the effective SgData of the layout attribute as integer array",
                     "SmallVector<int64_t>",
-                    "getSgDataAsInt">,
-    InterfaceMethod<"Get the InstData field of the attribute as integer array",
+                    "getEffectiveSgDataAsInt">,
+    InterfaceMethod<"Get the effective InstData of the layout attribute as integer array",
                     "SmallVector<int64_t>",
-                    "getInstDataAsInt">,
-    InterfaceMethod<"Get the LaneLayout field of the attribute as integer array",
+                    "getEffectiveInstDataAsInt">,
+    InterfaceMethod<"Get the effective LaneLayout of the layout attribute as integer array",
                     "SmallVector<int64_t>",
-                    "getLaneLayoutAsInt">,
-    InterfaceMethod<"Get the LaneData field of the attribute as integer array",
+                    "getEffectiveLaneLayoutAsInt">,
+    InterfaceMethod<"Get the effective LaneData of the layout attribute as integer array",
                     "SmallVector<int64_t>",
-                    "getLaneDataAsInt">,
+                    "getEffectiveLaneDataAsInt">,
     InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
                     "xegpu::DistributeLayoutAttr",
                     "dropSgLayoutAndData">,
@@ -231,7 +234,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       multiple blocks according to round-robin distribution rules.}],
                     "FailureOr<SmallVector<SmallVector<Value>>>",
                     "getOffsets",
-                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+    InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isSliceOf",
+                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
   ];
 }
 
@@ -391,31 +398,31 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
                              getLaneLayout(), getLaneData(), getOrder());
     }
 
-    SmallVector<int64_t> getSgLayoutAsInt() const {
+    SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
       if (DenseI32ArrayAttr layout = getSgLayout())
         return llvm::to_vector_of<int64_t>(layout.asArrayRef());
       return {};
     }
 
-    SmallVector<int64_t> getSgDataAsInt() const {
+    SmallVector<int64_t> getEffectiveSgDataAsInt() const {
       if (DenseI32ArrayAttr data = getSgData())
         return llvm::to_vector_of<int64_t>(data.asArrayRef());
       return {};
     }
 
-    SmallVector<int64_t> getInstDataAsInt() const {
+    SmallVector<int64_t> getEffectiveInstDataAsInt() const {
       if (DenseI32ArrayAttr inst = getInstData())
         return llvm::to_vector_of<int64_t>(inst.asArrayRef());
       return {};
     }
 
-    SmallVector<int64_t> getLaneLayoutAsInt() const {
+    SmallVector<int64_t> getEffectiveLaneLayoutAsInt() const {
       if (DenseI32ArrayAttr layout = getLaneLayout())
         return llvm::to_vector_of<int64_t>(layout.asArrayRef());
       return {};
     }
 
-    SmallVector<int64_t> getLaneDataAsInt() const {
+    SmallVector<int64_t> getEffectiveLaneDataAsInt() const {
       if (DenseI32ArrayAttr data = getLaneData())
         return llvm::to_vector_of<int64_t>(data.asArrayRef());
       return {};
@@ -433,6 +440,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<SmallVector<Value>>>
     getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    /// Check if this is a slice of some other layout.
+    bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
+
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
@@ -499,10 +509,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
 
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
-    SmallVector<int64_t> getSgLayoutAsInt() const {
+    SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      auto layout = parent.getSgLayoutAsInt();
+      auto layout = parent.getEffectiveSgLayoutAsInt();
       if (layout.size()) {
         ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
         return XeGPUDialect::slice(ArrayRef<int64_t>(layout), dims);
@@ -512,10 +522,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
 
     /// Returns the SgData of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
-    SmallVector<int64_t> getSgDataAsInt() const {
+    SmallVector<int64_t> getEffectiveSgDataAsInt() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      auto data = parent.getSgDataAsInt();
+      auto data = parent.getEffectiveSgDataAsInt();
       if (data.size()) {
         ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
         return XeGPUDialect::slice(ArrayRef<int64_t>(data), dims);
@@ -525,10 +535,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
 
     /// Returns the InstData of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
-    SmallVector<int64_t> getInstDataAsInt() const {
+    SmallVector<int64_t> getEffectiveInstDataAsInt() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      auto inst = parent.getInstDataAsInt();
+      auto inst = parent.getEffectiveInstDataAsInt();
       if (inst.size()) {
         ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
         return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(inst), dims);
@@ -538,10 +548,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
 
     /// Returns the LaneLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
-    SmallVector<int64_t> getLaneLayoutAsInt() const {
+    SmallVector<int64_t> getEffectiveLaneLayoutAsInt() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      auto layout = parent.getLaneLayoutAsInt();
+      auto layout = parent.getEffectiveLaneLayoutAsInt();
       if (layout.size()) {
         ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
         return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(layout), dims);
@@ -551,10 +561,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
 
     /// Returns the LaneData of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
-    SmallVector<int64_t> getLaneDataAsInt() const {
+    SmallVector<int64_t> getEffectiveLaneDataAsInt() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      auto data = parent.getLaneDataAsInt();
+      auto data = parent.getEffectiveLaneDataAsInt();
       if (data.size()) {
         ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
         return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(data), dims);
@@ -594,6 +604,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<SmallVector<Value>>>
     getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    /// Check if this is a slice of some other layout.
+    bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
+
   }];
 
   let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index ddf6b4ac85a90..59dca9f0d852a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -27,6 +27,10 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
                            "vector::VectorDialect"];
+  let options = [Option<
+    "enableSGReductions", "enable-sg-reductions", "bool",
+    /*default=*/"true",
+    "Enable subgroup reductions using subgroup shuffles.">];
 }
 
 def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7f3be7f91c56b..94c5509fd7c29 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -133,22 +133,23 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
   };
 
   // check the sgLayout and sgData
-  auto maybeSgShape =
-      tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
+  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
+                                    attr.getEffectiveSgDataAsInt());
   if (!maybeSgShape)
     return false;
   auto sgShape = maybeSgShape.value();
 
   // check InstData, it neither have layout nor need round-robin
   auto maybeInstShape =
-      tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false);
+      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
   if (!maybeInstShape)
     return false;
   auto instShape = maybeInstShape.value();
 
   // check LaneLayout and LaneData
-  auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
-                                      attr.getLaneDataAsInt(), false);
+  auto maybeLaneShape =
+      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
+                    attr.getEffectiveLaneDataAsInt(), false);
   return maybeLaneShape.has_value();
 }
 
@@ -282,9 +283,10 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   if (!hasDefaultOrder())
     return mlir::emitError(loc, "order attribute is currently not supported.");
 
-  auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value {
-    return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-  });
+  auto dims =
+      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
+        return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+      });
 
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
@@ -298,8 +300,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
   if (!isForWorkgroup())
     return failure();
 
-  SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
-  SmallVector<int64_t> sgShape = getSgDataAsInt();
+  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
   if (sgShape.empty()) {
     if (auto derivedShape = computeShapeRatio(shape, sgLayout))
       sgShape = derivedShape.value();
@@ -385,8 +387,8 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
   if (!isForWorkgroup())
     return failure();
 
-  SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
-  SmallVector<int64_t> sgShape = getSgDataAsInt();
+  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
   if (sgShape.empty()) {
     if (auto derivedShape = computeShapeRatio(shape, sgLayout))
       sgShape = derivedShape.value();
@@ -409,6 +411,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                                   shape);
 }
 
+bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
+  auto flattenedThis = flatten();
+  // If other is a LayoutAttr, just compare directly with parent of
+  // flattenedThis.
+  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
+    return flattenedThis.getParent() == otherLayout;
+  // If other is a SliceAttr, flatten it first before comparing.
+  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
+  // Both must have common parent LayoutAttr.
+  if (flattenedThis.getParent() != flattenedOther.getParent())
+    return false;
+  // flattenedOther's sliced dims must be a subset of flattenedThis's sliced
+  // dims.
+  llvm::SmallDenseSet<int64_t> thisDims(
+      flattenedThis.getDims().asArrayRef().begin(),
+      flattenedThis.getDims().asArrayRef().end());
+  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
+                      [&](int64_t dim) { return thisDims.contains(dim); });
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_RangeAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 5d5ff69e06886..7efa4b9fbd934 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -85,16 +85,16 @@ struct ConvertLayoutOpPattern
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                 PatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
-    xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
-    if (input_layout.getInstDataAsInt().empty() ||
-        target_layout.getInstDataAsInt().empty())
+    xegpu::DistributeLayoutAttr inputLayout = op.getInputLayoutAttr();
+    xegpu::DistributeLayoutAttr targetLayout = op.getTargetLayoutAttr();
+    if (inputLayout.getEffectiveInstDataAsInt().empty() ||
+        targetLayout.getEffectiveInstDataAsInt().empty())
       return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
 
-    input_layout = input_layout.dropInstData();
-    target_layout = target_layout.dropInstData();
+    inputLayout = inputLayout.dropInstData();
+    targetLayout = targetLayout.dropInstData();
     auto newOp = rewriter.createOrFold<xegpu::ConvertLayoutOp>(
-        op.getLoc(), op.getType(), op.getSource(), input_layout, target_layout);
+        op.getLoc(), op.getType(), op.getSource(), inputLayout, targetLayout);
     rewriter.replaceOp(op, newOp);
     return success();
   }
@@ -145,8 +145,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getInstDataAsInt().empty())
-      return layout.getInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty())
+      return layout.getEffectiveInstDataAsInt();
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -226,7 +226,7 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
     Type valTy = value.getType();
     if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
       xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
-      return layout && !layout.getInstDataAsInt().empty();
+      return layout && !layout.getEffectiveInstDataAsInt().empty();
     }
     auto shapedType = dyn_cast<ShapedType>(valTy);
     return shapedType && !llvm::equal(tileShape, shapedType.getShape());

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b33669259249a..21c1583bf2633 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -58,6 +58,12 @@ namespace {
 // SIMT Distribution Patterns
 //===----------------------------------------------------------------------===//
 
+/// In certain cases, we may need to favor XeGPU specific distribution patterns
+/// over generic vector distribution patterns. In such cases, we can assign
+/// priorities to patterns.
+static constexpr unsigned regularPatternBenefit = 1;
+static constexpr unsigned highPatternBenefit = 2;
+
 /// Helper function to get  distributed vector type for a source vector type
 /// according to the lane_layout. We simply divide each dimension of tensor
 /// descriptor shape by corresponding lane_layout dimension. If
@@ -72,27 +78,31 @@ namespace {
 /// | 32x16                 | [2, 8]      | 16x2                     |
 /// | 2x32x16               | [1, 16]     | 2x32x1                   |
 static FailureOr<VectorType>
-getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
+getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                 VectorType originalType) {
   if (!layout)
     return failure();
-
-  auto laneLayout = layout.getLaneLayout().asArrayRef();
-  assert(originalType.getShape().size() >= laneLayout.size() &&
+  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
+         "Expecting a valid layout.");
+  SmallVector<int64_t> effectiveLaneLayout =
+      layout.getEffectiveLaneLayoutAsInt();
+  assert(static_cast<size_t>(originalType.getRank()) >=
+             effectiveLaneLayout.size() &&
          "Rank of the original vector type should be greater or equal to the "
          "size of the lane layout to distribute the vector type.");
   SmallVector<int64_t> distributedShape(originalType.getShape());
   // Only distribute the last `laneLayout.size()` dimensions. The remaining
   // dimensions are not distributed.
-  unsigned distributionStart = originalType.getRank() - laneLayout.size();
+  unsigned distributionStart =
+      originalType.getRank() - effectiveLaneLayout.size();
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
     if (i < distributionStart)
       continue;
 
     // Check if the dimension can be distributed evenly.
-    if (dim % laneLayout[i - distributionStart] != 0)
+    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
       return failure();
-    distributedShape[i] = dim / laneLayout[i - distributionStart];
+    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
   }
   return VectorType::get(distributedShape, originalType.getElementType());
 }
@@ -1001,12 +1011,282 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Helper to rewrite a 2D VectorMultiReductionOp into a sequence of 1D
+/// VectorReductionOps.
+static Value lowerToVectorReductions(TypedValue<VectorType> src,
+                                     TypedValue<VectorType> acc,
+                                     vector::CombiningKind kind,
+                                     int64_t reductionDim, Location loc,
+                                     PatternRewriter &rewriter) {
+  // Expecting a 2D source vector.
+  assert(src.getType().getRank() == 2 && "expected a 2D source vector");
+  VectorType sourceType = src.getType();
+  int64_t sourceH = sourceType.getShape()[0];
+  int64_t sourceW = sourceType.getShape()[1];
+  int nSlices = (reductionDim == 0) ? sourceW : sourceH;
+  // Create a constant vector to hold the result of the reduction.
+  TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
+  Value reductionResult = arith::ConstantOp::create(
+      rewriter, loc, acc.getType(),
+      DenseElementsAttr::get(acc.getType(), zeroAttr));
+  // For each slice of the source, extract the slice vector, do a reduction,
+  // and insert the reduced value back into the result vector.
+  for (int i = 0; i < nSlices; ++i) {
+    SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
+    if (reductionDim == 1) {
+      sliceOffsets = {i, 0};
+      sliceSizes = {1, sourceW};
+    } else {
+      sliceOffsets = {0, i};
+      sliceSizes = {sourceH, 1};
+    }
+    vector::ExtractStridedSliceOp extractOp =
+        vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
+                                              sliceSizes, {1, 1});
+    int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
+    Value slice = vector::ShapeCastOp::create(
+        rewriter, loc,
+        VectorType::get({nSliceElements}, sourceType.getElementType()),
+        extractOp.getResult());
+    Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
+    Value reduction =
+        vector::ReductionOp::create(rewriter, loc, kind, slice, accExtract);
+    reductionResult =
+        vector::InsertOp::create(rewriter, loc, reduction, reductionResult, i);
+  }
+  return reductionResult;
+}
+
+/// This pattern distributes the `vector.multi_reduction` operation across
+/// lanes in a warp. Currently only 2D to 1D reductions are supported. Given
+/// layouts for the source and accumulator vectors,
+/// * If the reduction dimension is distributed across lanes, the reduction is
+///   non-lane-local and the reduction is done using warp shuffles. Here we
+///   simply rewrite the MultiDimReductionOp to a sequence of ReductionOps in
+///   the warp op body.
+/// * If the reduction dimension is not distributed across lanes, the reduction
+///   is lane-local. In this case, we yield the source and accumulator vectors
+///   from the warp op and perform the lane-local reduction outside the warp op
+///   using a sequence of ReductionOps.
+/// Example 1 (Reduction is lane-local):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [0] : vector<16x32xf32> to vector<32xf32>
+///   gpu.yield %1 : vector<32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>,
+/// vector<1xf32>) {
+///   %0 = "some_def"() : () -> (vector<16x32xf32>)
+///   %acc = "some_def"() : () -> (vector<32xf32>)
+///   gpu.yield %0, %acc : vector<16x32xf32>, vector<32xf32>
+/// }
+/// %c = arith.constant dense<0.0> : vector<1xf32>
+/// %1 = vector.shape_cast %r#0 : vector<16x1xf32> to vector<16xf32>
+/// %2 = vector.reduction <add>, %1, %r#1 : vector<16xf32> to f32
+/// %3 = vector.insert %2, %c[0] : f32 into vector<1xf32>
+/// ```
+/// Example 2 (Reduction is non-lane-local):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<2x32xf32> to
+///   vector<2xf32>
+///   gpu.yield %1 : vector<2xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+///   %0 = "some_def"() : () -> (vector<2x32xf32>)
+///   %acc = "some_def"() : () -> (vector<2xf32>)
+///   %1 = arith.constant dense<0.0> : vector<2xf32>
+///   %2 = vector.extract %0[0] : vector<32xf32> from <vector<2x32xf32>>
+///   %3 = ("warp.reduction %2") : f32
+///   %4 = vector.insert %3, %1[0] : f32 into vector<2xf32>
+///   ... repeat for row 1, inserting into %4 to produce %7
+///   gpu.yield %7 : vector<2xf32>
+/// }
+struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::MultiDimReductionOp>);
+    if (!yieldOperand)
+      return failure();
+    auto reductionOp =
+        cast<vector::MultiDimReductionOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    VectorType sourceType = reductionOp.getSourceVectorType();
+    // Only 2D vectors are supported.
+    if (sourceType.getRank() != 2)
+      return rewriter.notifyMatchFailure(warpOp,
+                                         "Only 2D reductions are supported.");
+    ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
+    // Only 1 reduction dimension is supported. This also ensures that the
+    // result is of vector type.
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Only 1 reduction dimension is supported.");
+    int64_t reductionDim = reductionDims[0];
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    VectorType resultType = cast<VectorType>(reductionOp.getType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(reductionOp.getSource());
+
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Failed to distribute the source vector type.");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Only single dimension distribution is supported.
+    bool dim0Distributed =
+        sourceDistType.getShape()[0] != sourceType.getShape()[0];
+    bool dim1Distributed =
+        sourceDistType.getShape()[1] != sourceType.getShape()[1];
+    if (dim0Distributed && dim1Distributed)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting source to be distributed in a single dimension.");
+    int64_t sourceDistDim = dim0Distributed ? 0 : (dim1Distributed ? 1 : -1);
+    if (sourceDistDim == -1)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting a distributed source vector.");
+    bool resultDistributed =
+        distributedResultType.getNumElements() < resultType.getNumElements();
+    // If the lane owns all the data required for reduction (i.e. reduction is
+    // fully parallel across lanes), then each lane owns part of the result
+    // (i.e. result is distributed). If the reduction requires cross-lane
+    // shuffling, then the result is shared among all lanes (broadcasted).
+    // Therefore we expect following cases:
+    //
+    // | Source vector        | Reduction dim  | Result vector  |
+    // |----------------------|----------------|----------------|
+    // |  dim-0 distributed   |       0        | broadcasted    |
+    // |  dim-0 distributed   |       1        | distributed    |
+    // |  dim-1 distributed   |       0        | distributed    |
+    // |  dim-1 distributed   |       1        | broadcasted    |
+
+    bool isReductionLaneLocal = (sourceDistDim == 0 && reductionDim == 1) ||
+                                (sourceDistDim == 1 && reductionDim == 0);
+    if (isReductionLaneLocal && !resultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp, "Expecting a distributed result for lane-local reduction.");
+
+    if (!isReductionLaneLocal && resultDistributed)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "Expecting a broadcasted result for non-lane-local reduction.");
+
+    // Handle lane-local reduction case. In this case we fully distribute the
+    // reduction result.
+    if (isReductionLaneLocal) {
+      // Yield the source and acc vectors from the WarpOp.
+      SmallVector<size_t> newRetIndices;
+      auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {reductionOp.getSource(), reductionOp.getAcc()},
+          {sourceDistType, distributedResultType}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+      Value result = lowerToVectorReductions(
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[0])),
+          cast<TypedValue<VectorType>>(newWarpOp->getResult(newRetIndices[1])),
+          reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+      // Replace all uses of the reduction op's result with the final result.
+      rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+      return success();
+    }
+    // For non-lane-local case, we simply rewrite the MultiReductionOp in terms
+    // of multiple ReductionOps. Actual distribution is done by the
+    // WarpOpReduction pattern.
+    rewriter.setInsertionPointAfter(reductionOp);
+    Value result = lowerToVectorReductions(
+        cast<TypedValue<VectorType>>(reductionOp.getSource()),
+        cast<TypedValue<VectorType>>(reductionOp.getAcc()),
+        reductionOp.getKind(), reductionDim, reductionOp.getLoc(), rewriter);
+    // Replace the warp op result with the final result.
+    rewriter.replaceAllUsesWith(reductionOp.getResult(), result);
+    return success();
+  }
+};
+
+/// Sink a `vector.shape_cast` feeding the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region: yield its source and cast it outside.
+struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto shapeCastOp =
+        cast<vector::ShapeCastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandNumber = yieldOperand->getOperandNumber();
+    auto resultDistTy =
+        cast<VectorType>(warpOp.getResult(operandNumber).getType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(shapeCastOp.getSource());
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(shapeCastOp.getResult());
+    if (!sourceLayout || !resultLayout)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "the source or result of shape_cast op lacks distribution layout");
+
+    // A rank-changing shape_cast needs the lower-rank layout to be a slice of
+    // the higher-rank layout (source rank < result rank is rank increasing).
+    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
+    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
+    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
+      return rewriter.notifyMatchFailure(
+          warpOp, "shape_cast is rank increasing but source layout is not a "
+                  "slice of result layout");
+    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
+      return rewriter.notifyMatchFailure(
+          warpOp, "shape_cast is rank reducing but result layout is not a "
+                  "slice of source layout");
+
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout,
+                                        shapeCastOp.getSourceVectorType());
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          warpOp, "failed to get distributed vector type for source");
+    VectorType sourceDistType = sourceDistTypeOrFailure.value();
+    // Build a new warp op that additionally yields the shape_cast source.
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {shapeCastOp.getSource()}, {sourceDistType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value source = newWarpOp.getResult(newRetIndices[0]);
+    // Re-create the shape_cast after the warp op on the distributed source.
+    Value newShapeCast = vector::ShapeCastOp::create(
+        rewriter, shapeCastOp.getLoc(), resultDistTy, source);
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
+                                newShapeCast);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
 struct XeGPUSubgroupDistributePass final
     : public xegpu::impl::XeGPUSubgroupDistributeBase<
           XeGPUSubgroupDistributePass> {
+  XeGPUSubgroupDistributePass() = default;
+  XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
+      default;
+  XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
+      : XeGPUSubgroupDistributeBase(options) {}
   void runOnOperation() override;
 };
 } // namespace
@@ -1016,8 +1296,13 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
   patterns
       .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
            DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
-           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
-          patterns.getContext());
+           GpuBarrierDistribution, VectorMultiReductionDistribution,
+           LoadDistribution, StoreDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
+  patterns.add<VectorShapeCastDistribution>(
+      patterns.getContext(),
+      /*pattern benefit=*/highPatternBenefit);
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -1032,8 +1317,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
-      auto layout =
-          xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
+      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
         op->emitError("Could not find layout attribute for operand ")
             << operand.getOperandNumber() << " of operation " << op->getName();
@@ -1074,18 +1358,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
     // Get the layout of the vector type.
-    // TODO: support more layout types
-    auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
+    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
     // If no layout is specified, assume the inner most dimension is distributed
     // for now.
     if (!layout)
       return AffineMap::getMultiDimMapWithTargets(
           vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
     SmallVector<unsigned int> distributedDims;
-    // Get the distributed dimensions based on the layout.
-    ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
-    for (unsigned i = 0; i < laneLayout.size(); ++i) {
-      if (laneLayout[i] > 1)
+    for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
+      if (v > 1)
         distributedDims.push_back(i);
     }
     return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
@@ -1094,8 +1375,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
   // TODO: shuffleFn is not used.
   auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
                       int64_t warpSz) { return Value(); };
+
+  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
+                          vector::CombiningKind kind, uint32_t size) {
+    // First reduce on a single thread to get per lane reduction value.
+    Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+    // Parallel reduction using butterfly shuffles.
+    for (uint64_t i = 1; i < size; i <<= 1) {
+      Value shuffled =
+          builder
+              .create<gpu::ShuffleOp>(loc, laneVal, i,
+                                      /*width=*/size,
+                                      /*mode=*/gpu::ShuffleMode::XOR)
+              .getShuffleResult();
+      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+    }
+    return laneVal;
+  };
+
+  if (enableSGReductions)
+    vector::populateDistributeReduction(
+        patterns, warpReduction,
+        /*pattern benefit=*/regularPatternBenefit);
+
   vector::populatePropagateWarpVectorDistributionPatterns(
-      patterns, distributionFn, shuffleFn);
+      patterns, distributionFn, shuffleFn,
+      /*pattern benefit=*/regularPatternBenefit);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
     signalPassFailure();
     return;

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 5d0f1d18402f2..3f48400fedf5e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -52,9 +52,9 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
   if (layout && layout.isForWorkgroup()) {
-    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
-    if (!layout.getSgDataAsInt().empty())
-      sgShape = layout.getSgDataAsInt();
+    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
+    if (!layout.getEffectiveSgDataAsInt().empty())
+      sgShape = layout.getEffectiveSgDataAsInt();
     else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
       sgShape = *maybeDerivedSgData;
     SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
@@ -488,7 +488,7 @@ struct WgToSgVectorBroadcastOp
         VectorType::get(sgShape, resultType.getElementType());
 
     // Check if the output layout is distributable
-    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
+    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
     if (sgLayout.empty())
       return failure();
 

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 60acea06c9a12..30ca9816df5bc 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt -xegpu-subgroup-distribute -allow-unregistered-dialect -canonicalize -cse -split-input-file %s | FileCheck %s
 
+// RUN: mlir-opt -xegpu-subgroup-distribute="enable-sg-reductions=false" -allow-unregistered-dialect \
+// RUN: -canonicalize -cse -split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION
+
 // CHECK-LABEL: gpu.func @store_nd_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<16xf32>) {
 // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
@@ -320,6 +323,116 @@ gpu.module @test {
   }
 }
 
+// -----
+// CHECK-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction
+// CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] ->
+// CHECK-SAME:    (!xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, vector<16x2xf32>) {
+// CHECK:             %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x32xf32>
+// CHECK-NEXT:        gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, vector<16x32xf32>
+// CHECK-NEXT:  }
+// CHECK:       %[[COL0:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-NEXT:  %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-NEXT:  %[[RED0:.*]] = vector.reduction <add>, %[[CAST0]], %{{.*}} : vector<16xf32> into f32
+// CHECK:       %[[COL1:.*]] = vector.extract_strided_slice %[[W]]#1 {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-NEXT:  %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-NEXT:  %[[RED1:.*]] = vector.reduction <add>, %[[CAST1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-NEXT:  vector.from_elements %[[RED0]], %[[RED1]] : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}  : () -> (vector<16x32xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.0>  : vector<32xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}  [0]
+    : vector<16x32xf32> to vector<32xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : vector<32xf32> to vector<1x32xf32>
+  xegpu.store_nd %3, %0 : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction
+// CHECK-REDUCTION:         %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<2x16xf32,
+// CHECK-REDUCTION-SAME:      #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, f32, f32) {
+// CHECK-REDUCTION:           %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<2x16xf32>
+// CHECK-REDUCTION-NEXT:      %[[ROW0:.*]] = vector.extract %[[SRC]][0] : vector<16xf32> from vector<2x16xf32>
+// CHECK-REDUCTION-NEXT:      %[[R0:.*]] = vector.reduction <add>, %[[ROW0]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:      %[[ROW1:.*]] = vector.extract %[[SRC]][1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-REDUCTION-NEXT:      %[[R1:.*]] = vector.reduction <add>, %[[ROW1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:      gpu.yield %{{.*}}, %[[R1]], %[[R0]] : !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, f32, f32
+// CHECK-REDUCTION-NEXT:    }
+// CHECK-REDUCTION-NEXT:    vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}  : () -> (vector<2x16xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} dense<0.0>  : vector<2xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}
+    [1] : vector<2x16xf32> to vector<2xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : vector<2xf32> to vector<2x1xf32>
+  %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<2x1xf32> to vector<2x16xf32>
+  xegpu.store_nd %4, %0 : vector<2x16xf32>, !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL:   gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction
+// CHECK:             %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] ->
+// CHECK-SAME:          (!xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<2x16xf32>) {
+// CHECK:                 %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<32x16xf32>
+// CHECK-NEXT:            gpu.yield %{{.*}}, %[[SRC]] : !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<32x16xf32>
+// CHECK-NEXT:        }
+// CHECK:             %[[ROW0:.*]] = vector.extract %[[W]]#1[0] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:        %[[R0:.*]] = vector.reduction <add>, %[[ROW0]], %{{.*}} : vector<16xf32> into f32
+// CHECK:             %[[ROW1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
+// CHECK-NEXT:        %[[R1:.*]] = vector.reduction <add>, %[[ROW1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-NEXT:        vector.from_elements %[[R0]], %[[R1]] : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}  : () -> (vector<32x16xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>} dense<0.0>  : vector<32xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>}  [1]
+    : vector<32x16xf32> to vector<32xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+    : vector<32xf32> to vector<32x1xf32>
+  xegpu.store_nd %3, %0 : vector<32x1xf32>, !xegpu.tensor_desc<32x1xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-REDUCTION-LABEL: gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction
+// CHECK-REDUCTION:       %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.tensor_desc<16x2xf32,
+// CHECK-REDUCTION-SAME:    #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, f32, f32) {
+// CHECK-REDUCTION:          %[[SRC:.*]] = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : () -> vector<16x2xf32>
+// CHECK-REDUCTION-NEXT:     %[[COL0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REDUCTION-NEXT:     %[[CAST0:.*]] = vector.shape_cast %[[COL0]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-REDUCTION-NEXT:     %[[R0:.*]] = vector.reduction <add>, %[[CAST0]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:     %[[COL1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 1], sizes = [16, 1], strides = [1, 1]} : vector<16x2xf32> to vector<16x1xf32>
+// CHECK-REDUCTION-NEXT:     %[[CAST1:.*]] = vector.shape_cast %[[COL1]] : vector<16x1xf32> to vector<16xf32>
+// CHECK-REDUCTION-NEXT:     %[[R1:.*]] = vector.reduction <add>, %[[CAST1]], %{{.*}} : vector<16xf32> into f32
+// CHECK-REDUCTION-NEXT:     gpu.yield %{{.*}}, %[[R1]], %[[R0]] : !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, f32, f32
+// CHECK-REDUCTION-NEXT:   }
+// CHECK-REDUCTION-NEXT:   vector.from_elements %[[W]]#2, %[[W]]#1 : vector<2xf32>
+gpu.module @test {
+gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction() {
+  %0 = "some_def"() : () -> !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  %src = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}  : () -> (vector<16x2xf32>)
+  %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>} dense<0.0>  : vector<2xf32>
+  %1 = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>}
+    [0] : vector<16x2xf32> to vector<2xf32>
+  %3 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+    : vector<2xf32> to vector<1x2xf32>
+  %4 = vector.broadcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : vector<1x2xf32> to vector<16x2xf32>
+  xegpu.store_nd %4, %0 : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+  gpu.return
+}
+}
+
 // -----
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>

diff  --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 200323c7a4e51..e1ba45c60ac36 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -170,7 +170,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
     if (!sliceAttr || sliceAttr.getRank() != 1)
       return failure();
 
-    std::optional<SmallVector<int64_t>> sgShape = sliceAttr.getSgDataAsInt();
+    std::optional<SmallVector<int64_t>> sgShape =
+        sliceAttr.getEffectiveSgDataAsInt();
     if (!sgShape)
       return failure();
 


        


More information about the Mlir-commits mailing list