[Mlir-commits] [mlir] [MLIR][XeGPU] Add uArch limitation to scatter load store (PR #172845)

Artem Kroviakov llvmlistbot at llvm.org
Thu Dec 18 04:33:49 PST 2025


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/172845

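This PR adds uArch-aware layout propagation for XeGPU scattered load and store ops. The chunk size alone cannot dictate the instruction size: a single lane can access only a uArch-defined number of bits per instruction (128 on Xe2), so the per-lane chunk dimension of inst_data is clamped to that limit.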

>From 96467bbbef127e40991f8a8bddf6baa9e76a7107 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 18 Dec 2025 12:30:40 +0000
Subject: [PATCH] [MLIR][XeGPU] Add uArch limitation to scatter load store

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 34 +++++++++++++--
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      |  8 +++-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 41 +++++++++++++------
 .../XeGPU/propagate-layout-inst-data.mlir     | 22 ++++++++++
 4 files changed, 87 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index b3231a173f33a..055c10ab50652 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,6 +215,34 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
   const unsigned packedFormatBitSizeB;
 };
 
+/// Lane-scoped scattered store instruction (xegpu.store).
+/// getMaxBitSize() is the maximum number of bits a single lane may store in
+/// one instruction; layout propagation divides it by the element bit width to
+/// bound the per-lane chunk size of the inst_data layout.
+struct StoreScatterInstruction : public Instruction {
+  StoreScatterInstruction()
+      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::StoreScatter;
+  }
+
+  int32_t getMaxBitSize() const { return 128; }
+};
+
+/// Lane-scoped gathered load instruction (xegpu.load).
+/// getMaxBitSize() is the maximum number of bits a single lane may load in
+/// one instruction; layout propagation divides it by the element bit width to
+/// bound the per-lane chunk size of the inst_data layout.
+struct LoadGatherInstruction : public Instruction {
+  LoadGatherInstruction()
+      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LoadGather;
+  }
+
+  int32_t getMaxBitSize() const { return 128; }
+};
+
 //===----------------------------------------------------------------------===//
 // uArch instances
 //===----------------------------------------------------------------------===//
@@ -225,8 +253,11 @@ struct PVCuArch final : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
-                                       &prefetchNdInst};
+    static const StoreScatterInstruction storeScatterInst;
+    static const LoadGatherInstruction loadGatherInst;
+    static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
+                                       &storeNdInst,      &prefetchNdInst,
+                                       &storeScatterInst, &loadGatherInst};
     return arr;
   }
 
@@ -248,8 +279,11 @@ struct BMGuArch : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
-                                       &prefetchNdInst};
+    static const StoreScatterInstruction storeScatterInst;
+    static const LoadGatherInstruction loadGatherInst;
+    static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
+                                       &storeNdInst,      &prefetchNdInst,
+                                       &storeScatterInst, &loadGatherInst};
     return arr;
   }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 8f23b89134773..db1984b2edb1d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -38,7 +38,9 @@ enum class InstructionKind {
                              // matrix multiply-add operation
   Subgroup2DBlockStore,      // Subgroup-level 2D block write instruction
   Subgroup2DBlockLoad,       // Subgroup-level 2D block load instruction
-  Subgroup2DBlockPrefetch    // Subgroup-level 2D block prefetch instruction
+  Subgroup2DBlockPrefetch,   // Subgroup-level 2D block prefetch instruction
+  StoreScatter,              // Lane-level scattered store (scalar or vector)
+  LoadGather                 // Lane-level gathered load (scalar or vector)
   // @TODO: Add more instructions as needed
 };
 
@@ -65,6 +67,10 @@ struct Instruction {
       return "load_nd";
     case InstructionKind::Subgroup2DBlockPrefetch:
       return "prefetch_nd";
+    case InstructionKind::StoreScatter:
+      return "store";
+    case InstructionKind::LoadGather:
+      return "load";
     }
     llvm_unreachable("Unknown InstructionKind");
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 112d138c73e18..337d47a94965d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -982,15 +982,31 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       return;
     }
     auto uArch = getUArch(getChipStr(load).value_or(""));
     const int subgroupSize = uArch->getSubgroupSize();
     SmallVector<int> instData{subgroupSize};
-    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
-      instData.push_back(chunkSize);
-    else if (auto srcTdescTy =
-                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
-      if (srcTdescTy.getChunkSizeAsInt() > 1)
-        instData.push_back(chunkSize);
+    // The chunk size comes from the op attribute or, for tensor_desc sources,
+    // from the descriptor type.
+    auto chunkSize = load.getChunkSize().value_or(0);
+    if (auto srcTdescTy = dyn_cast<xegpu::TensorDescType>(load.getSourceType());
+        !chunkSize && srcTdescTy) {
+      chunkSize = srcTdescTy.getChunkSizeAsInt();
     }
+    // Only chunked (2D) payloads get a second inst_data dimension; 1D payloads
+    // keep inst_data = [subgroupSize]. The per-lane chunk is clamped to what a
+    // single load instruction can access on this uArch.
+    if (chunkSize > 1) {
+      int instChunkSize = static_cast<int>(chunkSize);
+      if (const auto *uArchInstruction =
+              dyn_cast_if_present<xegpu::uArch::LoadGatherInstruction>(
+                  uArch->getInstruction(
+                      xegpu::uArch::InstructionKind::LoadGather))) {
+        const int maxElemsPerInst =
+            uArchInstruction->getMaxBitSize() /
+            payloadTy.getElementType().getIntOrFloatBitWidth();
+        instChunkSize = std::min(instChunkSize, maxElemsPerInst);
+      }
+      instData.push_back(instChunkSize);
+    }
 
     if (layoutKind == LayoutKind::InstData)
       loadLayout =
@@ -1056,15 +1072,32 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     const int subgroupSize = uArch->getSubgroupSize();
 
     if (layoutKind == LayoutKind::InstData) {
       SmallVector<int> instData{subgroupSize};
-      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
-          chunkSize > 1)
-        instData.push_back(chunkSize);
-      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
-                   storeScatter.getDestType())) {
-        if (dstTdescTy.getChunkSizeAsInt() > 1)
-          instData.push_back(chunkSize);
+      // The chunk size comes from the op attribute or, for tensor_desc
+      // destinations, from the descriptor type.
+      auto chunkSize = storeScatter.getChunkSize().value_or(0);
+      if (auto dstTdescTy =
+              dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType());
+          !chunkSize && dstTdescTy) {
+        chunkSize = dstTdescTy.getChunkSizeAsInt();
       }
+      // Only chunked (2D) payloads get a second inst_data dimension; the
+      // per-lane chunk is clamped to what a single store instruction can
+      // access on this uArch.
+      if (chunkSize > 1) {
+        int instChunkSize = static_cast<int>(chunkSize);
+        if (const auto *uArchInstruction =
+                dyn_cast_if_present<xegpu::uArch::StoreScatterInstruction>(
+                    uArch->getInstruction(
+                        xegpu::uArch::InstructionKind::StoreScatter))) {
+          const int maxElemsPerInst =
+              uArchInstruction->getMaxBitSize() /
+              payloadTy.getElementType().getIntOrFloatBitWidth();
+          instChunkSize = std::min(instChunkSize, maxElemsPerInst);
+        }
+        instData.push_back(instChunkSize);
+      }
+
       payloadLayout = LayoutInfo(
           xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
     } else {
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 32fb3178a8af2..310f75817ec53 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -153,3 +153,27 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   return
 }
 }
+
+// -----
+
+// f32 elements with a 128-bit per-lane access limit allow at most 4 elements
+// per instruction, so chunk_size = 16 is split into inst_data = [16, 4].
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD:[0-9a-zA-Z_]+]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 4]>} : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+// CHECK: xegpu.store %[[LOAD]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
+  %mask = arith.constant dense<1> : vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %data = xegpu.load %src[%offset], %mask <{chunk_size=16}>
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+  xegpu.store %data, %src[%offset], %mask <{chunk_size=16}>
+      : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  return
+}
+}



More information about the Mlir-commits mailing list