[Mlir-commits] [mlir] [MLIR][XeGPU] Introduce `xegpu::uArch` usage in target-sensitive passes (PR #163801)
llvmlistbot at llvm.org
Thu Oct 16 07:55:59 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Artem Kroviakov (akroviakov)
This PR is the _first, minimally working, and somewhat crude_ application of the xegpu::uArch infrastructure in the uArch-sensitive parts of XeGPU. We completely remove `XeGPUTargetInfo.h` and instead rely on the target attribute attached to the GPU module. Because the design is still uncertain at this early stage, I only consider `pvc`, and all of the new instructions provide a minimal interface.
This PR also adds support for setting default `inst_data` for `store_nd`, `prefetch_nd`, `dpas`, `scatter`, and `gather`.
Some points to consider:
- LLVM does not use C++ RTTI; its own flavor of dynamic polymorphism requires manually amending the types for `dyn_cast` to work. This becomes crucial if you want to check for or obtain a specific uArch. One example is `requireTranspose()` in `XeGPUSubgroupDistribute.cpp`, where we still need to compare against hardcoded strings instead of using `isa<>`/`dyn_cast<>` (see the first sketch after this list).
- Should uArch be exposed via a shared pointer to a constant structure, or as a reference to a constant static structure? It depends on whether we allow a fallback (the chip string is missing, or no uArch is found for it, i.e., `if(!uArch){...}`) or strictly require a valid uArch (then an invalid uArch is not even representable; `llvm_unreachable` in `getUArch()`). Generally, I lean towards a reference to a static constant for simplicity, but the uArch surface exposed to use cases may still be too small to judge. The shared_ptr version currently remains the most flexible one (see the second sketch after this list).
- Type verification (`XeGPUDialect.cpp`, `TensorDescType::verify`) gets trickier, because there is no way to reach the target attribute from the type verifier. I would be glad to hear some feedback on this; for now, I use a constexpr placeholder.
- All anchor ops need a uArch, at least to query the subgroup size for `getDefaultSIMTLayoutInfo`. `inst_data` is not part of `getDefaultSIMTLayoutInfo`, because it requires querying a specific operation in the uArch.
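
For the RTTI point above, here is a minimal sketch of how LLVM-style `isa<>`/`dyn_cast<>` support could look for uArch classes. It assumes a hypothetical `ArchKind` discriminator that does not exist in the current `uArchBase.h`, and the types are simplified stand-ins for the real ones:

```cpp
#include "llvm/Support/Casting.h"

struct uArch {
  // Hypothetical discriminator; the real base class has no such field yet.
  enum class ArchKind { PVC, BMG };
  explicit uArch(ArchKind kind) : kind(kind) {}
  virtual ~uArch() = default;
  ArchKind getKind() const { return kind; }

private:
  const ArchKind kind;
};

struct PVCuArch : public uArch {
  PVCuArch() : uArch(ArchKind::PVC) {}
  // Hook used by llvm::isa<>/llvm::dyn_cast<>.
  static bool classof(const uArch *arch) {
    return arch->getKind() == ArchKind::PVC;
  }
};

// Replaces string comparisons such as `chipStr == "pvc"`.
static bool isPVC(const uArch &arch) { return llvm::isa<PVCuArch>(&arch); }
```

Intermediate classes like `Xe2Plus` would need a kind-range check in their `classof`, following the usual LLVM RTTI recipe.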
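
And for the shared_ptr-versus-static-reference question, a minimal sketch of the fallback-friendly and strict lookup flavors, again with simplified stand-in types rather than the real classes:

```cpp
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"

struct uArch { virtual ~uArch() = default; };
struct PVCuArch : uArch {};
struct BMGuArch : uArch {};

// Fallback-friendly flavor: callers must handle nullptr (`if (!uArch) ...`).
const uArch *getUArchOrNull(llvm::StringRef chip) {
  static PVCuArch pvc;
  static BMGuArch bmg;
  if (chip == "pvc")
    return &pvc;
  if (chip == "bmg")
    return &bmg;
  return nullptr;
}

// Strict flavor: an invalid uArch is not representable at the call site.
const uArch &getUArch(llvm::StringRef chip) {
  if (const uArch *arch = getUArchOrNull(chip))
    return *arch;
  llvm_unreachable("no uArch registered for this chip string");
}
```

Both flavors hand out the same function-local static instances, so no per-query allocation is involved; the shared_ptr factory in the patch instead constructs a fresh uArch on every call.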
I am eager to gather feedback on both the uArch API and its usage in general, so feel free to ask questions and propose changes.
My impression is that uArch is a nice fit for our purposes, but from this initial experience it currently appears a bit bloated (two std::maps of shared pointers to instructions per uArch instance for what seems to be compile-time-constant data) and tricky to use (e.g., no support for LLVM-style polymorphism). I may have missed some justification for this in the uArch PR, so it would be good to reiterate it here.
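
To illustrate the "compile-time-constant data" remark, a rough sketch (with hypothetical, simplified types) of how per-kind instruction descriptors could be served from static storage instead of per-instance maps of shared pointers:

```cpp
struct Instruction {
  virtual ~Instruction() = default;
};
struct DPASInstruction : Instruction {};
struct StoreNdInstruction : Instruction {};

enum class InstructionKind { DPAS, STORE_ND };

// One static descriptor per kind; no std::map or shared_ptr per uArch instance.
const Instruction *getInstruction(InstructionKind kind) {
  static DPASInstruction dpas;
  static StoreNdInstruction storeNd;
  switch (kind) {
  case InstructionKind::DPAS:
    return &dpas;
  case InstructionKind::STORE_ND:
    return &storeNd;
  }
  return nullptr;
}
```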
---
Patch is 57.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163801.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+23-11)
- (removed) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h (-30)
- (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td (+6-1)
- (modified) mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h (+74-4)
- (modified) mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h (+15-2)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9-7)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+164-62)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+16-10)
- (modified) mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir (+1-1)
- (added) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+51)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+59-23)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..ec236d702de0d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,29 +379,41 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
);
let builders = [
- AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+ AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
+ "llvm::ArrayRef<int32_t>": $lane_layout,
"llvm::ArrayRef<int32_t>": $lane_data),
[{
auto sg_layout = DenseI32ArrayAttr();
auto sg_data = DenseI32ArrayAttr();
- auto inst_data = DenseI32ArrayAttr();
auto order = DenseI32ArrayAttr();
- return $_get($_ctxt, sg_layout, sg_data, inst_data,
+ return $_get($_ctxt, sg_layout, sg_data,
+ DenseI32ArrayAttr::get($_ctxt, inst_data),
DenseI32ArrayAttr::get($_ctxt, lane_layout),
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
}]>,
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
- "llvm::ArrayRef<int32_t>": $lane_data,
- "llvm::ArrayRef<int32_t>": $order),
+ "llvm::ArrayRef<int32_t>": $lane_data),
[{
- return $_get($_ctxt,
- /*sg_layout =*/ nullptr,
- /*sg_data =*/ nullptr,
- /*inst_data =*/ nullptr,
+ auto sg_layout = DenseI32ArrayAttr();
+ auto sg_data = DenseI32ArrayAttr();
+ auto inst_data = DenseI32ArrayAttr();
+ auto order = DenseI32ArrayAttr();
+ return $_get($_ctxt, sg_layout, sg_data, inst_data,
DenseI32ArrayAttr::get($_ctxt, lane_layout),
- DenseI32ArrayAttr::get($_ctxt, lane_data),
- DenseI32ArrayAttr::get($_ctxt, order));
+ DenseI32ArrayAttr::get($_ctxt, lane_data), order);
}]>,
+ // AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+ // "llvm::ArrayRef<int32_t>": $lane_data,
+ // "llvm::ArrayRef<int32_t>": $order),
+ // [{
+ // return $_get($_ctxt,
+ // /*sg_layout =*/ nullptr,
+ // /*sg_data =*/ nullptr,
+ // /*inst_data =*/ nullptr,
+ // DenseI32ArrayAttr::get($_ctxt, lane_layout),
+ // DenseI32ArrayAttr::get($_ctxt, lane_data),
+ // DenseI32ArrayAttr::get($_ctxt, order));
+ // }]>,
AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
"DenseI32ArrayAttr": $lane_data,
"DenseI32ArrayAttr": $order),
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
deleted file mode 100644
index 8aa9536cb67c1..0000000000000
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h
+++ /dev/null
@@ -1,30 +0,0 @@
-//===- XeGPUTargetInfo.h - Target constants ---------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-#define MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
-
-namespace mlir {
-namespace xegpu {
-/// HW dependent constants.
-/// TODO: These constants should be queried from the target information.
-namespace targetinfo {
-constexpr unsigned subgroupSize = 16; // How many lanes in a subgroup.
-/// If DPAS A or B operands have low precision element types they must be packed
-/// according to the following sizes.
-constexpr unsigned packedSizeInBitsForDefault =
- 16; // Minimum packing size per register for DPAS A.
-constexpr unsigned packedSizeInBitsForDpasB =
- 32; // Minimum packing size per register for DPAS B.
-constexpr unsigned packedSizeInBitsForGatherScatter =
- 32; // Minimum packing size per register for Gather and Scatter ops.
-} // namespace targetinfo
-} // namespace xegpu
-} // namespace mlir
-
-#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTARGETINFO_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..5ef1d499d618f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -43,7 +43,12 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
let options = [Option<
"printOnly", "print-analysis-only", "bool",
/*default=*/"false",
- "Print the result of layout propagation analysis and exit.">];
+ "Print the result of layout propagation analysis and exit.">,
+ Option<
+ "assumeUnrolled", "assume-unrolled", "bool",
+ /*default=*/"false",
+ "If the input IR has SG-sized tiles matching instruction sizes, omit `inst_data`.">
+ ];
}
def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b2e277d..5cb6d61336391 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -42,12 +42,59 @@ struct Xe2Plus : public uArch {
&instrs = {})
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
xeCore(xeCore) {}
+ int getSubgroupSize() const override { return 16; }
+ int getPackedFormatBitSizeGatherScatter() const override { return 32; }
+ int getPackedFormatBitSize() const override { return 16; }
+ std::optional<int> getPackedFormatBitSizeDpasB() const override { return 32; }
+};
+
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+struct StoreNdInstruction : public Instruction {
+ StoreNdInstruction()
+ : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}
+
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+ // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
+ // the specified pointer
+ llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct LoadNdInstruction : public Instruction {
+ LoadNdInstruction()
+ : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}
+
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
+ // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
+ // the specified pointer.
+ llvm::SmallVector<int> getSortedLaneVectorLengths() { return {1, 2, 4, 8}; }
+};
+
+struct PrefetchNdInstruction : public Instruction {
+ PrefetchNdInstruction()
+ : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}
+
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+ llvm::SmallVector<int> getSortedLaneVectorLengths(int elementBitwidth) {
+ if (elementBitwidth == 8 || elementBitwidth == 16)
+ return {1, 2, 4, 8, 16};
+ else if (elementBitwidth == 32 || elementBitwidth == 64)
+ return {1, 2, 4, 8};
+ else
+ llvm_unreachable(
+ "Unsupported element bitwidth for PrefetchNdInstruction");
+ }
};
-// struct to represent DPAS instruction
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
DPASInstruction()
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+ // Source:
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
// Override all virtuals from MatrixOpInterface
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -72,6 +119,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
};
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
struct PVCuArch : public Xe2Plus {
// Maintaines ownership of the instructions owned by PVUarch
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
@@ -101,9 +151,15 @@ struct PVCuArch : public Xe2Plus {
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
// Add the instructions-
- auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getInstructionKind(), dpas);
- owned_instructions.push_back(dpas);
+ llvm::SmallVector<std::shared_ptr<Instruction>> instructionsToAdd{
+ std::make_shared<DPASInstruction>(),
+ std::make_shared<StoreNdInstruction>(),
+ std::make_shared<LoadNdInstruction>(),
+ std::make_shared<PrefetchNdInstruction>()};
+ for (auto &inst : instructionsToAdd) {
+ instructions.emplace(inst->getInstructionKind(), inst);
+ owned_instructions.push_back(inst);
+ }
}
};
@@ -139,10 +195,24 @@ struct BMGuArch : public Xe2Plus {
owned_instructions.push_back(dpas);
}
};
+
+inline std::shared_ptr<uArch> getUArch(const std::string &archName) {
+ if (archName == "pvc")
+ return std::make_shared<PVCuArch>();
+ else if (archName == "bmg")
+ return std::make_shared<BMGuArch>();
+ else
+ return nullptr;
+}
+
} // namespace uArch
} // namespace xegpu
} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994ea5ecf5..0f5b1282f0e24 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
// An enum class to represent the scope of an instruction
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
enum class InstructionKind {
- DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
- // multiply-add operation
+ DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
+ // multiply-add operation
+ STORE_ND, // Subgroup-level 2D block write instruction
+ LOAD_ND, // Subgroup-level 2D block load instruction
+ PREFETCH_ND // Subgroup-level 2D block prefetch instruction
// @TODO: Add more instructions as needed
};
@@ -148,6 +151,16 @@ struct uArch {
const std::string &getDescription() const { return description; }
+ virtual int getSubgroupSize() const = 0;
+ virtual int getPackedFormatBitSizeGatherScatter() const = 0;
+ virtual int getPackedFormatBitSize() const = 0;
+ virtual std::optional<int> getPackedFormatBitSizeDpasB() const = 0;
+
+ std::shared_ptr<Instruction> getInstruction(InstructionKind instKind) const {
+ assert(instructions.find(instKind) != instructions.end());
+ return instructions.at(instKind);
+ }
+
const std::map<RegisterFileType, RegisterFileInfo> &
getRegisterFileInfo() const {
return registerFileInfo;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9beb22d517473..afda04fa71105 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -11,7 +11,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -226,8 +226,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
- return emitError()
- << "expected inst_data and lane_layout to have the same rank";
+ return emitError() << "expected inst_data and lane_layout to have the same "
+ "rank, got inst_data "
+ << inst_data.size() << ", lane_layout "
+ << lane_layout.size();
}
// sg_data is optional for Workgroup layout, but its presence requires
@@ -565,10 +567,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
- int chunkAlignmentFactor =
- bitWidth < targetinfo::packedSizeInBitsForGatherScatter
- ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
- : 1;
+ constexpr int packingBitSizeGatherScatter{32};
+ int chunkAlignmentFactor = bitWidth < packingBitSizeGatherScatter
+ ? packingBitSizeGatherScatter / bitWidth
+ : 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8fab255d6347f..9c09908f3547d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
@@ -37,6 +36,8 @@
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
@@ -104,6 +105,8 @@ struct LayoutInfo {
SmallVector<int> getLaneData() const;
+ SmallVector<int> getInstData() const;
+
bool isSliceLayout() const {
if (!isAssigned())
return false;
@@ -137,6 +140,13 @@ SmallVector<int> LayoutInfo::getLaneData() const {
[](int64_t val) { return static_cast<int>(val); });
}
+SmallVector<int> LayoutInfo::getInstData() const {
+ if (!isAssigned())
+ return {};
+ return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
+ [](int64_t val) { return static_cast<int>(val); });
+}
+
void LayoutInfo::print(raw_ostream &os) const {
if (isAssigned()) {
os << storage;
@@ -174,12 +184,14 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
SmallVector<int32_t> laneLayout;
SmallVector<int32_t> laneData;
+ SmallVector<int32_t> instData;
for (int64_t idx : permutation) {
laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+ instData.push_back(static_cast<int32_t>(getInstData()[idx]));
}
- return LayoutInfo(
- xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
+ return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
+ laneLayout, laneData));
}
//===----------------------------------------------------------------------===//
@@ -199,20 +211,33 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
/// Helper Function to get the default layout for uniform values like constants.
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
- unsigned rank) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank,
+ std::shared_ptr<xegpu::uArch::uArch> &uArch,
+ ArrayRef<int> instData) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1) {
return LayoutInfo(
- xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+ xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
}
return LayoutInfo(xegpu::LayoutAttr::get(
- ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
+ ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+}
+
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+ unsigned rank, int subgroupSize) {
+ assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+ if (rank == 1) {
+ return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
+ }
+ return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
- bool isScattered = false) {
+static LayoutInfo
+getDefaultSIMTLayoutInfo(VectorType vectorTy,
+ std::shared_ptr<xegpu::uArch::uArch> &uArch,
+ ArrayRef<int> instData, bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
@@ -221,29 +246,31 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
if (isScattered) {
packingFactor =
- bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
- ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+ bitwidth < uArch->getPackedFormatBitSizeGatherScatter()
+ ? uArch->getPackedFormatBitSizeGatherScatter() / bitwidth
: 1;
- return LayoutInfo(xegpu::LayoutAttr::get(
- vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
- {1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ {uArch->getSubgroupSize(), 1},
+ {1, packingFactor}));
}
- if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
- packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
- {1, xegpu::targetinfo::subgroupSize},
+ if (bitwidth < uArch->getPackedFormatBitSize())
+ packingFactor = uArch->getPackedFormatBitSize() / bitwidth;
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ {1, uArch->getSubgroupSize()},
{1, packingFactor}));
}
/// Helper to get the default layout for a vector type.
-static LayoutInfo get...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/163801
More information about the Mlir-commits mailing list