[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute load_gather/store_scatter op from Wg To Sg (PR #154420)

Nishant Patel llvmlistbot at llvm.org
Fri Aug 22 07:41:24 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/154420

>From b94a37fb467510f4cfe9e3b902a2147f99010c8f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 18 Aug 2025 16:18:26 +0000
Subject: [PATCH 1/2] Add patterns for load_gather and store_scatter ops

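Add builder overloads for LoadGatherOp/StoreScatterOp that take the gather
offsets as ArrayRef<OpFoldResult> together with an explicit chunk_size,
and add WgToSg distribution patterns that rewrite workgroup-level
load_gather/store_scatter ops with explicit offsets into per-subgroup ops,
dropping sg_layout and sg_data from their layout attributes.
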
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  12 +++
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  34 ++++++
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 100 +++++++++++++++++-
 3 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 480b43e740736..34429a34f6d96 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -751,6 +751,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
@@ -859,6 +865,12 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest, 
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7b7ce19e6937b..2e9bb88174c74 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -737,6 +737,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
         l1_hint, l2_hint, l3_hint);
 }
 
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreScatterOp
 //===----------------------------------------------------------------------===//
@@ -785,6 +801,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
         l2_hint, l3_hint);
 }
 
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest,
+                           ArrayRef<OpFoldResult> offsets, Value mask,
+                           IntegerAttr chunk_size,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // StoreScatterOp has no results, so call the builder overload that
+  // takes no result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_UpdateOffsetOp
 //===----------------------------------------------------------------------===//
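
For context (not part of the patch): the new overloads accept the offsets
as ArrayRef<OpFoldResult>, the usual MLIR idiom for mixing static offsets
(attributes) with dynamic ones (SSA values); the builders materialize them
into a single index vector via vector.from_elements, as shown above. A
minimal call-site sketch, where `vecTy`, `src`, `dynOff`, and `mask` are
illustrative names rather than code from the patch:

  // Illustrative only: gather load from mixed static/dynamic offsets.
  SmallVector<OpFoldResult> offsets = {builder.getIndexAttr(0), dynOff};
  auto load = builder.create<xegpu::LoadGatherOp>(
      loc, vecTy, src, offsets, mask,
      builder.getI64IntegerAttr(/*chunk_size=*/1),
      /*l1_hint=*/nullptr, /*l2_hint=*/nullptr, /*l3_hint=*/nullptr);
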
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 270d71aaa7273..a4d697b5357e6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -685,6 +685,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern transforms the LoadGatherOp with explicit offsets to load
+// subgroup data, similar to WgToSgLoadNdOpWithOffset.
+struct WgToSgLoadGatherOpWithOffset
+    : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    SmallVector<Value> newLoadOps;
+    // Default to chunk_size = 1 when the attribute is absent.
+    int64_t chunkSize = op.getChunkSize().value_or(1);
+    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
+    for (auto [offsets, mask] :
+         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newTy, op.getSource(), offsets, mask,
+          chunkSizeAttr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      xegpu::setLayoutAttr(newLoadOp->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data, similar to WgToSgStoreNdOpWithOffset.
+struct WgToSgStoreScatterOpWithOffset
+    : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = valueType.getShape();
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    // Default to chunk_size = 1 when the attribute is absent.
+    int64_t chunkSize = op.getChunkSize().value_or(1);
+    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+    for (auto [val, offs, mask] : llvm::zip(
+             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      auto newStoreOp = rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      // StoreScatterOp has no results, so its layout is carried as the
+      // "layout_result_0" attribute on the op itself. Attach it to the new
+      // op with sg_layout and sg_data dropped; setting it on the original
+      // op would be lost when that op is erased below.
+      if (auto layoutAttr =
+              op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0")) {
+        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+          newStoreOp->setAttr("layout_result_0", newLayout);
+      }
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -694,7 +776,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
                WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
                WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
                WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
-               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
+               WgToSgConvertLayoutOp, WgToSgArithConstantOp,
+               WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset>(
       patterns.getContext());
 }
 } // namespace xegpu
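
For reference (not part of the patch), a minimal sketch of driving these
patterns through the dialect conversion framework; the pass below does the
equivalent wiring, and `ctx` and `module` are illustrative names:

  // Illustrative only: apply the WgToSg patterns via partial conversion.
  ConversionTarget target(*ctx);
  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    return signalPassFailure();
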
@@ -815,6 +898,21 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
+      [=](xegpu::LoadGatherOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
+      [=](xegpu::StoreScatterOp op) -> bool {
+        // StoreScatterOp has no results; its layout is carried as the
+        // "layout_result_0" attribute on the op itself.
+        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0");
+        if (!layout)
+          return true;
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());

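A note on the shape math the patterns rely on: each dimension of the
workgroup shape is divided by the corresponding sg_layout entry to obtain
the per-subgroup shape. For example, a vector<256x16xf16> result with
sg_layout = [8, 4] and sg_data = [32, 4] distributes to vector<32x4xf16>
per subgroup, since 256 / 8 = 32 and 16 / 4 = 4; the tests below exercise
exactly these shapes.
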
>From 459e98ab8e4f770f4832873d9adc4d97035c9933 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 19 Aug 2025 17:26:10 +0000
Subject: [PATCH 2/2] Add tests

---
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 41 +++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 07a0b86223c33..03fd95f0e778e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -263,4 +263,45 @@ gpu.module @test_distribution {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
+
+  // CHECK-LABEL: @load_gather
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_gather(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x4xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<32x4xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @store_scatter
+  // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
+  gpu.func @store_scatter(%dest : memref<256xf16>) {
+    // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
+    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+      : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @load_with_chunk_size
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?xf16>
+  gpu.func @load_with_chunk_size(%src : memref<?xf16>) {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 4 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+      : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
+    gpu.return
+  }
 }
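
A note on the chunk_size test above: the offsets and mask stay 1-D while
the loaded value gains a trailing chunk dimension. With sg_layout = [32],
the vector<256xindex> offsets distribute to vector<8xindex> per subgroup
(256 / 32 = 8), and chunk_size = 4 then yields a vector<8x4xf16> load per
subgroup.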


