[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)
Artem Kroviakov
llvmlistbot at llvm.org
Tue Aug 26 07:23:56 PDT 2025
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/154949
>From bb959236f321bf79fc90e47ea5115f986f3689d9 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sat, 23 Aug 2025 10:22:34 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Scattered ops sg-to-wi distribution
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 141 +++++++++++++++++-
.../Dialect/XeGPU/subgroup-distribute.mlir | 42 ++++++
2 files changed, 179 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 8e47968609d32..647b11b3706e1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ auto yield = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ Operation *lastNode = yield->getPrevNode();
+ auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+ if (!storeScatterOp)
+ return failure();
+ else if (!storeScatterOp.getOffsets())
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Store op must have offsets argument");
+ else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+ .getRank() != 1)
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Expected 1D offsets vector");
+
+ VectorType storeVecTy =
+ cast<VectorType>(storeScatterOp.getValue().getType());
+ assert(storeVecTy.getRank() <= 2 &&
+ "Expected at most 2D result at SG level");
+ VectorType distStoreVecTy;
+ if (storeVecTy.getRank() == 2)
+ distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+ else // rank 1
+ distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands =
+ llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+ SmallVector<Type> operandTypes =
+ llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+ operandTypes[0] = distStoreVecTy;
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ Value offsetsVec = newStoreScatterOpOperands[2];
+ Value maskVec = newStoreScatterOpOperands[3];
+
+ auto loc = newWarpOp.getLoc();
+ Value laneId = warpOp.getLaneid();
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value laneOffset =
+ vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+ laneOffset = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+ Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+ laneMask = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+ newStoreScatterOpOperands[2] = laneOffset;
+ newStoreScatterOpOperands[3] = laneMask;
+
+ xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+ rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+ storeScatterOp->getAttrs());
+ xegpu::removeLayoutAttrs(newOp);
+ rewriter.eraseOp(storeScatterOp);
+ return success();
+ }
+};
+
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
+ if (!isa<xegpu::LoadGatherOp>(op))
+ return false;
+ auto yield = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ return yield->getPrevNode() == op;
+ });
+ if (!yieldOperand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a xegpu::LoadGatherOp op");
+
+ auto loadGatherOp =
+ yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
+ if (!loadGatherOp.getOffsets())
+ return rewriter.notifyMatchFailure(loadGatherOp,
+ "Load op must have offsets argument");
+ else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
+ 1)
+ return rewriter.notifyMatchFailure(loadGatherOp,
+ "Expected 1D offsets vector");
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands =
+ llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+ SmallVector<Type> operandTypes =
+ llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+ SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ const unsigned operandIdx = yieldOperand->getOperandNumber();
+ VectorType loadVecTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
+
+ Value offsetsVec = newLoadGatherOperands[1];
+ Value maskVec = newLoadGatherOperands[2];
+ auto loc = newWarpOp.getLoc();
+ Value laneId = warpOp.getLaneid();
+ rewriter.setInsertionPointAfter(newWarpOp);
+ Value laneOffset =
+ vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+ laneOffset = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+ Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+ laneMask = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+ newLoadGatherOperands[1] = laneOffset;
+ newLoadGatherOperands[2] = laneMask;
+
+ xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+ Value distributedVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -819,10 +949,11 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
- patterns.add<CreateNdDescDistribution, StoreNdDistribution,
- LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
- UpdateNdOffsetDistribution, GpuBarrierDistribution>(
- patterns.getContext());
+ patterns
+ .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+ DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+ GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+ patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -837,6 +968,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
+ continue;
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..b319162dc3f25 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,45 @@ gpu.module @test {
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+ gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+ gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+ }
+}
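
For orientation (a sketch, not verbatim IR from the patch): both new patterns key off a scatter op that is the last operation before the terminator of the gpu.warp_execute_on_lane_0 region, with 1D offset and mask vectors of subgroup-size length. A schematic example of such an input, where producer_op, %src and the exact shapes are placeholders (subgroup size 16, chunk_size = 8):

  gpu.warp_execute_on_lane_0(%laneid)[16] {
    %mask = producer_op : vector<16xi1>
    %offsets = producer_op : vector<16xindex>
    %payload = producer_op : vector<16x8xf16>
    xegpu.store %payload, %src[%offsets], %mask <{chunk_size = 8 : i64}>
      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
    gpu.yield
  }

In this first revision the per-lane offset and mask are then recovered outside the warp op with gpu.lane_id, vector.extract and vector.broadcast, as the CHECK lines in the test above show.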
>From b55fbb236b40964830826e757ae3216c9cc3ca20 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 25 Aug 2025 18:04:16 +0000
Subject: [PATCH 2/3] Assume distributable offset and mask producers
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 54 ++++++-------------
.../Dialect/XeGPU/subgroup-distribute.mlir | 30 +++++------
2 files changed, 31 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 647b11b3706e1..63b5b6559dfdd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -817,14 +817,14 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
if (!storeScatterOp)
return failure();
- else if (!storeScatterOp.getOffsets())
+ if (!storeScatterOp.getOffsets())
return rewriter.notifyMatchFailure(storeScatterOp,
"Store op must have offsets argument");
- else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
- .getRank() != 1)
+ VectorType offsetsTy =
+ cast<VectorType>(storeScatterOp.getOffsets().getType());
+ if (offsetsTy.getRank() != 1)
return rewriter.notifyMatchFailure(storeScatterOp,
"Expected 1D offsets vector");
-
VectorType storeVecTy =
cast<VectorType>(storeScatterOp.getValue().getType());
assert(storeVecTy.getRank() <= 2 &&
@@ -836,33 +836,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
SmallVector<size_t> newRetIndices;
- SmallVector<Value> operands =
- llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+ SmallVector<Value> operands = storeScatterOp->getOperands();
SmallVector<Type> operandTypes =
llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
operandTypes[0] = distStoreVecTy;
+ // Assume offset and mask pproducers will be distributed as well.
+ operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+ operandTypes[3] = VectorType::get(
+ {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, operands, operandTypes, newRetIndices);
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
- Value offsetsVec = newStoreScatterOpOperands[2];
- Value maskVec = newStoreScatterOpOperands[3];
-
auto loc = newWarpOp.getLoc();
- Value laneId = warpOp.getLaneid();
rewriter.setInsertionPointAfter(newWarpOp);
- Value laneOffset =
- vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
- laneOffset = vector::BroadcastOp::create(
- rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
- Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
- laneMask = vector::BroadcastOp::create(
- rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
- newStoreScatterOpOperands[2] = laneOffset;
- newStoreScatterOpOperands[3] = laneMask;
-
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
storeScatterOp->getAttrs());
@@ -892,16 +881,20 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
if (!loadGatherOp.getOffsets())
return rewriter.notifyMatchFailure(loadGatherOp,
"Load op must have offsets argument");
- else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
- 1)
+ VectorType offsetsTy =
+ cast<VectorType>(loadGatherOp.getOffsets().getType());
+ if (offsetsTy.getRank() != 1)
return rewriter.notifyMatchFailure(loadGatherOp,
"Expected 1D offsets vector");
SmallVector<size_t> newRetIndices;
- SmallVector<Value> operands =
- llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+ SmallVector<Value> operands = loadGatherOp->getOperands();
SmallVector<Type> operandTypes =
llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+ // Assume offset and mask pproducers will be distributed as well.
+ operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+ operandTypes[2] = VectorType::get(
+ {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, operands, operandTypes, newRetIndices);
@@ -914,21 +907,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
cast<VectorType>(warpOp.getResult(operandIdx).getType());
assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
- Value offsetsVec = newLoadGatherOperands[1];
- Value maskVec = newLoadGatherOperands[2];
auto loc = newWarpOp.getLoc();
- Value laneId = warpOp.getLaneid();
rewriter.setInsertionPointAfter(newWarpOp);
- Value laneOffset =
- vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
- laneOffset = vector::BroadcastOp::create(
- rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
- Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
- laneMask = vector::BroadcastOp::create(
- rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
- newLoadGatherOperands[1] = laneOffset;
- newLoadGatherOperands[2] = laneMask;
-
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
Value distributedVal = newWarpOp.getResult(operandIdx);
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index b319162dc3f25..1c4684681b62b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -323,19 +323,18 @@ gpu.module @test {
// -----
// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
-// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.module @test {
- gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+ gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
- %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
- xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
gpu.return
}
@@ -344,19 +343,18 @@ gpu.module @test {
// -----
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
-// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.module @test {
- gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+ gpu.func @scatter_ops(%src: memref<256xf16>) {
%1 = arith.constant dense<1>: vector<16xi1>
- %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
- xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ xegpu.store %3, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
gpu.return
}
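
With this revision the per-lane offset/mask are no longer materialized via gpu.lane_id + vector.extract; instead the yielded offset and mask types are shrunk to single-element vectors and their producers are left to the existing vector-distribution patterns (the updated test relies on constant distribution). A rough sketch of the resulting structure for the store with chunk_size, assuming placeholder names and a subgroup size of 16 (producers inside the region elided):

  %r:4 = gpu.warp_execute_on_lane_0(%laneid)[16]
      -> (vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
    // ... SG-level producers of %payload, %offsets, %mask ...
    gpu.yield %payload, %src, %offsets, %mask
      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
  }
  xegpu.store %r#0, %r#1[%r#2], %r#3 <{chunk_size = 8 : i64}>
    : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>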
>From 4971e49ed93c7a6f526e149c9fb5c635b4a3f0d2 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 26 Aug 2025 14:23:30 +0000
Subject: [PATCH 3/3] Address feedback
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 123 ++++++++++++------
1 file changed, 82 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 63b5b6559dfdd..fe5231368a895 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,26 +807,47 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element vectors;
+/// **it is assumed that their producers will also be distributed**.
+/// The payload vector also has a fixed distribution:
+/// no chunk size -> a single-element vector,
+/// chunk size -> a vector of the innermost dimension of the SG payload.
+/// Example 1 (no chunk size):
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+/// memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+/// memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
struct StoreDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- Operation *lastNode = yield->getPrevNode();
+ Operation *lastNode = warpOp.getTerminator()->getPrevNode();
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
if (!storeScatterOp)
return failure();
- if (!storeScatterOp.getOffsets())
- return rewriter.notifyMatchFailure(storeScatterOp,
- "Store op must have offsets argument");
- VectorType offsetsTy =
- cast<VectorType>(storeScatterOp.getOffsets().getType());
+ auto offsets = storeScatterOp.getOffsets();
+ if (!offsets || !isa<VectorType>(offsets.getType()))
+ return rewriter.notifyMatchFailure(
+ storeScatterOp, "Store op must have a vector of offsets argument");
+ VectorType offsetsTy = cast<VectorType>(offsets.getType());
if (offsetsTy.getRank() != 1)
return rewriter.notifyMatchFailure(storeScatterOp,
"Expected 1D offsets vector");
- VectorType storeVecTy =
- cast<VectorType>(storeScatterOp.getValue().getType());
+ VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
assert(storeVecTy.getRank() <= 2 &&
"Expected at most 2D result at SG level");
VectorType distStoreVecTy;
@@ -837,23 +858,23 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = storeScatterOp->getOperands();
- SmallVector<Type> operandTypes =
+ SmallVector<Type> operandTypesToYield =
llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
- operandTypes[0] = distStoreVecTy;
- // Assume offset and mask pproducers will be distributed as well.
- operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
- operandTypes[3] = VectorType::get(
+ operandTypesToYield[0] = distStoreVecTy;
+ // Assume offset and mask producers will be distributed as well.
+ operandTypesToYield[2] =
+ VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+ operandTypesToYield[3] = VectorType::get(
{1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, operands, operandTypes, newRetIndices);
+ rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
- auto loc = newWarpOp.getLoc();
rewriter.setInsertionPointAfter(newWarpOp);
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
- rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+ rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
storeScatterOp->getAttrs());
xegpu::removeLayoutAttrs(newOp);
rewriter.eraseOp(storeScatterOp);
@@ -861,56 +882,75 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's result vector (the
+/// loaded payload) is expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
struct LoadDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
- if (!isa<xegpu::LoadGatherOp>(op))
- return false;
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- return yield->getPrevNode() == op;
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ // Check whether the yield operand was produced by the *last* scattered
+ // load op, to avoid sinking it before barriers (maintain memory order).
+ return isa<xegpu::LoadGatherOp>(op) &&
+ warpOp.getTerminator()->getPrevNode() == op;
});
- if (!yieldOperand)
+ if (!producedByLastLoad)
return rewriter.notifyMatchFailure(
- warpOp, "warp result is not a xegpu::LoadGatherOp op");
+ warpOp, "The last op is not xegpu::LoadGatherOp");
auto loadGatherOp =
- yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
- if (!loadGatherOp.getOffsets())
- return rewriter.notifyMatchFailure(loadGatherOp,
- "Load op must have offsets argument");
- VectorType offsetsTy =
- cast<VectorType>(loadGatherOp.getOffsets().getType());
+ producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+ auto offsets = loadGatherOp.getOffsets();
+ if (!offsets || !isa<VectorType>(offsets.getType()))
+ return rewriter.notifyMatchFailure(
+ loadGatherOp, "Load op must have a vector of offsets argument");
+ VectorType offsetsTy = cast<VectorType>(offsets.getType());
if (offsetsTy.getRank() != 1)
return rewriter.notifyMatchFailure(loadGatherOp,
"Expected 1D offsets vector");
SmallVector<size_t> newRetIndices;
SmallVector<Value> operands = loadGatherOp->getOperands();
- SmallVector<Type> operandTypes =
+ SmallVector<Type> operandTypesToYield =
llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
- // Assume offset and mask pproducers will be distributed as well.
- operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
- operandTypes[2] = VectorType::get(
- {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
+ // Assume offset and mask producers will be distributed as well.
+ operandTypesToYield[1] =
+ VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+ operandTypesToYield[2] =
+ VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, operands, operandTypes, newRetIndices);
+ rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
- const unsigned operandIdx = yieldOperand->getOperandNumber();
+ const unsigned operandIdx = producedByLastLoad->getOperandNumber();
VectorType loadVecTy =
cast<VectorType>(warpOp.getResult(operandIdx).getType());
assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
- auto loc = newWarpOp.getLoc();
rewriter.setInsertionPointAfter(newWarpOp);
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
- loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+ newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+ loadGatherOp->getAttrs());
Value distributedVal = newWarpOp.getResult(operandIdx);
rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
return success();
@@ -948,6 +988,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
+ // Vector operands of these ops have a fixed and implicit layout.
if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
continue;
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
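
The load path mirrors the store path: the gather's operands are yielded with the shrunk offset/mask types and the op is recreated after the warp op with the already-distributed result type, which then replaces the original warp result. A rough sketch for the chunk_size case, with assumed names and shapes and producers elided:

  %r:4 = gpu.warp_execute_on_lane_0(%laneid)[16]
      -> (vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
    // ... SG-level producers of %offsets and %mask ...
    %ld = xegpu.load %src[%offsets], %mask <{chunk_size = 8 : i64}>
      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
    gpu.yield %ld, %src, %offsets, %mask
      : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
  }
  %0 = xegpu.load %r#1[%r#2], %r#3 <{chunk_size = 8 : i64}>
    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
  // uses of %r#0 are replaced with %0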