[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi pass (PR #181917)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 17 13:53:16 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Nishant Patel (nbpatel)

<details>
<summary>Changes</summary>

This PR adds distribution patterns for the xegpu.load and xegpu.store ops to the new sg-to-wi pass.

---
Full diff: https://github.com/llvm/llvm-project/pull/181917.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp (+148-2) 
- (modified) mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir (+114) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 3787fbb44e1b8..b3a9f8cd86667 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -395,6 +395,78 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+
+    VectorType resultTy = op.getValueType();
+    if (!resultTy)
+      return failure();
+
+    // Check that leading dimensions are unit.
+    int chunkSize = op.getChunkSize().value_or(1);
+    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+    for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
+      if (resultTy.getShape()[i] != 1)
+        return rewriter.notifyMatchFailure(
+            op, "Only unit dimensions allowed for the leading "
+                "dimensions of the load vector!");
+    }
+
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
+    VectorType supportedWiResultTy =
+        VectorType::get({expectedWiResultTy.getNumElements()},
+                        expectedWiResultTy.getElementType());
+
+    // Flatten offsets and mask to 1D to match the 1D result type.
+    Value offsets = adaptor.getOffsets();
+    if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
+      VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+                                               offsetsTy.getElementType());
+      if (offsetsTy != offsetsTy1D)
+        offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                              offsetsTy1D, offsets)
+                      .getResult();
+    }
+    Value mask = adaptor.getMask();
+    if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+      VectorType maskTy1D =
+          VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+      if (maskTy != maskTy1D)
+        mask =
+            vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+                .getResult();
+    }
+
+    auto newOp = xegpu::LoadGatherOp::create(
+        rewriter, op.getLoc(), supportedWiResultTy, adaptor.getSource(),
+        offsets, mask, op.getChunkSizeAttr(), op.getL1HintAttr(),
+        op.getL2HintAttr(), op.getL3HintAttr(), /*layout=*/nullptr);
+
+    Value result = newOp->getResult(0);
+    if (supportedWiResultTy != expectedWiResultTy)
+      result = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                           expectedWiResultTy, result)
+                   .getResult();
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /// This pattern distributes a subgroup-level vector.reduction op to
 /// workitem-level. This require shuffling the data across the workitems (using
 /// gpu::ShuffleOp) and reducing in stages until all workitems have the final
@@ -522,6 +594,80 @@ struct LowerVectorMultiReductionPattern
   }
 };
 
+/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
+/// workitem-level.
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+
+    VectorType valueTy = op.getValueType();
+    if (!valueTy)
+      return failure();
+
+    // Check that all leading dimensions are unit dimensions.
+    int chunkSize = op.getChunkSize().value_or(1);
+    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+    for (int i = 0; i < valueTy.getRank() - effectiveVecRank; i++) {
+      if (valueTy.getShape()[i] != 1)
+        return rewriter.notifyMatchFailure(
+            op, "Only unit dimensions allowed for the leading "
+                "dimensions of the store vector!");
+    }
+
+    auto expectedWiValueTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, valueTy);
+    if (failed(expectedWiValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    VectorType expectedWiValueTy = expectedWiValueTyOrFailure.value();
+    VectorType supportedWiValueTy =
+        VectorType::get({expectedWiValueTy.getNumElements()},
+                        expectedWiValueTy.getElementType());
+
+    Value adaptedValue = adaptor.getValue();
+    if (adaptedValue.getType() != supportedWiValueTy)
+      adaptedValue =
+          vector::ShapeCastOp::create(rewriter, op.getLoc(), supportedWiValueTy,
+                                      adaptedValue)
+              .getResult();
+
+    // Flatten offsets and mask to 1D to match the 1D value type.
+    Value offsets = adaptor.getOffsets();
+    if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
+      VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+                                               offsetsTy.getElementType());
+      if (offsetsTy != offsetsTy1D)
+        offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                              offsetsTy1D, offsets)
+                      .getResult();
+    }
+    Value mask = adaptor.getMask();
+    if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+      VectorType maskTy1D =
+          VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+      if (maskTy != maskTy1D)
+        mask =
+            vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+                .getResult();
+    }
+
+    xegpu::StoreScatterOp::create(
+        rewriter, op.getLoc(), adaptedValue, adaptor.getDest(), offsets, mask,
+        op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /*layout=*/nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -730,8 +876,8 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
-               SgToWiVectorReduction, SgToWiMultiDimReduction>(
-      typeConverter, patterns.getContext());
+               SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
+               SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
 }
 
 void xegpu::populateXeGPUSgToWiLowerVectorMultiReductionAndLegality(
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 1ec0879d4fb47..1cf41667b7a97 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -154,6 +154,120 @@ gpu.func @prefetch_nd() {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @scatter_load_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+gpu.func @scatter_load_chunksize(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store_chunksize
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK: %[[C1:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf16> to vector<1x8xf16>
+// CHECK: %[[C2:.*]] = vector.shape_cast %[[C1]] : vector<1x8xf16> to vector<8xf16>
+// CHECK: xegpu.store %[[C2]], %arg0[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
+// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store_chunksize(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+  xegpu.store %0, %src[%offset], %mask
+    <{chunk_size = 8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}>
+    : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_load
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+gpu.func @scatter_load(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_store
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK: xegpu.store %[[LOAD]], %arg0[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_store(%src: memref<256xf16>) {
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<12> : vector<16xindex>
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    dense<true> : vector<16xi1>
+  %0 = xegpu.load %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+    : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+  xegpu.store %0, %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
+    : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x1xi1>
+// CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1x1x1xindex>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V2:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[V1]]], %[[V2]]
+// CHECK-SAME: : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<1xf16> to vector<1x1x1xf16>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[CAST]] : vector<1x1x1xf16> to vector<1xf16>
+// CHECK: %[[V3:.*]] = vector.shape_cast %[[OFFSET]] : vector<1x1x1xindex> to vector<1xindex>
+// CHECK: %[[V4:.*]] = vector.shape_cast %[[MASK]] : vector<1x1x1xi1> to vector<1xi1>
+// CHECK: xegpu.store %[[CAST2]], %arg0[%[[V3]]], %[[V4]]
+// CHECK-SAME: : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>) {
+  %mask = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+    dense<1> : vector<1x1x16xi1>
+  %offset = arith.constant
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
+    dense<12> : vector<1x1x16xindex>
+  %0 = xegpu.load %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+    : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+  xegpu.store %0, %src[%offset], %mask
+    <{layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}>
+    : vector<1x1x16xf16>, memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+  gpu.return
+}
+
 // CHECK-LABEL: gpu.func @vector_reduction
 // CHECK:     %[[CST:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK:     %[[LANE_RED:.*]] = vector.reduction <add>, %[[CAST:.*]] : vector<2xf32> into f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/181917


More information about the Mlir-commits mailing list