[Mlir-commits] [mlir] [MLIR][XeGPU] Reorganize uArch for easier extension (PR #178907)

Artem Kroviakov llvmlistbot at llvm.org
Fri Jan 30 07:48:35 PST 2026


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/178907

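The scatter/gather instructions are common across architectures, so they are turned into interfaces that allow each uArch to specify its own limitations. `getMaxLaneLoadStoreSize()` now also takes the element bit width, so a uArch can make the per-lane limit depend on the element type; PVC and BMG keep the existing SPIR-V vector-size limit of 16.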

>From e4a61c8421c87d82572e403faba5c0118e68b798 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 30 Jan 2026 15:45:55 +0000
Subject: [PATCH] [MLIR][XeGPU] Reorganize uArch for easier extension

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 30 +++++++------------
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 26 ++++++++++++++++
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 21 +++++++------
 3 files changed, 48 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 29e75b57f4a5f..05b4dbdbb0317 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,26 +215,16 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
   const unsigned packedFormatBitSizeB;
 };
 
-struct StoreScatterInstruction : public Instruction {
-  StoreScatterInstruction()
-      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::StoreScatter;
+struct SpirvLoadGatherInstruction : public LoadGatherInstructionInterface {
+  int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
+    return 16;
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
-struct LoadGatherInstruction : public Instruction {
-  LoadGatherInstruction()
-      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::LoadGather;
+struct SpirvStoreScatterInstruction : public StoreScatterInstructionInterface {
+  int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
+    return 16;
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
 //===----------------------------------------------------------------------===//
@@ -247,8 +237,8 @@ struct PVCuArch final : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const StoreScatterInstruction storeScatterInst;
-    static const LoadGatherInstruction loadGatherInst;
+    static const SpirvStoreScatterInstruction storeScatterInst;
+    static const SpirvLoadGatherInstruction loadGatherInst;
     static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
                                        &storeNdInst,      &prefetchNdInst,
                                        &storeScatterInst, &loadGatherInst};
@@ -273,8 +263,8 @@ struct BMGuArch : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const StoreScatterInstruction storeScatterInst;
-    static const LoadGatherInstruction loadGatherInst;
+    static const SpirvStoreScatterInstruction storeScatterInst;
+    static const SpirvLoadGatherInstruction loadGatherInst;
     static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
                                        &storeNdInst,      &prefetchNdInst,
                                        &storeScatterInst, &loadGatherInst};
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index db1984b2edb1d..ee3d5a5a8c398 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -250,6 +250,32 @@ struct MMAInstructionInterface {
   virtual ~MMAInstructionInterface() = default;
 };
 
+//===----------------------------------------------------------------------===//
+// Common instructions (shared across architectures)
+//===----------------------------------------------------------------------===//
+
+struct StoreScatterInstructionInterface : public Instruction {
+  StoreScatterInstructionInterface()
+      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::StoreScatter;
+  }
+
+  virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
+  virtual ~StoreScatterInstructionInterface() = default;
+};
+
+struct LoadGatherInstructionInterface : public Instruction {
+  LoadGatherInstructionInterface()
+      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LoadGather;
+  }
+
+  virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
+  virtual ~LoadGatherInstructionInterface() = default;
+};
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 96fdced39d9ab..faafb7e8cee61 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1168,7 +1168,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       return;
     }
     const auto *uArchInstruction =
-        dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+        dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
             uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
 
     // Check if value inst_data complies with uArch
@@ -1181,9 +1181,10 @@ void LayoutInfoPropagation::visitLoadGatherOp(
           load.emitWarning("Expected 2D payload for LoadGatherOp.");
           return;
         }
-        instDataUarch.push_back(
-            (std::min(static_cast<int>(payloadTy.getShape().back()),
-                      uArchInstruction->getMaxLaneLoadStoreSize())));
+        int elemBitWidth = payloadTy.getElementTypeBitWidth();
+        instDataUarch.push_back((
+            std::min(static_cast<int>(payloadTy.getShape().back()),
+                     uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
       }
       // If inst data does not match, enforce the uArch-based one
       if (!llvm::equal(instDataIncoming, instDataUarch)) {
@@ -1274,8 +1275,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
 
     if (layoutKind == xegpu::LayoutKind::InstData) {
       const auto *uArchInstruction =
-          dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
-              xegpu::uArch::InstructionKind::StoreScatter));
+          dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
+              uArch->getInstruction(
+                  xegpu::uArch::InstructionKind::StoreScatter));
       const int subgroupSize = uArch->getSubgroupSize();
       SmallVector<int> instDataUarch{subgroupSize};
       if (payloadTy.getRank() != 1) {
@@ -1283,9 +1285,10 @@ void LayoutInfoPropagation::visitStoreScatterOp(
           storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
           return;
         }
-        instDataUarch.push_back(
-            (std::min(static_cast<int>(payloadTy.getShape().back()),
-                      uArchInstruction->getMaxLaneLoadStoreSize())));
+        int elemBitWidth = payloadTy.getElementTypeBitWidth();
+        instDataUarch.push_back((
+            std::min(static_cast<int>(payloadTy.getShape().back()),
+                     uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
       }
       payloadLayout = LayoutInfo(
           xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));



More information about the Mlir-commits mailing list