[Mlir-commits] [mlir] [MLIR][XeGPU] Consider alignment in dpas sg_layout creation (PR #181141)
Artem Kroviakov
llvmlistbot at llvm.org
Fri Feb 13 07:38:48 PST 2026
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/181141
>From 8bcedb7617e46a91d1677b77736127e7c020f6b4 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 12 Feb 2026 13:50:44 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Dpas sg_layout propagation to consider A/B
alignment
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 6 +++
.../XeGPU/propagate-layout-subgroup.mlir | 49 +++++++++++++++----
2 files changed, 45 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index ca62b2893eb68..2822805829342 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -999,9 +999,15 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
layoutsCD.end());
std::optional<LayoutRepresentation> bestPick;
+ auto checkSgDataAB = [&](LayoutRepresentation sgLayout) {
+ return aTy.getShape().back() / sgLayout.second ==
+ bTy.getShape().front() / sgLayout.first;
+ };
for (auto &sgLayout : layoutsB) {
if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
// Is in (A and B and CD) and matches consumer -> best pick
+ if (!checkSgDataAB(sgLayout))
+ continue;
if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
bestPick = sgLayout;
break;
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 6e6ed01845ff1..d48da04c99c27 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -168,7 +168,7 @@ gpu.module @test {
// -----
gpu.module @test {
// CHECK-LABEL: for_loop_dpas
- gpu.func @for_loop_dpas(%arg0: memref<2048x8192xf16>, %arg1: memref<8192x4096xf16>, %arg2: memref<2048x4096xf32>) kernel attributes {known_block_size = array<i32: 8, 1, 16>} {
+ gpu.func @for_loop_dpas(%arg0: memref<2048x8192xf16>, %arg1: memref<8192x4096xf16>, %arg2: memref<2048x4096xf32>) kernel attributes {known_block_size = array<i32: 4, 1, 16>} {
%cst = arith.constant dense<0.000000e+00> : vector<128x128xf32>
%c128 = arith.constant 128 : index
%c8192 = arith.constant 8192 : index
@@ -180,25 +180,25 @@ gpu.module @test {
// CHECK: %2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (vector<128x128xf32>) {
// CHECK-NEXT: xegpu.create_nd_tdesc %{{.*}} : memref<2048x8192xf16> ->
// CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>,
- // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>>
+ // CHECK-SAME: #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>>
- // CHECK-NEXT: xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+ // CHECK-NEXT: xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>}>
// CHECK-NEXT: xegpu.create_nd_tdesc %{{.*}} : memref<8192x4096xf16> ->
// CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>,
- // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>>
+ // CHECK-SAME: #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>>
- // CHECK-NEXT: xegpu.load_nd %6[%arg3, %block_id_y] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+ // CHECK-NEXT: xegpu.load_nd %6[%arg3, %block_id_y] <{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>}>
// CHECK-NEXT: xegpu.dpas %{{.*}} {
- // CHECK-SAME: layout_a = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>,
- // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>,
- // CHECK-SAME: layout_cd = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}
+ // CHECK-SAME: layout_a = #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>,
+ // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>,
+ // CHECK-SAME: layout_cd = #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>}
// CHECK-SAME: : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
// CHECK-NEXT: scf.yield %{{.*}} : vector<128x128xf32>
- // CHECK-NEXT: } {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}
- // CHECK: xegpu.store_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+ // CHECK-NEXT: } {layout_result_0 = #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>}
+ // CHECK: xegpu.store_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [64, 64]>}>
%2 = scf.for %arg3 = %c0 to %c8192 step %c128 iter_args(%arg4 = %cst) -> (vector<128x128xf32>) {
%4 = xegpu.create_nd_tdesc %arg0 : memref<2048x8192xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
@@ -214,6 +214,35 @@ gpu.module @test {
}
}
+// -----
+gpu.module @test {
+ // CHECK-LABEL: for_loop_misaligned_dpas_fail
+ gpu.func @for_loop_misaligned_dpas_fail(%arg0: memref<2048x8192xf16>, %arg1: memref<8192x4096xf16>, %arg2: memref<2048x4096xf32>) kernel attributes {known_block_size = array<i32: 8, 1, 16>} {
+ %cst = arith.constant dense<0.000000e+00> : vector<128x128xf32>
+ %c128 = arith.constant 128 : index
+ %c8192 = arith.constant 8192 : index
+ %c0 = arith.constant 0 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %0 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%block_id_x]
+ %1 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%block_id_y]
+ %2 = scf.for %arg3 = %c0 to %c8192 step %c128 iter_args(%arg4 = %cst) -> (vector<128x128xf32>) {
+ %4 = xegpu.create_nd_tdesc %arg0 : memref<2048x8192xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+ %5 = xegpu.load_nd %4[%block_id_x, %arg3] : !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<128x128xf16>
+ %6 = xegpu.create_nd_tdesc %arg1 : memref<8192x4096xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+ %7 = xegpu.load_nd %6[%arg3, %block_id_y] : !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<128x128xf16>
+ // Could not find a layout whose sg_data is aligned on the reduction dimension, so only layout_cd is assigned.
+ // CHECK: xegpu.dpas %{{.*}} {layout_cd = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>} :
+ %8 = xegpu.dpas %5, %7, %arg4 : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+ scf.yield %8 : vector<128x128xf32>
+ }
+ %3 = xegpu.create_nd_tdesc %arg2 : memref<2048x4096xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+ xegpu.store_nd %2, %3[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+ gpu.return
+ }
+}
+
+
// -----
gpu.module @test {
// CHECK-LABEL: dpas_fails
>From e01e7dff9fc58cee0fbdb03a5c977803227a9836 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 13 Feb 2026 15:38:25 +0000
Subject: [PATCH 2/2] Move check
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 2822805829342..92b9a5788c004 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -999,15 +999,15 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
layoutsCD.end());
std::optional<LayoutRepresentation> bestPick;
- auto checkSgDataAB = [&](LayoutRepresentation sgLayout) {
+ auto checkAlignedSgDataAB = [&](LayoutRepresentation sgLayout) {
return aTy.getShape().back() / sgLayout.second ==
bTy.getShape().front() / sgLayout.first;
};
for (auto &sgLayout : layoutsB) {
if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
- // Is in (A and B and CD) and matches consumer -> best pick
- if (!checkSgDataAB(sgLayout))
+ if (!checkAlignedSgDataAB(sgLayout))
continue;
+ // Is in (A and B and CD) and matches consumer -> best pick
if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
bestPick = sgLayout;
break;
More information about the Mlir-commits
mailing list