[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi pass (PR #181917)
Nishant Patel
llvmlistbot at llvm.org
Thu Feb 26 12:53:36 PST 2026
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/181917
>From 766846bcae560849cf5827b1818215ab4af9eb59 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 10 Feb 2026 23:30:25 +0000
Subject: [PATCH 1/7] Add pattern for xegpu.load & store
---
.../XeGPUSgToWiDistributeExperimental.cpp | 140 +++++++++++++++++-
.../XeGPU/sg-to-wi-experimental-unit.mlir | 111 ++++++++++++++
2 files changed, 249 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 8e530642d9c7a..04b1f66b85f17 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -362,6 +362,141 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+/// The result at workitem level is always 1D. A ShapeCast is added to restore
+/// the expected rank from the lane layout if needed.
+///
+/// Example (with chunk_size):
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8,
+/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To:
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8}>
+/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+/// %1 = vector.shape_cast %0 : vector<8xf16> to vector<1x8xf16>
+///
+/// Example (without chunk_size):
+/// %0 = xegpu.load %src[%offset], %mask
+/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To:
+/// %0 = xegpu.load %src[%offset], %mask
+/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+ // If no layout, nothing to do.
+ if (!layout)
+ return failure();
+
+ VectorType resultTy = op.getValueType();
+ if (!resultTy)
+ return failure();
+
+ auto expectedWiResultTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
+ if (failed(expectedWiResultTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op,
+ "unable to compute expected workitem vector type from lane layout");
+
+ VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
+ // The hardware-supported WI type for scatter ops is always 1D.
+ VectorType supportedWiResultTy =
+ VectorType::get({expectedWiResultTy.getNumElements()},
+ expectedWiResultTy.getElementType());
+
+ // Build the new op with adapted (type-converted) values.
+ // Use Value() for offsets if not present (optional operand).
+ Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ auto newOp = xegpu::LoadGatherOp::create(
+ rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
+ offsets, adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+
+ // Cast the result to the expected type if needed (e.g., 1D to 2D).
+ Value result = newOp->getResult(0);
+ if (supportedWiResultTy != expectedWiResultTy)
+ result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ expectedWiResultTy, result)
+ .getResult();
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
+/// workitem-level. Stored value in workitem-level StoreScatter op is 1D.
+/// A ShapeCast is added to cast the incoming value to 1D if needed.
+///
+/// Example (with chunk_size):
+/// xegpu.store %val, %src[%offset], %mask <{chunk_size = 8,
+/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+/// : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To:
+/// %0 = vector.shape_cast %val : vector<1x8xf16> to vector<8xf16>
+/// xegpu.store %0, %src[%offset], %mask <{chunk_size = 8}>
+/// : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example (without chunk_size):
+/// xegpu.store %val, %src[%offset], %mask
+/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+/// : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To:
+/// xegpu.store %val, %src[%offset], %mask
+/// : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+ using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+ // If no layout, nothing to do.
+ if (!layout)
+ return failure();
+
+ VectorType valueTy = op.getValueType();
+ if (!valueTy)
+ return failure();
+
+ auto expectedWiValueTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
+ if (failed(expectedWiValueTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op,
+ "unable to compute expected workitem vector type from lane layout");
+
+ VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
+ // The hardware-supported WI type for scatter ops is always 1D.
+ VectorType supportedWiValueTy =
+ VectorType::get({expectedWiValueTy.getNumElements()},
+ expectedWiValueTy.getElementType());
+
+ // Cast the adapted value to the supported 1D type if needed.
+ Value adaptedValue = adaptor.getValue();
+ if (adaptedValue.getType() != supportedWiValueTy)
+ adaptedValue =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), supportedWiValueTy,
+ adaptedValue)
+ .getResult();
+
+ // Build the new op with adapted values.
+ // Use Value() for offsets if not present (optional operand).
+ Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ xegpu::StoreScatterOp::create(
+ rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets,
+ adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
struct XeGPUSgToWiDistributeExperimentalPass
: public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
XeGPUSgToWiDistributeExperimentalPass> {
@@ -553,6 +688,7 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
});
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
- SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
- typeConverter, patterns.getContext());
+ SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+ SgToWiLoadGather, SgToWiStoreScatter>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0e9843f4626d4..8a3dc92cd3a44 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -149,4 +149,115 @@ gpu.func @prefetch_nd() {
: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
+
+// CHECK-LABEL: gpu.func @scatter_load_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+gpu.func @scatter_load_chunksize(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[C1:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+// CHECK: %[[C2:.*]] = vector.shape_cast %[[C1]] : vector<1x8xf16> to vector<8xf16>
+// CHECK: xegpu.store %[[C2]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store_chunksize(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+ : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_load
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+gpu.func @scatter_load(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK: xegpu.store %[[LOAD]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store(%src: memref<256xf16>) {
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<12> : vector<16xindex>
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ dense<true> : vector<16xi1>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+ : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1> -> vector<1xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>
+gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
+ %mask = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ dense<1> : vector<1x1x16xi1>
+ %offset = arith.constant
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+ dense<12> : vector<1x1x16xindex>
+ %0 = xegpu.load %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+ : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+ xegpu.store %0, %src[%offset], %mask
+ <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+ : vector<1x1x16xf16>, memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+ gpu.return
+}
+
}
>From 9374957ca5b6f7b064f5c9cf7975ff33626cc3a8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 17:10:50 +0000
Subject: [PATCH 2/7] clang-format
---
.../XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 9ae564c685e07..ec9d1bcd90cb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -885,8 +885,8 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
- SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction, SgToWiMultiDimReduction>(typeConverter,
- patterns.getContext());
+ SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
+ SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
}
void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
>From d9982f36dddfe98cc9d851e5d3a50b9756a8784e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 17:23:46 +0000
Subject: [PATCH 3/7] Flatten offsets and mask to 1D for scatter load/store
---
.../XeGPUSgToWiDistributeExperimental.cpp | 50 ++++++++++++++++---
.../XeGPU/sg-to-wi-experimental-unit.mlir | 12 +++--
2 files changed, 50 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index ec9d1bcd90cb8..f5a48f39860c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -453,12 +453,29 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
VectorType::get({expectedWiResultTy.getNumElements()},
expectedWiResultTy.getElementType());
- // Build the new op with adapted (type-converted) values.
- // Use Value() for offsets if not present (optional operand).
+ // Flatten offsets and mask to 1D for hardware compatibility.
Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ if (offsets) {
+ auto offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ if (offsetsTy != offsetsTy1D)
+ offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ offsetsTy1D, offsets)
+ .getResult();
+ }
+ Value mask = adaptor.getMask();
+ auto maskTy = cast<VectorType>(mask.getType());
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+
+ // Build the new op with adapted (type-converted) values.
auto newOp = xegpu::LoadGatherOp::create(
rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
- offsets, adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+ offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
// Cast the result to the expected type if needed (e.g., 1D to 2D).
@@ -665,13 +682,30 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
adaptedValue)
.getResult();
- // Build the new op with adapted values.
- // Use Value() for offsets if not present (optional operand).
+ // Flatten offsets and mask to 1D for hardware compatibility.
Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+ if (offsets) {
+ auto offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ if (offsetsTy != offsetsTy1D)
+ offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ offsetsTy1D, offsets)
+ .getResult();
+ }
+ Value mask = adaptor.getMask();
+ auto maskTy = cast<VectorType>(mask.getType());
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+
+ // Build the new op with adapted values.
xegpu::StoreScatterOp::create(
- rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets,
- adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+ rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
+ op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr(), /*layout=*/nullptr);
rewriter.eraseOp(op);
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 1dd3b07c69a13..1cf41667b7a97 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -242,12 +242,16 @@ gpu.func @scatter_store(%src: memref<256xf16>) {
// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
-// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
-// CHECK-SAME: : memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1> -> vector<1xf16>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V2:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[V1]]], %[[V2]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
-// CHECK: xegpu.store %[[CAST2]], %arg0[%[[OFFSET]]], %[[MASK]]
-// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>
+// CHECK: %[[V3:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V4:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[V3]]], %[[V4]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
%mask = arith.constant
{layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
>From be9507e56df7623ad21f125e9daaf38c6d94d609 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 18:49:53 +0000
Subject: [PATCH 4/7] Clean up
---
.../XeGPUSgToWiDistributeExperimental.cpp | 95 +++++--------------
1 file changed, 26 insertions(+), 69 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index f5a48f39860c2..b3a9f8cd86667 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -396,25 +396,6 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
};
/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
-/// The result at workitem level is always 1D. A ShapeCast is added to restore
-/// the expected rank from the lane layout if needed.
-///
-/// Example (with chunk_size):
-/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8,
-/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// To:
-/// %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8}>
-/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-/// %1 = vector.shape_cast %0 : vector<8xf16> to vector<1x8xf16>
-///
-/// Example (without chunk_size):
-/// %0 = xegpu.load %src[%offset], %mask
-/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
-/// : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// To:
-/// %0 = xegpu.load %src[%offset], %mask
-/// : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
@@ -422,7 +403,6 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- // If no layout, nothing to do.
if (!layout)
return failure();
@@ -430,7 +410,7 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
if (!resultTy)
return failure();
- // Check that all leading dimensions are unit dimensions.
+ // Check that leading dimensions are unit.
int chunkSize = op.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
@@ -448,15 +428,13 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
"unable to compute expected workitem vector type from lane layout");
VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
- // The hardware-supported WI type for scatter ops is always 1D.
VectorType supportedWiResultTy =
VectorType::get({expectedWiResultTy.getNumElements()},
expectedWiResultTy.getElementType());
- // Flatten offsets and mask to 1D for hardware compatibility.
- Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
- if (offsets) {
- auto offsetsTy = cast<VectorType>(offsets.getType());
+ // Flatten offsets and mask to 1D to match the 1D result type.
+ Value offsets = adaptor.getOffsets();
+ if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
offsetsTy.getElementType());
if (offsetsTy != offsetsTy1D)
@@ -465,20 +443,20 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
.getResult();
}
Value mask = adaptor.getMask();
- auto maskTy = cast<VectorType>(mask.getType());
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- if (maskTy != maskTy1D)
- mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
- .getResult();
-
- // Build the new op with adapted (type-converted) values.
+ if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+ }
+
auto newOp = xegpu::LoadGatherOp::create(
rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
- // Cast the result to the expected type if needed (e.g., 1D to 2D).
Value result = newOp->getResult(0);
if (supportedWiResultTy != expectedWiResultTy)
result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
@@ -617,25 +595,7 @@ struct LowerVectorMultiReductionPattern
};
/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
-/// workitem-level. Stored value in workitem-level StoreScatter op is 1D.
-/// A ShapeCast is added to cast the incoming value to 1D if needed.
-///
-/// Example (with chunk_size):
-/// xegpu.store %val, %src[%offset], %mask <{chunk_size = 8,
-/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-/// : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To:
-/// %0 = vector.shape_cast %val : vector<1x8xf16> to vector<8xf16>
-/// xegpu.store %0, %src[%offset], %mask <{chunk_size = 8}>
-/// : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-///
-/// Example (without chunk_size):
-/// xegpu.store %val, %src[%offset], %mask
-/// <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
-/// : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To:
-/// xegpu.store %val, %src[%offset], %mask
-/// : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// workitem-level.
struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
@@ -643,7 +603,6 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
- // If no layout, nothing to do.
if (!layout)
return failure();
@@ -669,12 +628,10 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
"unable to compute expected workitem vector type from lane layout");
VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
- // The hardware-supported WI type for scatter ops is always 1D.
VectorType supportedWiValueTy =
VectorType::get({expectedWiValueTy.getNumElements()},
expectedWiValueTy.getElementType());
- // Cast the adapted value to the supported 1D type if needed.
Value adaptedValue = adaptor.getValue();
if (adaptedValue.getType() != supportedWiValueTy)
adaptedValue =
@@ -682,10 +639,9 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
adaptedValue)
.getResult();
- // Flatten offsets and mask to 1D for hardware compatibility.
- Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
- if (offsets) {
- auto offsetsTy = cast<VectorType>(offsets.getType());
+ // Flatten offsets and mask to 1D to match the 1D value type.
+ Value offsets = adaptor.getOffsets();
+ if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
offsetsTy.getElementType());
if (offsetsTy != offsetsTy1D)
@@ -694,14 +650,15 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
.getResult();
}
Value mask = adaptor.getMask();
- auto maskTy = cast<VectorType>(mask.getType());
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- if (maskTy != maskTy1D)
- mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
- .getResult();
-
- // Build the new op with adapted values.
+ if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+ }
+
xegpu::StoreScatterOp::create(
rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
>From d5e9f43b372571d4ee8ac32230b262eb80fb35ad Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Feb 2026 20:38:36 +0000
Subject: [PATCH 5/7] Address feedback
---
.../XeGPUSgToWiDistributeExperimental.cpp | 58 ++++++++-----------
1 file changed, 24 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index b3a9f8cd86667..02b7ab86e8cb4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -413,12 +413,12 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
// Check that leading dimensions are unit.
int chunkSize = op.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
- for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
- if (resultTy.getShape()[i] != 1)
- return rewriter.notifyMatchFailure(
- op, "Only unit dimensions allowed for the leading "
- "dimensions of the load vector!");
- }
+ ArrayRef<int64_t> shape = resultTy.getShape();
+ if (llvm::any_of(shape.take_front(resultTy.getRank() - effectiveVecRank),
+ [](int64_t d) { return d != 1; }))
+ return rewriter.notifyMatchFailure(
+ op, "Only unit dimensions allowed for the leading "
+ "dimensions of the load vector!");
auto expectedWiResultTyOrFailure =
xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
@@ -437,19 +437,15 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
offsetsTy.getElementType());
- if (offsetsTy != offsetsTy1D)
- offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
- offsetsTy1D, offsets)
- .getResult();
+ offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+ offsetsTy1D);
}
Value mask = adaptor.getMask();
if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
VectorType maskTy1D =
VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- if (maskTy != maskTy1D)
- mask =
- vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
- .getResult();
+ mask =
+ castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
}
auto newOp = xegpu::LoadGatherOp::create(
@@ -459,9 +455,8 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
Value result = newOp->getResult(0);
if (supportedWiResultTy != expectedWiResultTy)
- result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
- expectedWiResultTy, result)
- .getResult();
+ result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
+ expectedWiResultTy);
rewriter.replaceOp(op, result);
return success();
}
@@ -613,12 +608,12 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
// Check that all leading dimensions are unit dimensions.
int chunkSize = op.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
- for (int i = 0; i < valueTy.getRank() - effectiveVecRank; i++) {
- if (valueTy.getShape()[i] != 1)
- return rewriter.notifyMatchFailure(
- op, "Only unit dimensions allowed for the leading "
- "dimensions of the store vector!");
- }
+ ArrayRef<int64_t> shape = valueTy.getShape();
+ if (llvm::any_of(shape.take_front(valueTy.getRank() - effectiveVecRank),
+ [](int64_t d) { return d != 1; }))
+ return rewriter.notifyMatchFailure(
+ op, "Only unit dimensions allowed for the leading "
+ "dimensions of the store vector!");
auto expectedWiValueTyOrFailure =
xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
@@ -635,28 +630,23 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
Value adaptedValue = adaptor.getValue();
if (adaptedValue.getType() != supportedWiValueTy)
adaptedValue =
- vector::ShapeCastOp::create(rewriter, op.getLoc(), supportedWiValueTy,
- adaptedValue)
- .getResult();
+ castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptedValue),
+ supportedWiValueTy);
// Flatten offsets and mask to 1D to match the 1D value type.
Value offsets = adaptor.getOffsets();
if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
offsetsTy.getElementType());
- if (offsetsTy != offsetsTy1D)
- offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
- offsetsTy1D, offsets)
- .getResult();
+ offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+ offsetsTy1D);
}
Value mask = adaptor.getMask();
if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
VectorType maskTy1D =
VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- if (maskTy != maskTy1D)
- mask =
- vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
- .getResult();
+ mask =
+ castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
}
xegpu::StoreScatterOp::create(
>From 58183276462fa4df8db45ff24f23f8d29f145a32 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Feb 2026 20:56:53 +0000
Subject: [PATCH 6/7] Clean up
---
.../XeGPUSgToWiDistributeExperimental.cpp | 44 +++++++++----------
1 file changed, 20 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 02b7ab86e8cb4..33c982f34eb6f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -434,19 +434,17 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
// Flatten offsets and mask to 1D to match the 1D result type.
Value offsets = adaptor.getOffsets();
- if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
- VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
- offsetsTy.getElementType());
- offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
- offsetsTy1D);
- }
+ auto offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+ offsetsTy1D);
+
Value mask = adaptor.getMask();
- if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- mask =
- castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
- }
+ auto maskTy = cast<VectorType>(mask.getType());
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
auto newOp = xegpu::LoadGatherOp::create(
rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
@@ -635,19 +633,17 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
// Flatten offsets and mask to 1D to match the 1D value type.
Value offsets = adaptor.getOffsets();
- if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
- VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
- offsetsTy.getElementType());
- offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
- offsetsTy1D);
- }
+ auto offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+ offsetsTy1D);
+
Value mask = adaptor.getMask();
- if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- mask =
- castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
- }
+ auto maskTy = cast<VectorType>(mask.getType());
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
xegpu::StoreScatterOp::create(
rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
>From da81eeb6ee3c81a7a5cf74498154a9468e3a9a77 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 26 Feb 2026 20:48:39 +0000
Subject: [PATCH 7/7] Address feedback
---
.../XeGPUSgToWiDistributeExperimental.cpp | 192 ++++++++++++------
1 file changed, 129 insertions(+), 63 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 33c982f34eb6f..b6a95d8782bf0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -396,6 +396,38 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
};
/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+///
+/// Example 1 (1D, no chunk size):
+/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// Distributed to:
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
+///
+/// Example 2 (2D with chunk size, mask & offset as in Example 1):
+/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// Distributed to (the vector<8xf16> result is cast to vector<1x8xf16>):
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+///
+/// Example 3 (3D with leading unit dims):
+/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+/// %mask = producer_op : vector<1x1x16xi1>
+/// %offset = producer_op : vector<1x1x16xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+/// Distributed to (offset/mask flattened to 1D; result cast back to 3D):
+/// %mask = producer_op : vector<1x1x1xi1>
+/// %offset = producer_op : vector<1x1x1xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
@@ -406,55 +438,57 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
if (!layout)
return failure();
- VectorType resultTy = op.getValueType();
- if (!resultTy)
+ VectorType origResultTy = op.getValueType();
+ if (!origResultTy)
return failure();
// Check that leading dimensions are unit.
int chunkSize = op.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
- ArrayRef<int64_t> shape = resultTy.getShape();
- if (llvm::any_of(shape.take_front(resultTy.getRank() - effectiveVecRank),
- [](int64_t d) { return d != 1; }))
+ ArrayRef<int64_t> shape = origResultTy.getShape();
+ if (llvm::any_of(
+ shape.take_front(origResultTy.getRank() - effectiveVecRank),
+ [](int64_t d) { return d != 1; }))
return rewriter.notifyMatchFailure(
op, "Only unit dimensions allowed for the leading "
"dimensions of the load vector!");
- auto expectedWiResultTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
- if (failed(expectedWiResultTyOrFailure))
+ auto distResultTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
+ if (failed(distResultTyOrFailure))
return rewriter.notifyMatchFailure(
op,
"unable to compute expected workitem vector type from lane layout");
- VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
- VectorType supportedWiResultTy =
- VectorType::get({expectedWiResultTy.getNumElements()},
- expectedWiResultTy.getElementType());
+ VectorType distResultTy = distResultTyOrFailure.value();
+ VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
+ distResultTy.getElementType());
// Flatten offsets and mask to 1D to match the 1D result type.
- Value offsets = adaptor.getOffsets();
- auto offsetsTy = cast<VectorType>(offsets.getType());
- VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
- offsetsTy.getElementType());
- offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
- offsetsTy1D);
-
- Value mask = adaptor.getMask();
- auto maskTy = cast<VectorType>(mask.getType());
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
-
+ Value distOffsets = adaptor.getOffsets();
+ auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
+ VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
+ distOffsetsTy.getElementType());
+ distOffsets = castValueTo(
+ rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
+
+ Value distMask = adaptor.getMask();
+ auto distMaskTy = cast<VectorType>(distMask.getType());
+ VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
+ distMaskTy.getElementType());
+ distMask =
+ castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
+
+ Value distSource = adaptor.getSource();
auto newOp = xegpu::LoadGatherOp::create(
- rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
- offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+ rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
+ distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr(), /*layout=*/nullptr);
Value result = newOp->getResult(0);
- if (supportedWiResultTy != expectedWiResultTy)
+ if (distResultTy1D != distResultTy)
result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
- expectedWiResultTy);
+ distResultTy);
rewriter.replaceOp(op, result);
return success();
}
@@ -589,6 +623,38 @@ struct LowerVectorMultiReductionPattern
/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
/// workitem-level.
+///
+/// Example 1 (1D, no chunk size):
+/// layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+/// memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to:
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+/// memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 2 (2D with chunk size, mask & offset as in Example 1):
+/// layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to (payload, offset and mask flattened to 1D):
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 3 (3D with leading unit dims):
+/// layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+/// %mask = producer_op : vector<1x1x16xi1>
+/// %offset = producer_op : vector<1x1x16xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
+/// memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+/// Distributed to (payload, offset and mask flattened to 1D):
+/// %mask = producer_op : vector<1x1x1xi1>
+/// %offset = producer_op : vector<1x1x1xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+/// memref<256xf16>, vector<1xindex>, vector<1xi1>
struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
@@ -599,56 +665,56 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
if (!layout)
return failure();
- VectorType valueTy = op.getValueType();
- if (!valueTy)
+ VectorType origValueTy = op.getValueType();
+ if (!origValueTy)
return failure();
// Check that all leading dimensions are unit dimensions.
int chunkSize = op.getChunkSize().value_or(1);
int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
- ArrayRef<int64_t> shape = valueTy.getShape();
- if (llvm::any_of(shape.take_front(valueTy.getRank() - effectiveVecRank),
+ ArrayRef<int64_t> shape = origValueTy.getShape();
+ if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
[](int64_t d) { return d != 1; }))
return rewriter.notifyMatchFailure(
op, "Only unit dimensions allowed for the leading "
"dimensions of the store vector!");
- auto expectedWiValueTyOrFailure =
- xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
- if (failed(expectedWiValueTyOrFailure))
+ auto distValueTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
+ if (failed(distValueTyOrFailure))
return rewriter.notifyMatchFailure(
op,
"unable to compute expected workitem vector type from lane layout");
- VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
- VectorType supportedWiValueTy =
- VectorType::get({expectedWiValueTy.getNumElements()},
- expectedWiValueTy.getElementType());
+ VectorType distValueTy = distValueTyOrFailure.value();
+ VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
+ distValueTy.getElementType());
- Value adaptedValue = adaptor.getValue();
- if (adaptedValue.getType() != supportedWiValueTy)
- adaptedValue =
- castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptedValue),
- supportedWiValueTy);
+ Value distValue = adaptor.getValue();
+ if (distValue.getType() != distValueTy1D)
+ distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
+ distValueTy1D);
// Flatten offsets and mask to 1D to match the 1D value type.
- Value offsets = adaptor.getOffsets();
- auto offsetsTy = cast<VectorType>(offsets.getType());
- VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
- offsetsTy.getElementType());
- offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
- offsetsTy1D);
-
- Value mask = adaptor.getMask();
- auto maskTy = cast<VectorType>(mask.getType());
- VectorType maskTy1D =
- VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
- mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
-
- xegpu::StoreScatterOp::create(
- rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
- op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
- op.getL3HintAttr(), /*layout=*/nullptr);
+ Value distOffsets = adaptor.getOffsets();
+ auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
+ VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
+ distOffsetsTy.getElementType());
+ distOffsets = castValueTo(
+ rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
+
+ Value distMask = adaptor.getMask();
+ auto distMaskTy = cast<VectorType>(distMask.getType());
+ VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
+ distMaskTy.getElementType());
+ distMask =
+ castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
+
+ Value distDest = adaptor.getDest();
+ xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
+ distOffsets, distMask, op.getChunkSizeAttr(),
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr(), /*layout=*/nullptr);
rewriter.eraseOp(op);
return success();
}
More information about the Mlir-commits
mailing list