[Mlir-commits] [mlir] [MLIR][XeGPU] Add handling for unit-dim expansion in ShapeCast workgroup-to-subgroup distribution (PR #171758)
Charitha Saumya
llvmlistbot at llvm.org
Tue Dec 16 10:11:40 PST 2025
================
@@ -1111,41 +1111,58 @@ struct WgToSgVectorShapeCastOp
if (!layout || !layout.isForWorkgroup())
return failure();
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
- VectorType newResultType =
- VectorType::get(sgShape, resultType.getElementType());
-
- // TODO: Add check for compatible layouts in layout attr.
- auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+ // Check that srcShape and destShape, if they differ, only differ by
+ // expand of unit dimensions.
+ auto srcType = dyn_cast<VectorType>(op.getSource().getType());
if (!srcType)
return failure();
- // Check that shape_cast only adds/removes unit dimensions,
- auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
- // Remove all 1s from both shapes and compare the rest.
- SmallVector<int64_t> srcNonUnit, dstNonUnit;
- for (int64_t d : src)
- if (d != 1)
- srcNonUnit.push_back(d);
- for (int64_t d : dst)
- if (d != 1)
- dstNonUnit.push_back(d);
- return srcNonUnit == dstNonUnit;
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ llvm::SetVector<int64_t> expandedUnitDims;
+
+ // Check if shapes only differ by expanding unit dimensions (like
+ // expand_dims)
+ auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+ ArrayRef<int64_t> dst) -> bool {
+ // All unit dimensions in dst that don't appear in src are the expanded
+ // unit dimensions
+ size_t srcIdx = 0;
+ for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+ if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+ srcIdx++;
+ else if (dst[dstIdx] == 1)
+ expandedUnitDims.insert(dstIdx);
+ else
+ return false;
+ return srcIdx == src.size();
};
- if (!onlyUnitDims(srcType.getShape(), sgShape))
- return failure();
+ if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getSource());
- // For rank reducing or increasing shape_cast ops, the lower rank layout
- // must be a slice of higher rank layout.
- int64_t sourceRank = srcType.getRank();
- int64_t resultRank = sgShape.size();
- xegpu::DistributeLayoutAttr sourceLayout =
- xegpu::getDistributeLayoutAttr(op.getSource());
- if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
- return failure();
- if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
- return failure();
+ auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
----------------
charithaintc wrote:
What if this value is passed into an scf.for as an iter_arg and then used by a broadcast inside the scf.for body?
https://github.com/llvm/llvm-project/pull/171758
More information about the Mlir-commits
mailing list