[Mlir-commits] [mlir] [MLIR][XeGPU] Wrap layout with a slice attr when propagating broadcast (PR #169054)

Artem Kroviakov llvmlistbot at llvm.org
Fri Nov 21 07:36:22 PST 2025


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/169054

The input vector of a broadcast operation effectively has a lower rank than the broadcast result: in XeGPU terms, the input data is _sliced_ along the broadcast (unit) dimension.
Currently, layout propagation for broadcast simply passes the result layout unchanged to the operand, which is incorrect.
This PR wraps the result layout in a slice attribute over the broadcast unit dimension before propagating it to the operand.

For the shape cast changes, I assume that the line
```
  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
```
in the propagation code implicitly handles only `Nx1` or `1xN` shape-cast results, so a result layout that is already sliced along that dimension does not need further slicing.

>From 0207ac98ace3b22e55613c6d84759e5c2ce065aa Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 21 Nov 2025 15:19:04 +0000
Subject: [PATCH] [MLIR][XeGPU] Wrap layout with a slice attr when propagating
 broadcast

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 23 +++++++++++----
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 29 +++++++++++++++++--
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3a780abd3f12..2d8b5150d96fc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -572,8 +572,12 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
                           "one broadcasted dimension.");
     return;
   }
+  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+      broadcast->getContext(),
+      cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+      DenseI64ArrayAttr::get(broadcast->getContext(), {broadcastUnitDims[0]}));
   // Propagate the result layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
 }
 
 void LayoutInfoPropagation::visitShapeCastOp(
@@ -593,10 +597,19 @@ void LayoutInfoPropagation::visitShapeCastOp(
     return;
   }
   int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
-  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
-      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+  LayoutInfo operandLayout;
+  if (auto sliceResultAttr = dyn_cast<xegpu::SliceAttr>(resultLayout.get())) {
+    auto sliceDims = sliceResultAttr.getDims().asArrayRef();
+    if (sliceDims.size() == 1 && sliceDims[0] == slicedDim)
+      operandLayout = resultLayout;
+  } else {
+    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+        shapeCast->getContext(),
+        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+        DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
+    operandLayout = LayoutInfo(sliceLayout);
+  }
+  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
 }
 
 /// Propagate the layout of the result tensor to the source tensor descriptor
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index eb004932af4be..58ccb90f0bdb1 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -603,7 +603,7 @@ gpu.module @test {
 // CHECK-SAME:      !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
 // CHECK-NEXT:    %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME:       {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT:    %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT:    %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} :
 // CHECK-SAME:       vector<16xf16> to vector<1x16xf16>
 func.func @vector_shape_cast_1d_to_2d_dim1_distributed(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
   %c0 = arith.constant 0 : index
@@ -626,7 +626,7 @@ gpu.module @test {
 // CHECK-NEXT:     %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
 // CHECK-SAME:        {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1]
 // CHECK-SAME:        vector<16x16xf16> to vector<16xf16>
-// CHECK-NEXT:     %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-NEXT:     %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} :
 // CHECK-SAME:        vector<16xf16> to vector<16x1xf16>
 func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
   %c0 = arith.constant 0 : index
@@ -639,3 +639,28 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_slice_operand(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: i64) {
+// CHECK: %[[CST_0_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<1xindex>
+// CHECK: %[[CST_TRUE_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<1xi1>
+// CHECK: %[[CST_TRUE_32:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<32xi1>
+// CHECK: %[[CST_0_32:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<32xindex>
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%[[CST_0_1]]], %[[CST_TRUE_1]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>} :
+// CHECK-SAME: i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
+// CHECK: %[[BCASTED:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<1xf32> to vector<32xf32>
+// CHECK:  xegpu.store %[[BCASTED]], %[[ARG0]][%[[CST_0_32]]], %[[CST_TRUE_32]]  : vector<32xf32>, i64, vector<32xindex>, vector<32xi1>
+func.func @vector_broadcast_slice_operand(%arg0: i64) {
+  %offsets = arith.constant dense<0> : vector<1xindex>
+  %cst_4 = arith.constant dense<1> : vector<1xi1>
+  %cst_2 = arith.constant dense<1> : vector<32xi1>
+  %offsets_1 = arith.constant dense<0> : vector<32xindex>
+  %1 = xegpu.load %arg0[%offsets], %cst_4 : i64, vector<1xindex>, vector<1xi1> -> vector<1xf32>
+  %2 = vector.broadcast %1 : vector<1xf32> to vector<32xf32>
+  xegpu.store %2, %arg0[%offsets_1], %cst_2 : vector<32xf32>, i64, vector<32xindex>, vector<32xi1>
+  return
+}
+}



More information about the Mlir-commits mailing list