[Mlir-commits] [mlir] [mlir][XeGPU] Add optional layout attribute to LoadGather StoreScatter ops (PR #163414)

Dmitry Chigarev llvmlistbot at llvm.org
Thu Oct 23 01:52:58 PDT 2025


https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/163414

>From feb4def9770e5ae56b90a23b1d343b8e4e7b8e4b Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 14 Oct 2025 15:22:38 +0000
Subject: [PATCH] [mlir][XeGPU] Add optional layout attribute to LoadGather
 and StoreScatter ops

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 24 +++++++++--
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 12 ++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 41 +++++++++++++++++--
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  4 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  7 +++-
 5 files changed, 72 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 426377fcf598f..4c67856b559b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -843,7 +843,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
   let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
@@ -895,7 +896,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $source,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
@@ -979,7 +987,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
-      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+      OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+      OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
@@ -1030,7 +1039,14 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                     "IntegerAttr": $chunk_size,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $dest,
+                    "ArrayRef<OpFoldResult>": $offsets, "Value": $mask,
+                    "IntegerAttr": $chunk_size,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
    ];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..b5d9323de47a6 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -435,7 +435,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
       /*chunk_size=*/IntegerAttr{},
       /*l1_hint=*/xegpu::CachePolicyAttr{},
       /*l2_hint=*/xegpu::CachePolicyAttr{},
-      /*l3_hint=*/xegpu::CachePolicyAttr{});
+      /*l3_hint=*/xegpu::CachePolicyAttr{},
+      /*layout=*/nullptr);
 
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
@@ -469,7 +470,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                 /*chunk_size=*/IntegerAttr{},
                                 /*l1_hint=*/xegpu::CachePolicyAttr{},
                                 /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+                                /*l3_hint=*/xegpu::CachePolicyAttr{},
+                                /*layout=*/nullptr);
   rewriter.eraseOp(writeOp);
   return success();
 }
@@ -621,7 +623,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
         /*chunk_size=*/IntegerAttr{},
         /*l1_hint=*/xegpu::CachePolicyAttr{},
         /*l2_hint=*/xegpu::CachePolicyAttr{},
-        /*l3_hint=*/xegpu::CachePolicyAttr{});
+        /*l3_hint=*/xegpu::CachePolicyAttr{},
+        /*layout=*/nullptr);
 
     auto selectOp =
         arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -655,7 +658,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
                                   /*chunk_size=*/IntegerAttr{},
                                   /*l1_hint=*/xegpu::CachePolicyAttr{},
                                   /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
+                                  /*layout=*/nullptr);
     rewriter.eraseOp(scatterOp);
     return success();
   }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..2a7c7ac7e8cde 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -859,7 +859,7 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
-        l1_hint, l2_hint, l3_hint);
+        l1_hint, l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -875,7 +875,24 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
   auto offset = vector::FromElementsOp::create(builder, loc, type, values);
 
   build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source,
+                         ArrayRef<OpFoldResult> offsets, Value mask,
+                         IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint,
+                         DistributeLayoutAttr layout) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint,
+        l2_hint, l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
@@ -926,7 +943,7 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            xegpu::CachePolicyAttr l2_hint,
                            xegpu::CachePolicyAttr l3_hint) {
   build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
-        l2_hint, l3_hint);
+        l2_hint, l3_hint, /*layout=*/nullptr);
 }
 
 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -944,7 +961,23 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
 
   // Call the correct builder overload that does not expect result types.
   build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
-        l3_hint);
+        l3_hint, /*layout=*/nullptr);
+}
+
+void StoreScatterOp::build(
+    OpBuilder &builder, OperationState &state, Value value, Value dest,
+    ArrayRef<OpFoldResult> offsets, Value mask, IntegerAttr chunk_size,
+    xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint,
+    xegpu::CachePolicyAttr l3_hint, DistributeLayoutAttr layout) {
+  auto loc = dest.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = vector::FromElementsOp::create(builder, loc, type, values);
+
+  // Call the correct builder overload that does not expect result types.
+  build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint,
+        l3_hint, layout);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index aafa1b7deb84b..ccc4c9bc9dbe9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -687,7 +687,7 @@ struct UnrollLoadGatherOpWithOffset
       auto newOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newValueTy, op.getSource(), o, m,
           rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), /*layout*/ nullptr);
       newOps.push_back(newOp);
     }
 
@@ -783,7 +783,7 @@ struct UnrollStoreScatterOpWithOffsets
       xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
                                     rewriter.getI64IntegerAttr(chunkSize),
                                     op.getL1HintAttr(), op.getL2HintAttr(),
-                                    op.getL3HintAttr());
+                                    op.getL3HintAttr(), /*layout*/ nullptr);
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..460b04eaf1994 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -914,7 +914,8 @@ struct WgToSgLoadGatherOpWithOffset
          llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
       auto newLoadOp = xegpu::LoadGatherOp::create(
           rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          /*layout*/ nullptr);
       xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0),
                                      layout.dropSgLayoutAndData());
       newLoadOps.push_back(newLoadOp);
@@ -962,9 +963,11 @@ struct WgToSgStoreScatterOpWithOffset
     auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
     for (auto [val, offs, mask] : llvm::zip(
              adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+
       auto store = xegpu::StoreScatterOp::create(
           rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
+          /*layout*/ nullptr);
       // Update the layout attribute to drop sg_layout and sg_data.
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty()) {



More information about the Mlir-commits mailing list