[Mlir-commits] [mlir] [MLIR][XeGPU] Add handling for unit-dim expansion in ShapeCast workgroup-to-subgroup distribution (PR #171758)

Artem Kroviakov llvmlistbot at llvm.org
Mon Dec 15 00:43:54 PST 2025


================
@@ -1111,41 +1111,58 @@ struct WgToSgVectorShapeCastOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    VectorType newResultType =
-        VectorType::get(sgShape, resultType.getElementType());
-
-    // TODO: Add check for compatible layouts in layout attr.
-    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    // Check that srcShape and destShape, if they differ, only differ by
+    // expand of unit dimensions.
+    auto srcType = dyn_cast<VectorType>(op.getSource().getType());
     if (!srcType)
       return failure();
 
-    // Check that shape_cast only adds/removes unit dimensions,
-    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
-      // Remove all 1s from both shapes and compare the rest.
-      SmallVector<int64_t> srcNonUnit, dstNonUnit;
-      for (int64_t d : src)
-        if (d != 1)
-          srcNonUnit.push_back(d);
-      for (int64_t d : dst)
-        if (d != 1)
-          dstNonUnit.push_back(d);
-      return srcNonUnit == dstNonUnit;
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    llvm::SetVector<int64_t> expandedUnitDims;
+
+    // Check if shapes only differ by expanding unit dimensions (like
+    // expand_dims)
+    auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+                                       ArrayRef<int64_t> dst) -> bool {
+      // All unit dimensions in dst that don't appear in src are the expanded
+      // unit dimensions
+      size_t srcIdx = 0;
+      for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+        if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+          srcIdx++;
+        else if (dst[dstIdx] == 1)
+          expandedUnitDims.insert(dstIdx);
+        else
+          return false;
+      return srcIdx == src.size();
     };
 
-    if (!onlyUnitDims(srcType.getShape(), sgShape))
-      return failure();
+    if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+      xegpu::DistributeLayoutAttr sourceLayout =
+          xegpu::getDistributeLayoutAttr(op.getSource());
 
-    // For rank reducing or increasing shape_cast ops, the lower rank layout
-    // must be a slice of higher rank layout.
-    int64_t sourceRank = srcType.getRank();
-    int64_t resultRank = sgShape.size();
-    xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getDistributeLayoutAttr(op.getSource());
-    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
-      return failure();
-    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
-      return failure();
+      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
+        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
+          return isa<vector::BroadcastOp>(user);
+        });
+      };
+
+      if (!usedByBroadcastOp(op)) {
+        return rewriter.notifyMatchFailure(
+            op, "ShapeCast ops that expand unit dimensions and are used by "
+                "non-broadcast operations are not supported.");
+      }
+      if (!sourceLayout.isSliceOf(layout))
+        return rewriter.notifyMatchFailure(
+            op, "The ShapeCast op only expands dimensions, the result layout "
+                "must be a slice of the input layout, or vice versa.");
+      layout = layout.setUnitDimData(expandedUnitDims);
----------------
akroviakov wrote:

The new restriction adds clarity, but modifying layouts inside distribution passes (as `layout.setUnitDimData(...)` does here) is discouraged, AFAIK.

https://github.com/llvm/llvm-project/pull/171758


More information about the Mlir-commits mailing list