[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi pass (PR #181917)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 17 13:53:16 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
This PR adds distribution patterns for the xegpu.load and xegpu.store ops to the new sg-to-wi pass.
---
Full diff: https://github.com/llvm/llvm-project/pull/181917.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp (+148-2)
- (modified) mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir (+114)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 3787fbb44e1b8..b3a9f8cd86667 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -395,6 +395,78 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level: a new load producing a 1D vector is created and shape_cast back to the lane-layout-derived type when the two differ.
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+ if (!layout) // Without an anchor layout the pattern does not apply.
+ return failure();
+
+ VectorType resultTy = op.getValueType();
+ if (!resultTy) // A null value type is rejected as well.
+ return failure();
+
+ // Check that leading dimensions are unit: only the trailing 1 dim (chunk size 1) or trailing 2 dims (any other chunk size) may be non-unit.
+ int chunkSize = op.getChunkSize().value_or(1); // An absent chunk size is treated as 1.
+ int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+ for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
+ if (resultTy.getShape()[i] != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Only unit dimensions allowed for the leading "
+ "dimensions of the load vector!");
+ }
+
+ auto expectedWiResultTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
+ if (failed(expectedWiResultTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op,
+ "unable to compute expected workitem vector type from lane layout");
+
+ VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
+ VectorType supportedWiResultTy = // 1D type with the same element count and element type as the expected type.
+ VectorType::get({expectedWiResultTy.getNumElements()},
+ expectedWiResultTy.getElementType());
+
+ // Flatten offsets and mask to 1D to match the 1D result type; operands that are not vectors or are already 1D are passed through unchanged.
+ Value offsets = adaptor.getOffsets();
+ if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ if (offsetsTy != offsetsTy1D)
+ offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ offsetsTy1D, offsets)
+ .getResult();
+ }
+ Value mask = adaptor.getMask();
+ if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+ }
+
+ auto newOp = xegpu::LoadGatherOp::create(
+ rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
+ offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr); // Chunk size and cache hints are forwarded; a null layout is passed.
+
+ Value result = newOp->getResult(0);
+ if (supportedWiResultTy != expectedWiResultTy) // Cast back to the expected workitem type before replacing the op.
+ result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ expectedWiResultTy, result)
+ .getResult();
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
/// This pattern distributes a subgroup-level vector.reduction op to
/// workitem-level. This require shuffling the data across the workitems (using
/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
@@ -522,6 +594,80 @@ struct LowerVectorMultiReductionPattern
}
};
+/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
+/// workitem-level; the stored value, offsets and mask are cast to 1D vectors where needed before the new store is created.
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+ if (!layout) // Without an anchor layout the pattern does not apply.
+ return failure();
+
+ VectorType valueTy = op.getValueType();
+ if (!valueTy) // A null value type is rejected as well.
+ return failure();
+
+ // Check that all leading dimensions are unit dimensions: only the trailing 1 dim (chunk size 1) or trailing 2 dims (any other chunk size) may be non-unit.
+ int chunkSize = op.getChunkSize().value_or(1); // An absent chunk size is treated as 1.
+ int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+ for (int i = 0; i < valueTy.getRank() - effectiveVecRank; i++) {
+ if (valueTy.getShape()[i] != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Only unit dimensions allowed for the leading "
+ "dimensions of the store vector!");
+ }
+
+ auto expectedWiValueTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
+ if (failed(expectedWiValueTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op,
+ "unable to compute expected workitem vector type from lane layout");
+
+ VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
+ VectorType supportedWiValueTy = // 1D type with the same element count and element type as the expected type.
+ VectorType::get({expectedWiValueTy.getNumElements()},
+ expectedWiValueTy.getElementType());
+
+ Value adaptedValue = adaptor.getValue();
+ if (adaptedValue.getType() != supportedWiValueTy) // Cast the converted value only when it is not already of the 1D type.
+ adaptedValue =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), supportedWiValueTy,
+ adaptedValue)
+ .getResult();
+
+ // Flatten offsets and mask to 1D to match the 1D value type; operands that are not vectors or are already 1D are passed through unchanged.
+ Value offsets = adaptor.getOffsets();
+ if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ if (offsetsTy != offsetsTy1D)
+ offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ offsetsTy1D, offsets)
+ .getResult();
+ }
+ Value mask = adaptor.getMask();
+ if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+ }
+
+ xegpu::StoreScatterOp::create(
+ rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
+ op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr(), /*layout=*/nullptr); // Chunk size and cache hints are forwarded; a null layout is passed.
+ rewriter.eraseOp(op); // The original op is erased rather than replaced; no result value is forwarded.
+ return success();
+ }
+};
+
struct XeGPUSgToWiDistributeExperimentalPass
: public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
XeGPUSgToWiDistributeExperimentalPass> {
@@ -730,8 +876,8 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
- SgToWiVectorReduction, SgToWiMultiDimReduction>(
- typeConverter, patterns.getContext());
+ SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
+ SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
}
void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 1ec0879d4fb47..1cf41667b7a97 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -154,6 +154,120 @@ gpu.func @prefetch_nd() {
gpu.return
}
+// CHECK-LABEL: gpu.func @scatter_load_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+gpu.func @scatter_load_chunksize(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[C1:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+// CHECK: %[[C2:.*]] = vector.shape_cast %[[C1]] : vector<1x8xf16> to vector<8xf16>
+// CHECK: xegpu.store %[[C2]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store_chunksize(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_load
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+gpu.func @scatter_load(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK: xegpu.store %[[LOAD]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V2:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[V1]]], %[[V2]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
+// CHECK: %[[V3:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V4:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[V3]]], %[[V4]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ dense<1> : vector<1x1x16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ dense<12> : vector<1x1x16xindex>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+ : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+ : vector<1x1x16xf16>, memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+ gpu.return
+}
+
// CHECK-LABEL: gpu.func @vector_reduction
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[LANE_RED:.*]] = vector.reduction <add>, %[[CAST:.*]] : vector<2xf32> into f32
``````````
</details>
https://github.com/llvm/llvm-project/pull/181917
More information about the Mlir-commits
mailing list