[Mlir-commits] [mlir] 56d3169 - [MLIR][XeGPU] Reorganize uArch for easier extension (#178907)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 3 00:47:17 PST 2026


Author: Artem Kroviakov
Date: 2026-02-03T09:47:11+01:00
New Revision: 56d31697c2a2c38844bdc0f43e537c29a5115d87

URL: https://github.com/llvm/llvm-project/commit/56d31697c2a2c38844bdc0f43e537c29a5115d87
DIFF: https://github.com/llvm/llvm-project/commit/56d31697c2a2c38844bdc0f43e537c29a5115d87.diff

LOG: [MLIR][XeGPU] Reorganize uArch for easier extension (#178907)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
    mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 29e75b57f4a5f..05b4dbdbb0317 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,26 +215,16 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
   const unsigned packedFormatBitSizeB;
 };
 
-struct StoreScatterInstruction : public Instruction {
-  StoreScatterInstruction()
-      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::StoreScatter;
+struct SpirvLoadGatherInstruction : public LoadGatherInstructionInterface {
+  int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
+    return 16;
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
-struct LoadGatherInstruction : public Instruction {
-  LoadGatherInstruction()
-      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::LoadGather;
+struct SpirvStoreScatterInstruction : public StoreScatterInstructionInterface {
+  int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const override {
+    return 16;
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
 //===----------------------------------------------------------------------===//
@@ -247,8 +237,8 @@ struct PVCuArch final : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const StoreScatterInstruction storeScatterInst;
-    static const LoadGatherInstruction loadGatherInst;
+    static const SpirvStoreScatterInstruction storeScatterInst;
+    static const SpirvLoadGatherInstruction loadGatherInst;
     static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
                                        &storeNdInst,      &prefetchNdInst,
                                        &storeScatterInst, &loadGatherInst};
@@ -273,8 +263,8 @@ struct BMGuArch : public Xe2Plus {
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
-    static const StoreScatterInstruction storeScatterInst;
-    static const LoadGatherInstruction loadGatherInst;
+    static const SpirvStoreScatterInstruction storeScatterInst;
+    static const SpirvLoadGatherInstruction loadGatherInst;
     static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
                                        &storeNdInst,      &prefetchNdInst,
                                        &storeScatterInst, &loadGatherInst};

diff  --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index db1984b2edb1d..ee3d5a5a8c398 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -250,6 +250,32 @@ struct MMAInstructionInterface {
   virtual ~MMAInstructionInterface() = default;
 };
 
+//===----------------------------------------------------------------------===//
+// Common instructions (shared across architectures)
+//===----------------------------------------------------------------------===//
+
+struct StoreScatterInstructionInterface : public Instruction {
+  StoreScatterInstructionInterface()
+      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::StoreScatter;
+  }
+
+  virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
+  virtual ~StoreScatterInstructionInterface() = default;
+};
+
+struct LoadGatherInstructionInterface : public Instruction {
+  LoadGatherInstructionInterface()
+      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LoadGather;
+  }
+
+  virtual int32_t getMaxLaneLoadStoreSize(int32_t bitWidth) const = 0;
+  virtual ~LoadGatherInstructionInterface() = default;
+};
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 96fdced39d9ab..faafb7e8cee61 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1168,7 +1168,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       return;
     }
     const auto *uArchInstruction =
-        dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+        dyn_cast<xegpu::uArch::LoadGatherInstructionInterface>(
             uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
 
     // Check if value inst_data complies with uArch
@@ -1181,9 +1181,10 @@ void LayoutInfoPropagation::visitLoadGatherOp(
           load.emitWarning("Expected 2D payload for LoadGatherOp.");
           return;
         }
-        instDataUarch.push_back(
-            (std::min(static_cast<int>(payloadTy.getShape().back()),
-                      uArchInstruction->getMaxLaneLoadStoreSize())));
+        int elemBitWidth = payloadTy.getElementTypeBitWidth();
+        instDataUarch.push_back((
+            std::min(static_cast<int>(payloadTy.getShape().back()),
+                     uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
       }
       // If inst data does not match, enforce the uArch-based one
       if (!llvm::equal(instDataIncoming, instDataUarch)) {
@@ -1274,8 +1275,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
 
     if (layoutKind == xegpu::LayoutKind::InstData) {
       const auto *uArchInstruction =
-          dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
-              xegpu::uArch::InstructionKind::StoreScatter));
+          dyn_cast<xegpu::uArch::StoreScatterInstructionInterface>(
+              uArch->getInstruction(
+                  xegpu::uArch::InstructionKind::StoreScatter));
       const int subgroupSize = uArch->getSubgroupSize();
       SmallVector<int> instDataUarch{subgroupSize};
       if (payloadTy.getRank() != 1) {
@@ -1283,9 +1285,10 @@ void LayoutInfoPropagation::visitStoreScatterOp(
           storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
           return;
         }
-        instDataUarch.push_back(
-            (std::min(static_cast<int>(payloadTy.getShape().back()),
-                      uArchInstruction->getMaxLaneLoadStoreSize())));
+        int elemBitWidth = payloadTy.getElementTypeBitWidth();
+        instDataUarch.push_back((
+            std::min(static_cast<int>(payloadTy.getShape().back()),
+                     uArchInstruction->getMaxLaneLoadStoreSize(elemBitWidth))));
       }
       payloadLayout = LayoutInfo(
           xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));


        


More information about the Mlir-commits mailing list