[Mlir-commits] [mlir] 9109c60 - [MLIR][XeGPU] Add uArch limitation to scatter load store (#172845)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 10:10:55 PST 2026
Author: Artem Kroviakov
Date: 2026-01-23T19:10:51+01:00
New Revision: 9109c603a88f6ae1ab914cf2064bb5ad05970918
URL: https://github.com/llvm/llvm-project/commit/9109c603a88f6ae1ab914cf2064bb5ad05970918
DIFF: https://github.com/llvm/llvm-project/commit/9109c603a88f6ae1ab914cf2064bb5ad05970918.diff
LOG: [MLIR][XeGPU] Add uArch limitation to scatter load store (#172845)
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index b3231a173f33a..29e75b57f4a5f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,6 +215,28 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
const unsigned packedFormatBitSizeB;
};
+/// Lane-scope description of the scattered store instruction
+/// (InstructionKind::StoreScatter).
+struct StoreScatterInstruction : public Instruction {
+  StoreScatterInstruction()
+      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+
+  /// LLVM-style RTTI hook: lets `isa`/`dyn_cast<StoreScatterInstruction>`
+  /// be applied to a generic `const Instruction *`.
+  static bool classof(const Instruction *inst) {
+    return inst->getInstructionKind() == InstructionKind::StoreScatter;
+  }
+
+  /// Largest number of elements a single lane may store in one instruction.
+  /// The bound mirrors the vector-size restriction imposed by SPIR-V.
+  int32_t getMaxLaneLoadStoreSize() const {
+    constexpr int32_t kMaxElemsPerLane = 16;
+    return kMaxElemsPerLane;
+  }
+};
+
+/// Lane-scope description of the gathered load instruction
+/// (InstructionKind::LoadGather).
+struct LoadGatherInstruction : public Instruction {
+  LoadGatherInstruction()
+      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+
+  /// LLVM-style RTTI hook: lets `isa`/`dyn_cast<LoadGatherInstruction>`
+  /// be applied to a generic `const Instruction *`.
+  static bool classof(const Instruction *inst) {
+    return inst->getInstructionKind() == InstructionKind::LoadGather;
+  }
+
+  /// Largest number of elements a single lane may load in one instruction.
+  /// The bound mirrors the vector-size restriction imposed by SPIR-V.
+  int32_t getMaxLaneLoadStoreSize() const {
+    constexpr int32_t kMaxElemsPerLane = 16;
+    return kMaxElemsPerLane;
+  }
+};
+
//===----------------------------------------------------------------------===//
// uArch instances
//===----------------------------------------------------------------------===//
@@ -225,8 +247,11 @@ struct PVCuArch final : public Xe2Plus {
static const Subgroup2DBlockLoadInstruction loadNdInst;
static const Subgroup2DBlockStoreInstruction storeNdInst;
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
- static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
- &prefetchNdInst};
+ static const StoreScatterInstruction storeScatterInst;
+ static const LoadGatherInstruction loadGatherInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst,
+ &storeNdInst, &prefetchNdInst,
+ &storeScatterInst, &loadGatherInst};
return arr;
}
@@ -248,8 +273,11 @@ struct BMGuArch : public Xe2Plus {
static const Subgroup2DBlockLoadInstruction loadNdInst;
static const Subgroup2DBlockStoreInstruction storeNdInst;
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
- static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
- &prefetchNdInst};
+ static const StoreScatterInstruction storeScatterInst;
+ static const LoadGatherInstruction loadGatherInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst,
+ &storeNdInst, &prefetchNdInst,
+ &storeScatterInst, &loadGatherInst};
return arr;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 8f23b89134773..db1984b2edb1d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -38,7 +38,9 @@ enum class InstructionKind {
// matrix multiply-add operation
Subgroup2DBlockStore, // Subgroup-level 2D block write instruction
Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
- Subgroup2DBlockPrefetch // Subgroup-level 2D block prefetch instruction
+ Subgroup2DBlockPrefetch, // Subgroup-level 2D block prefetch instruction
+ StoreScatter, // Lane-level store (scalar, vector)
+ LoadGather // Lane-level load (scalar, vector)
// @TODO: Add more instructions as needed
};
@@ -65,6 +67,10 @@ struct Instruction {
return "load_nd";
case InstructionKind::Subgroup2DBlockPrefetch:
return "prefetch_nd";
+ case InstructionKind::StoreScatter:
+ return "store";
+ case InstructionKind::LoadGather:
+ return "load";
}
llvm_unreachable("Unknown InstructionKind");
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e8eadb6de5b30..b46f6c7e751a1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1141,40 +1141,86 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   LayoutInfo loadLayout;
   LayoutInfo maskLayout;
+  auto uArch = getUArch(getChipStr(load).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
   xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
   if (hasParamsOfLayoutKind(anchorLayout)) {
     loadLayout = LayoutInfo(anchorLayout);
     maskLayout = loadLayout;
   } else {
-    // The layout is strictly determined by the payload type.
+    // Without an anchor, the load adopts the layout its users require on the
+    // loaded value. Until that layout is known there is nothing to propagate.
+    LayoutInfo valueLayout = results[0]->getValue();
+    if (!valueLayout.isAssigned())
+      return;
+
     VectorType payloadTy = load.getValueType();
     if (!payloadTy) {
       load.emitWarning("Not propagating, non-vector payload supplied.");
       return;
     }
-    auto uArch = getUArch(getChipStr(load).value_or(""));
-    const int subgroupSize = uArch->getSubgroupSize();
-    SmallVector<int> instData{subgroupSize};
-    if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
-      instData.push_back(chunkSize);
-    else if (auto srcTdescTy =
-                 dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
-      if (srcTdescTy.getChunkSizeAsInt() > 1)
-        instData.push_back(chunkSize);
+
+    // Defensive: bail out if the lattice value holds no DistributeLayoutAttr
+    // instead of dereferencing a null attribute below.
+    auto resAttr =
+        dyn_cast_if_present<xegpu::DistributeLayoutAttr>(valueLayout.get());
+    if (!resAttr)
+      return;
+    auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr);
+
+    // Make sure the value's inst_data complies with the uArch.
+    if (layoutKind == LayoutKind::InstData) {
+      // inst_data currently requested by the users. For a slice, compare the
+      // inst_data of the flattened parent layout, which may be absent.
+      auto instDataIncoming = resAttr.getEffectiveInstDataAsInt();
+      if (sliceAttr) {
+        auto parentInstData =
+            cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
+                .getInstData();
+        instDataIncoming =
+            parentInstData ? SmallVector<int64_t>(parentInstData.asArrayRef())
+                           : SmallVector<int64_t>();
+      }
+
+      // Each lane loads either a single element (1D payload) or, for a 2D
+      // payload, a run of elements along the innermost dimension, capped by
+      // the per-lane limit of the uArch instruction.
+      SmallVector<int> instDataUarch{subgroupSize};
+      if (payloadTy.getRank() != 1) {
+        if (payloadTy.getRank() != 2) {
+          load.emitWarning("Expected 2D payload for LoadGatherOp.");
+          return;
+        }
+        // Tolerate a uArch that provides no LoadGather description rather
+        // than dereferencing a null instruction.
+        const auto *uArchInstruction =
+            dyn_cast_if_present<xegpu::uArch::LoadGatherInstruction>(
+                uArch->getInstruction(
+                    xegpu::uArch::InstructionKind::LoadGather));
+        if (!uArchInstruction) {
+          load.emitWarning(
+              "Not propagating, uArch does not describe LoadGather.");
+          return;
+        }
+        instDataUarch.push_back(
+            std::min(static_cast<int>(payloadTy.getShape().back()),
+                     uArchInstruction->getMaxLaneLoadStoreSize()));
+      }
+
+      // If the requested inst_data does not match, enforce the uArch-based
+      // one while keeping every other layout parameter (and the slice, if
+      // any) unchanged.
+      if (!llvm::equal(instDataIncoming, instDataUarch)) {
+        xegpu::LayoutAttr sourceAttr =
+            sliceAttr
+                ? cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
+                : dyn_cast<xegpu::LayoutAttr>(resAttr);
+        assert(sourceAttr && "expected a LayoutAttr or a SliceAttr over one");
+        // NOTE(review): for a slice, instDataUarch has the payload (sliced)
+        // rank while sourceAttr is the higher-rank parent; confirm the
+        // rebuilt parent layout is well-formed.
+        xegpu::DistributeLayoutAttr updatedLayoutAttr = xegpu::LayoutAttr::get(
+            load.getContext(), sourceAttr.getSgLayout(), sourceAttr.getSgData(),
+            DenseI32ArrayAttr::get(load.getContext(), instDataUarch),
+            sourceAttr.getLaneLayout(), sourceAttr.getLaneData(),
+            sourceAttr.getOrder());
+        if (sliceAttr)
+          updatedLayoutAttr = xegpu::SliceAttr::get(
+              load.getContext(), updatedLayoutAttr, sliceAttr.getDims());
+        valueLayout = LayoutInfo(updatedLayoutAttr);
+      }
     }
+    loadLayout = valueLayout;
+    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+  }
+
+  // Mask layout. With a user anchor on 1D data the mask already shares the
+  // anchor layout (assigned above). Otherwise (no anchor, or >1D chunked
+  // data) the mask gets the default 1D layout for the current layout kind.
+  // A non-vector payload has no rank and is treated like 1D data here.
+  VectorType valueTy = load.getValueType();
+  const bool isMultiDimPayload = valueTy && valueTy.getRank() > 1;
+  if (!hasParamsOfLayoutKind(anchorLayout) || isMultiDimPayload) {
     if (layoutKind == LayoutKind::InstData)
-      loadLayout =
-          LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
-    else
-      loadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
-
-    // Mask operand should have 1D default layout.
-    maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
-
-    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+      maskLayout = LayoutInfo(
+          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
+    else if (layoutKind == LayoutKind::Lane)
+      maskLayout =
+          getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
   }
+
   // Propagate the new layout to the tensor descriptor operand.
   if (isa<xegpu::TensorDescType>(load.getSourceType()))
     propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
@@ -1209,6 +1255,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   LayoutInfo payloadLayout;
   LayoutInfo maskLayout;
   xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+  auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
+
   if (hasParamsOfLayoutKind(anchorLayout)) {
     payloadLayout = LayoutInfo(anchorLayout);
     maskLayout = payloadLayout;
@@ -1222,21 +1271,23 @@ void LayoutInfoPropagation::visitStoreScatterOp(
       return;
     }
-    auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
-    const int subgroupSize = uArch->getSubgroupSize();
-
     if (layoutKind == LayoutKind::InstData) {
-      SmallVector<int> instData{subgroupSize};
-      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
-          chunkSize > 1)
-        instData.push_back(chunkSize);
-      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
-                   storeScatter.getDestType())) {
-        if (dstTdescTy.getChunkSizeAsInt() > 1)
-          instData.push_back(chunkSize);
+      // Each lane stores either a single element (1D payload) or, for a 2D
+      // payload, a run of elements along the innermost dimension, capped by
+      // the per-lane limit of the uArch instruction.
+      SmallVector<int> instDataUarch{subgroupSize};
+      if (payloadTy.getRank() != 1) {
+        if (payloadTy.getRank() != 2) {
+          storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
+          return;
+        }
+        // Tolerate a uArch that provides no StoreScatter description rather
+        // than dereferencing a null instruction.
+        const auto *uArchInstruction =
+            dyn_cast_if_present<xegpu::uArch::StoreScatterInstruction>(
+                uArch->getInstruction(
+                    xegpu::uArch::InstructionKind::StoreScatter));
+        if (!uArchInstruction) {
+          storeScatter.emitWarning(
+              "Not propagating, uArch does not describe StoreScatter.");
+          return;
+        }
+        instDataUarch.push_back(
+            std::min(static_cast<int>(payloadTy.getShape().back()),
+                     uArchInstruction->getMaxLaneLoadStoreSize()));
       }
       payloadLayout = LayoutInfo(
-          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+          xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));
     } else {
       auto payloadShape = payloadTy.getShape();
       if (payloadShape.size() > 1)
@@ -1247,12 +1298,24 @@ void LayoutInfoPropagation::visitStoreScatterOp(
       payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
     }
-    maskLayout =
-        getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
-
     storeScatter.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
   }
+
+  // Mask layout. With a user anchor on 1D data the mask already shares the
+  // anchor layout (assigned above). Otherwise (no anchor, or >1D chunked
+  // data) the mask gets the default 1D layout for the current layout kind.
+  // A payload type that is null is treated like 1D data here.
+  auto valueTy = storeScatter.getValueType();
+  const bool isMultiDimPayload = valueTy && valueTy.getRank() > 1;
+  if (!hasParamsOfLayoutKind(anchorLayout) || isMultiDimPayload) {
+    if (layoutKind == LayoutKind::InstData)
+      maskLayout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
+    else if (layoutKind == LayoutKind::Lane)
+      maskLayout =
+          getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+  }
+
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
   // Propagate the destination (if tdesc) operand layout
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index aef9d7fc9e03a..5aad0f592abed 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -138,8 +138,8 @@ gpu.module @test_kernel {
gpu.module @test {
// CHECK-LABEL: func.func @scatter_ops_chunksize(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}>
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
@@ -162,6 +162,75 @@ gpu.module @test {
 func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   %cst = arith.constant dense<0.0000> : vector<16x16xf16>
   xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+
+  return
+}
+}
+
+// -----
+
+// chunk_size = 32 exceeds the per-lane limit of 16, so the load/store
+// inst_data is clamped to [16, 16]; mask and offsets get inst_data = [16].
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+// CHECK: xegpu.store %[[LOADED]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=32}>
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=32}>
+      : vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  return
+}
+}
+
+// -----
+
+// Same as above, but the store carries a user anchor layout that already
+// respects the per-lane limit.
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_anchor(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+// CHECK: xegpu.store %[[LOADED]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1 <{chunk_size=32}>
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+  xegpu.store %3, %src[%offset], %1 <{chunk_size=32, layout = #xegpu.layout<inst_data = [16, 16]>}>
+      : vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  return
+}
+}
+
+// -----
+
+// A 1D load feeding a broadcast receives a slice of the consumer's layout.
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_slice(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf32> to vector<16x16xf32>
+// CHECK: xegpu.store %[[BCAST]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
+  %1 = arith.constant dense<1>: vector<16xi1>
+  %offset = arith.constant dense<12> : vector<16xindex>
+  %3 = xegpu.load %src[%offset], %1
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+
+  %4 = vector.broadcast %3 : vector<16xf32> to vector<16x16xf32>
+  xegpu.store %4, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.layout<inst_data = [16, 16]>}>
+      : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
   return
 }
 }
More information about the Mlir-commits mailing list