[Mlir-commits] [mlir] [MLIR][XeGPU] Add uArch limitation to scatter load store (PR #172845)

Artem Kroviakov llvmlistbot at llvm.org
Fri Jan 23 08:45:45 PST 2026


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/172845

>From 0cbdce947fda2038d9d89a53e263c27d72142d53 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 18 Dec 2025 12:30:40 +0000
Subject: [PATCH 1/6] [MLIR][XeGPU] Add uArch limitation to scatter load store

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 34 +++++++++++++--
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      |  8 +++-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 41 +++++++++++++------
 .../XeGPU/propagate-layout-inst-data.mlir     | 22 ++++++++++
 4 files changed, 87 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index b3231a173f33a..055c10ab50652 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,6 +215,26 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
   const unsigned packedFormatBitSizeB;
 };
 
+struct StoreScatterInstruction : public Instruction {
+  StoreScatterInstruction()
+      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::StoreScatter;
+  }
+
+  int32_t getMaxBitSize() const { return 128; }
+};
+
+struct LoadGatherInstruction : public Instruction {
+  LoadGatherInstruction()
+      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LoadGather;
+  }
+
+  int32_t getMaxBitSize() const { return 128; }
+};
+
 //===----------------------------------------------------------------------===//
 // uArch instances
 //===----------------------------------------------------------------------===//
@@ -225,8 +245,11 @@ struct PVCuArch final : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
-                                       &prefetchNdInst};
+    static const StoreScatterInstruction storeScatterInst;
+    static const LoadGatherInstruction loadGatherInst;
+    static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
+                                       &storeNdInst,      &prefetchNdInst,
+                                       &storeScatterInst, &loadGatherInst};
     return arr;
   }
 
@@ -248,8 +271,11 @@ struct BMGuArch : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
-                                       &prefetchNdInst};
+    static const StoreScatterInstruction storeScatterInst;
+    static const LoadGatherInstruction loadGatherInst;
+    static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
+                                       &storeNdInst,      &prefetchNdInst,
+                                       &storeScatterInst, &loadGatherInst};
     return arr;
   }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 8f23b89134773..db1984b2edb1d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -38,7 +38,9 @@ enum class InstructionKind {
                              // matrix multiply-add operation
   Subgroup2DBlockStore,      // Subgroup-level 2D block write instruction
   Subgroup2DBlockLoad,       // Subgroup-level 2D block load instruction
-  Subgroup2DBlockPrefetch    // Subgroup-level 2D block prefetch instruction
+  Subgroup2DBlockPrefetch,   // Subgroup-level 2D block prefetch instruction
+  StoreScatter,              // Lane-level store (scalar, vector)
+  LoadGather                 // Lane-level load (scalar, vector)
   // @TODO: Add more instructions as needed
 };
 
@@ -65,6 +67,10 @@ struct Instruction {
       return "load_nd";
     case InstructionKind::Subgroup2DBlockPrefetch:
       return "prefetch_nd";
+    case InstructionKind::StoreScatter:
+      return "store";
+    case InstructionKind::LoadGather:
+      return "load";
     }
     llvm_unreachable("Unknown InstructionKind");
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e8eadb6de5b30..faaa8f95014b9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1154,15 +1154,21 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       return;
     }
     auto uArch = getUArch(getChipStr(load).value_or(""));
+    const auto *uArchInstruction =
+        dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+            uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+    const int maxElemsPerInst =
+        uArchInstruction->getMaxBitSize() /
+        payloadTy.getElementType().getIntOrFloatBitWidth();
+
     const int subgroupSize = uArch->getSubgroupSize();
     SmallVector<int> instData{subgroupSize};
-    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
-      instData.push_back(chunkSize);
-    else if (auto srcTdescTy =
-                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
-      if (srcTdescTy.getChunkSizeAsInt() > 1)
-        instData.push_back(chunkSize);
+    auto chunkSize = load.getChunkSize().value_or(0);
+    if (auto srcTdescTy = dyn_cast<xegpu::TensorDescType>(load.getSourceType());
+        !chunkSize && srcTdescTy) {
+      chunkSize = srcTdescTy.getChunkSizeAsInt();
     }
+    instData.push_back(std::min(static_cast<int>(chunkSize), maxElemsPerInst));
 
     if (layoutKind == LayoutKind::InstData)
       loadLayout =
@@ -1226,15 +1232,24 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     const int subgroupSize = uArch->getSubgroupSize();
 
     if (layoutKind == LayoutKind::InstData) {
+      const auto *uArchInstruction =
+          dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+              uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+      const int maxElemsPerInst =
+          uArchInstruction->getMaxBitSize() /
+          payloadTy.getElementType().getIntOrFloatBitWidth();
+
+      const int subgroupSize = uArch->getSubgroupSize();
       SmallVector<int> instData{subgroupSize};
-      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
-          chunkSize > 1)
-        instData.push_back(chunkSize);
-      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
-                   storeScatter.getDestType())) {
-        if (dstTdescTy.getChunkSizeAsInt() > 1)
-          instData.push_back(chunkSize);
+      auto chunkSize = storeScatter.getChunkSize().value_or(0);
+      if (auto srcTdescTy =
+              dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType());
+          !chunkSize && srcTdescTy) {
+        chunkSize = srcTdescTy.getChunkSizeAsInt();
       }
+      instData.push_back(
+          std::min(static_cast<int>(chunkSize), maxElemsPerInst));
+
       payloadLayout = LayoutInfo(
           xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
     } else {
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index aef9d7fc9e03a..eda60ec0ef163 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -162,6 +162,28 @@ gpu.module @test {
 func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   %cst = arith.constant dense<0.0000> : vector<16x16xf16>
   xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+
+  return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 4]>} : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=16}>
+      : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
   return
 }
 }

>From ea64cc9906263c58721a8fafb4faf78080f8aa5b Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 19 Dec 2025 12:30:18 +0000
Subject: [PATCH 2/6] Let LoadGather adjust its inst_data to uArch limits

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 70 ++++++++++++++-----
 .../XeGPU/propagate-layout-inst-data.mlir     | 26 ++++++-
 2 files changed, 75 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index faaa8f95014b9..e9c1332982ffa 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1146,8 +1146,19 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     loadLayout = LayoutInfo(anchorLayout);
     maskLayout = loadLayout;
   } else {
+    LayoutInfo valueLayout = results[0]->getValue();
+    // Need the layout of the value to propagate to the tensor descriptor.
+    if (!valueLayout.isAssigned())
+      return;
+
+    auto resAttr = dyn_cast<xegpu::DistributeLayoutAttr>(valueLayout.get());
+    auto instDataIncoming = resAttr.getEffectiveInstDataAsInt();
+    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
+      instDataIncoming = SmallVector<int64_t>(
+          cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
+              .getInstData()
+              .asArrayRef());
 
-    // The layout is strictly determined by the payload type.
     VectorType payloadTy = load.getValueType();
     if (!payloadTy) {
       load.emitWarning("Not propagating, non-vector payload supplied.");
@@ -1157,28 +1168,48 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     const auto *uArchInstruction =
         dyn_cast<xegpu::uArch::LoadGatherInstruction>(
             uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
-    const int maxElemsPerInst =
-        uArchInstruction->getMaxBitSize() /
-        payloadTy.getElementType().getIntOrFloatBitWidth();
 
     const int subgroupSize = uArch->getSubgroupSize();
-    SmallVector<int> instData{subgroupSize};
-    auto chunkSize = load.getChunkSize().value_or(0);
-    if (auto srcTdescTy = dyn_cast<xegpu::TensorDescType>(load.getSourceType());
-        !chunkSize && srcTdescTy) {
-      chunkSize = srcTdescTy.getChunkSizeAsInt();
-    }
-    instData.push_back(std::min(static_cast<int>(chunkSize), maxElemsPerInst));
-
-    if (layoutKind == LayoutKind::InstData)
-      loadLayout =
-          LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
-    else
-      loadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
-
     // Mask operand should have 1D default layout.
     maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
 
+    // Check if value inst_data complies with uArch
+    if (!instDataIncoming.empty()) {
+      const int maxElemsPerInst =
+          uArchInstruction->getMaxBitSize() /
+          payloadTy.getElementType().getIntOrFloatBitWidth();
+
+      xegpu::LayoutAttr sourceAttr;
+      // Each lane loads either one element
+      SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
+      // Or multiple elements as 2D with lane's elements in the inner dimension
+      if (payloadTy.getRank() == 1) {
+        instDataUarch.back() = subgroupSize;
+      } else {
+        *std::prev(instDataUarch.end(), 2) = subgroupSize;
+        instDataUarch.back() = (std::min(
+            static_cast<int>(payloadTy.getShape().back()), maxElemsPerInst));
+      }
+      // If inst data does not match, enforce the uArch-based one
+      if (!llvm::equal(instDataIncoming, instDataUarch)) {
+        xegpu::LayoutAttr sourceAttr = dyn_cast<xegpu::LayoutAttr>(resAttr);
+        if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr)) {
+          sourceAttr = cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent());
+        }
+        assert(sourceAttr);
+        xegpu::DistributeLayoutAttr updatedLayoutAttr = xegpu::LayoutAttr::get(
+            load.getContext(), sourceAttr.getSgLayout(), sourceAttr.getSgData(),
+            DenseI32ArrayAttr::get(load.getContext(), instDataUarch),
+            sourceAttr.getLaneLayout(), sourceAttr.getLaneData(),
+            sourceAttr.getOrder());
+
+        if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
+          updatedLayoutAttr = xegpu::SliceAttr::get(
+              load.getContext(), updatedLayoutAttr, sliceAttr.getDims());
+        valueLayout = LayoutInfo(updatedLayoutAttr);
+      }
+    }
+    loadLayout = valueLayout;
     load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
   }
   // Propagate the new layout to the tensor descriptor operand.
@@ -1427,7 +1458,8 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    xegpu::setDistributeLayoutAttr(result, layout);
+    if (!isa<xegpu::LoadGatherOp>(op))
+      xegpu::setDistributeLayoutAttr(result, layout);
   }
   return success();
 }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index eda60ec0ef163..0a360a4c47111 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -173,8 +173,8 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 4]>} : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
 // CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
 // CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
@@ -187,3 +187,25 @@ func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
   return
 }
 }
+
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_slice(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 4]>, dims = [0]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive_slice(%src: memref<1024xf32>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}>
+      : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  return
+}
+}

>From 44a620e253cf5c00187e30fa4e198a4accded021 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 8 Jan 2026 17:24:01 +0000
Subject: [PATCH 3/6] Remove payload layout for mask/offsets

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 24 +++++++++----------
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  4 ++--
 2 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e9c1332982ffa..adaf1b68f391d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1141,10 +1141,11 @@ void LayoutInfoPropagation::visitLoadGatherOp(
 
   LayoutInfo loadLayout;
   LayoutInfo maskLayout;
+  auto uArch = getUArch(getChipStr(load).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
   xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
   if (hasParamsOfLayoutKind(anchorLayout)) {
     loadLayout = LayoutInfo(anchorLayout);
-    maskLayout = loadLayout;
   } else {
     LayoutInfo valueLayout = results[0]->getValue();
     // Need the layout of the value to propagate to the tensor descriptor.
@@ -1164,22 +1165,16 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       load.emitWarning("Not propagating, non-vector payload supplied.");
       return;
     }
-    auto uArch = getUArch(getChipStr(load).value_or(""));
     const auto *uArchInstruction =
         dyn_cast<xegpu::uArch::LoadGatherInstruction>(
             uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
 
-    const int subgroupSize = uArch->getSubgroupSize();
-    // Mask operand should have 1D default layout.
-    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
-
     // Check if value inst_data complies with uArch
     if (!instDataIncoming.empty()) {
       const int maxElemsPerInst =
           uArchInstruction->getMaxBitSize() /
           payloadTy.getElementType().getIntOrFloatBitWidth();
 
-      xegpu::LayoutAttr sourceAttr;
       // Each lane loads either one element
       SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
       // Or multiple elements as 2D with lane's elements in the inner dimension
@@ -1212,6 +1207,9 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     loadLayout = valueLayout;
     load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
   }
+  // Mask operand should have 1D default layout.
+  maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+
   // Propagate the new layout to the tensor descriptor operand.
   if (isa<xegpu::TensorDescType>(load.getSourceType()))
     propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
@@ -1246,6 +1244,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   LayoutInfo payloadLayout;
   LayoutInfo maskLayout;
   xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+  auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
+
   if (hasParamsOfLayoutKind(anchorLayout)) {
     payloadLayout = LayoutInfo(anchorLayout);
     maskLayout = payloadLayout;
@@ -1259,9 +1260,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
       return;
     }
 
-    auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
-    const int subgroupSize = uArch->getSubgroupSize();
-
     if (layoutKind == LayoutKind::InstData) {
       const auto *uArchInstruction =
           dyn_cast<xegpu::uArch::LoadGatherInstruction>(
@@ -1293,12 +1291,12 @@ void LayoutInfoPropagation::visitStoreScatterOp(
       payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
     }
 
-    maskLayout =
-        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
-
     storeScatter.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
   }
+
+  maskLayout =
+      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
   // Propagate the destination (if tdesc) operand layout
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index bf6c5d992a47f..3953bb7f407ec 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -218,8 +218,8 @@ func.func @scatter_ops(%src: memref<256xf16>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @scatter_ops_custom_perm_layout(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>

>From d0513414f0425758edce944f6e7010a35db25e5e Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 9 Jan 2026 11:00:38 +0000
Subject: [PATCH 4/6] Assign inst_data layout to mask operand

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    |  4 ++--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 21 +++++++++++++-----
 .../XeGPU/propagate-layout-inst-data.mlir     | 22 +++++++++----------
 3 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 055c10ab50652..fd577af612523 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -222,7 +222,7 @@ struct StoreScatterInstruction : public Instruction {
     return B->getInstructionKind() == InstructionKind::StoreScatter;
   }
 
-  int32_t getMaxBitSize() const { return 128; }
+  int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
 };
 
 struct LoadGatherInstruction : public Instruction {
@@ -232,7 +232,7 @@ struct LoadGatherInstruction : public Instruction {
     return B->getInstructionKind() == InstructionKind::LoadGather;
   }
 
-  int32_t getMaxBitSize() const { return 128; }
+  int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index adaf1b68f391d..e95b4a136b954 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1172,7 +1172,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     // Check if value inst_data complies with uArch
     if (!instDataIncoming.empty()) {
       const int maxElemsPerInst =
-          uArchInstruction->getMaxBitSize() /
+          uArchInstruction->getMaxLaneLoadStoreBitSize() /
           payloadTy.getElementType().getIntOrFloatBitWidth();
 
       // Each lane loads either one element
@@ -1207,8 +1207,12 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     loadLayout = valueLayout;
     load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
   }
-  // Mask operand should have 1D default layout.
-  maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+
+  if (layoutKind == LayoutKind::InstData)
+    maskLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
+  else if (layoutKind == LayoutKind::Lane)
+    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
 
   // Propagate the new layout to the tensor descriptor operand.
   if (isa<xegpu::TensorDescType>(load.getSourceType()))
@@ -1265,7 +1269,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
           dyn_cast<xegpu::uArch::LoadGatherInstruction>(
               uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
       const int maxElemsPerInst =
-          uArchInstruction->getMaxBitSize() /
+          uArchInstruction->getMaxLaneLoadStoreBitSize() /
           payloadTy.getElementType().getIntOrFloatBitWidth();
 
       const int subgroupSize = uArch->getSubgroupSize();
@@ -1295,8 +1299,13 @@ void LayoutInfoPropagation::visitStoreScatterOp(
         dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
   }
 
-  maskLayout =
-      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+  if (layoutKind == LayoutKind::InstData)
+    maskLayout = LayoutInfo(
+        xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
+  else if (layoutKind == LayoutKind::Lane)
+    maskLayout =
+        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
   // Propagate the destination (if tdesc) operand layout
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 0a360a4c47111..d1934e4c648bd 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -138,8 +138,8 @@ gpu.module @test_kernel {
 gpu.module @test {
 // CHECK-LABEL: func.func @scatter_ops_chunksize(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
 // CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}>
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
 // CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
@@ -171,8 +171,8 @@ func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @scatter_ops_chunksize_excessive(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
 // CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
 // CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
 // CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
@@ -191,20 +191,20 @@ func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
 // -----
 
 gpu.module @test {
-// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_slice(
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_anchor(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 4]>, dims = [0]>}> :
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
 // CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}> :
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
 // CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
-func.func @scatter_ops_chunksize_excessive_slice(%src: memref<1024xf32>) {
+func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
   %1 = arith.constant dense<1>: vector<16xi1>
   %offset = arith.constant dense<12> : vector<16xindex>
   %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
       : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-  xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.layout<inst_data = [16, 16]>}>
       : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
   return
 }

>From e2692c7a2bc560ab2ebe499c6e62051b064c9392 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 12 Jan 2026 17:34:13 +0000
Subject: [PATCH 5/6] Change uArch restriction to reflect SPIR-V vector size limit

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    |  6 ++--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 17 +++-------
 .../XeGPU/propagate-layout-inst-data.mlir     | 32 +++++++++----------
 3 files changed, 25 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index fd577af612523..29e75b57f4a5f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -222,7 +222,8 @@ struct StoreScatterInstruction : public Instruction {
     return B->getInstructionKind() == InstructionKind::StoreScatter;
   }
 
-  int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
+  // SPIR-V restricts the vector size, so each lane may load/store at most 16
+  // elements per instruction.
+  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
 struct LoadGatherInstruction : public Instruction {
@@ -232,7 +233,8 @@ struct LoadGatherInstruction : public Instruction {
     return B->getInstructionKind() == InstructionKind::LoadGather;
   }
 
-  int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
+  // SPIR-V restricts the vector size, so each lane may load/store at most 16
+  // elements per instruction.
+  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e95b4a136b954..b61ad817e093b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1171,10 +1171,6 @@ void LayoutInfoPropagation::visitLoadGatherOp(
 
     // Check if value inst_data complies with uArch
     if (!instDataIncoming.empty()) {
-      const int maxElemsPerInst =
-          uArchInstruction->getMaxLaneLoadStoreBitSize() /
-          payloadTy.getElementType().getIntOrFloatBitWidth();
-
       // Each lane loads either one element
       SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
       // Or multiple elements as 2D with lane's elements in the inner dimension
@@ -1182,8 +1178,9 @@ void LayoutInfoPropagation::visitLoadGatherOp(
         instDataUarch.back() = subgroupSize;
       } else {
         *std::prev(instDataUarch.end(), 2) = subgroupSize;
-        instDataUarch.back() = (std::min(
-            static_cast<int>(payloadTy.getShape().back()), maxElemsPerInst));
+        instDataUarch.back() =
+            (std::min(static_cast<int>(payloadTy.getShape().back()),
+                      uArchInstruction->getMaxLaneLoadStoreSize()));
       }
       // If inst data does not match, enforce the uArch-based one
       if (!llvm::equal(instDataIncoming, instDataUarch)) {
@@ -1268,10 +1265,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
       const auto *uArchInstruction =
           dyn_cast<xegpu::uArch::LoadGatherInstruction>(
               uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
-      const int maxElemsPerInst =
-          uArchInstruction->getMaxLaneLoadStoreBitSize() /
-          payloadTy.getElementType().getIntOrFloatBitWidth();
-
       const int subgroupSize = uArch->getSubgroupSize();
       SmallVector<int> instData{subgroupSize};
       auto chunkSize = storeScatter.getChunkSize().value_or(0);
@@ -1280,8 +1273,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
           !chunkSize && srcTdescTy) {
         chunkSize = srcTdescTy.getChunkSizeAsInt();
       }
-      instData.push_back(
-          std::min(static_cast<int>(chunkSize), maxElemsPerInst));
+      instData.push_back(std::min(static_cast<int>(chunkSize),
+                                  uArchInstruction->getMaxLaneLoadStoreSize()));
 
       payloadLayout = LayoutInfo(
           xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index d1934e4c648bd..53f3f014690a7 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -173,17 +173,17 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
-// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
-// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
   %1 = arith.constant dense<1>: vector<16xi1>
   %offset = arith.constant dense<12> : vector<16xindex>
-  %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
-      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-  xegpu.store %3, %src[%offset], %1 <{chunk_size=16}>
-      : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=32}>
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=32}>
+      : vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
   return
 }
 }
@@ -195,17 +195,17 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
-// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
-// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
   %1 = arith.constant dense<1>: vector<16xi1>
   %offset = arith.constant dense<12> : vector<16xindex>
-  %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
-      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-  xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.layout<inst_data = [16, 16]>}>
-      : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=32}>
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=32, layout = #xegpu.layout<inst_data = [16, 16]>}>
+      : vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
   return
 }
 }

>From 57414e607481f2a80fbedc4488b88b2c34eb16ea Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 23 Jan 2026 10:40:32 +0000
Subject: [PATCH 6/6] Feedback

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 82 +++++++++++--------
 .../XeGPU/propagate-layout-inst-data.mlir     | 25 ++++++
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  4 +-
 3 files changed, 76 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b61ad817e093b..b46f6c7e751a1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1146,6 +1146,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
   if (hasParamsOfLayoutKind(anchorLayout)) {
     loadLayout = LayoutInfo(anchorLayout);
+    maskLayout = loadLayout;
   } else {
     LayoutInfo valueLayout = results[0]->getValue();
     // Need the layout of the value to propagate to the tensor descriptor.
@@ -1170,17 +1171,18 @@ void LayoutInfoPropagation::visitLoadGatherOp(
             uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
 
     // Check if value inst_data complies with uArch
-    if (!instDataIncoming.empty()) {
+    if (layoutKind == LayoutKind::InstData) {
       // Each lane loads either one element
-      SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
+      SmallVector<int> instDataUarch{subgroupSize};
       // Or multiple elements as 2D with lane's elements in the inner dimension
-      if (payloadTy.getRank() == 1) {
-        instDataUarch.back() = subgroupSize;
-      } else {
-        *std::prev(instDataUarch.end(), 2) = subgroupSize;
-        instDataUarch.back() =
+      if (payloadTy.getRank() != 1) {
+        if (payloadTy.getRank() != 2) {
+          load.emitWarning("Expected 2D payload for LoadGatherOp.");
+          return;
+        }
+        instDataUarch.push_back(
             (std::min(static_cast<int>(payloadTy.getShape().back()),
-                      uArchInstruction->getMaxLaneLoadStoreSize()));
+                      uArchInstruction->getMaxLaneLoadStoreSize())));
       }
       // If inst data does not match, enforce the uArch-based one
       if (!llvm::equal(instDataIncoming, instDataUarch)) {
@@ -1205,11 +1207,19 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
   }
 
-  if (layoutKind == LayoutKind::InstData)
-    maskLayout =
-        LayoutInfo(xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
-  else if (layoutKind == LayoutKind::Lane)
-    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+  // Set the default mask layout unless a user-defined anchor is present on
+  // rank-1 data:
+  //   Rank-1 data with an anchor: keep the mask layout aligned with the data.
+  //   No anchor, or rank>1 (chunked) data: enforce the default XeGPU 1D layout
+  //   for the mask.
+  if (!hasParamsOfLayoutKind(anchorLayout) ||
+      load.getValueType().getRank() > 1) {
+    if (layoutKind == LayoutKind::InstData)
+      maskLayout = LayoutInfo(
+          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
+    else if (layoutKind == LayoutKind::Lane)
+      maskLayout =
+          getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+  }
 
   // Propagate the new layout to the tensor descriptor operand.
   if (isa<xegpu::TensorDescType>(load.getSourceType()))
@@ -1263,21 +1273,21 @@ void LayoutInfoPropagation::visitStoreScatterOp(
 
     if (layoutKind == LayoutKind::InstData) {
       const auto *uArchInstruction =
-          dyn_cast<xegpu::uArch::LoadGatherInstruction>(
-              uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+          dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
+              xegpu::uArch::InstructionKind::StoreScatter));
       const int subgroupSize = uArch->getSubgroupSize();
-      SmallVector<int> instData{subgroupSize};
-      auto chunkSize = storeScatter.getChunkSize().value_or(0);
-      if (auto srcTdescTy =
-              dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType());
-          !chunkSize && srcTdescTy) {
-        chunkSize = srcTdescTy.getChunkSizeAsInt();
+      SmallVector<int> instDataUarch{subgroupSize};
+      if (payloadTy.getRank() != 1) {
+        if (payloadTy.getRank() != 2) {
+          storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
+          return;
+        }
+        instDataUarch.push_back(
+            (std::min(static_cast<int>(payloadTy.getShape().back()),
+                      uArchInstruction->getMaxLaneLoadStoreSize())));
       }
-      instData.push_back(std::min(static_cast<int>(chunkSize),
-                                  uArchInstruction->getMaxLaneLoadStoreSize()));
-
       payloadLayout = LayoutInfo(
-          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+          xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));
     } else {
       auto payloadShape = payloadTy.getShape();
       if (payloadShape.size() > 1)
@@ -1292,12 +1302,19 @@ void LayoutInfoPropagation::visitStoreScatterOp(
         dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
   }
 
-  if (layoutKind == LayoutKind::InstData)
-    maskLayout = LayoutInfo(
-        xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
-  else if (layoutKind == LayoutKind::Lane)
-    maskLayout =
-        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+  // Set the default mask layout unless a user-defined anchor is present on
+  // rank-1 data:
+  //   Rank-1 data with an anchor: keep the mask layout aligned with the data.
+  //   No anchor, or rank>1 (chunked) data: enforce the default XeGPU 1D layout
+  //   for the mask.
+  if (!hasParamsOfLayoutKind(anchorLayout) ||
+      storeScatter.getValueType().getRank() > 1) {
+    if (layoutKind == LayoutKind::InstData)
+      maskLayout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
+    else if (layoutKind == LayoutKind::Lane)
+      maskLayout =
+          getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+  }
 
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
@@ -1458,8 +1475,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    if (!isa<xegpu::LoadGatherOp>(op))
-      xegpu::setDistributeLayoutAttr(result, layout);
+    xegpu::setDistributeLayoutAttr(result, layout);
   }
   return success();
 }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 53f3f014690a7..38ce62cf9b514 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -209,3 +209,28 @@ func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
   return
 }
 }
+
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_anchor(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf32> to vector<16x16xf32>
+// CHECK: xegpu.store %[[BCAST]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+
+  %4 = vector.broadcast %3 : vector<16xf32> to vector<16x16xf32>
+  xegpu.store %4, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.layout<inst_data = [16, 16]>}>
+      : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 3953bb7f407ec..bf6c5d992a47f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -218,8 +218,8 @@ func.func @scatter_ops(%src: memref<256xf16>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @scatter_ops_custom_perm_layout(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>



More information about the Mlir-commits mailing list