[Mlir-commits] [mlir] 6c6afdd - [MLIR][XeGPU] Reapply attempt for "Scattered ops sg-to-wi distribution #154949" (#156924)
llvmlistbot at llvm.org
Thu Sep 4 12:04:34 PDT 2025
Author: Artem Kroviakov
Date: 2025-09-04T12:04:30-07:00
New Revision: 6c6afdd8c262f49bb23cf455d98108f31b732c6c
URL: https://github.com/llvm/llvm-project/commit/6c6afdd8c262f49bb23cf455d98108f31b732c6c
DIFF: https://github.com/llvm/llvm-project/commit/6c6afdd8c262f49bb23cf455d98108f31b732c6c.diff
LOG: [MLIR][XeGPU] Reapply attempt for "Scattered ops sg-to-wi distribution #154949" (#156924)
This PR is a reapply of
https://github.com/llvm/llvm-project/pull/154949, which failed one of the
sanitizer checks.
The issue was that `LoadDistribution` queried the `warpOp` results after
calling `moveRegionToNewWarpOpAndAppendReturns()`, which resulted in a
use-after-free. This PR fixes the issue by moving the result query before
that call and is otherwise identical to the PR linked above.
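
For reference, a condensed sketch of the corrected ordering in
`LoadDistribution` (excerpted from the patch below; not the complete
pattern):

  // Read the distributed result type from the *old* warpOp first ...
  const unsigned operandIdx = producedByLastLoad->getOperandNumber();
  VectorType loadVecTy =
      cast<VectorType>(warpOp.getResult(operandIdx).getType());
  // ... and only then rebuild the warp op; the original `warpOp` is erased
  // here, so querying its results after this call was a use-after-free.
  gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, operands, operandTypesToYield, newRetIndices);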
---------
Co-authored-by: Charitha Saumya <136391709+charithaintc at users.noreply.github.com>
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/test/Dialect/XeGPU/propagate-layout.mlir
mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 5cb47b2accd68..c0c4394f73d4a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -194,7 +194,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
}
/// Helper to get the default layout for a vector type.
-static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
+ bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
@@ -207,6 +208,14 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+ if (isScattered) {
+ packingFactor =
+ bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
+ ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
+ : 1;
+ return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+ LaneData({1, packingFactor}));
+ }
if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
@@ -214,7 +223,8 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
}
/// Helper to get the default layout for a tensor descriptor type.
-static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
+ bool isScattered = false) {
// Expecting a 1D or 2D TensorDesc.
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
"Expected 1D or 2D TensorDesc.");
@@ -227,7 +237,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
- if (tdescTy.isScattered()) {
+ if (isScattered) {
int packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
@@ -541,21 +551,29 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
-/// Propagate the layout of the result to the tensor descriptor and mask
+/// Propagate the layout of the result to the tensor descriptor, mask and offset
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- // The layout is strictly determined by the tensor descriptor type.
- LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
+ // The layout is strictly determined by the payload type.
+ auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+ if (!payloadTy) {
+ load.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+ LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
// Mask operand should have 1D default layout.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
// Propagate the new layout to the tensor descriptor operand.
- propagateIfChanged(operands[0], operands[0]->meet(layout));
- // Propagate the new layout to the mask operand.
+ if (isa<xegpu::TensorDescType>(load.getSourceType()))
+ propagateIfChanged(operands[0], operands[0]->meet(layout));
+ // Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+ if (load.getOffsets())
+ propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
/// Propagate the layout of the descriptor to the vector offset operand in
@@ -572,31 +590,39 @@ void LayoutInfoPropagation::visitCreateDescOp(
propagateIfChanged(operands[1], operands[1]->meet(layout));
}
-/// Set the layout for the value, tensor descriptor, and mask operands in the
-/// StoreScatterOp.
+/// Set the layout for the value, tensor descriptor, offset and mask operands in
+/// the StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Currently, for 2D StoreScatterOp we expect that the height dimension of
// the tensor descriptor is equal to the subgroup size. This is ensured by
// the op verifier.
- ArrayRef<int64_t> tdescShape = storeScatter.getTensorDescType().getShape();
- if (tdescShape.size() > 1)
+ auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+ if (!payloadTy) {
+ storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
+ return;
+ }
+ auto payloadShape = payloadTy.getShape();
+ if (payloadShape.size() > 1)
assert(
- tdescShape[0] == xegpu::targetinfo::subgroupSize &&
+ payloadShape[0] == xegpu::targetinfo::subgroupSize &&
"Expected the first dimension of 2D tensor descriptor to be equal to "
"subgroup size.");
- LayoutInfo layout =
- getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+ LayoutInfo payloadLayout =
+ getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
- // Propagate the value layout.
- propagateIfChanged(operands[0], operands[0]->meet(layout));
- // Propagate the tensor descriptor layout.
- propagateIfChanged(operands[1], operands[1]->meet(layout));
- // Use default 1D layout for mask operand.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+ // Propagate the payload operand layout
+ propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
+ // Propagate the destination (if tdesc) operand layout
+ if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+ propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
+ // Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+ if (storeScatter.getOffsets())
+ propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
namespace {
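
To make the scattered packing-factor rule above concrete, here is a minimal
standalone sketch. The concrete values of subgroupSize (16) and
packedSizeInBitsForGatherScatter (32) are assumptions inferred from the f16
tests further below, which expect lane_layout = [16, 1] and
lane_data = [1, 2]:

  #include <cstdio>

  int main() {
    // Assumed target constants (stand-ins for the xegpu::targetinfo values).
    const unsigned subgroupSize = 16;
    const unsigned packedSizeInBitsForGatherScatter = 32;
    for (unsigned bitwidth : {8u, 16u, 32u}) {
      // Same rule as getDefaultSIMTLayoutInfo(..., /*isScattered=*/true).
      unsigned packingFactor =
          bitwidth < packedSizeInBitsForGatherScatter
              ? packedSizeInBitsForGatherScatter / bitwidth
              : 1;
      std::printf("bitwidth=%u -> lane_layout=[%u, 1], lane_data=[1, %u]\n",
                  bitwidth, subgroupSize, packingFactor);
    }
    return 0;
  }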
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index dddb5eaece2cb..6b8367dd8c201 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,200 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element vectors;
+/// **it is assumed that their producers will also be distributed**.
+/// The payload vector also has a fixed distribution:
+/// no chunk size -> vector of one element.
+/// chunk size -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+/// memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+/// memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+ auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+ if (!storeScatterOp)
+ return failure();
+ auto offsets = storeScatterOp.getOffsets();
+ if (!offsets || !isa<VectorType>(offsets.getType()))
+ return rewriter.notifyMatchFailure(
+ storeScatterOp, "Store op must have a vector of offsets argument");
+ VectorType offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+ if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Expected 1D offsets and mask vector");
+ VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+ if (storeVecTy.getRank() > 2)
+ return rewriter.notifyMatchFailure(
+ storeScatterOp, "Expected at most 2D result at SG level");
+
+ std::string layoutPayloadName =
+ xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+ std::string layoutOffsetsName =
+ xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+ std::string layoutMaskName =
+ xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+ xegpu::LayoutAttr layoutPayload =
+ storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+ xegpu::LayoutAttr layoutOffsets =
+ storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+ xegpu::LayoutAttr layoutMask =
+ storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+ FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+ FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+ FailureOr<VectorType> distMaskByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+ if (failed(distStoreVecByWarpOpOrFailure) ||
+ failed(distOffsetsByWarpOpOrFailure) ||
+ failed(distMaskByWarpOpOrFailure)) {
+ return rewriter.notifyMatchFailure(
+ storeScatterOp,
+ "Some vector operands have no layouts, using defaults instead.");
+ }
+ VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
+ VectorType expectedPayloadTy = VectorType::get(
+ {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands = storeScatterOp->getOperands();
+ SmallVector<Type> operandTypesToYield = {
+ expectedPayloadTy, operands[1].getType(),
+ distOffsetsByWarpOpOrFailure.value(),
+ distMaskByWarpOpOrFailure.value()};
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+ SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ rewriter.setInsertionPointAfter(newWarpOp);
+ xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+ rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
+ storeScatterOp->getAttrs());
+ xegpu::removeLayoutAttrs(newOp);
+ rewriter.eraseOp(storeScatterOp);
+ return success();
+ }
+};
+
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      // Check whether the yield operand was produced by the *last* scattered
+      // load op, to avoid sinking it before barriers (maintain memory order).
+ return isa<xegpu::LoadGatherOp>(op) &&
+ warpOp.getTerminator()->getPrevNode() == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadGatherOp");
+
+ auto loadGatherOp =
+ producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+ auto offsets = loadGatherOp.getOffsets();
+ if (!offsets || !isa<VectorType>(offsets.getType()) ||
+ !isa<VectorType>(loadGatherOp.getMask().getType()))
+ return rewriter.notifyMatchFailure(
+ loadGatherOp,
+          "Load op must have vector arguments for offsets and mask");
+ VectorType offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+ if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+ return rewriter.notifyMatchFailure(loadGatherOp,
+ "Expected 1D offsets and mask vector");
+ // Assume offset and mask producers will be distributed as well.
+ std::string layoutOffsetsName =
+ xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+ std::string layoutMaskName =
+ xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+ xegpu::LayoutAttr layoutOffsets =
+ loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+ xegpu::LayoutAttr layoutMask =
+ loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+ FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+ FailureOr<VectorType> distMaskByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+ if (failed(distOffsetsByWarpOpOrFailure) ||
+ failed(distMaskByWarpOpOrFailure)) {
+ return rewriter.notifyMatchFailure(
+ loadGatherOp,
+ "Some vector operands have no layouts, using defaults instead.");
+ }
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands = loadGatherOp->getOperands();
+ SmallVector<Type> operandTypesToYield = {
+ operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
+ distMaskByWarpOpOrFailure.value()};
+
+ const unsigned operandIdx = producedByLastLoad->getOperandNumber();
+ VectorType loadVecTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+
+ SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ rewriter.setInsertionPointAfter(newWarpOp);
+ xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+ newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
+ loadGatherOp->getAttrs());
+ xegpu::removeLayoutAttrs(newOp);
+ Value distributedVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -819,10 +1013,11 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
- patterns.add<CreateNdDescDistribution, StoreNdDistribution,
- LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
- UpdateNdOffsetDistribution, GpuBarrierDistribution>(
- patterns.getContext());
+ patterns
+ .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+ DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+ GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+ patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
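
As a sanity check on the payload distribution performed by StoreDistribution
and LoadDistribution above, here is a minimal sketch of the shape arithmetic
for the chunk_size = 8 case from the tests below, assuming the distributed
type is obtained by dividing the subgroup-level shape elementwise by the lane
layout [16, 1] and then flattening to 1-D (as expectedPayloadTy does):

  #include <cstdio>

  int main() {
    // Subgroup-level payload and its lane layout for the chunk_size = 8 case.
    const long sgShape[2] = {16, 8};     // vector<16x8xf16> at subgroup level
    const long laneLayout[2] = {16, 1};  // one payload row per lane
    long numElems = 1;
    for (int i = 0; i < 2; ++i)
      numElems *= sgShape[i] / laneLayout[i];  // per-lane shape: 1 x 8
    std::printf("distributed payload: vector<%ldxf16>\n", numElems);  // 8
    return 0;
  }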
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 0214d84f2c16f..cba3f0bd690c3 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -162,6 +162,40 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
return
}
+// -----
+// CHECK-LABEL: func.func @scatter_ops_chunksize(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ return
+}
+
+// -----
+// CHECK-LABEL: func.func @scatter_ops(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+func.func @scatter_ops(%src: memref<256xf16>) {
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %offset = arith.constant dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ return
+}
+
// -----
// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..a39aa90bbe3a8 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,39 @@ gpu.module @test {
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+ gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
+ %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
+ layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+ } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+ gpu.func @scatter_ops(%src: memref<256xf16>) {
+ %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+ %3 = xegpu.load %src[%offset], %1 {
+ layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+ } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %3, %src[%offset], %1 : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+ }
+}