[Mlir-commits] [mlir] [MLIR][XeGPU] Add handling for unit-dim expansion in ShapeCast workgroup-to-subgroup distribution (PR #171758)

Jianhui Li llvmlistbot at llvm.org
Wed Dec 10 20:10:02 PST 2025


https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/171758

Add special-case handling for ShapeCast when it only expands unit dimensions to feed a subsequent broadcast op. In this scenario, distribution requires the source layout to be a slice of the result layout, and the result layout is first normalized by setting its entries for the expanded unit dimensions to 1 before computing the distributed result shape. In all other cases, ShapeCast is distributed as usual.

This PR also updates the propagation rule for vectors with expanded unit dimensions, allowing them to share the same layout as the result of the broadcast op. This enables correct layout propagation back to the source of the ShapeCast op, because the broadcast result layout must ultimately be restored as the parent layout of the source's slice layout.

>From a8f6c51683164884bd744d6a6b978572d821a235 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 11 Dec 2025 03:13:43 +0000
Subject: [PATCH] adjust the layout for expandedUnitDims and wg-to-sg
 distribution shapecast op

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  4 -
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  2 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 75 ++++++++++++-------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 17 +++++
 4 files changed, 64 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index dc9eb96c169b4..12d1c494a0b72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -609,10 +609,6 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
     propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
     return;
   }
-
-  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
-  resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
-                     .setUnitDimData(broadcastUnitDims);
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index ca81c3cd7be42..27273ee245cf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1530,12 +1530,12 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
       if (rankDiff == 0) {
         SetVector<int64_t> broadcastUnitDims =
             broadcastOp.computeBroadcastedUnitDims();
-        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
         bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
         if (!isEqualTo)
           return rewriter.notifyMatchFailure(
               warpOp, "For same-rank broadcast, source must be identical to "
                       "adjusted result layouts with unit dims.");
+        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
         sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
       }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index be82cda574f1e..ef5da57c5f3b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1111,41 +1111,58 @@ struct WgToSgVectorShapeCastOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    VectorType newResultType =
-        VectorType::get(sgShape, resultType.getElementType());
-
-    // TODO: Add check for compatible layouts in layout attr.
-    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    // Check that srcShape and destShape, if they differ, only differ by
+    // expand of unit dimensions.
+    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
     if (!srcType)
       return failure();
 
-    // Check that shape_cast only adds/removes unit dimensions,
-    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
-      // Remove all 1s from both shapes and compare the rest.
-      SmallVector<int64_t> srcNonUnit, dstNonUnit;
-      for (int64_t d : src)
-        if (d != 1)
-          srcNonUnit.push_back(d);
-      for (int64_t d : dst)
-        if (d != 1)
-          dstNonUnit.push_back(d);
-      return srcNonUnit == dstNonUnit;
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    llvm::SetVector<int64_t> expandedUnitDims;
+
+    // Check if shapes only differ by expanding unit dimensions (like
+    // expand_dims)
+    auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+                                       ArrayRef<int64_t> dst) -> bool {
+      // All unit dimensions in dst that don't appear in src are the expanded
+      // unit dimensions
+      size_t srcIdx = 0;
+      for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+        if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+          srcIdx++;
+        else if (dst[dstIdx] == 1)
+          expandedUnitDims.insert(dstIdx);
+        else
+          return false;
+      return srcIdx == src.size();
     };
 
-    if (!onlyUnitDims(srcType.getShape(), sgShape))
-      return failure();
+    if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+      xegpu::DistributeLayoutAttr sourceLayout =
+          xegpu::getDistributeLayoutAttr(op.getSource());
 
-    // For rank reducing or increasing shape_cast ops, the lower rank layout
-    // must be a slice of higher rank layout.
-    int64_t sourceRank = srcType.getRank();
-    int64_t resultRank = sgShape.size();
-    xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getDistributeLayoutAttr(op.getSource());
-    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
-      return failure();
-    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
-      return failure();
+      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
+        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
+          return isa<vector::BroadcastOp>(user);
+        });
+      };
+
+      if (!usedByBroadcastOp(op)) {
+        return rewriter.notifyMatchFailure(
+            op, "ShapeCast ops that expand unit dimensions and are used by "
+                "non-broadcast operations are not supported.");
+      }
+      if (!sourceLayout.isSliceOf(layout))
+        return rewriter.notifyMatchFailure(
+            op, "The ShapeCast op only expands dimensions, the result layout "
+                "must be a slice of the input layout, or vice versa.");
+      layout = layout.setUnitDimData(expandedUnitDims);
+      layout = layout.setUnitDimLayout(expandedUnitDims);
+    }
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
 
     SmallVector<Value> newShapeCastOps;
     for (auto src : adaptor.getSource()) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index a8015cced7eb4..7f651ef5fdc14 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -380,4 +380,21 @@ gpu.module @test_1_1_assignment {
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 1]>} : vector<8xf32> to vector<8x1xf32> 
+  // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
+  gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
+    %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
+    %block_id_x = gpu.block_id  x
+    %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    %1 = xegpu.load_nd %0[%block_id_x, 0]  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
+    %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
+    %4 = vector.broadcast %3 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
+    %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+    gpu.return
+  }
 }
+



More information about the Mlir-commits mailing list