[Mlir-commits] [mlir] [MLIR][XeGPU] Add handling for unit-dim expansion in ShapeCast workgroup-to-subgroup distribution (PR #171758)
Jianhui Li
llvmlistbot at llvm.org
Tue Dec 16 09:00:58 PST 2025
================
@@ -1111,41 +1111,58 @@ struct WgToSgVectorShapeCastOp
if (!layout || !layout.isForWorkgroup())
return failure();
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
- VectorType newResultType =
- VectorType::get(sgShape, resultType.getElementType());
-
- // TODO: Add check for compatible layouts in layout attr.
- auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+ // Check that srcShape and destShape, if they differ, only differ by
+ // expand of unit dimensions.
+ auto srcType = dyn_cast<VectorType>(op.getSource().getType());
if (!srcType)
return failure();
- // Check that shape_cast only adds/removes unit dimensions,
- auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
- // Remove all 1s from both shapes and compare the rest.
- SmallVector<int64_t> srcNonUnit, dstNonUnit;
- for (int64_t d : src)
- if (d != 1)
- srcNonUnit.push_back(d);
- for (int64_t d : dst)
- if (d != 1)
- dstNonUnit.push_back(d);
- return srcNonUnit == dstNonUnit;
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ llvm::SetVector<int64_t> expandedUnitDims;
+
+ // Check if shapes only differ by expanding unit dimensions (like
+ // expand_dims)
+ auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+ ArrayRef<int64_t> dst) -> bool {
+ // All unit dimensions in dst that don't appear in src are the expanded
+ // unit dimensions
+ size_t srcIdx = 0;
+ for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+ if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+ srcIdx++;
+ else if (dst[dstIdx] == 1)
+ expandedUnitDims.insert(dstIdx);
+ else
+ return false;
+ return srcIdx == src.size();
};
- if (!onlyUnitDims(srcType.getShape(), sgShape))
- return failure();
+ if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(op.getSource());
- // For rank reducing or increasing shape_cast ops, the lower rank layout
- // must be a slice of higher rank layout.
- int64_t sourceRank = srcType.getRank();
- int64_t resultRank = sgShape.size();
- xegpu::DistributeLayoutAttr sourceLayout =
- xegpu::getDistributeLayoutAttr(op.getSource());
- if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
- return failure();
- if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
- return failure();
+ auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
+ return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
+ return isa<vector::BroadcastOp>(user);
+ });
+ };
+
+ if (!usedByBroadcastOp(op)) {
+ return rewriter.notifyMatchFailure(
+ op, "ShapeCast ops that expand unit dimensions and are used by "
+ "non-broadcast operations are not supported.");
+ }
+ if (!sourceLayout.isSliceOf(layout))
+ return rewriter.notifyMatchFailure(
+ op, "The ShapeCast op only expands dimensions, the result layout "
+ "must be a slice of the input layout, or vice versa.");
+ layout = layout.setUnitDimData(expandedUnitDims);
----------------
Jianhui-Li wrote:
The intent here is not to change the layout associated with the op. Rather, it is to adjust the local layout value before feeding it to the getSgShapeAndCount function (since that function doesn't handle unit dimensions). I will rename the variable to layoutForDistribution to avoid confusion with the layout attached to the op.
https://github.com/llvm/llvm-project/pull/171758
More information about the Mlir-commits
mailing list