[Mlir-commits] [mlir] [mlir][XeGPU] Add optional layout attribute to LoadGather StoreScatter ops (PR #163414)
    Dmitry Chigarev 
    llvmlistbot at llvm.org
       
    Tue Oct 14 08:45:44 PDT 2025
    
    
  
https://github.com/dchigarev created https://github.com/llvm/llvm-project/pull/163414
As [suggested here](https://github.com/llvm/llvm-project/pull/163071#discussion_r2427229637), this PR adds an optional `layout` attribute to the LoadGather/StoreScatter ops.
>From fe902beb885d5a0a188583d57f77e8db75ea8783 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 14 Oct 2025 15:22:38 +0000
Subject: [PATCH 1/2] [mlir][XeGPU] Add optional layout attribute to LoadGather
 StoreScatter ops
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 54 ++++++++++++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 64 +++++++++++++++++--
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  4 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 21 +++---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 15 ++++-
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 10 +--
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 12 ++--
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 15 ++---
 8 files changed, 152 insertions(+), 43 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..9e80b8e453a73 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
   let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -852,6 +853,16 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       return getSource().getType();
     }
 
+    xegpu::DistributeLayoutAttr getDistributeLayout() {
+      xegpu::DistributeLayoutAttr layout = nullptr;
+      if (auto tdescType = getTensorDescType()) {
+        layout = tdescType.getLayoutAttr();
+      }
+      if (!layout)
+        layout = getLayoutAttr();
+      return layout;
+    }
+
     TypedValue<xegpu::TensorDescType> getTensorDesc() {
       if (auto tdescType = getTensorDescType()) {
         return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
@@ -895,7 +906,19 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
@@ -979,7 +1002,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
@@ -993,6 +1017,16 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       return TypedValue<xegpu::TensorDescType>();
     }
 
+    xegpu::DistributeLayoutAttr getDistributeLayout() {
+      xegpu::DistributeLayoutAttr layout = nullptr;
+      if (auto tdescType = getTensorDescType()) {
+        layout = tdescType.getLayoutAttr();
+      }
+      if (!layout)
+        layout = getLayoutAttr();
+      return layout;
+    }
+
     xegpu::TensorDescType getTensorDescType() {
       return dyn_cast<xegpu::TensorDescType>(getDestType());
     }
@@ -1030,7 +1064,19 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest, 
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788d0b9b4..22240d9bad095 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -816,7 +816,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
-        l1_hint, l2_hint, l3_hint);
+        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -832,7 +832,34 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
   auto offset = vector::FromElementsOp::create(builder, loc, type, values);
 
   build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source, Value mask,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint,
+                         DistributeLayoutAttr layout) {
+  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+        l1_hint, l2_hint, l3_hint, layout);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint,
+                         DistributeLayoutAttr layout) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
@@ -883,7 +910,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            xegpu::CachePolicyAttr l2_hint,
                            xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -901,7 +928,36 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
 
   // Call the correct builder overload that does not expect result types.
   build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
-        l3_hint);
+        l3_hint, /*layout=*/nullptr);
+}
+
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest, Value mask,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint,
+                           DistributeLayoutAttr layout) {
+  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+        l2_hint, l3_hint, layout);
+}
+
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest,
+                           ArrayRef<OpFoldResult> offsets, Value mask,
+                           IntegerAttr chunk_size,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint,
+                           DistributeLayoutAttr layout) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the correct builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0fe4b0b0..fad698bd2001f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -687,7 +687,7 @@ struct UnrollLoadGatherOpWithOffset
       auto newOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newValueTy, op.getSource(), o, m,
           rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), op.getLayoutAttr());
       newOps.push_back(newOp);
     }
 
@@ -783,7 +783,7 @@ struct UnrollStoreScatterOpWithOffsets
       xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                     rewriter.getI64IntegerAttr(chunkSize),
                                     op.getL1HintAttr(), op.getL2HintAttr(),
-                                    op.getL3HintAttr());
+                                    op.getL3HintAttr(), op.getLayoutAttr());
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c28d2fc6c2b63..6feee60a92abe 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -913,7 +913,7 @@ struct WgToSgLoadGatherOpWithOffset
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), /*layout*/nullptr);
       xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
                                      layout.dropSgLayoutAndData());
       newLoadOps.push_back(newLoadOp);
@@ -961,19 +961,16 @@ struct WgToSgStoreScatterOpWithOffset
     auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
     for (auto [val, offs, mask] : llvm::zip(
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      xegpu::DistributeLayoutAttr newLayout = nullptr;
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        // Update the layout attribute to drop sg_layout and sg_data.
+        newLayout = layout.dropSgLayoutAndData();
+
       auto store = xegpu::StoreScatterOp::create(
           rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
-      // Update the layout attribute to drop sg_layout and sg_data.
-      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-          !layout.getEffectiveInstDataAsInt().empty()) {
-        for (OpOperand &operand : store->getOpOperands()) {
-          // Skip for operand one (memref)
-          if (operand.getOperandNumber() == 1)
-            continue;
-          xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
-        }
-      }
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), /*layout*/newLayout);
+      
     }
     rewriter.eraseOp(op);
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a438ea62c..8981d3fa846ca 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -109,6 +109,8 @@ std::string xegpu::getLayoutName(const OpOperand &operand) {
 }
 
 std::string xegpu::getLayoutName(const OpResult result) {
+  if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(result.getOwner()))
+    return "layout";
   const StringRef prefix = "layout_result_";
   return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
 }
@@ -141,6 +143,9 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
       return storeOp.getLayoutAttr();
 
+    if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
+      return loadGatherOp.getLayoutAttr();
+
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
       return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
@@ -168,6 +173,12 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
     return storeOp.getLayoutAttr();
 
+  // if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op))
+  //   return loadGatherOp.getDistributeLayout();
+
+  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+    return storeScatterOp.getDistributeLayout();
+
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
     return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
@@ -196,7 +207,7 @@ template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
 void xegpu::setDistributeLayoutAttrs(
     Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {
-    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(nestOp))
+    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(nestOp))
       return;
 
     for (OpOperand &opr : nestOp->getOpOperands()) {
@@ -216,6 +227,8 @@ void xegpu::removeLayoutAttr(const T &operandOrResult) {
   std::string name = xegpu::getLayoutName(operandOrResult);
   if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
     owner->removeAttr(name);
+  if (isa<xegpu::StoreMatrixOp, xegpu::LoadGatherOp>(owner) && owner->hasAttrOfType<DistributeLayoutAttr>("layout"))
+    owner->removeAttr("layout");
 }
 
 // Explicit instantiation for OpResult
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 30f785ded975a..f807bcbda1984 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -97,7 +97,7 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -122,7 +122,7 @@ func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
 func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -167,8 +167,8 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64,
+// CHECK-SAME: layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
@@ -186,7 +186,7 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 0e1365aa64171..f8e8794b4e066 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -136,9 +136,9 @@ gpu.module @xevm_module{
     %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     %loaded = scf.if %pred -> (vector<16x8xf16>) {
-      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, 
+        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      }> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
       scf.yield %3 : vector<16x8xf16>
     } else {
       %3 = arith.constant {
@@ -168,9 +168,9 @@ gpu.module @xevm_module{
     %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     scf.if %pred  {
-      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-      } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+      %3 = xegpu.load %src[%offset], %1 <{chunk_size=8,
+        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+      }> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
       xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     }
     gpu.return
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 742d11f8052ec..888cbddebf2f0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -274,7 +274,7 @@ gpu.module @test_distribution {
     // CHECK-SAME: : memref<?xf16>, vector<32x4xindex>, vector<32x4xi1> -> vector<32x4xf16>
     %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<0> : vector<256x16xindex>
     %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>} dense<1> : vector<256x16xi1>
-    %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 4]>, l1_hint = #xegpu.cache_hint<cached>}
       : memref<?xf16>, vector<256x16xindex>, vector<256x16xi1> -> vector<256x16xf16>
     gpu.return
   }
@@ -285,16 +285,13 @@ gpu.module @test_distribution {
     // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
     // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
     // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
-    // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
-    // CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
-    // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
+    // CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>,
+    // CHECK-SAME: layout = #xegpu.layout<inst_data = [8]>}>
     // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
     %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
     %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
     %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
-    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>, 
-                                             layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
-                                             layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
+    xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
                                              l1_hint = #xegpu.cache_hint<cached>}
       : vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
     gpu.return
@@ -309,7 +306,7 @@ gpu.module @test_distribution {
     // CHECK-SAME: : memref<?xf16>, vector<8xindex>, vector<8xi1> -> vector<8x4xf16>
     %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
     %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
-    %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
+    %load = xegpu.load %src[%offset], %mask {chunk_size = 4, layout = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 4]>, l1_hint = #xegpu.cache_hint<cached>}
       : memref<?xf16>, vector<256xindex>, vector<256xi1> -> vector<256x4xf16>
     gpu.return
   }
@@ -401,7 +398,7 @@ gpu.module @test_distribution {
       %cst_acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} dense<0.0> : vector<4x2x6xf16>
       %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<0>  : vector<4x2x6x32xindex>
       %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<true> : vector<4x2x6x32xi1>
-      %load = xegpu.load %src[%offset], %mask  {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> vector<4x2x6x32xf16>
+      %load = xegpu.load %src[%offset], %mask  {layout = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : ui64, vector<4x2x6x32xindex>, vector<4x2x6x32xi1> -> vector<4x2x6x32xf16>
       // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [3] : vector<1x1x1x32xf16> to vector<1x1x1xf16>
       %reduce = vector.multi_reduction <add>, %load, %cst_acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>, dims = [3]>} [3]
       : vector<4x2x6x32xf16> to vector<4x2x6xf16>
>From 946ab9a7a1531e108674afc0c08775be35325035 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 14 Oct 2025 15:26:26 +0000
Subject: [PATCH 2/2] clang-format
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp              | 13 +++++--------
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp      |  7 ++++---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp         |  6 ++++--
 3 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 22240d9bad095..fa42ffbad4d31 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -941,14 +941,11 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
         l2_hint, l3_hint, layout);
 }
 
-void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
-                           Value value, Value dest,
-                           ArrayRef<OpFoldResult> offsets, Value mask,
-                           IntegerAttr chunk_size,
-                           xegpu::CachePolicyAttr l1_hint,
-                           xegpu::CachePolicyAttr l2_hint,
-                           xegpu::CachePolicyAttr l3_hint,
-                           DistributeLayoutAttr layout) {
+void StoreScatterOp::build(
+    OpBuilder &builder, OperationState &state, Value value, Value dest,
+    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
   auto loc = dest.getLoc();
   int64_t size = static_cast<int64_t>(offsets.size());
   auto type = VectorType::get(size, builder.getIndexType());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6feee60a92abe..33bf2808abd01 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -913,7 +913,8 @@ struct WgToSgLoadGatherOpWithOffset
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), /*layout*/nullptr);
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          /*layout*/ nullptr);
       xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
                                      layout.dropSgLayoutAndData());
       newLoadOps.push_back(newLoadOp);
@@ -969,8 +970,8 @@ struct WgToSgStoreScatterOpWithOffset
 
       auto store = xegpu::StoreScatterOp::create(
           rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(), /*layout*/newLayout);
-      
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          /*layout*/ newLayout);
     }
     rewriter.eraseOp(op);
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 8981d3fa846ca..c9ef37dc1daa9 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -207,7 +207,8 @@ template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
 void xegpu::setDistributeLayoutAttrs(
     Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {
-    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(nestOp))
+    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::LoadGatherOp,
+            xegpu::StoreScatterOp>(nestOp))
       return;
 
     for (OpOperand &opr : nestOp->getOpOperands()) {
@@ -227,7 +228,8 @@ void xegpu::removeLayoutAttr(const T &operandOrResult) {
   std::string name = xegpu::getLayoutName(operandOrResult);
   if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
     owner->removeAttr(name);
-  if (isa<xegpu::StoreMatrixOp, xegpu::LoadGatherOp>(owner) && owner->hasAttrOfType<DistributeLayoutAttr>("layout"))
+  if (isa<xegpu::StoreMatrixOp, xegpu::LoadGatherOp>(owner) &&
+      owner->hasAttrOfType<DistributeLayoutAttr>("layout"))
     owner->removeAttr("layout");
 }
 
    
    
More information about the Mlir-commits
mailing list