[Mlir-commits] [mlir] [MLIR][XeGPU] Use operand layouts for store scatter (PR #161447)

Nishant Patel llvmlistbot at llvm.org
Tue Sep 30 14:19:32 PDT 2025


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/161447

This PR changes the store scatter handling to take the layouts from the op's operands, since a store op doesn't have a result.

>From 5a2ef29d93ede034b0e651bba032e78cd42c62f0 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 30 Sep 2025 21:10:00 +0000
Subject: [PATCH] Use operand layouts for store scatter

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 24 +++++++++++--------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 19 +++++++++------
 2 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..784e5d68ce885 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset
       return failure();
 
     xegpu::DistributeLayoutAttr layout =
-        xegpu::getDistributeLayoutAttr(op.getValue());
+        xegpu::getDistributeLayoutAttr(op.getOperand(0));
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset
     auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
     for (auto [val, offs, mask] : llvm::zip(
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
-      xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
-                                    mask, chunkSizeAttr, op.getL1HintAttr(),
-                                    op.getL2HintAttr(), op.getL3HintAttr());
+      auto store = xegpu::StoreScatterOp::create(
+          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
       // Update the layout attribute to drop sg_layout and sg_data.
-      if (auto newLayout = layout.dropSgLayoutAndData())
-        op->setAttr("layout", newLayout);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        for (OpOperand &operand : store->getOpOperands()) {
+          // Skip for operand one (memref)
+          if (operand.getOperandNumber() == 1)
+            continue;
+          xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
+        }
+      }
     }
     rewriter.eraseOp(op);
     return success();
@@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
       [=](xegpu::StoreScatterOp op) -> bool {
-        // Check if the layout attribute is present on the result.
-        auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
-        if (!layout)
-          return true;
+        auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
         return isLegal(layout);
       });
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 03c63861705d9..38392fd10b742 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -282,15 +282,20 @@ gpu.module @test_distribution {
   // CHECK-LABEL: @store_scatter
   // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
   gpu.func @store_scatter(%dest : memref<256xf16>) {
-    // CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
-    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
-    // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
+    // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
+    // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
+    // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
     // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
+    // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
+    // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
     // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
-    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
-    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
-    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
-    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
+    %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, 
+                                             layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+                                             layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+                                             l1_hint = #xegpu.cache_hint<cached>}
       : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
     gpu.return
   }



More information about the Mlir-commits mailing list