[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi pass (PR #181917)

Nishant Patel llvmlistbot at llvm.org
Thu Feb 26 12:53:36 PST 2026


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/181917

>From 766846bcae560849cf5827b1818215ab4af9eb59 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 10 Feb 2026 23:30:25 +0000
Subject: [PATCH 1/7] Add pattern for xegpu.load & store

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 140 +++++++++++++++++-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 111 ++++++++++++++
 2 files changed, 249 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 8e530642d9c7a..04b1f66b85f17 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -362,6 +362,141 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+/// The result at workitem level is always 1D. A ShapeCast is added to restore
+/// the expected rank from the lane layout if needed.
+///
+/// Example (with chunk_size):
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8,
+///     layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+///     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To:
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8}>
+///     : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+///   %1 = vector.shape_cast %0 : vector<8xf16> to vector<1x8xf16>
+///
+/// Example (without chunk_size):
+///   %0 = xegpu.load %src[%offset], %mask
+///     <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+///     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To:
+///   %0 = xegpu.load %src[%offset], %mask
+///     : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+
+    VectorType resultTy = op.getValueType();
+    if (!resultTy)
+      return failure();
+
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
+    // The hardware-supported WI type for scatter ops is always 1D.
+    VectorType supportedWiResultTy =
+        VectorType::get({expectedWiResultTy.getNumElements()},
+                        expectedWiResultTy.getElementType());
+
+    // Build the new op with adapted (type-converted) values.
+    // Use Value() for offsets if not present (optional operand).
+    Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+    auto newOp = xegpu::LoadGatherOp::create(
+        rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
+        offsets, adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+
+    // Cast the result to the expected type if needed (e.g., 1D to 2D).
+    Value result = newOp->getResult(0);
+    if (supportedWiResultTy != expectedWiResultTy)
+      result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                           expectedWiResultTy, result)
+                   .getResult();
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
+/// workitem-level. Stored value in workitem-level StoreScatter op is 1D.
+/// A ShapeCast is added to cast the incoming value to 1D if needed.
+///
+/// Example (with chunk_size):
+///   xegpu.store %val, %src[%offset], %mask <{chunk_size = 8,
+///     layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+///     : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To:
+///   %0 = vector.shape_cast %val : vector<1x8xf16> to vector<8xf16>
+///   xegpu.store %0, %src[%offset], %mask <{chunk_size = 8}>
+///     : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example (without chunk_size):
+///   xegpu.store %val, %src[%offset], %mask
+///     <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+///     : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To:
+///   xegpu.store %val, %src[%offset], %mask
+///     : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+
+    VectorType valueTy = op.getValueType();
+    if (!valueTy)
+      return failure();
+
+    auto expectedWiValueTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
+    if (failed(expectedWiValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
+    // The hardware-supported WI type for scatter ops is always 1D.
+    VectorType supportedWiValueTy =
+        VectorType::get({expectedWiValueTy.getNumElements()},
+                        expectedWiValueTy.getElementType());
+
+    // Cast the adapted value to the supported 1D type if needed.
+    Value adaptedValue = adaptor.getValue();
+    if (adaptedValue.getType() != supportedWiValueTy)
+      adaptedValue =
+          vector::ShapeCastOp::create(rewriter, op.getLoc(), supportedWiValueTy,
+                                      adaptedValue)
+              .getResult();
+
+    // Build the new op with adapted values.
+    // Use Value() for offsets if not present (optional operand).
+    Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+    xegpu::StoreScatterOp::create(
+        rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets,
+        adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -553,6 +688,7 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
       });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
-               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
-      typeConverter, patterns.getContext());
+               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+               SgToWiLoadGather, SgToWiStoreScatter>(typeConverter,
+                                                     patterns.getContext());
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0e9843f4626d4..8a3dc92cd3a44 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -149,4 +149,115 @@ gpu.func @prefetch_nd() {
     : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   gpu.return
 }
+
+// CHECK-LABEL: gpu.func @scatter_load_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+gpu.func @scatter_load_chunksize(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[C1:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+// CHECK: %[[C2:.*]] = vector.shape_cast %[[C1]] : vector<1x8xf16> to vector<8xf16>
+// CHECK: xegpu.store %[[C2]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store_chunksize(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+  xegpu.store %0, %src[%offset], %mask
+    <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+    : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_load
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+gpu.func @scatter_load(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK: xegpu.store %[[LOAD]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+  xegpu.store %0, %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+    : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1> -> vector<1xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>
+gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+    dense<1> : vector<1x1x16xi1>
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+    dense<12> : vector<1x1x16xindex>
+  %0 = xegpu.load %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+    : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+  xegpu.store %0, %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+    : vector<1x1x16xf16>, memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+  gpu.return
+}
+
 }

>From 9374957ca5b6f7b064f5c9cf7975ff33626cc3a8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 17:10:50 +0000
Subject: [PATCH 2/7] clang-format

---
 .../XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 9ae564c685e07..ec9d1bcd90cb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -885,8 +885,8 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
-               SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction, SgToWiMultiDimReduction>(typeConverter,
-                                                     patterns.getContext());
+               SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
+               SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
 }
 
 void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(

>From d9982f36dddfe98cc9d851e5d3a50b9756a8784e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 17:23:46 +0000
Subject: [PATCH 3/7] Flatten offsets and mask to 1D in scatter load/store patterns

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 50 ++++++++++++++++---
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 12 +++--
 2 files changed, 50 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index ec9d1bcd90cb8..f5a48f39860c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -453,12 +453,29 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
         VectorType::get({expectedWiResultTy.getNumElements()},
                         expectedWiResultTy.getElementType());
 
-    // Build the new op with adapted (type-converted) values.
-    // Use Value() for offsets if not present (optional operand).
+    // Flatten offsets and mask to 1D for hardware compatibility.
     Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+    if (offsets) {
+      auto offsetsTy = cast<VectorType>(offsets.getType());
+      VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+                                               offsetsTy.getElementType());
+      if (offsetsTy != offsetsTy1D)
+        offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                              offsetsTy1D, offsets)
+                      .getResult();
+    }
+    Value mask = adaptor.getMask();
+    auto maskTy = cast<VectorType>(mask.getType());
+    VectorType maskTy1D =
+        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+    if (maskTy != maskTy1D)
+      mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+                 .getResult();
+
+    // Build the new op with adapted (type-converted) values.
     auto newOp = xegpu::LoadGatherOp::create(
         rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
-        offsets, adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
+        offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
         op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
 
     // Cast the result to the expected type if needed (e.g., 1D to 2D).
@@ -665,13 +682,30 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
                                       adaptedValue)
               .getResult();
 
-    // Build the new op with adapted values.
-    // Use Value() for offsets if not present (optional operand).
+    // Flatten offsets and mask to 1D for hardware compatibility.
     Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
+    if (offsets) {
+      auto offsetsTy = cast<VectorType>(offsets.getType());
+      VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+                                               offsetsTy.getElementType());
+      if (offsetsTy != offsetsTy1D)
+        offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                              offsetsTy1D, offsets)
+                      .getResult();
+    }
+    Value mask = adaptor.getMask();
+    auto maskTy = cast<VectorType>(mask.getType());
+    VectorType maskTy1D =
+        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+    if (maskTy != maskTy1D)
+      mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+                 .getResult();
+
+    // Build the new op with adapted values.
     xegpu::StoreScatterOp::create(
-        rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets,
-        adaptor.getMask(), op.getChunkSizeAttr(), op.getL1HintAttr(),
-        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+        rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
+        op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /*layout=*/nullptr);
     rewriter.eraseOp(op);
     return success();
   }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 1dd3b07c69a13..1cf41667b7a97 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -242,12 +242,16 @@ gpu.func @scatter_store(%src: memref<256xf16>) {
 // CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
 // CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
-// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
-// CHECK-SAME: : memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1> -> vector<1xf16>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V2:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[V1]]], %[[V2]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
 // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
 // CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
-// CHECK: xegpu.store %[[CAST2]], %arg0[%[[OFFSET]]], %[[MASK]]
-// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>
+// CHECK: %[[V3:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V4:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[V3]]], %[[V4]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
   %mask = arith.constant
     {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}

>From be9507e56df7623ad21f125e9daaf38c6d94d609 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 17 Feb 2026 18:49:53 +0000
Subject: [PATCH 4/7] Clean up

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 95 +++++--------------
 1 file changed, 26 insertions(+), 69 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index f5a48f39860c2..b3a9f8cd86667 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -396,25 +396,6 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
 };
 
 /// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
-/// The result at workitem level is always 1D. A ShapeCast is added to restore
-/// the expected rank from the lane layout if needed.
-///
-/// Example (with chunk_size):
-///   %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8,
-///     layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-///     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// To:
-///   %0 = xegpu.load %src[%offset], %mask <{chunk_size = 8}>
-///     : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-///   %1 = vector.shape_cast %0 : vector<8xf16> to vector<1x8xf16>
-///
-/// Example (without chunk_size):
-///   %0 = xegpu.load %src[%offset], %mask
-///     <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
-///     : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// To:
-///   %0 = xegpu.load %src[%offset], %mask
-///     : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
 struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
   using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
 
@@ -422,7 +403,6 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
   matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
-    // If no layout, nothing to do.
     if (!layout)
       return failure();
 
@@ -430,7 +410,7 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
     if (!resultTy)
       return failure();
 
-    // Check that all leading dimensions are unit dimensions.
+    // Check that leading dimensions are unit.
     int chunkSize = op.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
     for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
@@ -448,15 +428,13 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
           "unable to compute expected workitem vector type from lane layout");
 
     VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
-    // The hardware-supported WI type for scatter ops is always 1D.
     VectorType supportedWiResultTy =
         VectorType::get({expectedWiResultTy.getNumElements()},
                         expectedWiResultTy.getElementType());
 
-    // Flatten offsets and mask to 1D for hardware compatibility.
-    Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
-    if (offsets) {
-      auto offsetsTy = cast<VectorType>(offsets.getType());
+    // Flatten offsets and mask to 1D to match the 1D result type.
+    Value offsets = adaptor.getOffsets();
+    if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
       VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
                                                offsetsTy.getElementType());
       if (offsetsTy != offsetsTy1D)
@@ -465,20 +443,20 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
                       .getResult();
     }
     Value mask = adaptor.getMask();
-    auto maskTy = cast<VectorType>(mask.getType());
-    VectorType maskTy1D =
-        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-    if (maskTy != maskTy1D)
-      mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
-                 .getResult();
-
-    // Build the new op with adapted (type-converted) values.
+    if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+      VectorType maskTy1D =
+          VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+      if (maskTy != maskTy1D)
+        mask =
+            vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+                .getResult();
+    }
+
     auto newOp = xegpu::LoadGatherOp::create(
         rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
         offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
         op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
 
-    // Cast the result to the expected type if needed (e.g., 1D to 2D).
     Value result = newOp->getResult(0);
     if (supportedWiResultTy != expectedWiResultTy)
       result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
@@ -617,25 +595,7 @@ struct LowerVectorMultiReductionPattern
 };
 
 /// Distributes a subgroup-level StoreScatter (xegpu.store) op to
-/// workitem-level. Stored value in workitem-level StoreScatter op is 1D.
-/// A ShapeCast is added to cast the incoming value to 1D if needed.
-///
-/// Example (with chunk_size):
-///   xegpu.store %val, %src[%offset], %mask <{chunk_size = 8,
-///     layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
-///     : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To:
-///   %0 = vector.shape_cast %val : vector<1x8xf16> to vector<8xf16>
-///   xegpu.store %0, %src[%offset], %mask <{chunk_size = 8}>
-///     : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-///
-/// Example (without chunk_size):
-///   xegpu.store %val, %src[%offset], %mask
-///     <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
-///     : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To:
-///   xegpu.store %val, %src[%offset], %mask
-///     : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// workitem-level.
 struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
 
@@ -643,7 +603,6 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
-    // If no layout, nothing to do.
     if (!layout)
       return failure();
 
@@ -669,12 +628,10 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
           "unable to compute expected workitem vector type from lane layout");
 
     VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
-    // The hardware-supported WI type for scatter ops is always 1D.
     VectorType supportedWiValueTy =
         VectorType::get({expectedWiValueTy.getNumElements()},
                         expectedWiValueTy.getElementType());
 
-    // Cast the adapted value to the supported 1D type if needed.
     Value adaptedValue = adaptor.getValue();
     if (adaptedValue.getType() != supportedWiValueTy)
       adaptedValue =
@@ -682,10 +639,9 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
                                       adaptedValue)
               .getResult();
 
-    // Flatten offsets and mask to 1D for hardware compatibility.
-    Value offsets = adaptor.getOffsets() ? adaptor.getOffsets() : Value();
-    if (offsets) {
-      auto offsetsTy = cast<VectorType>(offsets.getType());
+    // Flatten offsets and mask to 1D to match the 1D value type.
+    Value offsets = adaptor.getOffsets();
+    if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
       VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
                                                offsetsTy.getElementType());
       if (offsetsTy != offsetsTy1D)
@@ -694,14 +650,15 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
                       .getResult();
     }
     Value mask = adaptor.getMask();
-    auto maskTy = cast<VectorType>(mask.getType());
-    VectorType maskTy1D =
-        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-    if (maskTy != maskTy1D)
-      mask = vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
-                 .getResult();
-
-    // Build the new op with adapted values.
+    if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+      VectorType maskTy1D =
+          VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+      if (maskTy != maskTy1D)
+        mask =
+            vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+                .getResult();
+    }
+
     xegpu::StoreScatterOp::create(
         rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
         op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),

>From d5e9f43b372571d4ee8ac32230b262eb80fb35ad Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Feb 2026 20:38:36 +0000
Subject: [PATCH 5/7] Address review feedback: use llvm::any_of and castValueTo

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 58 ++++++++-----------
 1 file changed, 24 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index b3a9f8cd86667..02b7ab86e8cb4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -413,12 +413,12 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
     // Check that leading dimensions are unit.
     int chunkSize = op.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
-      if (resultTy.getShape()[i] != 1)
-        return rewriter.notifyMatchFailure(
-            op, "Only unit dimensions allowed for the leading "
-                "dimensions of the load vector!");
-    }
+    ArrayRef<int64_t> shape = resultTy.getShape();
+    if (llvm::any_of(shape.take_front(resultTy.getRank() - effectiveVecRank),
+                     [](int64_t d) { return d != 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "Only unit dimensions allowed for the leading "
+              "dimensions of the load vector!");
 
     auto expectedWiResultTyOrFailure =
         xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
@@ -437,19 +437,15 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
     if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
       VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
                                                offsetsTy.getElementType());
-      if (offsetsTy != offsetsTy1D)
-        offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
-                                              offsetsTy1D, offsets)
-                      .getResult();
+      offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+                            offsetsTy1D);
     }
     Value mask = adaptor.getMask();
     if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
       VectorType maskTy1D =
           VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-      if (maskTy != maskTy1D)
-        mask =
-            vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
-                .getResult();
+      mask =
+          castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
     }
 
     auto newOp = xegpu::LoadGatherOp::create(
@@ -459,9 +455,8 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
 
     Value result = newOp->getResult(0);
     if (supportedWiResultTy != expectedWiResultTy)
-      result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
-                                           expectedWiResultTy, result)
-                   .getResult();
+      result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
+                           expectedWiResultTy);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -613,12 +608,12 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
     // Check that all leading dimensions are unit dimensions.
     int chunkSize = op.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    for (int i = 0; i < valueTy.getRank() - effectiveVecRank; i++) {
-      if (valueTy.getShape()[i] != 1)
-        return rewriter.notifyMatchFailure(
-            op, "Only unit dimensions allowed for the leading "
-                "dimensions of the store vector!");
-    }
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    if (llvm::any_of(shape.take_front(valueTy.getRank() - effectiveVecRank),
+                     [](int64_t d) { return d != 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "Only unit dimensions allowed for the leading "
+              "dimensions of the store vector!");
 
     auto expectedWiValueTyOrFailure =
         xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
@@ -635,28 +630,23 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
     Value adaptedValue = adaptor.getValue();
     if (adaptedValue.getType() != supportedWiValueTy)
       adaptedValue =
-          vector::ShapeCastOp::create(rewriter, op.getLoc(), supportedWiValueTy,
-                                      adaptedValue)
-              .getResult();
+          castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptedValue),
+                      supportedWiValueTy);
 
     // Flatten offsets and mask to 1D to match the 1D value type.
     Value offsets = adaptor.getOffsets();
     if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
       VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
                                                offsetsTy.getElementType());
-      if (offsetsTy != offsetsTy1D)
-        offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
-                                              offsetsTy1D, offsets)
-                      .getResult();
+      offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+                            offsetsTy1D);
     }
     Value mask = adaptor.getMask();
     if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
       VectorType maskTy1D =
           VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-      if (maskTy != maskTy1D)
-        mask =
-            vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
-                .getResult();
+      mask =
+          castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
     }
 
     xegpu::StoreScatterOp::create(

>From 58183276462fa4df8db45ff24f23f8d29f145a32 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Feb 2026 20:56:53 +0000
Subject: [PATCH 6/7] Clean up

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 44 +++++++++----------
 1 file changed, 20 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 02b7ab86e8cb4..33c982f34eb6f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -434,19 +434,17 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
 
     // Flatten offsets and mask to 1D to match the 1D result type.
     Value offsets = adaptor.getOffsets();
-    if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
-      VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
-                                               offsetsTy.getElementType());
-      offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
-                            offsetsTy1D);
-    }
+    auto offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+                                             offsetsTy.getElementType());
+    offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+                          offsetsTy1D);
+
     Value mask = adaptor.getMask();
-    if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
-      VectorType maskTy1D =
-          VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-      mask =
-          castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
-    }
+    auto maskTy = cast<VectorType>(mask.getType());
+    VectorType maskTy1D =
+        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+    mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
 
     auto newOp = xegpu::LoadGatherOp::create(
         rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
@@ -635,19 +633,17 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
 
     // Flatten offsets and mask to 1D to match the 1D value type.
     Value offsets = adaptor.getOffsets();
-    if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
-      VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
-                                               offsetsTy.getElementType());
-      offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
-                            offsetsTy1D);
-    }
+    auto offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+                                             offsetsTy.getElementType());
+    offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
+                          offsetsTy1D);
+
     Value mask = adaptor.getMask();
-    if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
-      VectorType maskTy1D =
-          VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-      mask =
-          castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
-    }
+    auto maskTy = cast<VectorType>(mask.getType());
+    VectorType maskTy1D =
+        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+    mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
 
     xegpu::StoreScatterOp::create(
         rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,

>From da81eeb6ee3c81a7a5cf74498154a9468e3a9a77 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 26 Feb 2026 20:48:39 +0000
Subject: [PATCH 7/7] Address feedback

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 192 ++++++++++++------
 1 file changed, 129 insertions(+), 63 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 33c982f34eb6f..b6a95d8782bf0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -396,6 +396,38 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
 };
 
 /// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+///
+/// Example 1 (1D, no chunk size):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// Distributed to:
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+///
+/// Example 2 (2D with chunk size; mask and offsets produced as in Example 1):
+///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// Distributed to:
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+///
+/// Example 3 (3D with leading unit dims):
+///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+///   %mask = producer_op : vector<1x1x16xi1>
+///   %offset = producer_op : vector<1x1x16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+/// Distributed to:
+///   %mask = producer_op : vector<1x1x1xi1>
+///   %offset = producer_op : vector<1x1x1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
 struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
   using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
 
@@ -406,55 +438,57 @@ struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
     if (!layout)
       return failure();
 
-    VectorType resultTy = op.getValueType();
-    if (!resultTy)
+    VectorType origResultTy = op.getValueType();
+    if (!origResultTy)
       return failure();
 
     // Check that leading dimensions are unit.
     int chunkSize = op.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    ArrayRef<int64_t> shape = resultTy.getShape();
-    if (llvm::any_of(shape.take_front(resultTy.getRank() - effectiveVecRank),
-                     [](int64_t d) { return d != 1; }))
+    ArrayRef<int64_t> shape = origResultTy.getShape();
+    if (llvm::any_of(
+            shape.take_front(origResultTy.getRank() - effectiveVecRank),
+            [](int64_t d) { return d != 1; }))
       return rewriter.notifyMatchFailure(
           op, "Only unit dimensions allowed for the leading "
               "dimensions of the load vector!");
 
-    auto expectedWiResultTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
-    if (failed(expectedWiResultTyOrFailure))
+    auto distResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
+    if (failed(distResultTyOrFailure))
       return rewriter.notifyMatchFailure(
           op,
           "unable to compute expected workitem vector type from lane layout");
 
-    VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
-    VectorType supportedWiResultTy =
-        VectorType::get({expectedWiResultTy.getNumElements()},
-                        expectedWiResultTy.getElementType());
+    VectorType distResultTy = distResultTyOrFailure.value();
+    VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
+                                                distResultTy.getElementType());
 
     // Flatten offsets and mask to 1D to match the 1D result type.
-    Value offsets = adaptor.getOffsets();
-    auto offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
-                                             offsetsTy.getElementType());
-    offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
-                          offsetsTy1D);
-
-    Value mask = adaptor.getMask();
-    auto maskTy = cast<VectorType>(mask.getType());
-    VectorType maskTy1D =
-        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-    mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
-
+    Value distOffsets = adaptor.getOffsets();
+    auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
+    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
+                                             distOffsetsTy.getElementType());
+    distOffsets = castValueTo(
+        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
+
+    Value distMask = adaptor.getMask();
+    auto distMaskTy = cast<VectorType>(distMask.getType());
+    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
+                                          distMaskTy.getElementType());
+    distMask =
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
+
+    Value distSource = adaptor.getSource();
     auto newOp = xegpu::LoadGatherOp::create(
-        rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
-        offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
-        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+        rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
+        distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /*layout=*/nullptr);
 
     Value result = newOp->getResult(0);
-    if (supportedWiResultTy != expectedWiResultTy)
+    if (distResultTy1D != distResultTy)
       result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
-                           expectedWiResultTy);
+                           distResultTy);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -589,6 +623,38 @@ struct LowerVectorMultiReductionPattern
 
 /// Distributes a subgroup-level StoreScatter (xegpu.store) op to
 /// workitem-level.
+///
+/// Example 1 (1D, no chunk size):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to:
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 2 (2D with chunk size; mask and offsets produced as in Example 1):
+///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to:
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 3 (3D with leading unit dims):
+///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+///   %mask = producer_op : vector<1x1x16xi1>
+///   %offset = producer_op : vector<1x1x16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
+///     memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+/// Distributed to:
+///   %mask = producer_op : vector<1x1x1xi1>
+///   %offset = producer_op : vector<1x1x1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
 struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
 
@@ -599,56 +665,56 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
     if (!layout)
       return failure();
 
-    VectorType valueTy = op.getValueType();
-    if (!valueTy)
+    VectorType origValueTy = op.getValueType();
+    if (!origValueTy)
       return failure();
 
     // Check that all leading dimensions are unit dimensions.
     int chunkSize = op.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    ArrayRef<int64_t> shape = valueTy.getShape();
-    if (llvm::any_of(shape.take_front(valueTy.getRank() - effectiveVecRank),
+    ArrayRef<int64_t> shape = origValueTy.getShape();
+    if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
                      [](int64_t d) { return d != 1; }))
       return rewriter.notifyMatchFailure(
           op, "Only unit dimensions allowed for the leading "
               "dimensions of the store vector!");
 
-    auto expectedWiValueTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
-    if (failed(expectedWiValueTyOrFailure))
+    auto distValueTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
+    if (failed(distValueTyOrFailure))
       return rewriter.notifyMatchFailure(
           op,
           "unable to compute expected workitem vector type from lane layout");
 
-    VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
-    VectorType supportedWiValueTy =
-        VectorType::get({expectedWiValueTy.getNumElements()},
-                        expectedWiValueTy.getElementType());
+    VectorType distValueTy = distValueTyOrFailure.value();
+    VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
+                                               distValueTy.getElementType());
 
-    Value adaptedValue = adaptor.getValue();
-    if (adaptedValue.getType() != supportedWiValueTy)
-      adaptedValue =
-          castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptedValue),
-                      supportedWiValueTy);
+    Value distValue = adaptor.getValue();
+    if (distValue.getType() != distValueTy1D)
+      distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
+                              distValueTy1D);
 
     // Flatten offsets and mask to 1D to match the 1D value type.
-    Value offsets = adaptor.getOffsets();
-    auto offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
-                                             offsetsTy.getElementType());
-    offsets = castValueTo(rewriter, cast<TypedValue<VectorType>>(offsets),
-                          offsetsTy1D);
-
-    Value mask = adaptor.getMask();
-    auto maskTy = cast<VectorType>(mask.getType());
-    VectorType maskTy1D =
-        VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
-    mask = castValueTo(rewriter, cast<TypedValue<VectorType>>(mask), maskTy1D);
-
-    xegpu::StoreScatterOp::create(
-        rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
-        op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
-        op.getL3HintAttr(), /*layout=*/nullptr);
+    Value distOffsets = adaptor.getOffsets();
+    auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
+    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
+                                             distOffsetsTy.getElementType());
+    distOffsets = castValueTo(
+        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
+
+    Value distMask = adaptor.getMask();
+    auto distMaskTy = cast<VectorType>(distMask.getType());
+    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
+                                          distMaskTy.getElementType());
+    distMask =
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
+
+    Value distDest = adaptor.getDest();
+    xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
+                                  distOffsets, distMask, op.getChunkSizeAttr(),
+                                  op.getL1HintAttr(), op.getL2HintAttr(),
+                                  op.getL3HintAttr(), /*layout=*/nullptr);
     rewriter.eraseOp(op);
     return success();
   }



More information about the Mlir-commits mailing list