[Mlir-commits] [mlir] cd5d5b3 - [mlir][XeGPU] Use DistributeLayoutAttr instead of LayoutAttr for load gather/scatter ops (#167850)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 17 11:00:07 PST 2025


Author: Dmitry Chigarev
Date: 2025-11-17T11:00:03-08:00
New Revision: cd5d5b31bff0052be214357133ad3dd7f3f24a74

URL: https://github.com/llvm/llvm-project/commit/cd5d5b31bff0052be214357133ad3dd7f3f24a74
DIFF: https://github.com/llvm/llvm-project/commit/cd5d5b31bff0052be214357133ad3dd7f3f24a74.diff

LOG: [mlir][XeGPU] Use DistributeLayoutAttr instead of LayoutAttr for load gather/scatter ops (#167850)

The PR changes the layout attribute type for
`xegpu::LoadGatherOp/StoreScatterOp` from `LayoutAttr` to
`DistributeLayoutAttr` to also support `xegpu.slice` layouts.

Initially we [wanted to disallow slice
layouts](https://github.com/llvm/llvm-project/pull/163414#discussion_r2478978798)
for this attribute, but it turns out there are valid use cases for
them, for example:
```mlir
gpu.func @distribute_load_slice_attr() {
  %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
  %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex>
  %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1>

  %3 = xegpu.load %2[%offset], %mask <{chunk_size = 1, layout = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]>>} {
      layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> 
  } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>

  %4 = vector.broadcast %3 {layout_result_0 =
      #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
  gpu.return
}
```

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 689ebd0d1179a..4c67856b559b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -844,7 +844,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
-      OptionalAttr<XeGPU_LayoutAttr>:$layout);
+      OptionalAttr<DistributeLayoutAttr>:$layout);
   let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -903,7 +903,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint,
-                    "xegpu::LayoutAttr": $layout)>
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
@@ -988,7 +988,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
-      OptionalAttr<XeGPU_LayoutAttr>:$layout);
+      OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
@@ -1046,7 +1046,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
                     "xegpu::CachePolicyAttr": $l3_hint,
-                    "xegpu::LayoutAttr": $layout)>
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4dd10bedc6d84..85c9a966f0fe8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -901,7 +901,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint,
-                         xegpu::LayoutAttr layout) {
+                         DistributeLayoutAttr layout) {
   auto loc = source.getLoc();
   int64_t size = static_cast<int64_t>(offsets.size());
   auto type = VectorType::get(size, builder.getIndexType());
@@ -985,7 +985,7 @@ void StoreScatterOp::build(
     OpBuilder &builder, OperationState &state, Value value, Value dest,
     ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
     xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
-    xegpu::CachePolicyAttr l3_hint, xegpu::LayoutAttr layout) {
+    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
   auto loc = dest.getLoc();
   int64_t size = static_cast<int64_t>(offsets.size());
   auto type = VectorType::get(size, builder.getIndexType());

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c3bf9606693a8..330553564f81a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -678,7 +678,7 @@ struct UnrollLoadGatherOpWithOffset
           pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
     }
 
-    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    auto layout = op.getLayoutAttr();
     if (layout)
       layout = layout.dropInstData();
 
@@ -778,7 +778,7 @@ struct UnrollStoreScatterOpWithOffsets
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
-    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAttr());
+    auto layout = op.getLayoutAttr();
     if (layout)
       layout = layout.dropInstData();
 

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0a9ef0aa6df96..33d4b0457e5d3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -889,8 +889,8 @@ struct WgToSgLoadGatherOpWithOffset
       return failure();
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
-        xegpu::getDistributeLayoutAttr(op.getResult()));
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -913,10 +913,12 @@ struct WgToSgLoadGatherOpWithOffset
     VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
     for (auto [offsets, mask] :
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
+      auto newLayout = layout.dropSgLayoutAndData();
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
           op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
-          layout.dropSgLayoutAndData());
+          newLayout);
+      xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), newLayout);
       newLoadOps.push_back(newLoadOp);
     }
     rewriter.replaceOpWithMultiple(op, {newLoadOps});
@@ -941,8 +943,8 @@ struct WgToSgStoreScatterOpWithOffset
     if (!valueType)
       return failure();
 
-    xegpu::LayoutAttr layout = dyn_cast_if_present<xegpu::LayoutAttr>(
-        xegpu::getDistributeLayoutAttr(op.getOperand(0)));
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getOperand(0));
     if (!layout || !layout.isForWorkgroup())
       return failure();
 

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 4fbb566cfbe73..5dde84e8e0bc2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -547,4 +547,21 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_load_slice_attr
+  gpu.func @distribute_load_slice_attr() {
+    %2 = memref.alloca() {alignment = 1024} : memref<4096xf32>
+    %offset =  arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<0> : vector<256xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [32], inst_data = [16]> } dense<1> : vector<256xi1>
+
+    // CHECK: %[[LOAD:.*]] = xegpu.load {{.*}} <{chunk_size = 1 : i64, layout = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [8, 16]>, dims = [0]>} :
+    // CHECK-SAME: memref<4096xf32>, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+    %3 = xegpu.load %2[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>, dims = [0]> } : memref<4096xf32>, vector<256xindex>, vector<256xi1> -> vector<256xf32>
+
+    // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32xf32> to vector<32x32xf32>
+    %4 = vector.broadcast %3 {layout_result_0 =
+        #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
+    gpu.return
+  }
 }


        


More information about the Mlir-commits mailing list