[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute load_gather/store_scatter op from Wg To Sg (PR #154420)
Nishant Patel
llvmlistbot at llvm.org
Fri Aug 29 10:55:05 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/154420
From b94a37fb467510f4cfe9e3b902a2147f99010c8f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 18 Aug 2025 16:18:26 +0000
Subject: [PATCH 1/6] Add pattern for load_gather and store_scatter ops
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 12 +++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 34 ++++++
.../Transforms/XeGPUWgToSgDistribute.cpp | 100 +++++++++++++++++-
3 files changed, 145 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 480b43e740736..34429a34f6d96 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -751,6 +751,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Type": $value, "Value": $source,
+ "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+ "IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
@@ -859,6 +865,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+ "xegpu::CachePolicyAttr": $l1_hint,
+ "xegpu::CachePolicyAttr": $l2_hint,
+ "xegpu::CachePolicyAttr": $l3_hint)>,
+ OpBuilder<(ins "Value": $value, "Value": $dest,
+ "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+ "IntegerAttr": $chunk_size,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7b7ce19e6937b..2e9bb88174c74 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -737,6 +737,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = source.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+ l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_StoreScatterOp
//===----------------------------------------------------------------------===//
@@ -785,6 +801,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
l2_hint, l3_hint);
}
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest,
+ ArrayRef<OpFoldResult> offsets, Value mask,
+ IntegerAttr chunk_size,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ auto loc = dest.getLoc();
+ int64_t size = static_cast<int64_t>(offsets.size());
+ auto type = VectorType::get(size, builder.getIndexType());
+ auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+ auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+ // Call the correct builder overload that does not expect result types.
+ build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+ l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_UpdateOffsetOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 270d71aaa7273..a4d697b5357e6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -685,6 +685,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
};
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data, similar to WgToSgLoadNdOpWithOffset.
+struct WgToSgLoadGatherOpWithOffset
+ : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+ SmallVector<Value> newLoadOps;
+ auto chunkSizeAttr = rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
+ VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+ for (auto [offsets, mask] :
+ llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+ auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newTy, op.getSource(), offsets, mask,
+ chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ xegpu::setLayoutAttr(newLoadOp->getResult(0),
+ layout.dropSgLayoutAndData());
+ newLoadOps.push_back(newLoadOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newLoadOps});
+ return success();
+ }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data, similar to WgToSgStoreNdOpWithOffset.
+struct WgToSgStoreScatterOpWithOffset
+ : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getOffsets())
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+ if (!valueType)
+ return failure();
+
+ ArrayRef<int64_t> wgShape = valueType.getShape();
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ auto chunkSizeOpt = op.getChunkSize();
+ int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+ auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+ for (auto [val, offs, mask] : llvm::zip(
+ adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+ rewriter.create<xegpu::StoreScatterOp>(
+ loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ // Update the layout_result_0 attribute to drop sg_layout and sg_data.
+ if (auto layoutAttr =
+ op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
+ if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+ op->setAttr("layout_result_0", newLayout);
+ }
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -694,7 +776,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
- WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
+ WgToSgConvertLayoutOp, WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+ WgToSgStoreScatterOpWithOffset>(
patterns.getContext());
}
} // namespace xegpu
@@ -815,6 +898,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
+ target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+ [=](xegpu::LoadGatherOp op) -> bool {
+ auto layout = xegpu::getLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
+ target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+ [=](xegpu::StoreScatterOp op) -> bool {
+ // Check if the layout attribute is present on the result.
+ auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
+ if (!layout)
+ return true;
+ return isLegal(layout);
+ });
+
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
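
For readers skimming the patch, a minimal before/after sketch of what the new
patterns do (shapes taken from the tests added in the next commit; the
per-subgroup offset and mask values come from the already-distributed constant
producers via the 1:N adaptor, not from this pattern itself):

  // Workgroup-level gather: wg shape 256x16, sg_layout [8, 4], sg_data [32, 4].
  %load = xegpu.load %src[%offset], %mask
      {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>}
    : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>

  // After distribution, each subgroup issues one load over its 32x4 slice:
  %load_sg = xegpu.load %src[%offset_sg], %mask_sg <{chunk_size = 1 : i64}>
    : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>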
From 459e98ab8e4f770f4832873d9adc4d97035c9933 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 19 Aug 2025 17:26:10 +0000
Subject: [PATCH 2/6] Add tests
---
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 41 +++++++++++++++++++
1 file changed, 41 insertions(+)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 07a0b86223c33..03fd95f0e778e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -263,4 +263,45 @@ gpu.module @test_distribution {
} {sg_id_range = #xegpu.range<[3, 19]>}
gpu.return
}
+
+ // CHECK-LABEL: @load_gather
+ // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+ gpu.func @load_gather(%src : memref<?xf16>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
+ %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+ : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @store_scatter
+ // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
+ gpu.func @store_scatter(%dest : memref<256xf16>) {
+ // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+ // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+ %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+ xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+ : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @load_with_chunk_size
+ // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+ gpu.func @load_with_chunk_size(%src : memref<?xf16>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+ // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+ %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+ : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
+ gpu.return
+ }
}
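
A quick sanity check of the chunk_size case: 256 offsets distributed over
sg_layout [32] leaves 8 offsets per subgroup, and with chunk_size = 4 each
offset pulls 4 contiguous elements, hence the vector<8x4xf16> result; the
result layout sg_layout = [32, 1], sg_data = [8, 4] accounts for the extra
chunk dimension. The tests run under the file's existing RUN line (not shown
in this diff), presumably something like mlir-opt
--xegpu-wg-to-sg-distribute piped into FileCheck.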
From a25c40deb19866819c0a8efa61e30d878111826c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 28 Aug 2025 16:38:55 +0000
Subject: [PATCH 3/6] Feedback
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 16 ++++++----------
.../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 15 +++++++++------
2 files changed, 15 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 367d93b37c614..90084b4395355 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -780,7 +780,7 @@ struct WgToSgLoadGatherOpWithOffset
ArrayRef<int64_t> wgShape = resultType.getShape();
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- if (!layout || !layout.getSgLayout())
+ if (!layout || !layout.isForWorkgroup())
return failure();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -820,9 +820,8 @@ struct WgToSgStoreScatterOpWithOffset
if (!valueType)
return failure();
- ArrayRef<int64_t> wgShape = valueType.getShape();
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
- if (!layout || !layout.getSgLayout())
+ if (!layout || !layout.isForWorkgroup())
return failure();
auto chunkSizeOpt = op.getChunkSize();
@@ -833,12 +832,9 @@ struct WgToSgStoreScatterOpWithOffset
rewriter.create<xegpu::StoreScatterOp>(
loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
- // Update the layout_result_0 attribute to drop sg_layout and sg_data.
- if (auto layoutAttr =
- op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
- if (auto newLayout = layoutAttr.dropSgLayoutAndData())
- op->setAttr("layout_result_0", newLayout);
- }
+ // Update the layout attribute to drop sg_layout and sg_data.
+ if (auto newLayout = layout.dropSgLayoutAndData())
+ op->setAttr("layout", newLayout);
}
rewriter.eraseOp(op);
return success();
@@ -1042,7 +1038,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
[=](xegpu::StoreScatterOp op) -> bool {
// Check if the layout attribute is present on the result.
- auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
+ auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
if (!layout)
return true;
return isLegal(layout);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 90e36a4c994b9..afb2bf876c18f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -269,7 +269,8 @@ gpu.module @test_distribution {
gpu.func @load_gather(%src : memref<?xf16>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
- // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
%load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
@@ -283,21 +284,23 @@ gpu.module @test_distribution {
// CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
- // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+ // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
- xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+ xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
: vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
gpu.return
}
- // CHECK-LABEL: @load_with_chunk_size
+ // CHECK-LABEL: @load_with_non_unit_chunk_size
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
- gpu.func @load_with_chunk_size(%src : memref<?xf16>) {
+ gpu.func @load_with_non_unit_chunk_size(%src : memref<?xf16>) {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
- // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+ // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+ // CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
%load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
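
Note the attribute change in the store test above: xegpu.store has no
results, so after this commit the pattern and the legality check look for the
distribution layout under a plain "layout" attribute instead of
"layout_result_0", and dropSgLayoutAndData is applied to the layout already
retrieved from the value operand rather than re-read from the op.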
From bdbf14f0c1bcbea52863da47a1ef84f1c8024013 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 28 Aug 2025 16:45:09 +0000
Subject: [PATCH 4/6] Cleanup
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 90084b4395355..3ed177e398836 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -764,7 +764,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
};
// This pattern transforms the LoadGatherOp with explicit offsets to load
-// subgroup data, similar to WgToSgLoadNdOpWithOffset.
+// subgroup data.
struct WgToSgLoadGatherOpWithOffset
: public OpConversionPattern<xegpu::LoadGatherOp> {
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
@@ -804,7 +804,7 @@ struct WgToSgLoadGatherOpWithOffset
};
// This pattern transforms the StoreScatterOp with explicit offsets to store
-// subgroup data, similar to WgToSgStoreNdOpWithOffset.
+// subgroup data.
struct WgToSgStoreScatterOpWithOffset
: public OpConversionPattern<xegpu::StoreScatterOp> {
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
From c93090f2f0853aa7047d90bc58aba9d1be518a92 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 28 Aug 2025 17:21:24 +0000
Subject: [PATCH 5/6] Add check
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3ed177e398836..90848eebd1243 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -785,6 +785,14 @@ struct WgToSgLoadGatherOpWithOffset
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ // The offsets need to be distributed
+ if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
+ .getShape() !=
+ dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
SmallVector<Value> newLoadOps;
auto chunkSizeAttr =
rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
@@ -824,6 +832,14 @@ struct WgToSgStoreScatterOpWithOffset
if (!layout || !layout.isForWorkgroup())
return failure();
+ // The offsets need to be distributed
+ if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
+ .getShape() !=
+ dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
+ return rewriter.notifyMatchFailure(op,
+ "offsets have not been distributed");
+ }
+
auto chunkSizeOpt = op.getChunkSize();
int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
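
The new shape check is a pattern-ordering guard: with the 1:N conversion
driver, the constant feeding the offsets may not have been distributed yet
when this pattern first fires, in which case the adaptor still hands over a
workgroup-shaped offsets vector; comparing it against the mask shape makes
the pattern bail out via notifyMatchFailure and retry once both operands are
per-subgroup. As written, though, dyn_cast<VectorType>(...).getShape()
dereferences a null pointer whenever either type is not a vector; the final
commit below hoists the casts into named variables and adds the null checks.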
From a7b780d172be99f9b3b0672640e0aa27e958bd86 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 29 Aug 2025 17:50:03 +0000
Subject: [PATCH 6/6] Feedback
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 90848eebd1243..623cac66e90f3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -786,9 +786,12 @@ struct WgToSgLoadGatherOpWithOffset
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
// The offsets need to be distributed
- if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
- .getShape() !=
- dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
return rewriter.notifyMatchFailure(op,
"offsets have not been distributed");
}
@@ -833,9 +836,12 @@ struct WgToSgStoreScatterOpWithOffset
return failure();
// The offsets need to be distributed
- if (dyn_cast<VectorType>(adaptor.getOffsets().front().getType())
- .getShape() !=
- dyn_cast<VectorType>(adaptor.getMask().front().getType()).getShape()) {
+ auto offsetsVecType =
+ dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
+ auto maskVecType =
+ dyn_cast<VectorType>(adaptor.getMask().front().getType());
+ if (!offsetsVecType || !maskVecType ||
+ offsetsVecType.getShape() != maskVecType.getShape()) {
return rewriter.notifyMatchFailure(op,
"offsets have not been distributed");
}