[Mlir-commits] [mlir] [MLIR][XeGPU] Reapply attempt for "Scattered ops sg-to-wi distribution #154949" (PR #156924)

Artem Kroviakov llvmlistbot at llvm.org
Thu Sep 4 09:43:14 PDT 2025


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/156924

This PR is a reapplication of https://github.com/llvm/llvm-project/pull/154949, which failed one of the sanitizer checks.
The issue was that `LoadDistribution` queried the `warpOp` results after calling `moveRegionToNewWarpOpAndAppendReturns()`, which resulted in a use-after-free. This PR fixes the issue by moving the op query before that call and is otherwise identical to the PR linked above.
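For illustration, here is a minimal sketch of the corrected ordering (not a standalone compilable unit; the names mirror the `LoadDistribution` code in the diffs below):

    // Query the original warpOp *before* it is replaced by the helper.
    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
    VectorType loadVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    // moveRegionToNewWarpOpAndAppendReturns() rebuilds the warp op; after this
    // call only newWarpOp may be used. Reading warpOp.getResult(...) past this
    // point was the use-after-free flagged by the sanitizer.
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);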

From e174b696ecdaf379df6807c600d5dceaf797c74f Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sat, 23 Aug 2025 10:22:34 +0000
Subject: [PATCH 1/8] [MLIR][XeGPU] Scattered ops sg-to-wi distribution

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 141 +++++++++++++++++-
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  42 ++++++
 2 files changed, 179 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index dddb5eaece2cb..3c3a52581ce90 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    else if (!storeScatterOp.getOffsets())
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Store op must have offsets argument");
+    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+                 .getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets vector");
+
+    VectorType storeVecTy =
+        cast<VectorType>(storeScatterOp.getValue().getType());
+    assert(storeVecTy.getRank() <= 2 &&
+           "Expected at most 2D result at SG level");
+    VectorType distStoreVecTy;
+    if (storeVecTy.getRank() == 2)
+      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+    else // rank 1
+      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+    operandTypes[0] = distStoreVecTy;
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    Value offsetsVec = newStoreScatterOpOperands[2];
+    Value maskVec = newStoreScatterOpOperands[3];
+
+    auto loc = newWarpOp.getLoc();
+    Value laneId = warpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newStoreScatterOpOperands[2] = laneOffset;
+    newStoreScatterOpOperands[3] = laneMask;
+
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadGatherOp>(op))
+        return false;
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+    if (!yieldOperand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a xegpu::LoadGatherOp op");
+
+    auto loadGatherOp =
+        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
+    if (!loadGatherOp.getOffsets())
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Load op must have offsets argument");
+    else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
+             1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets vector");
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
+
+    Value offsetsVec = newLoadGatherOperands[1];
+    Value maskVec = newLoadGatherOperands[2];
+    auto loc = newWarpOp.getLoc();
+    Value laneId = warpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newLoadGatherOperands[1] = laneOffset;
+    newLoadGatherOperands[2] = laneMask;
+
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -819,10 +949,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -837,6 +968,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
+      if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
+        continue;      
       auto layout =
           xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..b319162dc3f25 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,45 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}

From 6d2296888df03dfdf08c82b3a39bb890ed465d4c Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 25 Aug 2025 18:04:16 +0000
Subject: [PATCH 2/8] Assume distributable offset and mask producers

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 54 ++++++-------------
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 30 +++++------
 2 files changed, 31 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 3c3a52581ce90..cf2c933e16cb4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -817,14 +817,14 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
     if (!storeScatterOp)
       return failure();
-    else if (!storeScatterOp.getOffsets())
+    if (!storeScatterOp.getOffsets())
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Store op must have offsets argument");
-    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
-                 .getRank() != 1)
+    VectorType offsetsTy =
+        cast<VectorType>(storeScatterOp.getOffsets().getType());
+    if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Expected 1D offsets vector");
-
     VectorType storeVecTy =
         cast<VectorType>(storeScatterOp.getValue().getType());
     assert(storeVecTy.getRank() <= 2 &&
@@ -836,33 +836,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
 
     SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands =
-        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+    SmallVector<Value> operands = storeScatterOp->getOperands();
     SmallVector<Type> operandTypes =
         llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
     operandTypes[0] = distStoreVecTy;
+    // Assume offset and mask pproducers will be distributed as well.
+    operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypes[3] = VectorType::get(
+        {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypes, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    Value offsetsVec = newStoreScatterOpOperands[2];
-    Value maskVec = newStoreScatterOpOperands[3];
-
     auto loc = newWarpOp.getLoc();
-    Value laneId = warpOp.getLaneid();
     rewriter.setInsertionPointAfter(newWarpOp);
-    Value laneOffset =
-        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
-    laneOffset = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
-    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
-    laneMask = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
-    newStoreScatterOpOperands[2] = laneOffset;
-    newStoreScatterOpOperands[3] = laneMask;
-
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
         rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
@@ -892,16 +881,20 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     if (!loadGatherOp.getOffsets())
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Load op must have offsets argument");
-    else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
-             1)
+    VectorType offsetsTy =
+        cast<VectorType>(loadGatherOp.getOffsets().getType());
+    if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Expected 1D offsets vector");
 
     SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands =
-        llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+    SmallVector<Value> operands = loadGatherOp->getOperands();
     SmallVector<Type> operandTypes =
         llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+    // Assume offset and mask pproducers will be distributed as well.
+    operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypes[2] = VectorType::get(
+        {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypes, newRetIndices);
@@ -914,21 +907,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
 
-    Value offsetsVec = newLoadGatherOperands[1];
-    Value maskVec = newLoadGatherOperands[2];
     auto loc = newWarpOp.getLoc();
-    Value laneId = warpOp.getLaneid();
     rewriter.setInsertionPointAfter(newWarpOp);
-    Value laneOffset =
-        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
-    laneOffset = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
-    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
-    laneMask = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
-    newLoadGatherOperands[1] = laneOffset;
-    newLoadGatherOperands[2] = laneMask;
-
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
         loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
     Value distributedVal = newWarpOp.getResult(operandIdx);
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index b319162dc3f25..1c4684681b62b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -323,19 +323,18 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
-// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
-  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
     %1 = arith.constant dense<1>: vector<16xi1>
-    %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    %offset = arith.constant dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-    xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
@@ -344,19 +343,18 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
-// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
-  gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+  gpu.func @scatter_ops(%src: memref<256xf16>) {
     %1 = arith.constant dense<1>: vector<16xi1>
-    %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    %offset = arith.constant dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-    xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    xegpu.store %3, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }

From a4d4e66062f797ff5fa1156f8450e55d3f124c46 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 26 Aug 2025 14:23:30 +0000
Subject: [PATCH 3/8] Address feedback

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 123 ++++++++++++------
 1 file changed, 82 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index cf2c933e16cb4..84278964cbb63 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,26 +807,47 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element
+/// vectors; **it is assumed that their producers will also be
+/// distributed**. The payload vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size    -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+///    %mask = producer_op : vector<16xi1>
+///    %offset = producer_op : vector<16xindex>
+///    xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///    %mask = producer_op : vector<1xi1>
+///    %offset = producer_op : vector<1xindex>
+///    xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///    xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///    xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 struct StoreDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
-        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
     if (!storeScatterOp)
       return failure();
-    if (!storeScatterOp.getOffsets())
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Store op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(storeScatterOp.getOffsets().getType());
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Expected 1D offsets vector");
-    VectorType storeVecTy =
-        cast<VectorType>(storeScatterOp.getValue().getType());
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
     assert(storeVecTy.getRank() <= 2 &&
            "Expected at most 2D result at SG level");
     VectorType distStoreVecTy;
@@ -837,23 +858,23 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
-    operandTypes[0] = distStoreVecTy;
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[3] = VectorType::get(
+    operandTypesToYield[0] = distStoreVecTy;
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[3] = VectorType::get(
         {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
-        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
     rewriter.eraseOp(storeScatterOp);
@@ -861,56 +882,75 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+///    %mask = producer_op : vector<16xi1>
+///    %offset = producer_op : vector<16xindex>
+///    %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///    vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+///    %mask = producer_op : vector<1xi1>
+///    %offset = producer_op : vector<1xindex>
+///    %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+///    %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+///    %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
 struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
-      if (!isa<xegpu::LoadGatherOp>(op))
-        return false;
-      auto yield = cast<gpu::YieldOp>(
-          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-      return yield->getPrevNode() == op;
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check if the yield operand was produced by the *last* scattered
+      // load op to avoid sinking it before barriers (maintain memory order).
+      return isa<xegpu::LoadGatherOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
     });
-    if (!yieldOperand)
+    if (!producedByLastLoad)
       return rewriter.notifyMatchFailure(
-          warpOp, "warp result is not a xegpu::LoadGatherOp op");
+          warpOp, "The last op is not xegpu::LoadGatherOp");
 
     auto loadGatherOp =
-        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
-    if (!loadGatherOp.getOffsets())
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Load op must have offsets argument");
-    VectorType offsetsTy =
-        cast<VectorType>(loadGatherOp.getOffsets().getType());
+        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    auto offsets = loadGatherOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          loadGatherOp, "Load op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
     if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Expected 1D offsets vector");
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypes =
+    SmallVector<Type> operandTypesToYield =
         llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
-    // Assume offset and mask pproducers will be distributed as well.
-    operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypes[2] = VectorType::get(
-        {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
+    // Assume offset and mask producers will be distributed as well.
+    operandTypesToYield[1] =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypesToYield[2] =
+        VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypes, newRetIndices);
+        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 
     SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
     VectorType loadVecTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
 
-    auto loc = newWarpOp.getLoc();
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
-        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+        loadGatherOp->getAttrs());
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
     return success();
@@ -948,6 +988,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
+      // Vector operands of these ops have a fixed and implicit layout.
       if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
         continue;      
       auto layout =

From daa143f5572534839993b880dc441364cbf511fb Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 28 Aug 2025 14:25:17 +0000
Subject: [PATCH 4/8] Add layout-based distribution

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 61 ++++++++----
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 98 ++++++++++++++-----
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 34 +++++++
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 47 ++++++---
 4 files changed, 182 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 5cb47b2accd68..46c5777d1c157 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
+                                           bool scattered = false) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -207,6 +208,14 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+  if (scattered) {
+    packingFactor =
+        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
+            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+            : 1;
+    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+                      LaneData({1, packingFactor}));
+  }
   if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
     packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
   return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
@@ -214,7 +223,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
+                                           bool scattered = false) {
   // Expecting a 1D or 2D vector.
   assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
          "Expected 1D or 2D TensorDesc.");
@@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
 
-  if (tdescTy.isScattered()) {
+  if (scattered) {
     int packingFactor =
         bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
             ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -541,21 +551,27 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
-/// Propagate the layout of the result to the tensor descriptor and mask
+/// Propagate the layout of the result to the tensor descriptor, mask and offset
 /// operands in LoadGatherOp.
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  // The layout is strictly determined by the tensor descriptor type.
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
+  // The layout is strictly determined by the payload type.
+  auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+  assert(payloadTy && "Only vector payload distribution is supported");
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
 
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
 
   // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(layout));
-  // Propagate the new layout to the mask operand.
+  if (isa<xegpu::TensorDescType>(load.getSourceType()))
+    propagateIfChanged(operands[0], operands[0]->meet(layout));
+  // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+  if (load.getOffsets()) {
+    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+  }
 }
 
 /// Propagate the layout of the descriptor to the vector offset operand in
@@ -572,31 +588,36 @@ void LayoutInfoPropagation::visitCreateDescOp(
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
-/// Set the layout for the value, tensor descriptor, and mask operands in the
-/// StoreScatterOp.
+/// Set the layout for the value, tensor descriptor, offset and mask operands in
+/// the StoreScatterOp.
 void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
   // Currently, for 2D StoreScatterOp we expect that the height dimension of
   // the tensor descriptor is equal to the subgroup size. This is ensured by
   // the op verifier.
-  ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
-  if (tdescShape.size() > 1)
+  auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+  assert(payloadTy && "Only vector payload distribution is supported");
+  auto payloadShape = payloadTy.getShape();
+  if (payloadShape.size() > 1)
     assert(
-        tdescShape[0] == xegpu::targetinfo::subgroupSize &&
+        payloadShape[0] == xegpu::targetinfo::subgroupSize &&
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
-  LayoutInfo layout =
-      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+  LayoutInfo payloadLayout =
+      getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
 
-  // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(layout));
-  // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(layout));
-  // Use default 1D layout for mask operand.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+  // Propagate the payload operand layout
+  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
+  // Propagate the destination (if tdesc) operand layout
+  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
+  // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+  if (storeScatter.getOffsets())
+    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 84278964cbb63..9bb0a2160f82e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -844,9 +844,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
           storeScatterOp, "Store op must have a vector of offsets argument");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    if (offsetsTy.getRank() != 1)
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Expected 1D offsets vector");
+                                         "Expected 1D offsets and mask vector");
     VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
     assert(storeVecTy.getRank() <= 2 &&
            "Expected at most 2D result at SG level");
@@ -855,17 +856,45 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
     else // rank 1
       distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypesToYield =
-        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
-    operandTypesToYield[0] = distStoreVecTy;
     // Assume offset and mask producers will be distributed as well.
-    operandTypesToYield[2] =
+    VectorType distOffsetsTy =
         VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypesToYield[3] = VectorType::get(
+    VectorType distMaskTy = VectorType::get(
         {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      storeScatterOp.emitWarning(
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
+    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
+    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -918,23 +947,47 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     auto loadGatherOp =
         producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
     auto offsets = loadGatherOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()))
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
       return rewriter.notifyMatchFailure(
-          loadGatherOp, "Load op must have a vector of offsets argument");
+          loadGatherOp,
+          "Load op must have a vector arguments for offsets and mask");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    if (offsetsTy.getRank() != 1)
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Expected 1D offsets vector");
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    VectorType distOffsetsTy =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));
+
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      loadGatherOp.emitWarning(
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
+    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield =
-        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
-    // Assume offset and mask producers will be distributed as well.
-    operandTypesToYield[1] =
-        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypesToYield[2] =
-        VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
+    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
+                                             distOffsetsTy, distMaskTy};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -951,6 +1004,7 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
         newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
         loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
     return success();
@@ -990,7 +1044,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
 
       // Vector operands of these ops have a fixed and implicit layout.
       if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
-        continue;      
+        continue;
       auto layout =
           xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 0214d84f2c16f..cba3f0bd690c3 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -162,6 +162,40 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
   return
 }
 
+// -----
+// CHECK-LABEL: func.func @scatter_ops_chunksize(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @scatter_ops(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops(%src: memref<256xf16>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+  xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  return
+}
+
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 1c4684681b62b..ddb279eb070ff 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -324,18 +324,14 @@ gpu.module @test {
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
 // CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
-// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
-// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
   gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
     %1 = arith.constant dense<1>: vector<16xi1>
     %offset = arith.constant dense<12> : vector<16xindex>
-    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-    xegpu.store %3, %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }
@@ -344,18 +340,37 @@ gpu.module @test {
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
 // CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
-// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
-// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
   gpu.func @scatter_ops(%src: memref<256xf16>) {
     %1 = arith.constant dense<1>: vector<16xi1>
     %offset = arith.constant dense<12> : vector<16xindex>
-    %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-    xegpu.store %3, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops(%src: memref<256xf16>) {
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %offset = arith.constant dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 {
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset], %1 {
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    } : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }
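
A note on the distributed shapes exercised by the tests above: a lane layout simply divides the subgroup-level vector shape per dimension, so vector<16x8xf16> under lane_layout = [16, 1] leaves an 8-element chunk per lane, and vector<16xf16> under lane_layout = [16] leaves a single element per lane, matching the vector<8xf16> and vector<1xf16> types in the CHECK lines. The standalone sketch below only reproduces that arithmetic; distributeShape is an illustrative helper, not part of the pass, and the rank-1 case is modeled by padding the second dimension to 1.

  // Hedged sketch: per-lane shape under a lane_layout, assuming even divisibility.
  // distributeShape() is illustrative and not an XeGPU API.
  #include <array>
  #include <cassert>
  #include <cstdint>
  #include <cstdio>

  static std::array<int64_t, 2> distributeShape(std::array<int64_t, 2> sgShape,
                                                std::array<int64_t, 2> laneLayout) {
    std::array<int64_t, 2> laneShape{};
    for (int d = 0; d < 2; ++d) {
      assert(laneLayout[d] != 0 && sgShape[d] % laneLayout[d] == 0 &&
             "subgroup shape must be divisible by the lane layout");
      laneShape[d] = sgShape[d] / laneLayout[d];
    }
    return laneShape;
  }

  int main() {
    // vector<16x8xf16>, lane_layout = [16, 1] -> 1x8 per lane (vector<8xf16> in the tests).
    auto chunked = distributeShape({16, 8}, {16, 1});
    // vector<16xf16>, lane_layout = [16] (second dim padded to 1) -> 1x1 per lane.
    auto plain = distributeShape({16, 1}, {16, 1});
    std::printf("%lldx%lld and %lldx%lld\n", (long long)chunked[0],
                (long long)chunked[1], (long long)plain[0], (long long)plain[1]);
    return 0;
  }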

>From bcc9d85b2c953eaf8fa6dd588a8f2aaafa6cd406 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 29 Aug 2025 16:26:33 +0000
Subject: [PATCH 5/8] Address feedback

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 21 +++++----
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 44 +++++++------------
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 27 ++++--------
 3 files changed, 39 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 46c5777d1c157..c0c4394f73d4a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -195,7 +195,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
 
 /// Helper to get the default layout for a vector type.
 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           bool scattered = false) {
+                                           bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -208,7 +208,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  if (scattered) {
+  if (isScattered) {
     packingFactor =
         bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
             ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -224,7 +224,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
 
 /// Helper to get the default layout for a vector type.
 static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
-                                           bool scattered = false) {
+                                           bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
          "Expected 1D or 2D TensorDesc.");
@@ -237,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
 
-  if (scattered) {
+  if (isScattered) {
     int packingFactor =
         bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
             ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -558,7 +558,10 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     ArrayRef<const LayoutInfoLattice *> results) {
   // The layout is strictly determined by the payload type.
   auto payloadTy = dyn_cast<VectorType>(load.getValueType());
-  assert(payloadTy && "Only vector payload distribution is supported");
+  if (!payloadTy) {
+    load.emitWarning("Not propagating, non-vector payload supplied.");
+    return;
+  }
   LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
 
   // Mask operand should have 1D default layout.
@@ -569,9 +572,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     propagateIfChanged(operands[0], operands[0]->meet(layout));
   // Propagate the new layout to the mask and optional offset operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
-  if (load.getOffsets()) {
+  if (load.getOffsets())
     propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
-  }
 }
 
 /// Propagate the layout of the descriptor to the vector offset operand in
@@ -597,7 +599,10 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   // the tensor descriptor is equal to the subgroup size. This is ensured by
   // the op verifier.
   auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
-  assert(payloadTy && "Only vector payload distribution is supported");
+  if (!payloadTy) {
+    storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+    return;
+  }
   auto payloadShape = payloadTy.getShape();
   if (payloadShape.size() > 1)
     assert(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 9bb0a2160f82e..7b9c8ff3e6f6f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -849,18 +849,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Expected 1D offsets and mask vector");
     VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
-    assert(storeVecTy.getRank() <= 2 &&
-           "Expected at most 2D result at SG level");
-    VectorType distStoreVecTy;
-    if (storeVecTy.getRank() == 2)
-      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
-    else // rank 1
-      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
-    // Assume offset and mask producers will be distributed as well.
-    VectorType distOffsetsTy =
-        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    VectorType distMaskTy = VectorType::get(
-        {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
+    if (storeVecTy.getRank() > 2)
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Expected at most 2D result at SG level");
+
     std::string layoutPayloadName =
         xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
     std::string layoutOffsetsName =
@@ -884,17 +876,20 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
     if (failed(distStoreVecByWarpOpOrFailure) ||
         failed(distOffsetsByWarpOpOrFailure) ||
         failed(distMaskByWarpOpOrFailure)) {
-      storeScatterOp.emitWarning(
+      return rewriter.notifyMatchFailure(
+          storeScatterOp,
           "Some vector operands have no layouts, using defaults instead.");
     }
-    distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
-    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
-    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
+    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+    VectorType expectedPayloadTy = VectorType::get(
+        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
     SmallVector<Type> operandTypesToYield = {
-        distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};
+        expectedPayloadTy, operands[1].getType(),
+        distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -958,10 +953,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Expected 1D offsets and mask vector");
     // Assume offset and mask producers will be distributed as well.
-    VectorType distOffsetsTy =
-        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));
-
     std::string layoutOffsetsName =
         xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
     std::string layoutMaskName =
@@ -978,16 +969,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
     if (failed(distOffsetsByWarpOpOrFailure) ||
         failed(distMaskByWarpOpOrFailure)) {
-      loadGatherOp.emitWarning(
+      return rewriter.notifyMatchFailure(
+          loadGatherOp,
           "Some vector operands have no layouts, using defaults instead.");
     }
-    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
-    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
-                                             distOffsetsTy, distMaskTy};
+    SmallVector<Type> operandTypesToYield = {
+        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
+        distMaskByWarpOpOrFailure.value()};
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -998,7 +989,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     const unsigned operandIdx = producedByLastLoad->getOperandNumber();
     VectorType loadVecTy =
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
-    assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
 
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index ddb279eb070ff..5a4030ce4bead 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -330,24 +330,15 @@ gpu.module @test {
   gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
     %1 = arith.constant dense<1>: vector<16xi1>
     %offset = arith.constant dense<12> : vector<16xindex>
-    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-    gpu.return
-  }
-}
-
-// -----
-// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-gpu.module @test {
-  gpu.func @scatter_ops(%src: memref<256xf16>) {
-    %1 = arith.constant dense<1>: vector<16xi1>
-    %offset = arith.constant dense<12> : vector<16xindex>
-    %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-    xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> {
+      layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
+      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    } : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }
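
One detail of the StoreDistribution change above is worth isolating: once the per-lane payload type has been derived from the layout (for example vector<1x8xf16> for a chunked store), it is re-expressed as a flat 1-D vector with the same element count, which is the shape the lane-level xegpu.store takes. A minimal sketch of that step using MLIR's VectorType API; flattenDistributedPayload is an illustrative free function, not an XeGPU helper, and the surrounding pattern boilerplate is assumed:

  // Hedged sketch of the payload flattening used by StoreDistribution:
  // vector<1x8xf16> -> vector<8xf16>.
  #include "mlir/IR/BuiltinTypes.h"

  static mlir::VectorType flattenDistributedPayload(mlir::VectorType distPayloadTy) {
    return mlir::VectorType::get({distPayloadTy.getNumElements()},
                                 distPayloadTy.getElementType());
  }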

>From ffde76c8e1871bea2b1dda9960237d556f8182d4 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 1 Sep 2025 15:01:28 +0000
Subject: [PATCH 6/8] Remove exceptions

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7b9c8ff3e6f6f..b4919932f1ce4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1032,9 +1032,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
-      // Vectors operands of these ops have a fixed and implicit layout.
-      if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
-        continue;
       auto layout =
           xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {

>From 6bafb05a129f674b685615b674fb45c12192c0de Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 1 Sep 2025 15:22:18 +0000
Subject: [PATCH 7/8] Restructure testing

---
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 26 ++++++-------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 5a4030ce4bead..a39aa90bbe3a8 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -328,17 +328,12 @@ gpu.module @test {
 // CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
   gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
-    %1 = arith.constant dense<1>: vector<16xi1>
-    %offset = arith.constant dense<12> : vector<16xindex>
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
     } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> {
-      layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-    } : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }
@@ -351,17 +346,12 @@ gpu.module @test {
 // CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
   gpu.func @scatter_ops(%src: memref<256xf16>) {
-    %1 = arith.constant dense<1>: vector<16xi1>
-    %offset = arith.constant dense<12> : vector<16xindex>
+    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     %3 = xegpu.load %src[%offset], %1 {
-      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
     } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-    xegpu.store %3, %src[%offset], %1 {
-      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-    } : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }
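
The test restructuring above attaches layouts to the producing constants (layout_result_0) rather than as per-operand attributes on the load/store. For cases where an operand layout still needs to be attached programmatically, the attribute name comes from xegpu::getLayoutName on the OpOperand, as the pattern code above does; the sketch below is hedged, with the header paths and the attachOperandLayout wrapper being assumptions:

  // Hedged sketch: attach a lane layout to a specific operand of an
  // xegpu.store_scatter by name (yielding e.g. a layout_operand_0 attribute).
  #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
  #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

  static void attachOperandLayout(mlir::xegpu::StoreScatterOp storeOp,
                                  unsigned operandIdx,
                                  mlir::xegpu::LayoutAttr layout) {
    std::string name = mlir::xegpu::getLayoutName(storeOp->getOpOperand(operandIdx));
    storeOp->setAttr(name, layout);
  }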

>From 2a6c7d678cfbfbfaebde7f98ceb42dc4c2d9c294 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 4 Sep 2025 16:33:17 +0000
Subject: [PATCH 8/8] Query warpOp results before moveRegion

---
 .../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp  | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b4919932f1ce4..6b8367dd8c201 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -980,16 +980,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
         distMaskByWarpOpOrFailure.value()};
 
+    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
 
     SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
-        cast<VectorType>(warpOp.getResult(operandIdx).getType());
-
     rewriter.setInsertionPointAfter(newWarpOp);
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
         newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,


