[Mlir-commits] [mlir] [mlir][XeGPU] Add optional layout attribute to LoadGather StoreScatter ops (PR #163414)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 26 06:03:25 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Dmitry Chigarev (dchigarev)

<details>
<summary>Changes</summary>

As [suggested here](https://github.com/llvm/llvm-project/pull/163071#discussion_r2427229637), this PR adds an optional layout attribute to the `LoadGather` and `StoreScatter` ops.

For the load op, the attribute describes the layout of the result (e.g. `layout_result_0`), and for the store op, it describes the layout of the vector-to-store operand (e.g. `layout_operand_0`).

The PR also reworks `xegpu-unroll` and `wg-to-sg-distribute` passes to consider and carry the new layout attribute.

The helper utility function `getDistributeLayoutAttr` is reworked to return either `layout_operand/result_0` or `layout` for load/store ops (depending on which one is set) to preserve backward compatibility. The `setDistributeLayoutAttr` function can now only set the new layout attribute.



---

Patch is 47.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163414.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+20-4) 
- (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+8-4) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+37-4) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+10-2) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+4-4) 
- (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+24-2) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+5-5) 
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir (+81) 
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+63) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+125) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir (+45) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+63-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 426377fcf598f..4c67856b559b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
   let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -895,7 +896,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
@@ -979,7 +987,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
@@ -1030,7 +1039,14 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..b5d9323de47a6 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -435,7 +435,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
       /*chunk_size=*/IntegerAttr{},
       /*l1_hint=*/xegpu::CachePolicyAttr{},
       /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+      /*l3_hint=*/xegpu::CachePolicyAttr{},
+      /*layout=*/nullptr);
 
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
@@ -469,7 +470,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                 /*chunk_size=*/IntegerAttr{},
                                 /*l1_hint=*/xegpu::CachePolicyAttr{},
                                 /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+                                /*l3_hint=*/xegpu::CachePolicyAttr{},
+                                /*layout=*/nullptr);
   rewriter.eraseOp(writeOp);
   return success();
 }
@@ -621,7 +623,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
         /*chunk_size=*/IntegerAttr{},
         /*l1_hint=*/xegpu::CachePolicyAttr{},
         /*l2_hint=*/xegpu::CachePolicyAttr{},
-        /*l3_hint=*/xegpu::CachePolicyAttr{});
+        /*l3_hint=*/xegpu::CachePolicyAttr{},
+        /*layout=*/nullptr);
 
     auto selectOp =
         arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -655,7 +658,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
                                   /*chunk_size=*/IntegerAttr{},
                                   /*l1_hint=*/xegpu::CachePolicyAttr{},
                                   /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
+                                  /*layout=*/nullptr);
     rewriter.eraseOp(scatterOp);
     return success();
   }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..2a7c7ac7e8cde 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -859,7 +859,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
-        l1_hint, l2_hint, l3_hint);
+        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -875,7 +875,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
   auto offset = vector::FromElementsOp::create(builder, loc, type, values);
 
   build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint,
+                         DistributeLayoutAttr layout) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
@@ -926,7 +943,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            xegpu::CachePolicyAttr l2_hint,
                            xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -944,7 +961,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
 
   // Call the correct builder overload that does not expect result types.
   build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
-        l3_hint);
+        l3_hint, /*layout=*/nullptr);
+}
+
+void StoreScatterOp::build(
+    OpBuilder &builder, OperationState &state, Value value, Value dest,
+    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the correct builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index e6e71cc29a80a..c3bf9606693a8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -678,12 +678,16 @@ struct UnrollLoadGatherOpWithOffset
           pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
     }
 
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    if (layout)
+      layout = layout.dropInstData();
+
     SmallVector<Value> newOps;
     for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
       auto newOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newValueTy, op.getSource(), o, m,
           rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
       newOps.push_back(newOp);
     }
 
@@ -774,12 +778,16 @@ struct UnrollStoreScatterOpWithOffsets
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    if (layout)
+      layout = layout.dropInstData();
+
     for (auto [v, o, m] :
          llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
       xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                     rewriter.getI64IntegerAttr(chunkSize),
                                     op.getL1HintAttr(), op.getL2HintAttr(),
-                                    op.getL3HintAttr());
+                                    op.getL3HintAttr(), layout);
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..ceeafbd7bd5e5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -914,9 +914,8 @@ struct WgToSgLoadGatherOpWithOffset
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
-      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          layout.dropSgLayoutAndData());
       newLoadOps.push_back(newLoadOp);
     }
     rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -964,7 +963,8 @@ struct WgToSgStoreScatterOpWithOffset
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
       auto store = xegpu::StoreScatterOp::create(
           rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          layout.dropSgLayoutAndData());
       // Update the layout attribute to drop sg_layout and sg_data.
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty()) {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b4605cd7e94d6..6e918f162a5ea 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -105,12 +105,22 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
 std::string xegpu::getLayoutName(const OpOperand &operand) {
   const StringRef prefix("layout_operand_");
   unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-  return llvm::formatv("{0}{1}", prefix, idx).str();
+  auto owner = operand.getOwner();
+  auto tempLayout = llvm::formatv("{0}{1}", prefix, idx).str();
+  if (isa<StoreScatterOp>(operand.getOwner()) && idx == 0 &&
+      !owner->hasAttr(tempLayout))
+    return "layout";
+  return tempLayout;
 }
 
 std::string xegpu::getLayoutName(const OpResult result) {
   const StringRef prefix = "layout_result_";
-  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+  auto owner = result.getOwner();
+  auto tempLayout =
+      llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+  if (isa<LoadGatherOp>(owner) && !owner->hasAttr(tempLayout))
+    return "layout";
+  return tempLayout;
 }
 
 xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
@@ -144,6 +154,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
       return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+    // check for "permament" layout only after "temporary" layout name lookup
+    // for backward compatibility
+    if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
+      return loadGatherOp.getLayoutAttr();
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -171,6 +186,13 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
     return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+  // check for "permament" layout only after "temporary" layout name lookup
+  // for backward compatibility
+  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+    if (auto layout = storeScatterOp.getLayoutAttr())
+      return layout;
+
   return getDistributeLayoutAttr(opr.get());
 }
 
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 30f785ded975a..3e2644beffc35 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -97,7 +97,7 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -122,7 +122,7 @@ func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]]  {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
 func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -167,8 +167,8 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64,
+// CHECK-SAME: layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
@@ -186,7 +186,7 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index f233dff609f2b..e602163ff5aaa 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -462,6 +462,47 @@ gpu.module @xevm_module{
   }
 }
 
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize_perm_layout({{.*}}) {
+// CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
+// CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME:    -> (vector<1x8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
+// CHECK:         gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]] :
+// CHECK-SAME:      vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[T1:.*]] = xegpu.load %[[W]]#1[%[[W]]#2], %[[W]]#3 <{chunk_size = 8 : i64}>
+// CHECK-SAME:    : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT:  xegpu.store %[[T1]], %[[W]]#1[%[[W]]#2],...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/163414


More information about the Mlir-commits mailing list