[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi (PR #181917)
Nishant Patel
llvmlistbot at llvm.org
Tue Feb 17 13:52:39 PST 2026
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/181917
This PR adds distribution patterns for the xegpu.load and xegpu.store ops to the new sg-to-wi pass.
>From 766846bcae560849cf5827b1818215ab4af9eb59 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 10 Feb 2026 23:30:25 +0000
Subject: [PATCH 1/4] Add pattern for xegpu.load & store
---
.../XeGPUSgToWiDistributeExperimental.cpp | 140 +++++++++++++++++-
.../XeGPU/sg-to-wi-experimental-unit.mlir | 111 ++++++++++++++
2 files changed, 249 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 8e530642d9c7a..04b1f66b85f17 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -362,6 +362,141 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+/// The result at workitem level is always 1D. A ShapeCast is added to restore
+/// the expected rank from the lane layout if needed.
+///
+/// Example (with chunk_size):
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8,
+/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To:
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8}>
+/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+/// %1 = vector.shape_cast %0 : vector<8xf16> to vector<1x8xf16>
+///
+/// Example (without chunk_size):
+/// %0 = xegpu.load %src[%offset], %mask
+/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To:
+/// %0 = xegpu.load %src[%offset], %mask
+/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+ // If no layout, nothing to do.
+ if (!layout)
+ return failure();
+
+ VectorType resultTy = op.getValueType();
+ if (!resultTy)
+ return failure();
+
+ auto expectedWiResultTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
+ if (failed(expectedWiResultTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op,
+ "unable to compute expected workitem vector type from lane layout");
+
+ VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
+ // The hardware-supported WI type for scatter ops is always 1D.
+ VectorType supportedWiResultTy =
+ VectorType::get({expectedWiResultTy.getNumElements()},
+ expectedWiResultTy.getElementType());
+
+ // Build the new op with adapted (type-converted) values.
+ // Use Value() for offsets if not present (optional operand).
+ Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ auto newOp = xegpu::LoadGatherOp::create(
+ rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
+ offsets, adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+
+ // Cast the result to the expected type if needed (e.g., 1D to 2D).
+ Value result = newOp->getResult(0);
+ if (supportedWiResultTy != expectedWiResultTy)
+ result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ expectedWiResultTy, result)
+ .getResult();
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
+/// workitem-level. Stored value in workitem-level StoreScatter op is 1D.
+/// A ShapeCast is added to cast the incoming value to 1D if needed.
+///
+/// Example (with chunk_size):
+/// xegpu.store %val, %src[%offset], %mask <{chunk_size = 8,
+/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+/// : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To:
+/// %0 = vector.shape_cast %val : vector<1x8xf16> to vector<8xf16>
+/// xegpu.store %0, %src[%offset], %mask <{chunk_size = 8}>
+/// : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example (without chunk_size):
+/// xegpu.store %val, %src[%offset], %mask
+/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+/// : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To:
+/// xegpu.store %val, %src[%offset], %mask
+/// : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+ // If no layout, nothing to do.
+ if (!layout)
+ return failure();
+
+ VectorType valueTy = op.getValueType();
+ if (!valueTy)
+ return failure();
+
+ auto expectedWiValueTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
+ if (failed(expectedWiValueTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op,
+ "unable to compute expected workitem vector type from lane layout");
+
+ VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
+ // The hardware-supported WI type for scatter ops is always 1D.
+ VectorType supportedWiValueTy =
+ VectorType::get({expectedWiValueTy.getNumElements()},
+ expectedWiValueTy.getElementType());
+
+ // Cast the adapted value to the supported 1D type if needed.
+ Value adaptedValue = adaptor.getValue();
+ if (adaptedValue.getType() != supportedWiValueTy)
+ adaptedValue =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), supportedWiValueTy,
+ adaptedValue)
+ .getResult();
+
+ // Build the new op with adapted values.
+ // Use Value() for offsets if not present (optional operand).
+ Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ xegpu::StoreScatterOp::create(
+ rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets,
+ adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
struct XeGPUSgToWiDistributeExperimentalPass
: public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
XeGPUSgToWiDistributeExperimentalPass> {
@@ -553,6 +688,7 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
});
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
- SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
- typeConverter, patterns.getContext());
+ SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+ SgToWiLoadGather, SgToWiStoreScatter>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0e9843f4626d4..8a3dc92cd3a44 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -149,4 +149,115 @@ gpu.func @prefetch_nd() {
: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
+
+// CHECK-LABEL: gpu.func @scatter_load_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+gpu.func @scatter_load_chunksize(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[C1:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+// CHECK: %[[C2:.*]] = vector.shape_cast %[[C1]] : vector<1x8xf16> to vector<8xf16>
+// CHECK: xegpu.store %[[C2]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store_chunksize(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_load
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+gpu.func @scatter_load(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK: xegpu.store %[[LOAD]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1> -> vector<1xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>
+gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ dense<1> : vector<1x1x16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ dense<12> : vector<1x1x16xindex>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+ : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+ : vector<1x1x16xf16>, memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+ gpu.return
+}
+
}
>From 9374957ca5b6f7b064f5c9cf7975ff33626cc3a8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 17:10:50 +0000
Subject: [PATCH 2/4] clang-format
---
.../XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 9ae564c685e07..ec9d1bcd90cb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -885,8 +885,8 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
- SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction, SgToWiMultiDimReduction>(typeConverter,
- patterns.getContext());
+ SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
+ SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
}
void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
>From d9982f36dddfe98cc9d851e5d3a50b9756a8784e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 17:23:46 +0000
Subject: [PATCH 3/4] add flatten
---
.../XeGPUSgToWiDistributeExperimental.cpp | 50 ++++++++++++++++---
.../XeGPU/sg-to-wi-experimental-unit.mlir | 12 +++--
2 files changed, 50 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index ec9d1bcd90cb8..f5a48f39860c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -453,12 +453,29 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
VectorType::get({expectedWiResultTy.getNumElements()},
expectedWiResultTy.getElementType());
- // Build the new op with adapted (type-converted) values.
- // Use Value() for offsets if not present (optional operand).
+ // Flatten offsets and mask to 1D for hardware compatibility.
Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ if (offsets) {
+ auto offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ if (offsetsTy != offsetsTy1D)
+ offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ offsetsTy1D, offsets)
+ .getResult();
+ }
+ Value mask = adaptor.getMask();
+ auto maskTy = cast<VectorType>(mask.getType());
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+
+ // Build the new op with adapted (type-converted) values.
auto newOp = xegpu::LoadGatherOp::create(
rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
- offsets, adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+ offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
// Cast the result to the expected type if needed (e.g., 1D to 2D).
@@ -665,13 +682,30 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
adaptedValue)
.getResult();
- // Build the new op with adapted values.
- // Use Value() for offsets if not present (optional operand).
+ // Flatten offsets and mask to 1D for hardware compatibility.
Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ if (offsets) {
+ auto offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ if (offsetsTy != offsetsTy1D)
+ offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ offsetsTy1D, offsets)
+ .getResult();
+ }
+ Value mask = adaptor.getMask();
+ auto maskTy = cast<VectorType>(mask.getType());
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+
+ // Build the new op with adapted values.
xegpu::StoreScatterOp::create(
- rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets,
- adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+ rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
+ op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr(), /*layout=*/nullptr);
rewriter.eraseOp(op);
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 1dd3b07c69a13..1cf41667b7a97 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -242,12 +242,16 @@ gpu.func @scatter_store(%src: memref<256xf16>) {
// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
-// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
-// CHECK-SAME: : memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1> -> vector<1xf16>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V2:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[V1]]], %[[V2]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
-// CHECK: xegpu.store %[[CAST2]], %arg0[%[[OFFSET]]], %[[MASK]]
-// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>
+// CHECK: %[[V3:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V4:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[V3]]], %[[V4]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
%mask = arith.constant
{layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
>From be9507e56df7623ad21f125e9daaf38c6d94d609 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 18:49:53 +0000
Subject: [PATCH 4/4] Clean up
---
.../XeGPUSgToWiDistributeExperimental.cpp | 95 +++++--------------
1 file changed, 26 insertions(+), 69 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index f5a48f39860c2..b3a9f8cd86667 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -396,25 +396,6 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
};
/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
-/// The result at workitem level is always 1D. A ShapeCast is added to restore
-/// the expected rank from the lane layout if needed.
-///
-/// Example (with chunk_size):
-/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8,
-/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// To:
-/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8}>
-/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-/// %1 = vector.shape_cast %0 : vector<8xf16> to vector<1x8xf16>
-///
-/// Example (without chunk_size):
-/// %0 = xegpu.load %src[%offset], %mask
-/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
-/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// To:
-/// %0 = xegpu.load %src[%offset], %mask
-/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
@@ -422,7 +403,6 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- // If no layout, nothing to do.
if (!layout)
return failure();
@@ -430,7 +410,7 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
if (!resultTy)
return failure();
- // Check that all leading dimensions are unit dimensions.
+ // Check that leading dimensions are unit.
int chunkSize = op.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
@@ -448,15 +428,13 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
"unable to compute expected workitem vector type from lane layout");
VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
- // The hardware-supported WI type for scatter ops is always 1D.
VectorType supportedWiResultTy =
VectorType::get({expectedWiResultTy.getNumElements()},
expectedWiResultTy.getElementType());
- // Flatten offsets and mask to 1D for hardware compatibility.
- Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
- if (offsets) {
- auto offsetsTy = cast<VectorType>(offsets.getType());
+ // Flatten offsets and mask to 1D to match the 1D result type.
+ Value offsets = adaptor.getOffsets();
+ if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
offsetsTy.getElementType());
if (offsetsTy != offsetsTy1D)
@@ -465,20 +443,20 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
.getResult();
}
Value mask = adaptor.getMask();
- auto maskTy = cast<VectorType>(mask.getType());
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- if (maskTy != maskTy1D)
- mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
- .getResult();
-
- // Build the new op with adapted (type-converted) values.
+ if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+ }
+
auto newOp = xegpu::LoadGatherOp::create(
rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
- // Cast the result to the expected type if needed (e.g., 1D to 2D).
Value result = newOp->getResult(0);
if (supportedWiResultTy != expectedWiResultTy)
result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
@@ -617,25 +595,7 @@ struct LowerVectorMultiReductionPattern
};
/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
-/// workitem-level. Stored value in workitem-level StoreScatter op is 1D.
-/// A ShapeCast is added to cast the incoming value to 1D if needed.
-///
-/// Example (with chunk_size):
-/// xegpu.store %val, %src[%offset], %mask <{chunk_size = 8,
-/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-/// : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To:
-/// %0 = vector.shape_cast %val : vector<1x8xf16> to vector<8xf16>
-/// xegpu.store %0, %src[%offset], %mask <{chunk_size = 8}>
-/// : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-///
-/// Example (without chunk_size):
-/// xegpu.store %val, %src[%offset], %mask
-/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
-/// : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To:
-/// xegpu.store %val, %src[%offset], %mask
-/// : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// workitem-level.
struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
@@ -643,7 +603,6 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- // If no layout, nothing to do.
if (!layout)
return failure();
@@ -669,12 +628,10 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
"unable to compute expected workitem vector type from lane layout");
VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
- // The hardware-supported WI type for scatter ops is always 1D.
VectorType supportedWiValueTy =
VectorType::get({expectedWiValueTy.getNumElements()},
expectedWiValueTy.getElementType());
- // Cast the adapted value to the supported 1D type if needed.
Value adaptedValue = adaptor.getValue();
if (adaptedValue.getType() != supportedWiValueTy)
adaptedValue =
@@ -682,10 +639,9 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
adaptedValue)
.getResult();
- // Flatten offsets and mask to 1D for hardware compatibility.
- Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
- if (offsets) {
- auto offsetsTy = cast<VectorType>(offsets.getType());
+ // Flatten offsets and mask to 1D to match the 1D value type.
+ Value offsets = adaptor.getOffsets();
+ if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
offsetsTy.getElementType());
if (offsetsTy != offsetsTy1D)
@@ -694,14 +650,15 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
.getResult();
}
Value mask = adaptor.getMask();
- auto maskTy = cast<VectorType>(mask.getType());
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- if (maskTy != maskTy1D)
- mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
- .getResult();
-
- // Build the new op with adapted values.
+ if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+ }
+
xegpu::StoreScatterOp::create(
rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
More information about the Mlir-commits
mailing list