[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for Convert Layout from Wg to Sg (PR #178922)
Nishant Patel
llvmlistbot at llvm.org
Thu Feb 12 08:14:47 PST 2026
================
@@ -604,44 +604,124 @@ struct WgToSgElementwiseOp : public ConversionPattern {
struct WgToSgConvertLayoutOp
: public OpConversionPattern<xegpu::ConvertLayoutOp> {
using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
- auto input = op.getInputLayout();
- auto target = op.getTargetLayout();
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ auto inputLayout = op.getInputLayout();
+ auto targetLayout = op.getTargetLayout();
- if (!input || !target || !input.isForWorkgroup() ||
- !target.isForWorkgroup())
+ if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+ !targetLayout.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
- SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
- DenseI32ArrayAttr inputOrder = input.getOrder();
- SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
- DenseI32ArrayAttr targetOrder = target.getOrder();
-
- // TODO: currently we only support for optimal case, where input and
- // output has the same sg_layout and sg_data, so SLM is not involved.
- if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
- inputOrder != targetOrder)
+ SmallVector<int64_t> inputSgLayout =
+ inputLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> targetSgLayout =
+ targetLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+ auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
+ if (shape.size() <= 2)
+ return true;
+ for (size_t i = 0; i + 2 < shape.size(); ++i)
+ if (shape[i] != 1)
+ return false;
+ return true;
+ };
+
+ if (wgShape.size() > 2) {
+ if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
+ return rewriter.notifyMatchFailure(
+ op, "rank > 2 requires unit leading dims for sg_data");
+ }
+
+ // Fast path: if sg_layout and sg_data are identical, no SLM needed
+ if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+ inputLayout = inputLayout.dropSgLayoutAndData();
+ targetLayout = targetLayout.dropSgLayoutAndData();
+
+ SmallVector<Value> newOps(adaptor.getSource());
+ if (inputLayout && targetLayout) {
+ for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+ auto newOp = xegpu::ConvertLayoutOp::create(
+ rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+ newOps[i] = newOp;
+ }
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ return success();
+ }
+
+ // SLM path: layouts differ, need cross-subgroup data redistribution
+ Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+ SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
+
+ // Calculate SLM size requirements
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto bytesPerElement = bitWidth / 8;
+ auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+ // Allocate SLM
+ auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+ auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
----------------
nbpatel wrote:
Not sure I understand the question. Where is the cf here?
https://github.com/llvm/llvm-project/pull/178922
More information about the Mlir-commits
mailing list