[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg (PR #155443)

Nishant Patel llvmlistbot at llvm.org
Fri Sep 12 10:33:06 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/155443

>From f6e6e10d44923f518120eadbae690e2f197f9e7d Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 25 Aug 2025 18:59:31 +0000
Subject: [PATCH 01/13] Add vector.step distribution pattern

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 204 +++++++++++++++++-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  40 ++++
 2 files changed, 239 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 93b4efcd125ec..39d0e9ea2e91d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -737,8 +737,18 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     if (!vecAttr || !vecAttr.isSplat() || !vecType)
       return failure();
 
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
-    if (!layout || !layout.getSgLayout())
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto attr = op->getAttr(layoutName);
+
+    xegpu::DistributeLayoutAttr layout = nullptr;
+    // Try to get either SliceAttr or LayoutAttr, and keep as is
+    if (auto trySlice = dyn_cast_if_present<xegpu::SliceAttr>(attr)) {
+      layout = trySlice;
+    } else if (auto tryLayout = dyn_cast_if_present<xegpu::LayoutAttr>(attr)) {
+      layout = tryLayout;
+    }
+
+    if (!layout)
       return failure();
 
     ArrayRef<int64_t> wgShape = vecType.getShape();
@@ -754,8 +764,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+    // Do nothing if layout is a SliceAttr
+    if (auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout)) {
+      if (auto newLayout = layoutAttr.dropSgLayoutAndData()) {
+        xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+      }
+    }
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -763,6 +777,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+<<<<<<< HEAD
+=======
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data, similar to WgToSgLoadNdOpWithOffset.
+struct WgToSgLoadGatherOpWithOffset
+    : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    SmallVector<Value> newLoadOps;
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+    for (auto [offsets, mask] :
+         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      xegpu::setLayoutAttr(newLoadOp->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data, similar to WgToSgStoreNdOpWithOffset.
+struct WgToSgStoreScatterOpWithOffset
+    : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = valueType.getShape();
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    auto chunkSizeOpt = op.getChunkSize();
+    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+    for (auto [val, offs, mask] : llvm::zip(
+             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      // Update the layout_result_0 attribute to drop sg_layout and sg_data.
+      if (auto layoutAttr =
+              op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
+        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+          op->setAttr("layout_result_0", newLayout);
+      }
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto attr = op->getAttr(layoutName);
+
+    xegpu::DistributeLayoutAttr layoutAttr = nullptr;
+    // Try to get either SliceAttr or LayoutAttr, and keep as is
+    if (auto trySlice = dyn_cast_if_present<xegpu::SliceAttr>(attr)) {
+      layoutAttr = trySlice;
+    } else if (auto tryLayout = dyn_cast_if_present<xegpu::LayoutAttr>(attr)) {
+      layoutAttr = tryLayout;
+    }
+
+    if (!layoutAttr)
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+    std::optional<SmallVector<int64_t>> sgShape =
+        getSgShapeAndCount(wgShape, layoutAttr).first;
+    if (!sgShape)
+      return failure();
+
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto maybeOffsets = layoutAttr.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(maybeOffsets))
+      return failure();
+
+    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+    Value base = vector::StepOp::create(rewriter, loc, newTy);
+    SmallVector<Value> newOps;
+    for (auto offsets : *maybeOffsets) {
+      Value bcast =
+          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+      newOps.push_back(add);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newOps});
+    return success();
+  }
+};
+
+>>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern)
 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
   LogicalResult
@@ -824,8 +971,14 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
            WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+<<<<<<< HEAD
            WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
           patterns.getContext());
+=======
+           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
+           WgToSgStoreMatrixOp, WgToSgVectorStepOp>(patterns.getContext());
+>>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern)
 }
 } // namespace xegpu
 } // namespace mlir
@@ -947,9 +1100,50 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         auto vecType = dyn_cast<VectorType>(op.getType());
         if (!vecType)
           return true;
-        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+
+        auto layoutName = xegpu::getLayoutName(op->getResult(0));
+        auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+        if (sliceAttr)
+          return isLegal(sliceAttr);
+
+        auto layoutAttr = op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+        if (layoutAttr)
+          return isLegal(layoutAttr);
+
+        // If neither attribute is present, consider the op legal.
+        return true;
+      });
+
+  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+      [=](xegpu::LoadGatherOp op) -> bool {
+        auto layout = xegpu::getLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+      [=](xegpu::StoreScatterOp op) -> bool {
+        // Check if the layout attribute is present on the result.
+        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
+        if (!layout)
+          return true;
+        return isLegal(layout);
       });
 
+  target.addDynamicallyLegalOp<vector::StepOp>([&](vector::StepOp op) -> bool {
+    // Check for either a SliceAttr or LayoutAttr on the result.
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+    if (sliceAttr)
+      return isLegal(sliceAttr);
+
+    auto layoutAttr = op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+    if (layoutAttr)
+      return isLegal(layoutAttr);
+
+    // If neither attribute is present, consider the op legal.
+    return true;
+  });
+
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 32157a7911f62..0869d0346fed7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -2,6 +2,7 @@
 
 //CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
 //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -321,4 +322,43 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_step_op
+  gpu.func @vector_step_op_slice_attr() {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+    gpu.return
+  }
+
+  gpu.func @vector_step_op_layout_attr() {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c16:%.+]] = arith.constant 16 : index
+    //CHECK: [[c8:%.+]] = arith.constant 8 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
+    %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
+    gpu.return
+  }
+
+  gpu.func @constant_with_slice_attr() {
+    //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
+    gpu.return
+  }
 }

>From ab57d1b226086686aaa3cf294bfb3a2147e677dd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 26 Aug 2025 15:01:43 +0000
Subject: [PATCH 02/13] Add vector.shape_cast pattern

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 180 +++++++-----------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  13 ++
 2 files changed, 78 insertions(+), 115 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 39d0e9ea2e91d..286a74bfebb34 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -535,8 +535,17 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
-    if (!layout || !layout.getSgLayout())
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto attr = op->getAttr(layoutName);
+
+    xegpu::DistributeLayoutAttr layout = nullptr;
+    if (auto trySlice = dyn_cast_if_present<xegpu::SliceAttr>(attr)) {
+      layout = trySlice;
+    } else if (auto tryLayout = dyn_cast_if_present<xegpu::LayoutAttr>(attr)) {
+      layout = tryLayout;
+    }
+
+    if (!layout)
       return failure();
 
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -741,7 +750,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto attr = op->getAttr(layoutName);
 
     xegpu::DistributeLayoutAttr layout = nullptr;
-    // Try to get either SliceAttr or LayoutAttr, and keep as is
     if (auto trySlice = dyn_cast_if_present<xegpu::SliceAttr>(attr)) {
       layout = trySlice;
     } else if (auto tryLayout = dyn_cast_if_present<xegpu::LayoutAttr>(attr)) {
@@ -764,7 +772,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    // Do nothing if layout is a SliceAttr
     if (auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout)) {
       if (auto newLayout = layoutAttr.dropSgLayoutAndData()) {
         xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
@@ -777,90 +784,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
-<<<<<<< HEAD
-=======
-// This pattern transforms the LoadGatherOp with explicit offsets to load
-// subgroup data, similar to WgToSgLoadNdOpWithOffset.
-struct WgToSgLoadGatherOpWithOffset
-    : public OpConversionPattern<xegpu::LoadGatherOp> {
-  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    if (!op.getOffsets())
-      return failure();
-
-    Location loc = op.getLoc();
-    VectorType resultType = op.getResult().getType();
-    ArrayRef<int64_t> wgShape = resultType.getShape();
-
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
-    if (!layout || !layout.getSgLayout())
-      return failure();
-
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-
-    SmallVector<Value> newLoadOps;
-    auto chunkSizeAttr =
-        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
-    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
-    for (auto [offsets, mask] :
-         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
-      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
-          loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
-      xegpu::setLayoutAttr(newLoadOp->getResult(0),
-                           layout.dropSgLayoutAndData());
-      newLoadOps.push_back(newLoadOp);
-    }
-    rewriter.replaceOpWithMultiple(op, {newLoadOps});
-    return success();
-  }
-};
-
-// This pattern transforms the StoreScatterOp with explicit offsets to store
-// subgroup data, similar to WgToSgStoreNdOpWithOffset.
-struct WgToSgStoreScatterOpWithOffset
-    : public OpConversionPattern<xegpu::StoreScatterOp> {
-  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    if (!op.getOffsets())
-      return failure();
-
-    Location loc = op.getLoc();
-    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
-    if (!valueType)
-      return failure();
-
-    ArrayRef<int64_t> wgShape = valueType.getShape();
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
-    if (!layout || !layout.getSgLayout())
-      return failure();
-
-    auto chunkSizeOpt = op.getChunkSize();
-    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
-    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
-    for (auto [val, offs, mask] : llvm::zip(
-             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
-      rewriter.create<xegpu::StoreScatterOp>(
-          loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
-      // Update the layout_result_0 attribute to drop sg_layout and sg_data.
-      if (auto layoutAttr =
-              op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
-        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
-          op->setAttr("layout_result_0", newLayout);
-      }
-    }
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
   using OpConversionPattern<vector::StepOp>::OpConversionPattern;
   LogicalResult
@@ -870,7 +793,6 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
     auto attr = op->getAttr(layoutName);
 
     xegpu::DistributeLayoutAttr layoutAttr = nullptr;
-    // Try to get either SliceAttr or LayoutAttr, and keep as is
     if (auto trySlice = dyn_cast_if_present<xegpu::SliceAttr>(attr)) {
       layoutAttr = trySlice;
     } else if (auto tryLayout = dyn_cast_if_present<xegpu::LayoutAttr>(attr)) {
@@ -909,7 +831,42 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
   }
 };
 
->>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern)
+// This pattern transforms vector.shape_cast ops to work at subgroup level.
+struct WgToSgVectorShapeCastOp
+    : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newShapeCastOps;
+    for (auto src : adaptor.getSource()) {
+      auto newShapeCast =
+          rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+      xegpu::setLayoutAttr(newShapeCast->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newShapeCastOps.push_back(newShapeCast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
+    return success();
+  }
+};
+
 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
   LogicalResult
@@ -971,14 +928,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
            WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-<<<<<<< HEAD
-           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
-          patterns.getContext());
-=======
-           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
-           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-           WgToSgStoreMatrixOp, WgToSgVectorStepOp>(patterns.getContext());
->>>>>>> ddbdb0d7eb4f (Add vector.step distribution pattern)
+           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+           WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1114,21 +1065,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return true;
       });
 
-  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
-      [=](xegpu::LoadGatherOp op) -> bool {
-        auto layout = xegpu::getLayoutAttr(op.getResult());
-        return isLegal(layout);
-      });
-
-  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
-      [=](xegpu::StoreScatterOp op) -> bool {
-        // Check if the layout attribute is present on the result.
-        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
-        if (!layout)
-          return true;
-        return isLegal(layout);
-      });
-
   target.addDynamicallyLegalOp<vector::StepOp>([&](vector::StepOp op) -> bool {
     // Check for either a SliceAttr or LayoutAttr on the result.
     auto layoutName = xegpu::getLayoutName(op->getResult(0));
@@ -1149,6 +1085,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<vector::ShapeCastOp>(
+      [=](vector::ShapeCastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
@@ -1174,8 +1115,17 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
           }
         }
 
-        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
-        return isLegal(layout);
+        auto layoutName = xegpu::getLayoutName(op->getResult(0));
+        auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+        if (sliceAttr)
+          return isLegal(sliceAttr);
+
+        auto layoutAttr = op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+        if (layoutAttr)
+          return isLegal(layoutAttr);
+
+        // If neither attribute is present, consider the op legal.
+        return true;
       });
 
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 0869d0346fed7..7601274ba4969 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -356,9 +356,22 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: constant_with_slice_attr
   gpu.func @constant_with_slice_attr() {
     //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
     %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_shape_cast
+  gpu.func @vector_shape_cast(%src: memref<256x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32>
+    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
+    gpu.return
+  }
 }

>From 2c96a5c1509fb8d4e4a91a0924373220e2f04af6 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 26 Aug 2025 16:36:20 +0000
Subject: [PATCH 03/13] Support slice for shapecast

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 59 +++++++++++--------
 1 file changed, 34 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 286a74bfebb34..1c13f59151a34 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -784,6 +784,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern distributes the vector.step ops to work at subgroup level
 struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
   using OpConversionPattern<vector::StepOp>::OpConversionPattern;
   LogicalResult
@@ -845,8 +846,17 @@ struct WgToSgVectorShapeCastOp
       return failure();
 
     ArrayRef<int64_t> wgShape = resultType.getShape();
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
-    if (!layout || !layout.getSgLayout())
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto attr = op->getAttr(layoutName);
+
+    xegpu::DistributeLayoutAttr layout = nullptr;
+    if (auto trySlice = dyn_cast_if_present<xegpu::SliceAttr>(attr)) {
+      layout = trySlice;
+    } else if (auto tryLayout = dyn_cast_if_present<xegpu::LayoutAttr>(attr)) {
+      layout = tryLayout;
+    }
+
+    if (!layout)
       return failure();
 
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -857,13 +867,16 @@ struct WgToSgVectorShapeCastOp
     for (auto src : adaptor.getSource()) {
       auto newShapeCast =
           rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
-      xegpu::setLayoutAttr(newShapeCast->getResult(0),
-                           layout.dropSgLayoutAndData());
-      newShapeCastOps.push_back(newShapeCast.getResult());
-    }
+      if (auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout)) {
+        if (auto newLayout = layoutAttr.dropSgLayoutAndData()) {
+          xegpu::setLayoutAttr(newShapeCast->getResult(0), newLayout);
+        }
+        newShapeCastOps.push_back(newShapeCast.getResult());
+      }
 
-    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
-    return success();
+      rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
+      return success();
+    }
   }
 };
 
@@ -1065,31 +1078,27 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return true;
       });
 
-  target.addDynamicallyLegalOp<vector::StepOp>([&](vector::StepOp op) -> bool {
-    // Check for either a SliceAttr or LayoutAttr on the result.
-    auto layoutName = xegpu::getLayoutName(op->getResult(0));
-    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
-    if (sliceAttr)
-      return isLegal(sliceAttr);
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+      [=](Operation *op) -> bool {
+        // Check for either a SliceAttr or LayoutAttr on the result.
+        auto layoutName = xegpu::getLayoutName(op->getResult(0));
+        auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+        if (sliceAttr)
+          return isLegal(sliceAttr);
 
-    auto layoutAttr = op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
-    if (layoutAttr)
-      return isLegal(layoutAttr);
+        auto layoutAttr = op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+        if (layoutAttr)
+          return isLegal(layoutAttr);
 
-    // If neither attribute is present, consider the op legal.
-    return true;
-  });
+        // If neither attribute is present, consider the op legal.
+        return true;
+      });
 
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp>(
-      [=](vector::ShapeCastOp op) -> bool {
-        return isLegal(xegpu::getLayoutAttr(op.getResult()));
-      });
-
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());

>From 6a19aa52c45966a4b291cb5b306e29b533b64f99 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 26 Aug 2025 16:39:38 +0000
Subject: [PATCH 04/13] Clean up

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp       | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 1c13f59151a34..a551e530c0e08 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -793,27 +793,27 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
     auto layoutName = xegpu::getLayoutName(op->getResult(0));
     auto attr = op->getAttr(layoutName);
 
-    xegpu::DistributeLayoutAttr layoutAttr = nullptr;
+    xegpu::DistributeLayoutAttr layout = nullptr;
     if (auto trySlice = dyn_cast_if_present<xegpu::SliceAttr>(attr)) {
-      layoutAttr = trySlice;
+      layout = trySlice;
     } else if (auto tryLayout = dyn_cast_if_present<xegpu::LayoutAttr>(attr)) {
-      layoutAttr = tryLayout;
+      layout = tryLayout;
     }
 
-    if (!layoutAttr)
+    if (!layout)
       return failure();
 
     Location loc = op.getLoc();
     VectorType type = op.getResult().getType();
     auto wgShape = type.getShape();
     std::optional<SmallVector<int64_t>> sgShape =
-        getSgShapeAndCount(wgShape, layoutAttr).first;
+        getSgShapeAndCount(wgShape, layout).first;
     if (!sgShape)
       return failure();
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto maybeOffsets = layoutAttr.getOffsets(rewriter, loc, sgId, wgShape);
+    auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
     if (failed(maybeOffsets))
       return failure();
 

>From f4d7108dbd22c3266913040ea5d1bd264174c745 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 29 Aug 2025 18:50:06 +0000
Subject: [PATCH 05/13] Temp

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 22 +------------------
 1 file changed, 1 insertion(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 059641af2219a..f749d55e501e9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -467,6 +467,7 @@ struct WgToSgVectorBroadcastOp
   LogicalResult
   matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     VectorType resultType = op.getResult().getType();
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
@@ -475,34 +476,14 @@ struct WgToSgVectorBroadcastOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    // TODO: Currently only supports cases where the source and result ranks
-    // are the same.
-    auto srcType =
-        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
-    if (!srcType || srcType.getRank() != resultType.getRank())
-      return failure();
-
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
-    // Check if the output layout is distributable
-    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt();
-    if (sgLayout.empty())
-      return failure();
 
     if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
       return failure();
 
-    // Check if the srcShape has unit dim in dimensions being broadcasted,
-    // and the other dimensions are the same as the destination type
-    // TODO: Generalize it
-    auto srcShape = srcType.getShape();
-    for (size_t i = 0; i < srcShape.size(); ++i) {
-      if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
-        return failure();
-    }
-
     SmallVector<Value> newBroadcastOps;
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
@@ -518,7 +499,6 @@ struct WgToSgVectorBroadcastOp
       }
       newBroadcastOps.push_back(newBroadcast.getResult());
     }
-
     rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
     return success();
   }

>From 8cb5ebef69d5e4b5a564e666adb17b14c0d85397 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 4 Sep 2025 17:10:10 +0000
Subject: [PATCH 06/13] Clean up check

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 48 ++++++-------------
 1 file changed, 14 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index f421f49f96494..af514ca047db8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -487,15 +487,10 @@ struct WgToSgVectorBroadcastOp
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                       newResultType, operand);
-      if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
-        if (sliceAttr.isForSubgroup())
-          xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
-                                         sliceAttr.dropSgLayoutAndData());
-      } else if (auto layoutAttr =
-                     dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
-        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
-          xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), newLayout);
-      }
+      if (!layout.getLaneLayoutAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+                                       layout.dropSgLayoutAndData());
+
       newBroadcastOps.push_back(newBroadcast.getResult());
     }
     rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
@@ -549,13 +544,10 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       // Copy all attributes, but update "layout_result_0" to drop
       // sgLayout/sgData
       for (auto attr : op->getAttrs()) {
-        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
-          if (auto newLayout = layout.dropSgLayoutAndData())
-            state.addAttribute(attr.getName(), newLayout);
-        } else if (auto sliceAttr =
-                       dyn_cast<xegpu::SliceAttr>(attr.getValue())) {
-          if (sliceAttr.isForSubgroup())
-            state.addAttribute(attr.getName(), sliceAttr.dropSgLayoutAndData());
+        if (auto layout =
+                dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
+          if (!layout.getLaneLayoutAsInt().empty())
+            state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
         } else {
           state.addAttribute(attr.getName(), attr.getValue());
         }
@@ -746,15 +738,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
-      if (sliceAttr.isForSubgroup())
-        xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
-                                       sliceAttr.dropSgLayoutAndData());
-    } else if (auto layoutAttr =
-                   dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
-      if (auto newLayout = layoutAttr.dropSgLayoutAndData())
-        xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
-    }
+    if (!layout.getLaneLayoutAsInt().empty())
+      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -983,15 +969,9 @@ struct WgToSgVectorShapeCastOp
     for (auto src : adaptor.getSource()) {
       auto newShapeCast =
           rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
-      if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
-        if (sliceAttr.isForSubgroup())
-          xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
-                                         sliceAttr.dropSgLayoutAndData());
-      } else if (auto layoutAttr =
-                     dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
-        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
-          xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), newLayout);
-      }
+      if (!layout.getLaneLayoutAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+                                       layout.dropSgLayoutAndData());
       newShapeCastOps.push_back(newShapeCast.getResult());
     }
 

>From 1161e28e3e2424b03a974343254306f59c71b44a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 4 Sep 2025 23:45:33 +0000
Subject: [PATCH 07/13] Feedback

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 28 ++++++++-----
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 40 +++++++++----------
 2 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index af514ca047db8..3b9bd98742080 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -487,7 +487,8 @@ struct WgToSgVectorBroadcastOp
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                       newResultType, operand);
-      if (!layout.getLaneLayoutAsInt().empty())
+      if (!layout.getLaneLayoutAsInt().empty() ||
+          !layout.getLaneDataAsInt().empty())
         xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
                                        layout.dropSgLayoutAndData());
 
@@ -546,7 +547,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       for (auto attr : op->getAttrs()) {
         if (auto layout =
                 dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
-          if (!layout.getLaneLayoutAsInt().empty())
+          if (!layout.getLaneLayoutAsInt().empty() ||
+              !layout.getLaneDataAsInt().empty())
             state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
         } else {
           state.addAttribute(attr.getName(), attr.getValue());
@@ -738,7 +740,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getLaneLayoutAsInt().empty())
+    if (!layout.getLaneLayoutAsInt().empty() ||
+        !layout.getLaneDataAsInt().empty())
       xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
                                      layout.dropSgLayoutAndData());
     SmallVector<Value> newConsts(count, cstOp);
@@ -923,18 +926,20 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
-    if (failed(maybeOffsets))
+    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(sgOffsets))
       return failure();
 
     VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
-    Value base = vector::StepOp::create(rewriter, loc, newTy);
+    Value steps = vector::StepOp::create(rewriter, loc, newTy);
     SmallVector<Value> newOps;
-    for (auto offsets : *maybeOffsets) {
-      Value bcast =
+    for (auto offsets : *sgOffsets) {
+      // Broadcast the offset scalar to a vector & add to the base steps
+      Value bcastOffset =
           vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
-      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
-      newOps.push_back(add);
+      Value finalSteps =
+          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
+      newOps.push_back(finalSteps);
     }
 
     rewriter.replaceOpWithMultiple(op, {newOps});
@@ -969,7 +974,8 @@ struct WgToSgVectorShapeCastOp
     for (auto src : adaptor.getSource()) {
       auto newShapeCast =
           rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
-      if (!layout.getLaneLayoutAsInt().empty())
+      if (!layout.getLaneLayoutAsInt().empty() ||
+          !layout.getInstDataAsInt().empty())
         xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
                                        layout.dropSgLayoutAndData());
       newShapeCastOps.push_back(newShapeCast.getResult());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index b7e313512e9b9..27d9fa6b06a7b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -370,15 +370,15 @@ gpu.module @test_distribution {
   // CHECK-LABEL: vector_step_op
   gpu.func @vector_step_op_slice_attr() {
     //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+    //CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
+    //CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+    //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
+    //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
     //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
     %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
     gpu.return
@@ -386,15 +386,15 @@ gpu.module @test_distribution {
 
   gpu.func @vector_step_op_layout_attr() {
     //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[c16:%.+]] = arith.constant 16 : index
-    //CHECK: [[c8:%.+]] = arith.constant 8 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
-    //CHECK: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+    //CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
+    //CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
+    //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
+    //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
+    //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
     //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
     %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
     gpu.return
@@ -414,8 +414,8 @@ gpu.module @test_distribution {
     %load =  xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf32>
-    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32>
-    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
+    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<2x16x4x8xf32>
+    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [2, 16, 4, 8]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
     gpu.return
   }
 }

>From 9457b54dc1b601df78aae0689d8a179e0dc18129 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 5 Sep 2025 20:43:05 +0000
Subject: [PATCH 08/13] Add check

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 24 +++++++++++++++++++
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  4 ++--
 2 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3b9bd98742080..5c15d0749e894 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -970,6 +970,30 @@ struct WgToSgVectorShapeCastOp
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
+    // TODO: Add check for compatible layouts in layout attr.
+    // Only support ShapeCast which expands or reduces unit dims only.
+    // That is, only allow shape casts where the non-unit dimensions are
+    // preserved, and any added or removed dimensions must be of size 1.
+    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    if (!srcType)
+      return failure();
+
+    auto isUnitOrPreserved = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
+      // Remove all 1s from both shapes and compare the rest.
+      SmallVector<int64_t> srcNonUnit, dstNonUnit;
+      for (int64_t d : src)
+        if (d != 1)
+          srcNonUnit.push_back(d);
+      for (int64_t d : dst)
+        if (d != 1)
+          dstNonUnit.push_back(d);
+      return srcNonUnit == dstNonUnit;
+    };
+
+    if (!isUnitOrPreserved(srcType.getShape(), sgShape) ||
+        !isUnitOrPreserved(sgShape, srcType.getShape()))
+      return failure();
+
     SmallVector<Value> newShapeCastOps;
     for (auto src : adaptor.getSource()) {
       auto newShapeCast =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 27d9fa6b06a7b..da015c6c0e4a7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -414,8 +414,8 @@ gpu.module @test_distribution {
     %load =  xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf32>
-    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<2x16x4x8xf32>
-    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [2, 16, 4, 8]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
+    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
+    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [32, 1, 32, 1]>} : vector<256x128xf32> to vector<256x1x128x1xf32>
     gpu.return
   }
 }

>From b8021edc0eeb20b27998691719aeec176930e6b8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 8 Sep 2025 15:57:57 +0000
Subject: [PATCH 09/13] Feedback

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 25 ++++++++++++-------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  9 +++++++
 2 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 5c15d0749e894..0d9ac35f07e02 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -931,14 +931,23 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
       return failure();
 
     VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
-    Value steps = vector::StepOp::create(rewriter, loc, newTy);
+    auto steps = vector::StepOp::create(rewriter, loc, newTy);
     SmallVector<Value> newOps;
     for (auto offsets : *sgOffsets) {
       // Broadcast the offset scalar to a vector & add to the base steps
-      Value bcastOffset =
+      auto bcastOffset =
           vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
-      Value finalSteps =
+      auto finalSteps =
           arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
+      if (!layout.getLaneLayoutAsInt().empty() ||
+          !layout.getLaneDataAsInt().empty()) {
+        xegpu::setDistributeLayoutAttr(steps->getResult(0),
+                                       layout.dropSgLayoutAndData());
+        xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
+                                       layout.dropSgLayoutAndData());
+        xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      }
       newOps.push_back(finalSteps);
     }
 
@@ -971,14 +980,12 @@ struct WgToSgVectorShapeCastOp
         VectorType::get(sgShape, resultType.getElementType());
 
     // TODO: Add check for compatible layouts in layout attr.
-    // Only support ShapeCast which expands or reduces unit dims only.
-    // That is, only allow shape casts where the non-unit dimensions are
-    // preserved, and any added or removed dimensions must be of size 1.
     auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
     if (!srcType)
       return failure();
 
-    auto isUnitOrPreserved = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
+    // Check that the shape_cast only adds or removes unit dimensions.
+    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
       // Remove all 1s from both shapes and compare the rest.
       SmallVector<int64_t> srcNonUnit, dstNonUnit;
       for (int64_t d : src)
@@ -990,8 +997,8 @@ struct WgToSgVectorShapeCastOp
       return srcNonUnit == dstNonUnit;
     };
 
-    if (!isUnitOrPreserved(srcType.getShape(), sgShape) ||
-        !isUnitOrPreserved(sgShape, srcType.getShape()))
+    if (!onlyUnitDims(srcType.getShape(), sgShape) ||
+        !onlyUnitDims(sgShape, srcType.getShape()))
       return failure();
 
     SmallVector<Value> newShapeCastOps;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index da015c6c0e4a7..7614b8a290ea1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -418,4 +418,13 @@ gpu.module @test_distribution {
     %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [32, 1, 32, 1]>} : vector<256x128xf32> to vector<256x1x128x1xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index
+  gpu.func @broadcast(%arg0: index, %arg1: index) {
+      %muli = arith.muli %arg0, %arg1 : index
+      // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
+      %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
+      gpu.return
+   }
 }

>From 2739fa83ace612f04afb5c6ec0a4ac50d227f1ba Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 8 Sep 2025 18:16:58 +0000
Subject: [PATCH 10/13] Fix

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0d9ac35f07e02..a05dcc9c474b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -997,8 +997,7 @@ struct WgToSgVectorShapeCastOp
       return srcNonUnit == dstNonUnit;
     };
 
-    if (!onlyUnitDims(srcType.getShape(), sgShape) ||
-        !onlyUnitDims(sgShape, srcType.getShape()))
+    if (!onlyUnitDims(srcType.getShape(), sgShape))
       return failure();
 
     SmallVector<Value> newShapeCastOps;

>From 77f32611477d962d5b46ac034dfc67dbf3e481a1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 9 Sep 2025 04:02:00 +0000
Subject: [PATCH 11/13] Add check

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 17 +++++++++++
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 29 +++++++++----------
 2 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a05dcc9c474b8..82ed77ae7130a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1000,6 +1000,23 @@ struct WgToSgVectorShapeCastOp
     if (!onlyUnitDims(srcType.getShape(), sgShape))
       return failure();
 
+    // Check to verify that if expanding dims, the input operand's layout
+    // is sliceAttr and if reducing dims, result's layout is
+    // sliceAttr.
+    int srcRank = srcType.getRank();
+    int dstRank = sgShape.size();
+    if (dstRank > srcRank) {
+      // Expanding dims: input operand's layout must be a SliceAttr
+      auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource());
+      if (!srcLayout || !isa<xegpu::SliceAttr>(srcLayout))
+        return failure();
+    } else if (dstRank < srcRank) {
+      // Reducing dims: result's layout must be a SliceAttr
+      auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult());
+      if (!resLayout || !isa<xegpu::SliceAttr>(resLayout))
+        return failure();
+    }
+
     SmallVector<Value> newShapeCastOps;
     for (auto src : adaptor.getSource()) {
       auto newShapeCast =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 7614b8a290ea1..3478a9b91da5f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -408,23 +408,20 @@ gpu.module @test_distribution {
   }
 
   // CHECK-LABEL: vector_shape_cast
-  gpu.func @vector_shape_cast(%src: memref<256x128xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load =  xegpu.load_nd %tdesc[0, 0]
-      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-      -> vector<256x128xf32>
-    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
-    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [32, 1, 32, 1]>} : vector<256x128xf32> to vector<256x1x128x1xf32>
+  gpu.func @vector_shape_cast() {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} dense<10> : vector<128xindex>
+    %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
+    %muli = arith.muli %cst, %step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
+    //CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex>
+    %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>} : vector<128xindex> to vector<1x1x1x128xindex>
     gpu.return
   }
 
-  // CHECK-LABEL: broadcast
-  // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index
-  gpu.func @broadcast(%arg0: index, %arg1: index) {
-      %muli = arith.muli %arg0, %arg1 : index
-      // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
-      %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
-      gpu.return
-   }
+  // CHECK-LABEL: vector_broadcast
+  gpu.func @vector_broadcast(%arg0: index, %arg1: index) {
+    %muli = arith.muli %arg0, %arg1 : index
+    // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
+    %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
+    gpu.return
+  }
 }

>From d0546b214e244fc36411a93ade6381151b6a282f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 9 Sep 2025 20:29:07 +0000
Subject: [PATCH 12/13] Add check

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp     | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 82ed77ae7130a..c62c90ab9693c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1008,12 +1008,22 @@ struct WgToSgVectorShapeCastOp
     if (dstRank > srcRank) {
       // Expanding dims: input operand's layout must be a SliceAttr
       auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource());
-      if (!srcLayout || !isa<xegpu::SliceAttr>(srcLayout))
+      auto srcSliceAttr = cast<xegpu::SliceAttr>(srcLayout);
+      if (!srcLayout || !srcSliceAttr)
+        return failure();
+      auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult());
+      // Check that srcLayout is a slice of resLayout (its parent is resLayout).
+      if (srcSliceAttr.getParent() != resLayout)
         return failure();
     } else if (dstRank < srcRank) {
       // Reducing dims: result's layout must be a SliceAttr
       auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult());
-      if (!resLayout || !isa<xegpu::SliceAttr>(resLayout))
+      auto resSliceAttr = cast<xegpu::SliceAttr>(resLayout);
+      auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource());
+      if (!resSliceAttr || !srcLayout)
+        return failure();
+      // Check that resLayout is a slice of srcLayout (its parent is srcLayout).
+      if (resSliceAttr.getParent() != srcLayout)
         return failure();
     }
 

>From 9f3446e9d6a0bd81a58b79b3a18f472707e66202 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 12 Sep 2025 17:32:35 +0000
Subject: [PATCH 13/13] Clang-format

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp    | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 30f40ad4969a9..d7592fed6d186 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1000,14 +1000,12 @@ struct WgToSgVectorShapeCastOp
     if (!onlyUnitDims(srcType.getShape(), sgShape))
       return failure();
 
-    // Check to verify that if expanding dims, the input operand's layout
-    // is sliceAttr and if reducing dims, result's layout is
-    // sliceAttr.
     // For rank reducing or increasing shape_cast ops, the lower rank layout
     // must be a slice of higher rank layout.
-    int64_t sourceRank = srcType.getRank();;
+    int64_t sourceRank = srcType.getRank();
     int64_t resultRank = sgShape.size();
-    xegpu::DistributeLayoutAttr sourceLayout = xegpu::getDistributeLayoutAttr(op.getSource());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(op.getSource());
     if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
       return failure();
     if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))



More information about the Mlir-commits mailing list