[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for Convert Layout from Wg to Sg (PR #178922)

Nishant Patel llvmlistbot at llvm.org
Thu Feb 12 08:14:47 PST 2026


================
@@ -604,44 +604,124 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 struct WgToSgConvertLayoutOp
     : public OpConversionPattern<xegpu::ConvertLayoutOp> {
   using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
   LogicalResult
   matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
 
-    auto input = op.getInputLayout();
-    auto target = op.getTargetLayout();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    auto inputLayout = op.getInputLayout();
+    auto targetLayout = op.getTargetLayout();
 
-    if (!input || !target || !input.isForWorkgroup() ||
-        !target.isForWorkgroup())
+    if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+        !targetLayout.isForWorkgroup())
       return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");
 
-    SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr inputOrder = input.getOrder();
-    SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
-    DenseI32ArrayAttr targetOrder = target.getOrder();
-
-    // TODO: currently we only support for optimal case, where input and
-    // output has the same sg_layout and sg_data, so SLM is not involved.
-    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
-        inputOrder != targetOrder)
+    SmallVector<int64_t> inputSgLayout =
+        inputLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> targetSgLayout =
+        targetLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+    auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
+      if (shape.size() <= 2)
+        return true;
+      for (size_t i = 0; i + 2 < shape.size(); ++i)
+        if (shape[i] != 1)
+          return false;
+      return true;
+    };
+
+    if (wgShape.size() > 2) {
+      if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
+        return rewriter.notifyMatchFailure(
+            op, "rank > 2 requires unit leading dims for sg_data");
+    }
+
+    // Fast path: if sg_layout and sg_data are identical, no SLM needed
+    if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+      inputLayout = inputLayout.dropSgLayoutAndData();
+      targetLayout = targetLayout.dropSgLayoutAndData();
+
+      SmallVector<Value> newOps(adaptor.getSource());
+      if (inputLayout && targetLayout) {
+        for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+          auto newOp = xegpu::ConvertLayoutOp::create(
+              rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+          newOps[i] = newOp;
+        }
+      }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+      return success();
+    }
+
+    // SLM path: layouts differ, need cross-subgroup data redistribution
+    Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+    SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
+
+    // Calculate SLM size requirements
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+    // Allocate SLM
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
----------------
nbpatel wrote:

Not sure I understand the question. Where is the cf here?

https://github.com/llvm/llvm-project/pull/178922


More information about the Mlir-commits mailing list