[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for Convert Layout from Wg to Sg (PR #178922)
Nishant Patel
llvmlistbot at llvm.org
Fri Feb 13 08:19:10 PST 2026
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/178922
>From de5c4b67dde500560422435b0e222291da148aae Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 23 Jan 2026 15:22:34 +0000
Subject: [PATCH 1/6] Add convert layout via SLM
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 152 ++++++++++++++----
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 37 +++++
2 files changed, 162 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8328c2797be4f..450dc34a4263d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -604,44 +604,142 @@ struct WgToSgElementwiseOp : public ConversionPattern {
struct WgToSgConvertLayoutOp
: public OpConversionPattern<xegpu::ConvertLayoutOp> {
using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // TODO: currently, we only support LayoutAttr
- auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
- auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
+ Location loc = op.getLoc();
+
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ auto inputLayout = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
+ auto targetLayout = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
- if (!input || !target || !input.isForWorkgroup() ||
- !target.isForWorkgroup())
+ if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+ !targetLayout.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
- DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
- DenseI32ArrayAttr inputSgData = input.getSgData();
- DenseI32ArrayAttr inputOrder = input.getOrder();
- DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
- DenseI32ArrayAttr targetSgData = target.getSgData();
- DenseI32ArrayAttr targetOrder = target.getOrder();
-
- // TODO: currently we only support for optimal case, where input and
- // output has the same sg_layout and sg_data, so SLM is not involved.
- if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
- inputOrder != targetOrder)
+ SmallVector<int64_t> inputSgLayout =
+ inputLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> targetSgLayout =
+ targetLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+ // if sg_layout and sg_data are identical, no SLM needed
+ if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+ inputLayout = inputLayout.dropSgLayoutAndData();
+ targetLayout = targetLayout.dropSgLayoutAndData();
+
+ SmallVector<Value> newOps(adaptor.getSource());
+ if (inputLayout && targetLayout) {
+ for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+ auto newOp = xegpu::ConvertLayoutOp::create(
+ rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+ newOps[i] = newOp;
+ }
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ return success();
+ }
+
+ // SLM path: layouts differ, need cross-subgroup data redistribution
+ auto srcVectorType = cast<VectorType>(op.getSource().getType());
+ Type elemTy = srcVectorType.getElementType();
+
+ // Calculate SLM size requirements
+ auto slmShape = wgShape;
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto bytesPerElement = bitWidth / 8;
+ auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+ // Allocate SLM
+ auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+ auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+ auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+ elemTy, nullptr);
+ auto memDesc =
+ xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+ auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+ rewriter.getIndexType(), nullptr);
+
+ // STORE PHASE: Store input data to SLM using input layout
+ // Convert input sg_layout to Values for delinearizeIndex
+ SmallVector<Value> inputSgLayoutValues;
+ for (int64_t dim : inputSgLayout) {
+ inputSgLayoutValues.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, dim));
+ }
+
+ auto inputSgIdsResult = affine::delinearizeIndex(
+ rewriter, loc, sgId.getResult(), inputSgLayoutValues);
+ if (failed(inputSgIdsResult))
return failure();
+ SmallVector<Value> inputSgIds = *inputSgIdsResult;
+
+ // Calculate store offsets based on input subgroup position and sg_data
+ SmallVector<Value> storeOffsets;
+ for (size_t i = 0; i < inputSgIds.size(); ++i) {
+ Value sgDataVal =
+ arith::ConstantIndexOp::create(rewriter, loc, inputSgData[i]);
+ Value offset =
+ arith::MulIOp::create(rewriter, loc, inputSgIds[i], sgDataVal);
+ storeOffsets.push_back(offset);
+ }
- input = input.dropSgLayoutAndData();
- target = target.dropSgLayoutAndData();
+ SmallVector<OpFoldResult> storeMatrixOffsets(storeOffsets.begin(),
+ storeOffsets.end());
- SmallVector<Value> newOps(adaptor.getSource());
- if (input && target) {
- // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
- for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
- auto newOp = xegpu::ConvertLayoutOp::create(
- rewriter, op.getLoc(), src.getType(), src, input, target);
- newOps[i] = newOp;
- }
+ for (auto src : adaptor.getSource()) {
+ xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
+ storeMatrixOffsets,
+ targetLayout.dropSgLayoutAndData());
}
- rewriter.replaceOpWithMultiple(op, {newOps});
+
+ gpu::BarrierOp::create(rewriter, loc);
+
+ // LOAD PHASE: Load data from SLM using target layout
+ // Convert target sg_layout to Values for delinearizeIndex
+ SmallVector<Value> targetSgLayoutValues;
+ for (int64_t dim : targetSgLayout) {
+ targetSgLayoutValues.push_back(
+ arith::ConstantIndexOp::create(rewriter, loc, dim));
+ }
+
+ auto targetSgIdsResult = affine::delinearizeIndex(
+ rewriter, loc, sgId.getResult(), targetSgLayoutValues);
+ if (failed(targetSgIdsResult))
+ return failure();
+ SmallVector<Value> targetSgIds = *targetSgIdsResult;
+
+ // Calculate load offsets based on target subgroup position and sg_data
+ SmallVector<Value> loadOffsets;
+ for (size_t i = 0; i < targetSgIds.size(); ++i) {
+ Value sgDataVal =
+ arith::ConstantIndexOp::create(rewriter, loc, targetSgData[i]);
+ Value offset =
+ arith::MulIOp::create(rewriter, loc, targetSgIds[i], sgDataVal);
+ loadOffsets.push_back(offset);
+ }
+
+ SmallVector<OpFoldResult> loadMatrixOffsets(loadOffsets.begin(),
+ loadOffsets.end());
+
+ VectorType targetVectorType = VectorType::get(targetSgData, elemTy);
+
+ SmallVector<Value> loadedVectors;
+ for (size_t i = 0; i < adaptor.getSource().size(); ++i) {
+ auto loadOp =
+ xegpu::LoadMatrixOp::create(rewriter, loc, targetVectorType,
+ memDesc.getResult(), loadMatrixOffsets,
+ /*layout=*/nullptr);
+ loadedVectors.push_back(loadOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {loadedVectors});
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index ff0946d100a63..06faf167803fc 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -8,6 +8,9 @@
// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
+// CHECK-DAG: #map8 = affine_map<()[s0] -> (s0 floordiv 16)>
+// CHECK-DAG: #map9 = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK-DAG: #map10 = affine_map<()[s0] -> (s0 mod 8)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -838,4 +841,38 @@ gpu.module @test_distribution {
-> vector<256x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: convert_layout_slm
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
+ gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[SGIDX]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[OFF_Y]], %[[OFF_X]]] : memref<128x256xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x16xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3>
+ // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32>
+ // CHECK-DAG: %[[SGID_:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map8()[%[[SGID_]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map6()[%[[SGID_]]]
+ // CHECK-DAG: %[[ROW_OFF:.*]] = arith.muli %[[AFFINE1]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[COL_OFF:.*]] = arith.muli %[[AFFINE2]], %[[C16:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[ROW_OFF]], %[[COL_OFF]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map9()[%[[SGID_]]]
+ // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map10()[%[[SGID_]]]
+ // CHECK-DAG: %[[ROW_OFF2:.*]] = arith.muli %[[AFFINE3]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[COL_OFF2:.*]] = arith.muli %[[AFFINE4]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[ROW_OFF2]], %[[COL_OFF2]]] : !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>>
+ %1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>> -> vector<128x256xf32>
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>,
+ target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32], inst_data = [16, 16]>}> : vector<128x256xf32>
+ gpu.return
+ }
}
>From 3fa629332a679a27aab27270b33f322e6983593d Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 26 Jan 2026 22:58:01 +0000
Subject: [PATCH 2/6] Add nD support
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 10 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 112 ++++++++----------
mlir/test/Dialect/XeGPU/invalid.mlir | 4 +-
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 72 ++++++++---
5 files changed, 112 insertions(+), 90 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..fefc8c9903497 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1593,9 +1593,8 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}];
let description = [{
- This operation loads a 2D block of data from shared local memory (SLM) as specified
- by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
- subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+ This operation loads an nD block of data from shared local memory (SLM) as specified
+ by the provided nD `mem_desc`. Memory descriptors of rank 2 or higher are supported.
This operation serves as an anchor through which users assign a layout attribute
to govern computation distribution.
@@ -1665,9 +1664,8 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands)}];
let description = [{
- This operation stores a 2D `data` fragment into the shared local memory region
- specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
- subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+ This operation stores an nD `data` fragment into the shared local memory region
+ specified by an nD `mem_desc`. Memory descriptors of rank 2 or higher are supported.
This operation serves as an anchor through which users assign a layout attribute
to govern computation distribution.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..c7226c7ebbd5d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -186,8 +186,8 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
return success();
}
- if (mdescTy.getRank() != 2)
- return emitError() << "mem_desc must be 2D.";
+ if (mdescTy.getRank() < 2)
+ return emitError() << "mem_desc must be 2D or greater.";
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 44659bf1dec31..8dbca952cc8c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -627,7 +627,22 @@ struct WgToSgConvertLayoutOp
targetLayout.getEffectiveSgLayoutAsInt();
SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
- // if sg_layout and sg_data are identical, no SLM needed
+ auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
+ if (shape.size() <= 2)
+ return true;
+ for (size_t i = 0; i + 2 < shape.size(); ++i)
+ if (shape[i] != 1)
+ return false;
+ return true;
+ };
+
+ if (wgShape.size() > 2) {
+ if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
+ return rewriter.notifyMatchFailure(
+ op, "rank > 2 requires unit leading dims for sg_data");
+ }
+
+ // Fast path: if sg_layout and sg_data are identical, no SLM needed
if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
inputLayout = inputLayout.dropSgLayoutAndData();
targetLayout = targetLayout.dropSgLayoutAndData();
@@ -645,11 +660,11 @@ struct WgToSgConvertLayoutOp
}
// SLM path: layouts differ, need cross-subgroup data redistribution
- auto srcVectorType = cast<VectorType>(op.getSource().getType());
- Type elemTy = srcVectorType.getElementType();
+ Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+ SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
// Calculate SLM size requirements
- auto slmShape = wgShape;
auto bitWidth = elemTy.getIntOrFloatBitWidth();
auto bytesPerElement = bitWidth / 8;
auto slmSize = computeProduct(slmShape) * bytesPerElement;
@@ -666,80 +681,47 @@ struct WgToSgConvertLayoutOp
auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
rewriter.getIndexType(), nullptr);
- // STORE PHASE: Store input data to SLM using input layout
- // Convert input sg_layout to Values for delinearizeIndex
- SmallVector<Value> inputSgLayoutValues;
- for (int64_t dim : inputSgLayout) {
- inputSgLayoutValues.push_back(
- arith::ConstantIndexOp::create(rewriter, loc, dim));
- }
-
- auto inputSgIdsResult = affine::delinearizeIndex(
- rewriter, loc, sgId.getResult(), inputSgLayoutValues);
- if (failed(inputSgIdsResult))
+ // STORE PHASE: Each subgroup stores in SLM using input layout
+ auto storeCoords = inputLayout.computeDistributedCoords(
+ rewriter, loc, sgId.getResult(), wgShape);
+ if (failed(storeCoords))
return failure();
- SmallVector<Value> inputSgIds = *inputSgIdsResult;
-
- // Calculate store offsets based on input subgroup position and sg_data
- SmallVector<Value> storeOffsets;
- for (size_t i = 0; i < inputSgIds.size(); ++i) {
- Value sgDataVal =
- arith::ConstantIndexOp::create(rewriter, loc, inputSgData[i]);
- Value offset =
- arith::MulIOp::create(rewriter, loc, inputSgIds[i], sgDataVal);
- storeOffsets.push_back(offset);
- }
-
- SmallVector<OpFoldResult> storeMatrixOffsets(storeOffsets.begin(),
- storeOffsets.end());
- for (auto src : adaptor.getSource()) {
+ // Store to SLM
+ for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
+ SmallVector<OpFoldResult> storeMatrixOffsets;
+ for (Value coord : coords) {
+ storeMatrixOffsets.push_back(coord);
+ }
xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
- storeMatrixOffsets,
- targetLayout.dropSgLayoutAndData());
+ storeMatrixOffsets, nullptr /*layout*/);
}
gpu::BarrierOp::create(rewriter, loc);
- // LOAD PHASE: Load data from SLM using target layout
- // Convert target sg_layout to Values for delinearizeIndex
- SmallVector<Value> targetSgLayoutValues;
- for (int64_t dim : targetSgLayout) {
- targetSgLayoutValues.push_back(
- arith::ConstantIndexOp::create(rewriter, loc, dim));
- }
-
- auto targetSgIdsResult = affine::delinearizeIndex(
- rewriter, loc, sgId.getResult(), targetSgLayoutValues);
- if (failed(targetSgIdsResult))
+ // LOAD PHASE: Each target subgroup loads from SLM using target layout
+ auto loadCoords = targetLayout.computeDistributedCoords(
+ rewriter, loc, sgId.getResult(), wgShape);
+ if (failed(loadCoords))
return failure();
- SmallVector<Value> targetSgIds = *targetSgIdsResult;
-
- // Calculate load offsets based on target subgroup position and sg_data
- SmallVector<Value> loadOffsets;
- for (size_t i = 0; i < targetSgIds.size(); ++i) {
- Value sgDataVal =
- arith::ConstantIndexOp::create(rewriter, loc, targetSgData[i]);
- Value offset =
- arith::MulIOp::create(rewriter, loc, targetSgIds[i], sgDataVal);
- loadOffsets.push_back(offset);
- }
- SmallVector<OpFoldResult> loadMatrixOffsets(loadOffsets.begin(),
- loadOffsets.end());
+ VectorType loadType = VectorType::get(targetSgData, elemTy);
- VectorType targetVectorType = VectorType::get(targetSgData, elemTy);
+ // Load vectors from SLM
+ SmallVector<Value> finalResults;
+ for (auto coords : *loadCoords) {
+ SmallVector<OpFoldResult> loadMatrixOffsets;
+ for (Value coord : coords) {
+ loadMatrixOffsets.push_back(coord);
+ }
+ auto loadOp = xegpu::LoadMatrixOp::create(
+ rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
+ targetLayout.dropSgLayoutAndData());
- SmallVector<Value> loadedVectors;
- for (size_t i = 0; i < adaptor.getSource().size(); ++i) {
- auto loadOp =
- xegpu::LoadMatrixOp::create(rewriter, loc, targetVectorType,
- memDesc.getResult(), loadMatrixOffsets,
- /*layout=*/nullptr);
- loadedVectors.push_back(loadOp.getResult());
+ finalResults.push_back(loadOp.getResult());
}
- rewriter.replaceOpWithMultiple(op, {loadedVectors});
+ rewriter.replaceOpWithMultiple(op, {finalResults});
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..e6376e3ecb4cd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -852,7 +852,7 @@ func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>)
// -----
func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
- // expected-error@+1 {{mem_desc must be 2D}}
+ // expected-error@+1 {{mem_desc must be 2D or greater}}
%data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
return
}
@@ -873,7 +873,7 @@ func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %
// -----
func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
- // expected-error@+1 {{mem_desc must be 2D.}}
+ // expected-error@+1 {{mem_desc must be 2D or greater}}
xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index d0419d1a11a8d..d4b611c713674 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -8,9 +8,6 @@
// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
-// CHECK-DAG: #map8 = affine_map<()[s0] -> (s0 floordiv 16)>
-// CHECK-DAG: #map9 = affine_map<()[s0] -> (s0 floordiv 8)>
-// CHECK-DAG: #map10 = affine_map<()[s0] -> (s0 mod 8)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -846,7 +843,7 @@ gpu.module @test_distribution {
// CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
// CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
// CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index
// CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
@@ -857,24 +854,69 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x16xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3>
// CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32>
- // CHECK-DAG: %[[SGID_:.*]] = gpu.subgroup_id : index
- // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map8()[%[[SGID_]]]
- // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map6()[%[[SGID_]]]
- // CHECK-DAG: %[[ROW_OFF:.*]] = arith.muli %[[AFFINE1]], %[[C32:.*]] : index
- // CHECK-DAG: %[[COL_OFF:.*]] = arith.muli %[[AFFINE2]], %[[C16:.*]] : index
- // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[ROW_OFF]], %[[COL_OFF]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
+ // CHECK-DAG: %[[SGID_STORE:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID_STORE]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_Y_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
// CHECK-DAG: gpu.barrier
- // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map9()[%[[SGID_]]]
- // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map10()[%[[SGID_]]]
- // CHECK-DAG: %[[ROW_OFF2:.*]] = arith.muli %[[AFFINE3]], %[[C16:.*]] : index
- // CHECK-DAG: %[[COL_OFF2:.*]] = arith.muli %[[AFFINE4]], %[[C32:.*]] : index
- // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[ROW_OFF2]], %[[COL_OFF2]]] : !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
+ // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID_STORE]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_Y_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
%0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>>
%1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>> -> vector<128x256xf32>
%2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>,
target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32], inst_data = [16, 16]>}> : vector<128x256xf32>
gpu.return
}
+
+ gpu.func @convert_layout_3D(%arg0: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<0> : vector<1x32x16xindex>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<true> : vector<1x32x16xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [1, 16, 16]>}> : memref<?xf32>, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3>
+ // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_Z_TMP:.*]] = arith.divui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_Z:.*]] = arith.remui %[[STORE_Z_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Z:.*]] = arith.remui %[[STORE_Z]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Z]], %[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<1x32x16xf32>, !xegpu.mem_desc<8x128x256xf32>, index, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Z_TMP:.*]] = arith.divui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Z:.*]] = arith.remui %[[LOAD_Z_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Z:.*]] = arith.remui %[[LOAD_Z]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Z]], %[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [1, 16, 16]>}>: !xegpu.mem_desc<8x128x256xf32>, index, index, index -> vector<1x16x32xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<0> : vector<8x128x256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<true> : vector<8x128x256xi1>
+ %1 = xegpu.load %arg0[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} : memref<?xf32>, vector<8x128x256xindex>, vector<8x128x256xi1> -> vector<8x128x256xf32>
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>,
+ target_layout = #xegpu.layout<sg_layout = [8, 8, 8], sg_data = [1, 16, 32], inst_data = [1, 16, 16]>}> : vector<8x128x256xf32>
+ gpu.return
+ }
+
// CHECK-LABEL: distribute_nested_slice
// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
// CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>
>From 1bf0df3c93110565642e3db583699179cc77de73 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 12 Feb 2026 15:57:03 +0000
Subject: [PATCH 3/6] Clean up
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 15 ---------------
1 file changed, 15 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8dbca952cc8c2..2f240b974892e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -627,21 +627,6 @@ struct WgToSgConvertLayoutOp
targetLayout.getEffectiveSgLayoutAsInt();
SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
- auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
- if (shape.size() <= 2)
- return true;
- for (size_t i = 0; i + 2 < shape.size(); ++i)
- if (shape[i] != 1)
- return false;
- return true;
- };
-
- if (wgShape.size() > 2) {
- if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
- return rewriter.notifyMatchFailure(
- op, "rank > 2 requires unit leading dims for sg_data");
- }
-
// Fast path: if sg_layout and sg_data are identical, no SLM needed
if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
inputLayout = inputLayout.dropSgLayoutAndData();
>From 9c436b36478139ad12a66e32765bb3e83a3b1160 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 12 Feb 2026 16:12:24 +0000
Subject: [PATCH 4/6] address feedback
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 3 +-
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 32 +++++++++++++++++--
2 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d3c666d00e99e..88290d2e5bb9c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -626,7 +626,8 @@ struct WgToSgConvertLayoutOp
SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
// Fast path: if sg_layout and sg_data are identical, no SLM needed
- if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+ if (llvm::equal(inputSgLayout, targetSgLayout) &&
+ llvm::equal(inputSgData, targetSgData)) {
inputLayout = inputLayout.dropSgLayoutAndData();
targetLayout = targetLayout.dropSgLayoutAndData();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 8606e89616c91..e2e94c5f0300f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -828,6 +828,34 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: convert_layout_no_slm
+ gpu.func @convert_layout_no_slm(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %c256 = arith.constant 256 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %0 = arith.muli %block_id_x, %c256 overflow<nsw> : index
+ %1 = arith.muli %block_id_y, %c256 overflow<nsw> : index
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> -> !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>>
+ %3 = xegpu.load_nd %2[%0, %1] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}> : !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>> -> vector<256x256xf32>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>>
+ %5 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>>
+ %6 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %3) -> (vector<256x256xf32>) {
+ %7 = xegpu.load_nd %4[%0, %arg3] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>}> : !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+ %8 = xegpu.load_nd %5[%arg3, %1] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>}> : !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<32x256xf16>
+ // CHECK: %[[CONVERT_A:.*]] = xegpu.convert_layout %{{.*}} <{input_layout = #xegpu.layout<inst_data = [32, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<32x32xf16>
+ // CHECK: %[[CONVERT_B:.*]] = xegpu.convert_layout %{{.*}} <{input_layout = #xegpu.layout<inst_data = [32, 16]>, target_layout = #xegpu.layout<inst_data = [16, 16]>}> : vector<32x32xf16>
+ %9 = xegpu.convert_layout %7 <{input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>, target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}> : vector<256x32xf16>
+ %10 = xegpu.convert_layout %8 <{input_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [32, 16]>, target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [16, 16]>}> : vector<32x256xf16>
+ %11 = xegpu.dpas %9, %10, %arg4 {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [16, 16]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32>
+ scf.yield %11 : vector<256x256xf32>
+ } {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}
+ xegpu.store_nd %6, %2[%0, %1] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>}> : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>>
+ gpu.return
+ }
+
// CHECK-LABEL: convert_layout_slm
// CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
@@ -869,8 +897,8 @@ gpu.module @test_distribution {
}
gpu.func @convert_layout_3D(%arg0: memref<?xf32>) {
- // CHECK-DAG: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<0> : vector<1x32x16xindex>
- // CHECK-DAG: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<true> : vector<1x32x16xi1>
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x32x16xindex>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<1x32x16xi1>
// CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [1, 16, 16]>}> : memref<?xf32>, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3>
// CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32>
>From af1bbfff62705b70ba8700615de8a87687dd1a5d Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 12 Feb 2026 22:12:24 +0000
Subject: [PATCH 5/6] Clean up isEqualTo utility
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 4 +++
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 28 +++++++++++++------
.../XeGPU/Transforms/XeGPULayoutImpl.h | 2 --
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 18 ------------
.../Transforms/XeGPUSubgroupDistribute.cpp | 3 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 3 +-
6 files changed, 26 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 7badfaf4a8216..f85dde0b98c8e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -27,6 +27,10 @@ class TensorDescType;
class DistributeLayoutAttr;
class LayoutAttr;
class SliceAttr;
+
+/// Specifies the level of a layout hierarchy for comparison or propagation.
+enum class LayoutKind { Lane, InstData, Subgroup };
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 898fb7e1d8e6d..71d0a4e711d31 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -311,10 +311,27 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
/*methodName=*/"isSliceOf",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
- InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
+ InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout
+ at a specific level of the layout hierarchy.}],
/*retTy=*/"bool",
/*methodName=*/"isEqualTo",
- /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
+ /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
+ "xegpu::LayoutKind": $level),
+ /*methodBody=*/[{
+ if (!other)
+ return false;
+ switch (level) {
+ case xegpu::LayoutKind::Subgroup:
+ return $_self.getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
+ $_self.getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt();
+ case xegpu::LayoutKind::InstData:
+ return $_self.getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();
+ case xegpu::LayoutKind::Lane:
+ return $_self.getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
+ $_self.getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
+ }
+ return false;
+ }]>
];
}
@@ -545,10 +562,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
-
- /// Check if this is identical to some other layout.
- bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
-
}];
let assemblyFormat = "`<` struct(params) `>`";
@@ -736,9 +749,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Drop the slice dims to get the original layout.
SliceAttr dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop);
-
- /// Check if this is identical to some other layout.
- bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
}];
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 0d5210f07f05a..35dc9541b5755 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -33,8 +33,6 @@ class TensorDescType;
namespace xegpu {
-enum class LayoutKind { Lane, InstData, Subgroup };
-
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
LayoutKind layoutKind, bool printOnly = false);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d99557e68f0ec..5e790d34171e8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -389,13 +389,6 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
-bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
- if (dyn_cast<xegpu::SliceAttr>(other))
- return false;
-
- return *this == dyn_cast<xegpu::LayoutAttr>(other);
-}
-
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
@@ -773,17 +766,6 @@ xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
return sliceWithoutDims;
}
-bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
- if (dyn_cast<xegpu::LayoutAttr>(other))
- return false;
-
- auto flattenedThis = flatten();
- auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
-
- return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
- (flattenedThis.getDims() == flattenedOther.getDims()));
-}
-
// Helper function to adjust dimensions from sliced space to parent space
// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index aa1dfaa9e0fda..48426016dcb17 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1536,7 +1536,8 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
broadcastUnitDimsSet.end());
- bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
+ bool isEqualTo =
+ sourceLayout.isEqualTo(resultLayout, xegpu::LayoutKind::Lane);
if (!isEqualTo)
return rewriter.notifyMatchFailure(
warpOp, "For same-rank broadcast, source must be identical to "
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 88290d2e5bb9c..4015b333cec02 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -626,8 +626,7 @@ struct WgToSgConvertLayoutOp
SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
// Fast path: if sg_layout and sg_data are identical, no SLM needed
- if (llvm::equal(inputSgLayout, targetSgLayout) &&
- llvm::equal(inputSgData, targetSgData)) {
+ if (inputLayout.isEqualTo(targetLayout, xegpu::LayoutKind::Subgroup)) {
inputLayout = inputLayout.dropSgLayoutAndData();
targetLayout = targetLayout.dropSgLayoutAndData();
>From 6bd954d9109d100ee5dbc77a90d4cf38cbfc1d63 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 13 Feb 2026 00:13:15 +0000
Subject: [PATCH 6/6] Clean up
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 22 +++++++++++++++----
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 18 +++++++++++++++
.../Transforms/XeGPUSubgroupDistribute.cpp | 3 +--
.../Transforms/XeGPUWgToSgDistribute.cpp | 3 ++-
4 files changed, 39 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 71d0a4e711d31..377967dfdb1e5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -311,10 +311,12 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
/*methodName=*/"isSliceOf",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
- InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout
- at a specific level of the layout hierarchy.}],
+ InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
+ at a specific level of the layout hierarchy. Unlike isEqualTo,
+ this compares only the effective (non-sliced) fields at the
+ requested level.}],
/*retTy=*/"bool",
- /*methodName=*/"isEqualTo",
+ /*methodName=*/"isCompatibleWith",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
"xegpu::LayoutKind": $level),
/*methodBody=*/[{
@@ -331,7 +333,13 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
$_self.getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
}
return false;
- }]>
+ }]>,
+ InterfaceMethod</*desc=*/[{Check if this layout is equal to another layout.
+ For LayoutAttr, this compares all fields.
+ For SliceAttr, this requires the same parent and same sliced dims.}],
+ /*retTy=*/"bool",
+ /*methodName=*/"isEqualTo",
+ /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
];
}
@@ -562,6 +570,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
+
+ /// Check if this layout is equal to another layout.
+ bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
}];
let assemblyFormat = "`<` struct(params) `>`";
@@ -747,6 +758,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
+ /// Check if this layout is equal to another layout.
+ bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
+
/// Drop the slice dims to get the original layout.
SliceAttr dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop);
}];
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 5e790d34171e8..7ace00a746e21 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -389,6 +389,13 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+  auto otherLayout = dyn_cast_if_present<xegpu::LayoutAttr>(other);
+  if (!otherLayout) // null, SliceAttr, or any other non-LayoutAttr layout
+    return false;
+  return *this == otherLayout; // full structural equality of all fields
+}
+
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
@@ -748,6 +755,17 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
[&](int64_t dim) { return thisDims.contains(dim); });
}
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+  auto otherSlice = dyn_cast_if_present<xegpu::SliceAttr>(other);
+  if (!otherSlice) // null, LayoutAttr, or any other non-slice layout
+    return false;
+  // Compare the flattened forms: same parent layout and same sliced dims.
+  auto flattenedThis = flatten();
+  auto flattenedOther = otherSlice.flatten();
+  return flattenedThis.getParent() == flattenedOther.getParent() &&
+         flattenedThis.getDims() == flattenedOther.getDims();
+}
+
xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
if (sliceDimsToDrop.empty())
return *this;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 48426016dcb17..aa1dfaa9e0fda 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1536,8 +1536,7 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
broadcastUnitDimsSet.end());
- bool isEqualTo =
- sourceLayout.isEqualTo(resultLayout, xegpu::LayoutKind::Lane);
+ bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
if (!isEqualTo)
return rewriter.notifyMatchFailure(
warpOp, "For same-rank broadcast, source must be identical to "
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 4015b333cec02..66eb7fc97aa1a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -626,7 +626,8 @@ struct WgToSgConvertLayoutOp
SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
// Fast path: if sg_layout and sg_data are identical, no SLM needed
- if (inputLayout.isEqualTo(targetLayout, xegpu::LayoutKind::Subgroup)) {
+ if (inputLayout.isCompatibleWith(targetLayout,
+ xegpu::LayoutKind::Subgroup)) {
inputLayout = inputLayout.dropSgLayoutAndData();
targetLayout = targetLayout.dropSgLayoutAndData();
More information about the Mlir-commits
mailing list