[Mlir-commits] [mlir] [MLIR][XeGPU] Add uArch limitation to scatter load store (PR #172845)
Artem Kroviakov
llvmlistbot at llvm.org
Thu Jan 15 07:07:41 PST 2026
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/172845
>From 2970ef37c677866d693651082fbce8bede731089 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 18 Dec 2025 12:30:40 +0000
Subject: [PATCH 1/6] [MLIR][XeGPU] Add uArch limitation to scatter load store
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 34 +++++++++++++--
.../mlir/Dialect/XeGPU/uArch/uArchBase.h | 8 +++-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 41 +++++++++++++------
.../XeGPU/propagate-layout-inst-data.mlir | 22 ++++++++++
4 files changed, 87 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index b3231a173f33a..055c10ab50652 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,6 +215,26 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
const unsigned packedFormatBitSizeB;
};
+struct StoreScatterInstruction : public Instruction {
+ StoreScatterInstruction()
+ : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::StoreScatter;
+ }
+
+ int32_t getMaxBitSize() const { return 128; }
+};
+
+struct LoadGatherInstruction : public Instruction {
+ LoadGatherInstruction()
+ : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::LoadGather;
+ }
+
+ int32_t getMaxBitSize() const { return 128; }
+};
+
//===----------------------------------------------------------------------===//
// uArch instances
//===----------------------------------------------------------------------===//
@@ -225,8 +245,11 @@ struct PVCuArch final : public Xe2Plus {
static const Subgroup2DBlockLoadInstruction loadNdInst;
static const Subgroup2DBlockStoreInstruction storeNdInst;
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
- static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
- &prefetchNdInst};
+ static const StoreScatterInstruction storeScatterInst;
+ static const LoadGatherInstruction loadGatherInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst,
+ &storeNdInst, &prefetchNdInst,
+ &storeScatterInst, &loadGatherInst};
return arr;
}
@@ -248,8 +271,11 @@ struct BMGuArch : public Xe2Plus {
static const Subgroup2DBlockLoadInstruction loadNdInst;
static const Subgroup2DBlockStoreInstruction storeNdInst;
static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
- static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
- &prefetchNdInst};
+ static const StoreScatterInstruction storeScatterInst;
+ static const LoadGatherInstruction loadGatherInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst,
+ &storeNdInst, &prefetchNdInst,
+ &storeScatterInst, &loadGatherInst};
return arr;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 8f23b89134773..db1984b2edb1d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -38,7 +38,9 @@ enum class InstructionKind {
// matrix multiply-add operation
Subgroup2DBlockStore, // Subgroup-level 2D block write instruction
Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
- Subgroup2DBlockPrefetch // Subgroup-level 2D block prefetch instruction
+ Subgroup2DBlockPrefetch, // Subgroup-level 2D block prefetch instruction
+ StoreScatter, // Lane-level store (scalar, vector)
+ LoadGather // Lane-level load (scalar, vector)
// @TODO: Add more instructions as needed
};
@@ -65,6 +67,10 @@ struct Instruction {
return "load_nd";
case InstructionKind::Subgroup2DBlockPrefetch:
return "prefetch_nd";
+ case InstructionKind::StoreScatter:
+ return "store";
+ case InstructionKind::LoadGather:
+ return "load";
}
llvm_unreachable("Unknown InstructionKind");
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..76fa7c0e2698d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -986,15 +986,21 @@ void LayoutInfoPropagation::visitLoadGatherOp(
return;
}
auto uArch = getUArch(getChipStr(load).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+ const int maxElemsPerInst =
+ uArchInstruction->getMaxBitSize() /
+ payloadTy.getElementType().getIntOrFloatBitWidth();
+
const int subgroupSize = uArch->getSubgroupSize();
SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto srcTdescTy =
- dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
- if (srcTdescTy.getChunkSizeAsInt() > 1)
- instData.push_back(chunkSize);
+ auto chunkSize = load.getChunkSize().value_or(0);
+ if (auto srcTdescTy = dyn_cast<xegpu::TensorDescType>(load.getSourceType());
+ !chunkSize && srcTdescTy) {
+ chunkSize = srcTdescTy.getChunkSizeAsInt();
}
+ instData.push_back(std::min(static_cast<int>(chunkSize), maxElemsPerInst));
if (layoutKind == LayoutKind::InstData)
loadLayout =
@@ -1060,15 +1066,24 @@ void LayoutInfoPropagation::visitStoreScatterOp(
const int subgroupSize = uArch->getSubgroupSize();
if (layoutKind == LayoutKind::InstData) {
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+ uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+ const int maxElemsPerInst =
+ uArchInstruction->getMaxBitSize() /
+ payloadTy.getElementType().getIntOrFloatBitWidth();
+
+ const int subgroupSize = uArch->getSubgroupSize();
SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
- chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
- storeScatter.getDestType())) {
- if (dstTdescTy.getChunkSizeAsInt() > 1)
- instData.push_back(chunkSize);
+ auto chunkSize = storeScatter.getChunkSize().value_or(0);
+ if (auto srcTdescTy =
+ dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType());
+ !chunkSize && srcTdescTy) {
+ chunkSize = srcTdescTy.getChunkSizeAsInt();
}
+ instData.push_back(
+ std::min(static_cast<int>(chunkSize), maxElemsPerInst));
+
payloadLayout = LayoutInfo(
xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
} else {
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..f2b270ab16218 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -153,3 +153,25 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
return
}
}
+
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 4]>} : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
+ : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=16}>
+ : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+ return
+}
+}
>From e80cad1faec06e5dabf39388343248d3047211a0 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 19 Dec 2025 12:30:18 +0000
Subject: [PATCH 2/6] LoadGather to modify its inst_data
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 72 +++++++++++++------
.../XeGPU/propagate-layout-inst-data.mlir | 26 ++++++-
2 files changed, 75 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 76fa7c0e2698d..ef57278662cb0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -978,8 +978,19 @@ void LayoutInfoPropagation::visitLoadGatherOp(
loadLayout = LayoutInfo(anchorLayout);
maskLayout = loadLayout;
} else {
+ LayoutInfo valueLayout = results[0]->getValue();
+ // Need the layout of the value to propagate to the tensor descriptor.
+ if (!valueLayout.isAssigned())
+ return;
+
+ auto resAttr = dyn_cast<xegpu::DistributeLayoutAttr>(valueLayout.get());
+ auto instDataIncoming = resAttr.getEffectiveInstDataAsInt();
+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
+ instDataIncoming = SmallVector<int64_t>(
+ cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
+ .getInstData()
+ .asArrayRef());
- // The layout is strictly determined by the payload type.
VectorType payloadTy = load.getValueType();
if (!payloadTy) {
load.emitWarning("Not propagating, non-vector payload supplied.");
@@ -989,30 +1000,48 @@ void LayoutInfoPropagation::visitLoadGatherOp(
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::LoadGatherInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
- const int maxElemsPerInst =
- uArchInstruction->getMaxBitSize() /
- payloadTy.getElementType().getIntOrFloatBitWidth();
const int subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> instData{subgroupSize};
- auto chunkSize = load.getChunkSize().value_or(0);
- if (auto srcTdescTy = dyn_cast<xegpu::TensorDescType>(load.getSourceType());
- !chunkSize && srcTdescTy) {
- chunkSize = srcTdescTy.getChunkSizeAsInt();
- }
- instData.push_back(std::min(static_cast<int>(chunkSize), maxElemsPerInst));
-
- if (layoutKind == LayoutKind::InstData)
- loadLayout =
- LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
- else
- loadLayout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
- /*scattered*/ true);
-
// Mask operand should have 1D default layout.
maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+ // Check if value inst_data complies with uArch
+ if (!instDataIncoming.empty()) {
+ const int maxElemsPerInst =
+ uArchInstruction->getMaxBitSize() /
+ payloadTy.getElementType().getIntOrFloatBitWidth();
+
+ xegpu::LayoutAttr sourceAttr;
+ // Each lane loads either one element
+ SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
+ // Or multiple elements as 2D with lane's elements in the inner dimension
+ if (payloadTy.getRank() == 1) {
+ instDataUarch.back() = subgroupSize;
+ } else {
+ *std::prev(instDataUarch.end(), 2) = subgroupSize;
+ instDataUarch.back() = (std::min(
+ static_cast<int>(payloadTy.getShape().back()), maxElemsPerInst));
+ }
+ // If inst data does not match, enforce the uArch-based one
+ if (!llvm::equal(instDataIncoming, instDataUarch)) {
+ xegpu::LayoutAttr sourceAttr = dyn_cast<xegpu::LayoutAttr>(resAttr);
+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr)) {
+ sourceAttr = cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent());
+ }
+ assert(sourceAttr);
+ xegpu::DistributeLayoutAttr updatedLayoutAttr = xegpu::LayoutAttr::get(
+ load.getContext(), sourceAttr.getSgLayout(), sourceAttr.getSgData(),
+ DenseI32ArrayAttr::get(load.getContext(), instDataUarch),
+ sourceAttr.getLaneLayout(), sourceAttr.getLaneData(),
+ sourceAttr.getOrder());
+
+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
+ updatedLayoutAttr = xegpu::SliceAttr::get(
+ load.getContext(), updatedLayoutAttr, sliceAttr.getDims());
+ valueLayout = LayoutInfo(updatedLayoutAttr);
+ }
+ }
+ loadLayout = valueLayout;
load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
}
// Propagate the new layout to the tensor descriptor operand.
@@ -1235,7 +1264,8 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setDistributeLayoutAttr(result, layout);
+ if (!isa<xegpu::LoadGatherOp>(op))
+ xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index f2b270ab16218..1fe32d380eafc 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -161,8 +161,8 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 4]>} : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
@@ -175,3 +175,25 @@ func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
return
}
}
+
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_slice(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 4]>, dims = [0]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive_slice(%src: memref<1024xf32>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
+ : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}>
+ : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+ return
+}
+}
>From 47ebcbbc9e11f65fdcea50dbd6b3d949fb2b20d3 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 8 Jan 2026 17:24:01 +0000
Subject: [PATCH 3/6] Remove payload layout for mask/offsets
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 24 +++++++++----------
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 16 ++++++-------
2 files changed, 19 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ef57278662cb0..991da194fbbe8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -973,10 +973,11 @@ void LayoutInfoPropagation::visitLoadGatherOp(
LayoutInfo loadLayout;
LayoutInfo maskLayout;
+ auto uArch = getUArch(getChipStr(load).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
if (hasParamsOfLayoutKind(anchorLayout)) {
loadLayout = LayoutInfo(anchorLayout);
- maskLayout = loadLayout;
} else {
LayoutInfo valueLayout = results[0]->getValue();
// Need the layout of the value to propagate to the tensor descriptor.
@@ -996,22 +997,16 @@ void LayoutInfoPropagation::visitLoadGatherOp(
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
- auto uArch = getUArch(getChipStr(load).value_or(""));
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::LoadGatherInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
- const int subgroupSize = uArch->getSubgroupSize();
- // Mask operand should have 1D default layout.
- maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
-
// Check if value inst_data complies with uArch
if (!instDataIncoming.empty()) {
const int maxElemsPerInst =
uArchInstruction->getMaxBitSize() /
payloadTy.getElementType().getIntOrFloatBitWidth();
- xegpu::LayoutAttr sourceAttr;
// Each lane loads either one element
SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
// Or multiple elements as 2D with lane's elements in the inner dimension
@@ -1044,6 +1039,9 @@ void LayoutInfoPropagation::visitLoadGatherOp(
loadLayout = valueLayout;
load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
}
+ // Mask operand should have 1D default layout.
+ maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
@@ -1078,6 +1076,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
LayoutInfo payloadLayout;
LayoutInfo maskLayout;
xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+ auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+
if (hasParamsOfLayoutKind(anchorLayout)) {
payloadLayout = LayoutInfo(anchorLayout);
maskLayout = payloadLayout;
@@ -1091,9 +1092,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
return;
}
- auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
-
if (layoutKind == LayoutKind::InstData) {
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::LoadGatherInstruction>(
@@ -1127,12 +1125,12 @@ void LayoutInfoPropagation::visitStoreScatterOp(
/*scattered=*/true);
}
- maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
-
storeScatter.setLayoutAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
}
+
+ maskLayout =
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..1e7da3ddb6aab 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
gpu.module @test {
// CHECK-LABEL: func.func @dpas_i8(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
@@ -218,8 +218,8 @@ func.func @scatter_ops(%src: memref<256xf16>) {
gpu.module @test {
// CHECK-LABEL: func.func @scatter_ops_custom_perm_layout(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
@@ -238,9 +238,9 @@ func.func @scatter_ops_custom_perm_layout(%src: memref<256xf16>) {
gpu.module @test {
// CHECK-LABEL: func.func @scatter_ops_preserve_load_perm_layout(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,4 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
return
}
-}
\ No newline at end of file
+}
>From 2a98ae0ac99cf2fea449627a706f153c0068c5c3 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 9 Jan 2026 11:00:38 +0000
Subject: [PATCH 4/6] inst_data for mask
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 4 ++--
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 21 +++++++++++++-----
.../XeGPU/propagate-layout-inst-data.mlir | 22 +++++++++----------
3 files changed, 28 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 055c10ab50652..fd577af612523 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -222,7 +222,7 @@ struct StoreScatterInstruction : public Instruction {
return B->getInstructionKind() == InstructionKind::StoreScatter;
}
- int32_t getMaxBitSize() const { return 128; }
+ int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
};
struct LoadGatherInstruction : public Instruction {
@@ -232,7 +232,7 @@ struct LoadGatherInstruction : public Instruction {
return B->getInstructionKind() == InstructionKind::LoadGather;
}
- int32_t getMaxBitSize() const { return 128; }
+ int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 991da194fbbe8..4a454cbebe59d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1004,7 +1004,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
// Check if value inst_data complies with uArch
if (!instDataIncoming.empty()) {
const int maxElemsPerInst =
- uArchInstruction->getMaxBitSize() /
+ uArchInstruction->getMaxLaneLoadStoreBitSize() /
payloadTy.getElementType().getIntOrFloatBitWidth();
// Each lane loads either one element
@@ -1039,8 +1039,12 @@ void LayoutInfoPropagation::visitLoadGatherOp(
loadLayout = valueLayout;
load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
}
- // Mask operand should have 1D default layout.
- maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+
+ if (layoutKind == LayoutKind::InstData)
+ maskLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
+ else if (layoutKind == LayoutKind::Lane)
+ maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
@@ -1097,7 +1101,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
dyn_cast<xegpu::uArch::LoadGatherInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
const int maxElemsPerInst =
- uArchInstruction->getMaxBitSize() /
+ uArchInstruction->getMaxLaneLoadStoreBitSize() /
payloadTy.getElementType().getIntOrFloatBitWidth();
const int subgroupSize = uArch->getSubgroupSize();
@@ -1129,8 +1133,13 @@ void LayoutInfoPropagation::visitStoreScatterOp(
dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
}
- maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+ if (layoutKind == LayoutKind::InstData)
+ maskLayout = LayoutInfo(
+ xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
+ else if (layoutKind == LayoutKind::Lane)
+ maskLayout =
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 1fe32d380eafc..24d4a99bb0b0b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -138,8 +138,8 @@ gpu.module @test_kernel {
gpu.module @test {
// CHECK-LABEL: func.func @scatter_ops_chunksize(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}>
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64, layout = #xegpu.layout<inst_data = [16, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
@@ -159,8 +159,8 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
gpu.module @test {
// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
@@ -179,20 +179,20 @@ func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
// -----
gpu.module @test {
-// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_slice(
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_anchor(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 4]>, dims = [0]>}> :
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}> :
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
-func.func @scatter_ops_chunksize_excessive_slice(%src: memref<1024xf32>) {
+func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
%3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
- xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16, 16]>, dims = [0]>}>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.layout<inst_data = [16, 16]>}>
: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
return
}
>From ea4d3060084f608d0a012d33f417f1fd8630109e Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 12 Jan 2026 17:34:13 +0000
Subject: [PATCH 5/6] Change uArch restriction to reflect spirv
---
.../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h | 6 ++--
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 17 +++-------
.../XeGPU/propagate-layout-inst-data.mlir | 32 +++++++++----------
3 files changed, 25 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index fd577af612523..29e75b57f4a5f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -222,7 +222,8 @@ struct StoreScatterInstruction : public Instruction {
return B->getInstructionKind() == InstructionKind::StoreScatter;
}
- int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
+ // SPIR-V caps vectors at 16 elements; this limit is in elements, not bits.
+ int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
struct LoadGatherInstruction : public Instruction {
@@ -232,7 +233,8 @@ struct LoadGatherInstruction : public Instruction {
return B->getInstructionKind() == InstructionKind::LoadGather;
}
- int32_t getMaxLaneLoadStoreBitSize() const { return 128; }
+ // SPIR-V caps vectors at 16 elements; this limit is in elements, not bits.
+ int32_t getMaxLaneLoadStoreSize() const { return 16; }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4a454cbebe59d..1427202ffcf16 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1003,10 +1003,6 @@ void LayoutInfoPropagation::visitLoadGatherOp(
// Check if value inst_data complies with uArch
if (!instDataIncoming.empty()) {
- const int maxElemsPerInst =
- uArchInstruction->getMaxLaneLoadStoreBitSize() /
- payloadTy.getElementType().getIntOrFloatBitWidth();
-
// Each lane loads either one element
SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
// Or multiple elements as 2D with lane's elements in the inner dimension
@@ -1014,8 +1010,9 @@ void LayoutInfoPropagation::visitLoadGatherOp(
instDataUarch.back() = subgroupSize;
} else {
*std::prev(instDataUarch.end(), 2) = subgroupSize;
- instDataUarch.back() = (std::min(
- static_cast<int>(payloadTy.getShape().back()), maxElemsPerInst));
+ instDataUarch.back() =
+ (std::min(static_cast<int>(payloadTy.getShape().back()),
+ uArchInstruction->getMaxLaneLoadStoreSize()));
}
// If inst data does not match, enforce the uArch-based one
if (!llvm::equal(instDataIncoming, instDataUarch)) {
@@ -1100,10 +1097,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
const auto *uArchInstruction =
dyn_cast<xegpu::uArch::LoadGatherInstruction>(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
- const int maxElemsPerInst =
- uArchInstruction->getMaxLaneLoadStoreBitSize() /
- payloadTy.getElementType().getIntOrFloatBitWidth();
-
const int subgroupSize = uArch->getSubgroupSize();
SmallVector<int> instData{subgroupSize};
auto chunkSize = storeScatter.getChunkSize().value_or(0);
@@ -1112,8 +1105,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
!chunkSize && srcTdescTy) {
chunkSize = srcTdescTy.getChunkSizeAsInt();
}
- instData.push_back(
- std::min(static_cast<int>(chunkSize), maxElemsPerInst));
+ instData.push_back(std::min(static_cast<int>(chunkSize),
+ uArchInstruction->getMaxLaneLoadStoreSize()));
payloadLayout = LayoutInfo(
xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 24d4a99bb0b0b..efcb1f062b82e 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -161,17 +161,17 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
-// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
-// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize_excessive(%src: memref<1024xf32>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
- : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
- xegpu.store %3, %src[%offset], %1 <{chunk_size=16}>
- : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=32}>
+ : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=32}>
+ : vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
return
}
}
@@ -183,17 +183,17 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
-// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 4]>}> :
-// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
-// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
-// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+// CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+// CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 32 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
%1 = arith.constant dense<1>: vector<16xi1>
%offset = arith.constant dense<12> : vector<16xindex>
- %3 = xegpu.load %src[%offset], %1 <{chunk_size=16}>
- : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x16xf32>
- xegpu.store %3, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.layout<inst_data = [16, 16]>}>
- : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=32}>
+ : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16x32xf32>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=32, layout = #xegpu.layout<inst_data = [16, 16]>}>
+ : vector<16x32xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
return
}
}
>From 2b9ea2f7a5e8fa78c17040485bd0a537dcf10371 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 15 Jan 2026 15:07:03 +0000
Subject: [PATCH 6/6] Feedback
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 44 +++++++++----------
.../XeGPU/propagate-layout-inst-data.mlir | 25 +++++++++++
2 files changed, 47 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1427202ffcf16..c4b61d2d53304 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1002,17 +1002,18 @@ void LayoutInfoPropagation::visitLoadGatherOp(
uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
// Check if value inst_data complies with uArch
- if (!instDataIncoming.empty()) {
+ if (layoutKind == LayoutKind::InstData) {
// Each lane loads either one element
- SmallVector<int> instDataUarch(instDataIncoming.size(), 1);
+ SmallVector<int> instDataUarch{subgroupSize};
// Or multiple elements as 2D with lane's elements in the inner dimension
- if (payloadTy.getRank() == 1) {
- instDataUarch.back() = subgroupSize;
- } else {
- *std::prev(instDataUarch.end(), 2) = subgroupSize;
- instDataUarch.back() =
+ if (payloadTy.getRank() != 1) {
+ if (payloadTy.getRank() != 2) {
+ load.emitWarning("Expected 2D payload for LoadGatherOp.");
+ return;
+ }
+ instDataUarch.push_back(
(std::min(static_cast<int>(payloadTy.getShape().back()),
- uArchInstruction->getMaxLaneLoadStoreSize()));
+ uArchInstruction->getMaxLaneLoadStoreSize())));
}
// If inst data does not match, enforce the uArch-based one
if (!llvm::equal(instDataIncoming, instDataUarch)) {
@@ -1095,21 +1096,21 @@ void LayoutInfoPropagation::visitStoreScatterOp(
if (layoutKind == LayoutKind::InstData) {
const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::LoadGatherInstruction>(
- uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+ dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::StoreScatter));
const int subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> instData{subgroupSize};
- auto chunkSize = storeScatter.getChunkSize().value_or(0);
- if (auto srcTdescTy =
- dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType());
- !chunkSize && srcTdescTy) {
- chunkSize = srcTdescTy.getChunkSizeAsInt();
+ SmallVector<int> instDataUarch{subgroupSize};
+ if (payloadTy.getRank() != 1) {
+ if (payloadTy.getRank() != 2) {
+ storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
+ return;
+ }
+ instDataUarch.push_back(
+ (std::min(static_cast<int>(payloadTy.getShape().back()),
+ uArchInstruction->getMaxLaneLoadStoreSize())));
}
- instData.push_back(std::min(static_cast<int>(chunkSize),
- uArchInstruction->getMaxLaneLoadStoreSize()));
-
payloadLayout = LayoutInfo(
- xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+ xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));
} else {
auto payloadShape = payloadTy.getShape();
if (payloadShape.size() > 1)
@@ -1264,8 +1265,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- if (!isa<xegpu::LoadGatherOp>(op))
- xegpu::setDistributeLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index efcb1f062b82e..18d892da8194a 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -197,3 +197,28 @@ func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
return
}
}
+
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: func.func @scatter_ops_chunksize_excessive_anchor(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}> :
+// CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf32> to vector<16x16xf32>
+// CHECK: xegpu.store %[[BCAST]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize_excessive_anchor(%src: memref<1024xf32>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1
+ : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+
+ %4 = vector.broadcast %3 : vector<16xf32> to vector<16x16xf32>
+ xegpu.store %4, %src[%offset], %1 <{chunk_size=16, layout = #xegpu.layout<inst_data = [16, 16]>}>
+ : vector<16x16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+ return
+}
+}
More information about the Mlir-commits
mailing list