[Mlir-commits] [mlir] 5fc8e87 - [MLIR][XeGPU] Retain anchor op layouts for XeGPU nD ops (#170934)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 5 21:49:17 PST 2025


Author: Nishant Patel
Date: 2025-12-05T21:49:13-08:00
New Revision: 5fc8e87fe20d41654d8efc6bc93f7503b3bfdd36

URL: https://github.com/llvm/llvm-project/commit/5fc8e87fe20d41654d8efc6bc93f7503b3bfdd36
DIFF: https://github.com/llvm/llvm-project/commit/5fc8e87fe20d41654d8efc6bc93f7503b3bfdd36.diff

LOG: [MLIR][XeGPU] Retain anchor op layouts for XeGPU nD ops (#170934)

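This PR adds support for retaining the anchor op layouts of the XeGPU nD
ops (prefetch_nd, load_nd, store_nd) during the workgroup-to-subgroup
distribution and unroll transformations. Layout fields that no longer
apply are dropped first: the subgroup layout and data (sg_layout,
sg_data) during workgroup-to-subgroup distribution, and inst_data during
unrolling.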

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index d93ffb70881bd..4fe1087d18879 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -329,7 +329,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
                    "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint,
+                   "xegpu::DistributeLayoutAttr": $layout)>
   ];
 
   let hasVerifier = 1;
@@ -453,7 +454,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
                     "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
-                    "xegpu::CachePolicyAttr": $l3_hint)>
+                    "xegpu::CachePolicyAttr": $l3_hint,
+                    "xegpu::DistributeLayoutAttr": $layout)>
   ];
 
   let hasVerifier = 1;
@@ -564,7 +566,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                   "ArrayRef<OpFoldResult>": $offsets,
                   "xegpu::CachePolicyAttr": $l1_hint,
                   "xegpu::CachePolicyAttr": $l2_hint,
-                  "xegpu::CachePolicyAttr": $l3_hint)>
+                  "xegpu::CachePolicyAttr": $l3_hint,
+                  "xegpu::DistributeLayoutAttr": $layout)>
   ];
 
 

diff  --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 4358ef07da91d..079e1e2a8ac67 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -567,7 +567,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                          /*layout=*/nullptr);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -621,7 +622,8 @@ struct TransferWriteLowering
     auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                             ndDesc, indices,
                                             /*l1_hint=*/hint,
-                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                            /*layout=*/nullptr);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -725,7 +727,8 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
         xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                 /*packed=*/nullptr, /*transpose=*/nullptr,
                                 /*l1_hint=*/hint,
-                                /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                /*layout=*/nullptr);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -763,7 +766,8 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto storeNdOp =
         xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                 /*layout=*/nullptr);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8cb666298c959..91ba07a8e0256 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -472,7 +472,8 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                          Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                          xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
-                         xegpu::CachePolicyAttr l3_hint) {
+                         xegpu::CachePolicyAttr l3_hint,
+                         xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -480,7 +481,7 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
 
   build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
-        l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l2_hint, l3_hint, /*anchor_layout=*/layout);
 }
 
 LogicalResult PrefetchNdOp::verify() {
@@ -527,7 +528,8 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                      UnitAttr packed, DenseI64ArrayAttr transpose,
                      xegpu::CachePolicyAttr l1_hint,
                      xegpu::CachePolicyAttr l2_hint,
-                     xegpu::CachePolicyAttr l3_hint) {
+                     xegpu::CachePolicyAttr l3_hint,
+                     xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -536,7 +538,7 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
 
   build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
         packed, transpose, l1_hint, l2_hint, l3_hint,
-        /*anchor_layout=*/nullptr);
+        /*anchor_layout=*/layout);
 }
 
 LogicalResult LoadNdOp::verify() {
@@ -647,7 +649,8 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                       Value tensorDesc, ArrayRef<OpFoldResult> offsets,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
-                      xegpu::CachePolicyAttr l3_hint) {
+                      xegpu::CachePolicyAttr l3_hint,
+                      xegpu::DistributeLayoutAttr layout) {
   SmallVector<Value> dynamicOffsets;
   SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -655,7 +658,7 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
 
   build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
-        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/nullptr);
+        l1_hint, l2_hint, l3_hint, /*anchor_layout=*/layout);
 }
 
 LogicalResult StoreNdOp::verify() {

diff  --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8995ab3082d24..e6009d5afeab2 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -528,7 +528,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
   xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                               newDescOp.getResult(),
                               getPrefetchOffsets(initForOp.getInductionVar()),
-                              readCacheHint, readCacheHint, readCacheHint);
+                              readCacheHint, readCacheHint, readCacheHint,
+                              /*layout=*/nullptr);
 
   // Insert prefetch op in main loop.
   // Calculate prefetch offset after the init prefetches have been issued.
@@ -539,7 +540,7 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
   xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
                               newDescOp.getResult(),
                               getPrefetchOffsets(prefetchOffset), readCacheHint,
-                              readCacheHint, readCacheHint);
+                              readCacheHint, readCacheHint, /*layout=*/nullptr);
 
   // Unroll the init loop.
   if (failed(loopUnrollFull(initForOp)))

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
index 4dc5ea4f7bb24..ab41fe4298d99 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -214,7 +214,7 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
           newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
           origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
           origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
-          origLoadOp.getL3HintAttr());
+          origLoadOp.getL3HintAttr(), origLoadOp.getLayoutAttr());
       // Set the layout for the loadOp.
       auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
       xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 330553564f81a..af63f09cd5e4a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -238,6 +238,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     if (!targetShape)
       return failure();
 
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
@@ -255,7 +258,7 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
       auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
         xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
                                     op.getL1HintAttr(), op.getL2HintAttr(),
-                                    op.getL3HintAttr());
+                                    op.getL3HintAttr(), layout);
         // return dummy Value to satisfy function's signature
         return nullptr;
       };
@@ -282,6 +285,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     if (!targetShape)
       return failure();
 
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
@@ -306,7 +312,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
         return xegpu::LoadNdOp::create(
             rewriter, loc, newValueTy, convertedTdescs[0], offsets,
             op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
-            op.getL2HintAttr(), op.getL3HintAttr());
+            op.getL2HintAttr(), op.getL3HintAttr(), layout);
       };
       newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                       *targetShape, createLoad, loc, rewriter);
@@ -331,6 +337,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     if (!targetShape)
       return failure();
 
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropInstData();
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
@@ -354,7 +363,7 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
         xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
                                  convertedTdescs[0], offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
-                                 op.getL3HintAttr());
+                                 op.getL3HintAttr(), layout);
         // return dummy Value to satisfy function's signature
         return nullptr;
       };

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e7182ed8d05f7..be82cda574f1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -317,6 +317,9 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     SmallVector<Value> newOps;
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
@@ -326,7 +329,7 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       auto newOp = xegpu::LoadNdOp::create(
           rewriter, op.getLoc(), newResTy, tdesc, offsets,
           /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          op.getL2HintAttr(), op.getL3HintAttr(), layout);
       newOps.push_back(newOp);
     }
     rewriter.replaceOpWithMultiple(op, {newOps});
@@ -347,11 +350,14 @@ struct WgToSgStoreNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [v, tdesc, offsets] :
          llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
       xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                                op.getL1HintAttr(), op.getL2HintAttr(),
-                               op.getL3HintAttr());
+                               op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);
 
@@ -371,11 +377,14 @@ struct WgToSgPrefetchNdOpWithOffset
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
+    if (layout)
+      layout = layout.dropSgLayoutAndData();
     for (auto [tdesc, offsets] :
          llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
       xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                   op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr());
+                                  op.getL3HintAttr(), layout);
     }
     rewriter.eraseOp(op);
 

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 6fc8a8d9df7fd..69eb8ce9dfba5 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -633,4 +633,17 @@ gpu.module @test_distribution {
         #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_anchor_layout
+  gpu.func @load_nd_tdesc_with_anchor_layout(%src: memref<256x128xf32>) {
+    //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<inst_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1]>}>
+    // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %load =  xegpu.load_nd %tdesc[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16],lane_layout = [1, 16], lane_data = [1, 1]>}>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
 }


        


More information about the Mlir-commits mailing list