[Mlir-commits] [mlir] [MLIR][XeGPU] Honor convert_layout's input_layout when resolving layout conflict (PR #194935)
Nishant Patel
llvmlistbot at llvm.org
Wed May 13 14:47:34 PDT 2026
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/194935
>From 97d555d71828aabd90ba2d0ba152c4dc9fb8d94e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 29 Apr 2026 15:17:50 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Honor convert_layout's input_layout when
resolving layout conflict
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 6 +++++
.../XeGPU/resolve-layout-conflicts.mlir | 26 +++++++++++++++++++
2 files changed, 32 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index d3925c40f9123b..4a7dc615dbde3f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1607,6 +1607,12 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
Operation *op = operand.getOwner();
+ // xegpu.convert_layout explicitly states its expected operand layout via
+ // the input_layout attribute. Use it directly so that ResolveLayoutConflicts
+ // sees the real expectation (and inserts a bridge convert_layout) when the
+ // producer's layout differs from this stated input layout.
+ if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op))
+ return convertOp.getInputLayoutAttr();
xegpu::DistributeLayoutAttr resLayout;
if (op->getNumResults() == 1)
resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index a33bd015180a6b..80f0dc7f49e1b0 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -278,3 +278,29 @@ func.func @convert_layout() {
return
}
}
+
+// -----
+
+// Producer carries a layout that differs from the input_layout declared on a
+// downstream xegpu.convert_layout consumer. ResolveLayoutConflicts must insert
+// a bridging convert_layout between the producer and the existing
+// convert_layout so the consumer's stated input_layout is honored.
+// CHECK-LABEL: func.func @convert_layout_bridge_input_mismatch
+// CHECK: %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : () -> vector<32x32xf16>
+// CHECK-NEXT: %[[BRIDGE:.*]] = xegpu.convert_layout %[[V0]]
+// CHECK-SAME: <{input_layout = #xegpu.layout<inst_data = [8, 16]>, target_layout = #xegpu.layout<inst_data = [16, 16]>}>
+// CHECK-SAME: : vector<32x32xf16>
+// CHECK-NEXT: %[[CVT:.*]] = xegpu.convert_layout %[[BRIDGE]]
+// CHECK-SAME: <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [32, 16]>}>
+// CHECK-SAME: : vector<32x32xf16>
+// CHECK-NEXT: return %[[CVT]] : vector<32x32xf16>
+gpu.module @test_convert_layout_bridge {
+func.func @convert_layout_bridge_input_mismatch() -> vector<32x32xf16> {
+ %0 = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : () -> vector<32x32xf16>
+ %1 = xegpu.convert_layout %0
+ <{input_layout = #xegpu.layout<inst_data = [16, 16]>,
+ target_layout = #xegpu.layout<inst_data = [32, 16]>}>
+ : vector<32x32xf16>
+ return %1 : vector<32x32xf16>
+}
+}
>From 3fa5a708103f669bde796e610fa1f092ecee364e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 29 Apr 2026 20:09:55 +0000
Subject: [PATCH 2/3] add check for anchor layout
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 3 +-
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 27 +++----
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 9 ++-
.../XeGPU/resolve-layout-conflicts.mlir | 73 +++++++++----------
4 files changed, 57 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 2dd8d9f610faf7..358bde51ddc6de 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -126,7 +126,8 @@ inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
/// Infers the source layout attribute for an operand using result layout
/// attribute
DistributeLayoutAttr
-inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
+inferSourceLayoutFromResultForNonAnchorOp(OpOperand &operand,
+ DistributeLayoutAttr resLayout);
/// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
/// the result.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 4a7dc615dbde3f..d4de88b125efc8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -153,7 +153,7 @@ static void propagateResultsToRegularOperands(Operation *op) {
for (OpOperand &opr : op->getOpOperands()) {
xegpu::DistributeLayoutAttr operandLayout =
- xegpu::inferSourceLayoutFromResult(opr, resLayout);
+ xegpu::inferSourceLayoutFromResultForNonAnchorOp(opr, resLayout);
if (isa<VectorType>(opr.get().getType()) && operandLayout)
xegpu::setTemporaryLayout(opr, operandLayout);
}
@@ -1523,9 +1523,8 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
return std::nullopt;
}
-xegpu::DistributeLayoutAttr
-xegpu::inferSourceLayoutFromResult(OpOperand &operand,
- xegpu::DistributeLayoutAttr resLayout) {
+xegpu::DistributeLayoutAttr xegpu::inferSourceLayoutFromResultForNonAnchorOp(
+ OpOperand &operand, xegpu::DistributeLayoutAttr resLayout) {
if (!resLayout)
return nullptr;
Operation *op = operand.getOwner();
@@ -1607,19 +1606,15 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
Operation *op = operand.getOwner();
- // xegpu.convert_layout explicitly states its expected operand layout via
- // the input_layout attribute. Use it directly so that ResolveLayoutConflicts
- // sees the real expectation (and inserts a bridge convert_layout) when the
- // producer's layout differs from this stated input layout.
- if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op))
- return convertOp.getInputLayoutAttr();
+ // Anchor ops declare the layout they require on each operand. Trust that
+ // declaration directly so that ResolveLayoutConflicts compares the
+ // producer's layout against the declared one.
+ if (isa<xegpu::AnchorLayoutInterface>(op))
+ return xegpu::getDistributeLayoutAttr(operand);
+ // For non-anchor ops, derive the operand layout from the op's result
+ // layout via op-specific semantics.
xegpu::DistributeLayoutAttr resLayout;
if (op->getNumResults() == 1)
resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
- auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
- if (inferredOperandLayout)
- return inferredOperandLayout;
- // By default, assume no layout conflict and return the current layout of
- // the operand.
- return xegpu::getDistributeLayoutAttr(operand.get());
+ return inferSourceLayoutFromResultForNonAnchorOp(operand, resLayout);
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index a5776ebce2e952..40b8e57a8610ce 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1472,7 +1472,14 @@ ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
consumerOp->emitWarning("Expected layout for non-1D vectors.");
return success(); // uniform non-tensor-data vector does not require layout
}
- // Get the consumer expected layout at this operand.
+ // Region branch ops (e.g. scf.for) and their terminators (e.g. scf.yield)
+ // forward their operands to successor region inputs / parent op results;
+ // their consumer layout is resolved through that forwarding, not at this
+ // use point.
+ if (isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(
+ consumerOp))
+ return success();
+
auto consumerLayout = xegpu::getConsumerLayoutAt(operand);
if (!consumerLayout)
return consumerOp->emitError(
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index 80f0dc7f49e1b0..0c83c69a7f4747 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -16,14 +16,14 @@ gpu.module @test {
// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>>
// CHECK-NEXT: %{{.*}} = xegpu.load_nd %[[T1]][%{{.*}}, %{{.*}}] <{layout = #xegpu.layout<inst_data = [8, 16]>}> :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x16xf16>
-func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) -> vector<16x16xf16> {
+func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
-> !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
%1 = xegpu.load_nd %0 [%c0, %c0] {layout = #inst_data_8x16} : !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
-> vector<16x16xf16>
xegpu.prefetch_nd %0 [%c0, %c0] {layout = #inst_data_16x16} : !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
- return %1 : vector<16x16xf16>
+ return
}
// CHECK-LABEL: func.func @multiple_tensor_desc_conflicts
@@ -40,7 +40,7 @@ func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) -> vec
// CHECK-SAME: !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [32, 16]>> -> vector<32x16xf16>
// CHECK-NEXT: xegpu.prefetch_nd %[[T1]][%[[C0]], %[[C0]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> :
// CHECK-SAME: !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [16, 16]>>
-func.func @multiple_tensor_desc_conflicts(%arg0: memref<64x64xf16>) -> (vector<32x16xf16>, vector<32x16xf16>) {
+func.func @multiple_tensor_desc_conflicts(%arg0: memref<64x64xf16>) {
%c0 = arith.constant 0 : index
%tdesc1 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
-> !xegpu.tensor_desc<32x16xf16, #inst_data_8x16>
@@ -49,7 +49,7 @@ func.func @multiple_tensor_desc_conflicts(%arg0: memref<64x64xf16>) -> (vector<3
%load2 = xegpu.load_nd %tdesc1 [%c0, %c0] {layout = #inst_data_32x16} : !xegpu.tensor_desc<32x16xf16, #inst_data_8x16>
-> vector<32x16xf16>
xegpu.prefetch_nd %tdesc1 [%c0, %c0] {layout = #inst_data_16x16} : !xegpu.tensor_desc<32x16xf16, #inst_data_8x16>
- return %load1, %load2 : vector<32x16xf16>, vector<32x16xf16>
+ return
}
// CHECK-LABEL: func.func @load_nd_with_conflicting_tensor_desc_in_loop
@@ -64,8 +64,8 @@ func.func @multiple_tensor_desc_conflicts(%arg0: memref<64x64xf16>) -> (vector<3
// CHECK: scf.yield %{{.*}}, %{{.*}} : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
// CHECK: xegpu.prefetch_nd %[[T0]][%{{.*}}] <{layout = #xegpu.layout<inst_data = [16, 16]>}> :
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
-// CHECK-NEXT: return %{{.*}}#0 : vector<16x16xf16>
-func.func @load_nd_with_conflicting_tensor_desc_in_loop(%arg0: memref<64x64xf16>) -> vector<16x16xf16> {
+// CHECK-NEXT: return
+func.func @load_nd_with_conflicting_tensor_desc_in_loop(%arg0: memref<64x64xf16>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -80,7 +80,7 @@ func.func @load_nd_with_conflicting_tensor_desc_in_loop(%arg0: memref<64x64xf16>
scf.yield %3, %tdesc : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
} {layout_result_0 = #inst_data_8x16}
xegpu.prefetch_nd %0 [%c0, %c0] {layout = #inst_data_16x16} : !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
- return %1#0 : vector<16x16xf16>
+ return
}
@@ -92,24 +92,24 @@ func.func @load_nd_with_conflicting_tensor_desc_in_loop(%arg0: memref<64x64xf16>
// CHECK-SAME: : vector<32x32xf16>
// CHECK: %[[ADD:.*]] = arith.addf %[[V0]], %[[CVT]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32x32xf16>
-// CHECK: return %[[ADD]] : vector<32x32xf16>
-func.func @elementwise_conflict() -> vector<32x32xf16> {
+// CHECK: return
+func.func @elementwise_conflict() {
%0 = "some_op"() {layout_result_0 = #inst_data_8x16} : () -> vector<32x32xf16>
%1 = "some_op"() {layout_result_0 = #inst_data_32x16} : () -> vector<32x32xf16>
%2 = arith.addf %0, %1 {layout_result_0 = #inst_data_8x16} : vector<32x32xf16>
- return %2 : vector<32x32xf16>
+ return
}
// CHECK-LABEL: func.func @elementwise_conflict_uniform
// CHECK-DAG: %[[V0:.*]] = "some_op"() : () -> vector<2xf16>
// CHECK-DAG: %[[V1:.*]] = "some_op"() : () -> vector<2xf16>
// CHECK: %[[ADD:.*]] = arith.addf %[[V0]], %[[V1]] : vector<2xf16>
-// CHECK: return %[[ADD]] : vector<2xf16>
-func.func @elementwise_conflict_uniform() -> vector<2xf16> {
+// CHECK: return
+func.func @elementwise_conflict_uniform() {
%0 = "some_op"() : () -> vector<2xf16>
%1 = "some_op"() : () -> vector<2xf16>
%non_tensor_data_vec = arith.addf %0, %1 : vector<2xf16>
- return %non_tensor_data_vec : vector<2xf16>
+ return
}
// CHECK-LABEL: func.func @broadcast_source_conflict
@@ -119,11 +119,11 @@ func.func @elementwise_conflict_uniform() -> vector<2xf16> {
// CHECK-SAME: : vector<16xf16>
// CHECK: %[[BC:.*]] = vector.broadcast %[[CVT]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf16> to vector<16x16xf16>
-// CHECK: return %[[BC]] : vector<16x16xf16>
-func.func @broadcast_source_conflict() -> vector<16x16xf16> {
+// CHECK: return
+func.func @broadcast_source_conflict() {
%0 = "some_op"() {layout_result_0 = #inst_data_16} : () -> vector<16xf16>
%1 = vector.broadcast %0 {layout_result_0 = #inst_data_16x16} : vector<16xf16> to vector<16x16xf16>
- return %1 : vector<16x16xf16>
+ return
}
// CHECK-LABEL: func.func @shapecast_source_conflict
@@ -133,11 +133,11 @@ func.func @broadcast_source_conflict() -> vector<16x16xf16> {
// CHECK-SAME: : vector<2x4x32xf16>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[CVT]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [1, 32]>} : vector<2x4x32xf16> to vector<1x256xf16>
-// CHECK: return %[[SC]] : vector<1x256xf16>
-func.func @shapecast_source_conflict() -> vector<1x256xf16> {
+// CHECK: return
+func.func @shapecast_source_conflict() {
%0 = "some_op"() {layout_result_0 = #inst_data_1x2x16} : () -> vector<2x4x32xf16>
%1 = vector.shape_cast %0 {layout_result_0 = #inst_data_1x32} : vector<2x4x32xf16> to vector<1x256xf16>
- return %1 : vector<1x256xf16>
+ return
}
// CHECK-LABEL: func.func @bitcast_source_conflict
@@ -147,11 +147,11 @@ func.func @shapecast_source_conflict() -> vector<1x256xf16> {
// CHECK-SAME: : vector<32x16xf32>
// CHECK: %[[BC:.*]] = vector.bitcast %[[CVT]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 32]>} : vector<32x16xf32> to vector<32x32xf16>
-// CHECK: return %[[BC]] : vector<32x32xf16>
-func.func @bitcast_source_conflict() -> vector<32x32xf16> {
+// CHECK: return
+func.func @bitcast_source_conflict() {
%0 = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} : () -> vector<32x16xf32>
%1 = vector.bitcast %0 {layout_result_0 = #xegpu.layout<inst_data = [16, 32]>} : vector<32x16xf32> to vector<32x32xf16>
- return %1 : vector<32x32xf16>
+ return
}
// CHECK-LABEL: func.func @multireduction_source_conflict
@@ -167,14 +167,14 @@ func.func @bitcast_source_conflict() -> vector<32x32xf16> {
// CHECK: %[[MR:.*]] = vector.multi_reduction <add>, %[[CVT0]], %[[CVT1]]
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}
// CHECK-SAME: [0] : vector<32x32xf16> to vector<32xf16>
-// CHECK: return %[[MR]] : vector<32xf16>
-func.func @multireduction_source_conflict() -> vector<32xf16> {
+// CHECK: return
+func.func @multireduction_source_conflict() {
%0 = "some_op"() {layout_result_0 = #inst_data_32x16} : () -> vector<32x32xf16>
%acc = arith.constant {layout_result_0 = #inst_data_32} dense<0.0> : vector<32xf16>
%1 = vector.multi_reduction <add>, %0, %acc
{layout_result_0 = #xegpu.slice<#inst_data_16x16, dims = [0]>}
[0] : vector<32x32xf16> to vector<32xf16>
- return %1 : vector<32xf16>
+ return
}
// CHECK-LABEL: func.func @insert_strided_slice_source_conflict
@@ -187,14 +187,14 @@ func.func @multireduction_source_conflict() -> vector<32xf16> {
// CHECK: %[[ISS:.*]] = vector.insert_strided_slice %[[CVT]], %[[CST]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>, offsets = [0, 0, 0], strides = [1, 1]}
// CHECK-SAME: : vector<16x16xf16> into vector<2x32x32xf16>
-// CHECK: return %[[ISS]] : vector<2x32x32xf16>
-func.func @insert_strided_slice_source_conflict() -> vector<2x32x32xf16> {
+// CHECK: return
+func.func @insert_strided_slice_source_conflict() {
%0 = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} : () -> vector<16x16xf16>
%1 = arith.constant { layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>}
dense<0.0> : vector<2x32x32xf16>
%2 = vector.insert_strided_slice %0, %1 {offsets = [0, 0, 0], strides = [1, 1],
layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} : vector<16x16xf16> into vector<2x32x32xf16>
- return %2: vector<2x32x32xf16>
+ return
}
// CHECK-LABEL: func.func @conflict_inside_loop
@@ -209,8 +209,8 @@ func.func @insert_strided_slice_source_conflict() -> vector<2x32x32xf16> {
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<16x16xf16>
// CHECK: scf.yield %[[ADD]] : vector<16x16xf16>
// CHECK: }
-// CHECK: return %[[FOR]] : vector<16x16xf16>
-func.func @conflict_inside_loop() -> vector<16x16xf16> {
+// CHECK: return
+func.func @conflict_inside_loop() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -220,7 +220,7 @@ func.func @conflict_inside_loop() -> vector<16x16xf16> {
%2 = arith.addf %acc, %1 {layout_result_0 = #inst_data_8x16} : vector<16x16xf16>
scf.yield %2 : vector<16x16xf16>
} {layout_result_0 = #inst_data_8x16}
- return %0 : vector<16x16xf16>
+ return
}
// CHECK-LABEL: func.func @conflict_postop
@@ -237,8 +237,8 @@ func.func @conflict_inside_loop() -> vector<16x16xf16> {
// CHECK-SAME: : vector<16x16xf16>
// CHECK: %[[EXP:.*]] = math.exp %[[CVT]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<16x16xf16>
-// CHECK: return %[[EXP]] : vector<16x16xf16>
-func.func @conflict_postop() -> vector<16x16xf16> {
+// CHECK: return
+func.func @conflict_postop() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
@@ -249,7 +249,7 @@ func.func @conflict_postop() -> vector<16x16xf16> {
scf.yield %2 : vector<16x16xf16>
} {layout_result_0 = #inst_data_16x16}
%1 = math.exp %0 {layout_result_0 = #inst_data_8x16} : vector<16x16xf16>
- return %1 : vector<16x16xf16>
+ return
}
// CHECK-LABEL: func.func @convert_layout
@@ -293,14 +293,13 @@ func.func @convert_layout() {
// CHECK-NEXT: %[[CVT:.*]] = xegpu.convert_layout %[[BRIDGE]]
// CHECK-SAME: <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [32, 16]>}>
// CHECK-SAME: : vector<32x32xf16>
-// CHECK-NEXT: return %[[CVT]] : vector<32x32xf16>
gpu.module @test_convert_layout_bridge {
-func.func @convert_layout_bridge_input_mismatch() -> vector<32x32xf16> {
+func.func @convert_layout_bridge_input_mismatch() {
%0 = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : () -> vector<32x32xf16>
%1 = xegpu.convert_layout %0
<{input_layout = #xegpu.layout<inst_data = [16, 16]>,
target_layout = #xegpu.layout<inst_data = [32, 16]>}>
: vector<32x32xf16>
- return %1 : vector<32x32xf16>
+ return
}
}
>From 47df312e0da88536457efb3afd95bf91cb6cefe8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 30 Apr 2026 19:51:58 +0000
Subject: [PATCH 3/3] Add layout for extract op
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index d4de88b125efc8..751cafdea615b1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1596,6 +1596,14 @@ xegpu::DistributeLayoutAttr xegpu::inferSourceLayoutFromResultForNonAnchorOp(
// For vector::ExtractStridedSliceOp, simply return result layout
if (dyn_cast<vector::ExtractStridedSliceOp>(op))
return resLayout;
+
+ // For vector::ExtractOp, propagate the result layout to the source vector
+ // operand (operand 0); all other operands get no layout.
+ if (auto extract = dyn_cast<vector::ExtractOp>(op)) {
+ if (idx != 0)
+ return nullptr;
+ return resLayout;
+ }
// For elementwise operations, all operands must have the same layout as the
// result.
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
More information about the Mlir-commits
mailing list