[Mlir-commits] [mlir] [MLIR][XeGPU] Add handling for unit-dim expansion in ShapeCast workgroup-to-subgroup distribution (PR #171758)

Jianhui Li llvmlistbot at llvm.org
Tue Dec 16 09:59:07 PST 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/171758

>From a8f6c51683164884bd744d6a6b978572d821a235 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 11 Dec 2025 03:13:43 +0000
Subject: [PATCH 1/2] adjust the layout for expandedUnitDims and wg-to-sg
 distribution shapecast op

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  4 -
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  2 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 75 ++++++++++++-------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 17 +++++
 4 files changed, 64 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index dc9eb96c169b4..12d1c494a0b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -609,10 +609,6 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
     propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
     return;
   }
-
-  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
-  resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
-                     .setUnitDimData(broadcastUnitDims);
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index ca81c3cd7be42..27273ee245cf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1530,12 +1530,12 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
       if (rankDiff == 0) {
         SetVector<int64_t> broadcastUnitDims =
             broadcastOp.computeBroadcastedUnitDims();
-        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
         bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
         if (!isEqualTo)
           return rewriter.notifyMatchFailure(
               warpOp, "For same-rank broadcast, source must be identical to "
                       "adjusted result layouts with unit dims.");
+        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
         sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
       }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be82cda574f1e..ef5da57c5f3b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1111,41 +1111,58 @@ struct WgToSgVectorShapeCastOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    VectorType newResultType =
-        VectorType::get(sgShape, resultType.getElementType());
-
-    // TODO: Add check for compatible layouts in layout attr.
-    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    // Check that srcShape and destShape, if they differ, differ only by the
+    // expansion of unit dimensions.
+    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
     if (!srcType)
       return failure();
 
-    // Check that shape_cast only adds/removes unit dimensions,
-    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
-      // Remove all 1s from both shapes and compare the rest.
-      SmallVector<int64_t> srcNonUnit, dstNonUnit;
-      for (int64_t d : src)
-        if (d != 1)
-          srcNonUnit.push_back(d);
-      for (int64_t d : dst)
-        if (d != 1)
-          dstNonUnit.push_back(d);
-      return srcNonUnit == dstNonUnit;
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    llvm::SetVector<int64_t> expandedUnitDims;
+
+    // Check if shapes only differ by expanding unit dimensions (like
+    // expand_dims)
+    auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+                                       ArrayRef<int64_t> dst) -> bool {
+      // All unit dimensions in dst that don't appear in src are the expanded
+      // unit dimensions
+      size_t srcIdx = 0;
+      for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+        if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+          srcIdx++;
+        else if (dst[dstIdx] == 1)
+          expandedUnitDims.insert(dstIdx);
+        else
+          return false;
+      return srcIdx == src.size();
     };
 
-    if (!onlyUnitDims(srcType.getShape(), sgShape))
-      return failure();
+    if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+      xegpu::DistributeLayoutAttr sourceLayout =
+          xegpu::getDistributeLayoutAttr(op.getSource());
 
-    // For rank reducing or increasing shape_cast ops, the lower rank layout
-    // must be a slice of higher rank layout.
-    int64_t sourceRank = srcType.getRank();
-    int64_t resultRank = sgShape.size();
-    xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getDistributeLayoutAttr(op.getSource());
-    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
-      return failure();
-    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
-      return failure();
+      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
+        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
+          return isa<vector::BroadcastOp>(user);
+        });
+      };
+
+      if (!usedByBroadcastOp(op)) {
+        return rewriter.notifyMatchFailure(
+            op, "ShapeCast ops that expand unit dimensions and are used by "
+                "non-broadcast operations are not supported.");
+      }
+      if (!sourceLayout.isSliceOf(layout))
+        return rewriter.notifyMatchFailure(
+            op, "The ShapeCast op only expands dimensions, the result layout "
+                "must be a slice of the input layout, or vice versa.");
+      layout = layout.setUnitDimData(expandedUnitDims);
+      layout = layout.setUnitDimLayout(expandedUnitDims);
+    }
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
 
     SmallVector<Value> newShapeCastOps;
     for (auto src : adaptor.getSource()) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index a8015cced7eb4..7f651ef5fdc14 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -380,4 +380,21 @@ gpu.module @test_1_1_assignment {
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 1]>} : vector<8xf32> to vector<8x1xf32> 
+  // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
+  gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
+    %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
+    %block_id_x = gpu.block_id  x
+    %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    %1 = xegpu.load_nd %0[%block_id_x, 0]  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
+    %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
+    %4 = vector.broadcast %3 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
+    %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    gpu.return
+  }
 }
+

>From 49d66980f904638c83aa47b3637900a851d8d3dc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 16 Dec 2025 17:58:07 +0000
Subject: [PATCH 2/2] address feedback

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp  | 14 +++++++++-----
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir      | 17 +++++++++++++++++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir     | 17 -----------------
 3 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index ef5da57c5f3b4..95e27e46d90ab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1137,6 +1137,8 @@ struct WgToSgVectorShapeCastOp
       return srcIdx == src.size();
     };
 
+    xegpu::DistributeLayoutAttr layoutToDistribute = layout;
+
     if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
       xegpu::DistributeLayoutAttr sourceLayout =
           xegpu::getDistributeLayoutAttr(op.getSource());
@@ -1147,20 +1149,22 @@ struct WgToSgVectorShapeCastOp
         });
       };
 
-      if (!usedByBroadcastOp(op)) {
+      if (!usedByBroadcastOp(op))
         return rewriter.notifyMatchFailure(
             op, "ShapeCast ops that expand unit dimensions and are used by "
                 "non-broadcast operations are not supported.");
-      }
+
       if (!sourceLayout.isSliceOf(layout))
         return rewriter.notifyMatchFailure(
             op, "The ShapeCast op only expands dimensions, the result layout "
                 "must be a slice of the input layout, or vice versa.");
-      layout = layout.setUnitDimData(expandedUnitDims);
-      layout = layout.setUnitDimLayout(expandedUnitDims);
+      layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
+      layoutToDistribute =
+          layoutToDistribute.setUnitDimLayout(expandedUnitDims);
     }
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    SmallVector<int64_t> sgShape =
+        getSgShapeAndCount(wgShape, layoutToDistribute).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index c95c64084f3f8..37a76f316e75a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -143,4 +143,21 @@ gpu.module @test_distribution {
     %constant_mask = vector.create_mask %cst16, %cst16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1>
     gpu.return
   } 
+
+  // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8xf32> to vector<8x1xf32>
+  // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
+  gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
+    %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
+    %block_id_x = gpu.block_id  x
+    %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    %1 = xegpu.load_nd %0[%block_id_x, 0]  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
+    %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
+    %4 = vector.broadcast %3 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
+    %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    gpu.return
+  }
+
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7f651ef5fdc14..a8015cced7eb4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -380,21 +380,4 @@ gpu.module @test_1_1_assignment {
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
     gpu.return
   }
-
-  // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
-  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 1]>} : vector<8xf32> to vector<8x1xf32> 
-  // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
-  gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
-    %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
-    %block_id_x = gpu.block_id  x
-    %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
-    %1 = xegpu.load_nd %0[%block_id_x, 0]  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
-    %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
-    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
-    %4 = vector.broadcast %3 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
-    %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
-    xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
-    gpu.return
-  }
 }
-



More information about the Mlir-commits mailing list