[Mlir-commits] [mlir] [MLIR][XeGPU] Introduce `xegpu::uArch` usage in target-sensitive passes (PR #163801)
    llvmlistbot at llvm.org 
       
    Thu Oct 16 07:55:59 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Artem Kroviakov (akroviakov)
<details>
<summary>Changes</summary>
This PR is the _first, minimally working, and somewhat crude_ application of the xegpu::uArch infra in the uarch-sensitive parts of XeGPU. We completely remove `XeGPUTargetInfo.h` and rely on the target attached to the GPU module. Because the design is still uncertain at this early stage, I only consider `pvc`, and each of the new instructions provides only a minimal interface.
This PR adds support for setting a default `inst_data` for store_nd, prefetch_nd, dpas, scatter, and gather.
Some points to consider:
- LLVM does not use C++ RTTI; its own dynamic polymorphism requires manually amending the types (e.g., with a static `classof`) for `dyn_cast` to work. This becomes crucial if you want to check for or get a specific uArch. An example can be found in `requireTranspose()` in `XeGPUSubgroupDistribute.cpp`, where we still need to check for hardcoded strings instead of using `isa<>`/`dyn_cast<>`.
- Should uArch be exposed via a shared pointer to a constant structure or as a reference to a constant static structure? It depends on whether we allow for some fallback (the chip string is not present or no uArch is found for it, i.e., `if(!uArch){...}`) or strictly require a valid uArch (then an invalid uArch is not even possible, `llvm_unreachable` in `getUArch()`). Generally, I lean towards trying a reference to a static constant for simplicity, but uArch has been exposed to too few use cases to judge yet. For now, the shared_ptr version remains because it is the most flexible one.
- Type verification (`XeGPUDialect.cpp`, `TensorDescType::verify`) gets trickier, because there is no way to get the target attribute. I'd be glad to hear some feedback on it. For now, I use a constexpr placeholder.
- All anchor ops need uArch, at least to query the subgroup size for `getDefaultSIMTLayoutInfo`. Inst_data is not part of `getDefaultSIMTLayoutInfo`, because it requires querying a specific operation in uArch.
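To make the first point concrete, here is a minimal sketch of what "amending the types" could look like. All names (`UArchKind`, `UArchBase`, `PVCSketch`, `BMGSketch`, `dynCastSketch`, `isPVC`) are hypothetical and not part of this patch. `dynCastSketch` is a local stand-in for `llvm::dyn_cast` so that the snippet compiles without LLVM headers; with LLVM available, the same `classof` hooks are what `isa<>`/`dyn_cast<>` would use.

```cpp
#include <cassert>

// Hypothetical kind tag; the patch does not define such an enum.
enum class UArchKind { PVC, BMG };

// Simplified stand-in for the uArch base class.
struct UArchBase {
  explicit UArchBase(UArchKind kind) : kind(kind) {}
  virtual ~UArchBase() = default;
  UArchKind getKind() const { return kind; }

private:
  const UArchKind kind;
};

struct PVCSketch : UArchBase {
  PVCSketch() : UArchBase(UArchKind::PVC) {}
  // LLVM-style RTTI hook: casts consult this instead of C++ RTTI.
  static bool classof(const UArchBase *u) {
    return u->getKind() == UArchKind::PVC;
  }
};

struct BMGSketch : UArchBase {
  BMGSketch() : UArchBase(UArchKind::BMG) {}
  static bool classof(const UArchBase *u) {
    return u->getKind() == UArchKind::BMG;
  }
};

// Local stand-in for llvm::dyn_cast: returns nullptr on a kind mismatch.
template <typename To>
const To *dynCastSketch(const UArchBase *u) {
  assert(u && "dynCastSketch on a null pointer");
  return To::classof(u) ? static_cast<const To *>(u) : nullptr;
}

// A type-based check that could replace a comparison against a chip string.
inline bool isPVC(const UArchBase *u) {
  return dynCastSketch<PVCSketch>(u) != nullptr;
}
```

With a kind tag like this, a check such as the one in `requireTranspose()` could in principle test the type rather than a hardcoded name. Whether the tag belongs on `uArch` itself is a design question this sketch does not settle.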
I am eager to gather feedback both for the uArch API and usage in general, so feel free to ask and propose changes. 
My impression is that uArch is a nice fit for our purposes, but from the initial experience, as of now, it appears a bit bloated (two std maps of shared pointers to instructions per uArch instance, for data that seems to be compile-time constant) and tricky to use (e.g., no support for LLVM polymorphism). I may have missed some justifications for this in the uArch PR though, so it would be good to reiterate them here.
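For comparison, the "reference to a static constant" alternative discussed above could look roughly like the sketch below. It is a simplification under stated assumptions: `UArchInfo`, `findUArch`, and `getUArchOrDie` are hypothetical names, the struct flattens the virtual getters into plain fields, and `packedFormatBitSizeDpasB` is a plain `int` here rather than the `std::optional<int>` used in the patch. The numeric values are the ones returned by the `Xe2Plus` overrides in this PR.

```cpp
#include <cassert>
#include <string>

// Hypothetical flat record of the per-target constants.
struct UArchInfo {
  const char *name;
  int subgroupSize;
  int packedFormatBitSize;
  int packedFormatBitSizeGatherScatter;
  int packedFormatBitSizeDpasB;
};

// Fallback-friendly lookup: an unknown chip yields nullptr, so callers can
// still write `if (!uArch) { ... }`.
inline const UArchInfo *findUArch(const std::string &chip) {
  // Function-local statics: built once, no heap allocation, no ownership.
  static const UArchInfo pvc{"pvc", 16, 16, 32, 32};
  static const UArchInfo bmg{"bmg", 16, 16, 32, 32};
  if (chip == "pvc")
    return &pvc;
  if (chip == "bmg")
    return &bmg;
  return nullptr;
}

// Strict lookup: returns a reference, so an invalid uArch cannot reach the
// caller. In-tree code would more likely use llvm_unreachable than assert.
inline const UArchInfo &getUArchOrDie(const std::string &chip) {
  const UArchInfo *info = findUArch(chip);
  assert(info && "no uArch registered for this chip");
  return *info;
}
```

The two entry points correspond to the two policies in the question: `findUArch` keeps the fallback path open, while `getUArchOrDie` treats a missing target as a programming error.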
---
Patch is 57.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163801.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+23-11) 
- (removed) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h (-30) 
- (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td (+6-1) 
- (modified) mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h (+74-4) 
- (modified) mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h (+15-2) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9-7) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+164-62) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+16-10) 
- (modified) mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir (+1-1) 
- (added) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+51) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+59-23) 
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..ec236d702de0d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,29 +379,41 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   );
 
   let builders = [
-    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
+                      "llvm::ArrayRef<int32_t>": $lane_layout,
                      "llvm::ArrayRef<int32_t>": $lane_data),
       [{
         auto sg_layout = DenseI32ArrayAttr();
         auto sg_data = DenseI32ArrayAttr();
-        auto inst_data = DenseI32ArrayAttr();
         auto order = DenseI32ArrayAttr();
-        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+        return $_get($_ctxt, sg_layout, sg_data,
+                     DenseI32ArrayAttr::get($_ctxt, inst_data),
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
       }]>,
     AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
-                     "llvm::ArrayRef<int32_t>": $lane_data,
-                     "llvm::ArrayRef<int32_t>": $order),
+                     "llvm::ArrayRef<int32_t>": $lane_data),
       [{
-        return $_get($_ctxt,
-                     /*sg_layout =*/ nullptr,
-                     /*sg_data   =*/ nullptr,
-                     /*inst_data =*/ nullptr,
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        auto order = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
-                     DenseI32ArrayAttr::get($_ctxt, lane_data),
-                     DenseI32ArrayAttr::get($_ctxt, order));
+                     DenseI32ArrayAttr::get($_ctxt, lane_data), order);
       }]>,
+    // AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+    //                  "llvm::ArrayRef<int32_t>": $lane_data,
+    //                  "llvm::ArrayRef<int32_t>": $order),
+    //   [{
+    //     return $_get($_ctxt,
+    //                  /*sg_layout =*/ nullptr,
+    //                  /*sg_data   =*/ nullptr,
+    //                  /*inst_data =*/ nullptr,
+    //                  DenseI32ArrayAttr::get($_ctxt, lane_layout),
+    //                  DenseI32ArrayAttr::get($_ctxt, lane_data),
+    //                  DenseI32ArrayAttr::get($_ctxt, order));
+    //   }]>,
     AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
                      "DenseI32ArrayAttr": $lane_data,
                      "DenseI32ArrayAttr": $order),
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
deleted file mode 100644
index 8aa9536cb67c1..0000000000000
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
+++ /dev/null
@@ -1,30 +0,0 @@
-//===- XeGPUTargetInfo.h - Target constants ---------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-#define MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-
-namespace mlir {
-namespace xegpu {
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-namespace targetinfo {
-constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes.
-constexpr unsigned packedSizeInBitsForDefault =
-    16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
-    32; // Minimum packing size per register for DPAS B.
-constexpr unsigned packedSizeInBitsForGatherScatter =
-    32; // Minimum packing size per register for Gather and Scatter ops.
-} // namespace targetinfo
-} // namespace xegpu
-} // namespace mlir
-
-#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..5ef1d499d618f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -43,7 +43,12 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
   let options = [Option<
     "printOnly", "print-analysis-only", "bool",
     /*default=*/"false",
-    "Print the result of layout propagation analysis and exit.">];
+    "Print the result of layout propagation analysis and exit.">,
+    Option<
+    "assumeUnrolled", "assume-unrolled", "bool",
+    /*default=*/"false",
+    "If the input IR has SG-sized tiles matching instruction sizes, omit `inst_data`.">
+  ];
 }
 
 def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b2e277d..5cb6d61336391 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -42,12 +42,59 @@ struct Xe2Plus : public uArch {
               &instrs = {})
       : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
         xeCore(xeCore) {}
+  int getSubgroupSize() const override { return 16; }
+  int getPackedFormatBitSizeGatherScatter() const override { return 32; }
+  int getPackedFormatBitSize() const override { return 16; }
+  std::optional<int> getPackedFormatBitSizeDpasB() const override { return 32; }
+};
+
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+struct StoreNdInstruction : public Instruction {
+  StoreNdInstruction()
+      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
+  // the specified pointer
+  llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct LoadNdInstruction : public Instruction {
+  LoadNdInstruction()
+      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
+  // the specified pointer.
+  llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct PrefetchNdInstruction : public Instruction {
+  PrefetchNdInstruction()
+      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+
+  // Source :
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+  llvm::SmallVector<int> getSortedLaneVectorLengths(int elementBitwidth) {
+    if (elementBitwidth == 8 || elementBitwidth == 16)
+      return {1, 2, 4, 8, 16};
+    else if (elementBitwidth == 32 || elementBitwidth == 64)
+      return {1, 2, 4, 8};
+    else
+      llvm_unreachable(
+          "Unsupported element bitwidth for PrefetchNdInstruction");
+  }
 };
 
-// struct to represent DPAS instruction
 struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   DPASInstruction()
       : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+  // Source:
+  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
 
   // Override all virtuals from MatrixOpInterface
   virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -72,6 +119,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
   virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
 };
 
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
 struct PVCuArch : public Xe2Plus {
   // Maintaines ownership of the instructions owned by PVUarch
   llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
@@ -101,9 +151,15 @@ struct PVCuArch : public Xe2Plus {
         CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
 
     // Add the instructions-
-    auto dpas = std::make_shared<DPASInstruction>();
-    instructions.emplace(dpas->getInstructionKind(), dpas);
-    owned_instructions.push_back(dpas);
+    llvm::SmallVector<std::shared_ptr<Instruction>> instructionsToAdd{
+        std::make_shared<DPASInstruction>(),
+        std::make_shared<StoreNdInstruction>(),
+        std::make_shared<LoadNdInstruction>(),
+        std::make_shared<PrefetchNdInstruction>()};
+    for (auto &inst : instructionsToAdd) {
+      instructions.emplace(inst->getInstructionKind(), inst);
+      owned_instructions.push_back(inst);
+    }
   }
 };
 
@@ -139,10 +195,24 @@ struct BMGuArch : public Xe2Plus {
     owned_instructions.push_back(dpas);
   }
 };
+
+inline std::shared_ptr<uArch> getUArch(const std::string &archName) {
+  if (archName == "pvc")
+    return std::make_shared<PVCuArch>();
+  else if (archName == "bmg")
+    return std::make_shared<BMGuArch>();
+  else
+    return nullptr;
+}
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
 
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
 inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
 DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
   auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994ea5ecf5..0f5b1282f0e24 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
 // An enum class to represent the scope of an instruction
 enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
-  DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
-        // multiply-add operation
+  DPAS,       // Dot Product Accumulate Systolic (DPAS) is a matrix
+              // multiply-add operation
+  STORE_ND,   // Subgroup-level 2D block write instruction
+  LOAD_ND,    // Subgroup-level 2D block load instruction
+  PREFETCH_ND // Subgroup-level 2D block prefetch instruction
   // @TODO: Add more instructions as needed
 };
 
@@ -148,6 +151,16 @@ struct uArch {
 
   const std::string &getDescription() const { return description; }
 
+  virtual int getSubgroupSize() const = 0;
+  virtual int getPackedFormatBitSizeGatherScatter() const = 0;
+  virtual int getPackedFormatBitSize() const = 0;
+  virtual std::optional<int> getPackedFormatBitSizeDpasB() const = 0;
+
+  std::shared_ptr<Instruction> getInstruction(InstructionKind instKind) const {
+    assert(instructions.find(instKind) != instructions.end());
+    return instructions.at(instKind);
+  }
+
   const std::map<RegisterFileType, RegisterFileInfo> &
   getRegisterFileInfo() const {
     return registerFileInfo;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d517473..afda04fa71105 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -11,7 +11,7 @@
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -226,8 +226,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   }
 
   if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
-    return emitError()
-           << "expected inst_data and lane_layout to have the same rank";
+    return emitError() << "expected inst_data and lane_layout to have the same "
+                          "rank, got inst_data "
+                       << inst_data.size() << ", lane_layout "
+                       << lane_layout.size();
   }
 
   // sg_data is optional for Workgroup layout, but its presence requires
@@ -565,10 +567,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 
   // for gather and scatter ops, Low-precision types are packed in 32-bit units.
   unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-  int chunkAlignmentFactor =
-      bitWidth < targetinfo::packedSizeInBitsForGatherScatter
-          ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
-          : 1;
+  constexpr int packingBitSizeGatherScatter{32};
+  int chunkAlignmentFactor = bitWidth < packingBitSizeGatherScatter
+                                 ? packingBitSizeGatherScatter / bitWidth
+                                 : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
     int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8fab255d6347f..9c09908f3547d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -37,6 +36,8 @@
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+
 namespace mlir {
 namespace xegpu {
 #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
@@ -104,6 +105,8 @@ struct LayoutInfo {
 
   SmallVector<int> getLaneData() const;
 
+  SmallVector<int> getInstData() const;
+
   bool isSliceLayout() const {
     if (!isAssigned())
       return false;
@@ -137,6 +140,13 @@ SmallVector<int> LayoutInfo::getLaneData() const {
                              [](int64_t val) { return static_cast<int>(val); });
 }
 
+SmallVector<int> LayoutInfo::getInstData() const {
+  if (!isAssigned())
+    return {};
+  return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
+                             [](int64_t val) { return static_cast<int>(val); });
+}
+
 void LayoutInfo::print(raw_ostream &os) const {
   if (isAssigned()) {
     os << storage;
@@ -174,12 +184,14 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
 
   SmallVector<int32_t> laneLayout;
   SmallVector<int32_t> laneData;
+  SmallVector<int32_t> instData;
   for (int64_t idx : permutation) {
     laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
     laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+    instData.push_back(static_cast<int32_t>(getInstData()[idx]));
   }
-  return LayoutInfo(
-      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
+  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
+                                           laneLayout, laneData));
 }
 
 //===----------------------------------------------------------------------===//
@@ -199,20 +211,33 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// Helper Function to get the default layout for uniform values like constants.
 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
-                                           unsigned rank) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank,
+                         std::shared_ptr<xegpu::uArch::uArch> &uArch,
+                         ArrayRef<int> instData) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1) {
     return LayoutInfo(
-        xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+        xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
   }
   return LayoutInfo(xegpu::LayoutAttr::get(
-      ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
+      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+}
+
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+                                           unsigned rank, int subgroupSize) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank == 1) {
+    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
+  }
+  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
-                                           bool isScattered = false) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(VectorType vectorTy,
+                         std::shared_ptr<xegpu::uArch::uArch> &uArch,
+                         ArrayRef<int> instData, bool isScattered = false) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -221,29 +246,31 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   if (isScattered) {
     packingFactor =
-        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
-            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+        bitwidth < uArch->getPackedFormatBitSizeGatherScatter()
+            ? uArch->getPackedFormatBitSizeGatherScatter() / bitwidth
             : 1;
-    return LayoutInfo(xegpu::LayoutAttr::get(
-        vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
-        {1, packingFactor}));
+    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+                                             {uArch->getSubgroupSize(), 1},
+                                             {1, packingFactor}));
   }
-  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
-    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {1, xegpu::targetinfo::subgroupSize},
+  if (bitwidth < uArch->getPackedFormatBitSize())
+    packingFactor = uArch->getPackedFormatBitSize() / bitwidth;
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+                                           {1, uArch->getSubgroupSize()},
                                            {1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo get...
[truncated]
``````````
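The packing-factor computations in the hunks above (for `chunkAlignmentFactor` and in `getDefaultSIMTLayoutInfo`) all follow one rule: an element narrower than the packed format size is packed `packedSize / bitWidth` to a unit, otherwise the factor is 1. A small worked sketch, where `packingFactor` is a hypothetical helper name and not a function from the patch:

```cpp
#include <cassert>

// Mirrors the ternary used in the patch: elements narrower than the packed
// format size are packed, wider or equal ones are left alone.
constexpr int packingFactor(int elementBitWidth, int packedFormatBitSize) {
  return elementBitWidth < packedFormatBitSize
             ? packedFormatBitSize / elementBitWidth
             : 1;
}

// Gather/scatter packs into 32-bit units (getPackedFormatBitSizeGatherScatter).
static_assert(packingFactor(8, 32) == 4, "four 8-bit elements per 32 bits");
static_assert(packingFactor(16, 32) == 2, "two 16-bit elements per 32 bits");
static_assert(packingFactor(32, 32) == 1, "32-bit elements are not packed");

// The default path packs into 16-bit units (getPackedFormatBitSize).
static_assert(packingFactor(8, 16) == 2, "two 8-bit elements per 16 bits");
static_assert(packingFactor(16, 16) == 1, "16-bit elements are not packed");
```

The resulting factor is what ends up in the last `lane_data` dimension of the default layout, e.g. `{1, packingFactor}`.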
</details>
https://github.com/llvm/llvm-project/pull/163801
    
    