[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg (PR #153432)

Nishant Patel llvmlistbot at llvm.org
Mon Aug 18 09:34:07 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/153432

>From fe75a0820d8260d4b39f20b8f8535beba682e447 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 1 Aug 2025 22:12:18 +0000
Subject: [PATCH 01/12] Add create_nd_desc pattern without offset

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  69 +++++++++-
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  60 ++++++---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 120 ++++++++++--------
 3 files changed, 175 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..b2eaa436ac76e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -161,6 +161,18 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // Offset-less ops (no dynamic offsets, and constant offsets absent or all
+    // zero) are left to WgToSgCreateNdOpNoOffset.
+    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+    if (offsetSize == 0) {
+      auto constOffsetsAttr = op.getConstOffsetsAttr();
+      if (!constOffsetsAttr || constOffsetsAttr.empty() ||
+          llvm::all_of(constOffsetsAttr.asArrayRef(),
+                       [](auto v) { return v == 0; }))
+        return failure();
+    }
+
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
@@ -250,6 +262,52 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
+// Distributes a CreateNdDescOp that carries no offsets (no dynamic offsets and
+// constant offsets absent or all zero) into `count` subgroup descriptors of
+// the subgroup shape, all without offsets (presumably supplied later by the
+// memory ops). NOTE(review): an explicit all-zero offset such as `[0, 0]` is
+// treated as "no offsets" and is routed here, not to WgToSgCreateNdOp.
+struct WgToSgCreateNdOpNoOffset
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Dynamic or non-zero constant offsets are handled by WgToSgCreateNdOp.
+    DenseI64ArrayAttr constOffsets = op.getConstOffsetsAttr();
+    if (!op.getOffsets().empty() ||
+        (constOffsets && llvm::any_of(constOffsets.asArrayRef(),
+                                      [](int64_t v) { return v != 0; })))
+      return failure();
+
+    xegpu::TensorDescType tdescTy = op.getType();
+    // The layout may be absent entirely; dyn_cast_if_present tolerates null.
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout)
+      return failure();
+
+    auto [sgShape, count] = getSgShapeAndCount(tdescTy.getShape(), layout);
+    xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(
+        op.getContext(), sgShape, tdescTy.getElementType(),
+        tdescTy.getEncoding(), layout.dropSgLayoutAndData());
+
+    // Forward the source shape/strides (dynamic operands and constant attrs)
+    // so dynamic memrefs and pointer sources stay valid; offsets stay empty.
+    SmallVector<Value> newCreateNdOps;
+    newCreateNdOps.reserve(count);
+    for (int i = 0; i < count; ++i)
+      newCreateNdOps.push_back(xegpu::CreateNdDescOp::create(
+          rewriter, op.getLoc(), newTdescTy, op.getSource(),
+          /*offsets=*/ValueRange(), op.getShape(), op.getStrides(),
+          /*const_offsets=*/DenseI64ArrayAttr(), op.getConstShapeAttr(),
+          op.getConstStridesAttr()));
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
@@ -654,11 +712,12 @@ struct UnrealizedConversionCastOpPattern
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
-               WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
-      patterns.getContext());
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
+           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 628a4857d1253..f1b68c0decdc2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -7,7 +7,20 @@ gpu.module @test_round_robin_assignment {
       // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
       // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.create_nd_tdesc
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      %cst0 = arith.constant 0 : index
+      %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      gpu.return
+    }
+
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][0, 0] : memref<256x128xf32>
+      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.create_nd_tdesc
+      // The workgroup descriptor is created without any offsets.
+      %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
     }
@@ -15,7 +28,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      %cst0 = arith.constant 0 : index
+      %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
       // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -30,7 +44,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: store_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd(%src: memref<256x128xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      %cst0 = arith.constant 0 : index
+      %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
       // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -46,7 +61,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: update_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @update_nd(%src: memref<256x128xf32>){
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       ->  !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
     // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>>
@@ -69,12 +85,13 @@ gpu.module @test_round_robin_assignment {
     // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
+    %cst0 = arith.constant 0 : index
+    %tdesc_a = xegpu.create_nd_tdesc %a[%cst0, %cst0] : memref<256x128xf16>
       -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a
       : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf16>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b[%cst0, %cst0] : memref<128x256xf16>
       -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
     %load_b =  xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
@@ -91,7 +108,8 @@ gpu.module @test_round_robin_assignment {
     // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
     // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.prefetch_nd %tdesc
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -101,7 +119,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: broadcast
   // CHECK-SAME: %[[ARG_0:.*]]: memref<128x1xf32>
   gpu.func @broadcast(%src: memref<128x1xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<128x1xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<128x1xf32>
       -> !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
@@ -122,8 +141,8 @@ gpu.module @test_round_robin_assignment {
     %c0 = arith.constant 0 : index
     %c256 = arith.constant 256 : index
     %c1024 = arith.constant 1024 : index
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
-    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg1[%c0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     // CHECK-LABEL: scf.for
     // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
     %2:2 = scf.for %arg2 = %c0 to %c1024 step %c256 iter_args(%arg3 = %0, %arg4 = %1)
@@ -143,9 +162,10 @@ gpu.module @test_round_robin_assignment {
     %c1_i32 = arith.constant 1 : i32
     %c10_i32 = arith.constant 10 : i32
     %c0_i32 = arith.constant 0 : i32
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
-    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
     %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
       %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
@@ -164,10 +184,11 @@ gpu.module @test_round_robin_assignment {
   }
 
   gpu.func @scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %cst0 = arith.constant 0 : index
     %c10 = arith.constant 10 : index
     %0 = gpu.subgroup_id : index
-    %1 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
-    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %3 = arith.cmpi eq, %0, %c10 : index
     // CHECK-LABEL: scf.if
     //  CHECK-SAME: (vector<16xf32>, vector<16xf32>)
@@ -189,20 +210,20 @@ gpu.module @test_round_robin_assignment {
   gpu.func @scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c10 = arith.constant 10 : index
     %id = gpu.subgroup_id : index
-
-    %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %t = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
 
     %0 = arith.cmpi eq, %id, %c10 : index
     // CHECK-LABEL: scf.if
     //  CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
     %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
-      %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %2 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       //  CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     } else {
-      %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
+      %3 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       //  CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -212,7 +233,8 @@ gpu.module @test_round_robin_assignment {
   }
 
   gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
-    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0, %cst0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
     //CHECK-2: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
     //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
     %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d4b00372bc193..a0352169f2380 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -6,32 +6,42 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
-  // CHECK: %[[SGID:.*]] = gpu.subgroup_id
-  // CHECK: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK: %[[C32:.*]] = arith.constant 32 : index
-  // CHECK: %[[C4:.*]] = arith.constant 4 : index
-  // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
-  // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
-  // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
-  // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
-  // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
-  // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[C256:.*]] = arith.constant 256 : index
-  // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
-  // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
-  // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
-  // CHECK: %[[C128:.*]] = arith.constant 128 : index
-  // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
-  // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
-  // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // CHECK: gpu.return
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-  gpu.return
+    // CHECK: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+    // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
+    // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+    // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
+    // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+    // CHECK: %[[C256:.*]] = arith.constant 256 : index
+    // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+    // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
+    // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+    // CHECK: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+    // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0]]
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: gpu.return
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][0, 0] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
   }
 
   // CHECK-LABEL: load_nd_tdesc
@@ -42,7 +52,8 @@ gpu.module @test_1_1_assignment {
     // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<32x32xf32>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -60,7 +71,8 @@ gpu.module @test_1_1_assignment {
     // CHECK-SAME: -> vector<32x32xf32>
     // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
     // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load = xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -77,7 +89,8 @@ gpu.func @update_nd(%src: memref<256x128xf32>){
   // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
   // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+  %cst0 = arith.constant 0 : index
+  %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
     -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
   %update = xegpu.update_nd_offset %tdesc, [0, 16]
     : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -86,13 +99,14 @@ gpu.func @update_nd(%src: memref<256x128xf32>){
 
 // CHECK-LABEL: dpas
 gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    %cst0 = arith.constant 0 : index
     // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+    %tdesc_a = xegpu.create_nd_tdesc %a[%cst0, %cst0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a
       : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<128x128xf16>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b[%cst0, %cst0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
     %load_b =  xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
@@ -106,15 +120,16 @@ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
 
 // CHECK-LABEL: dpas_no_sg_data
 gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    %cst0 = arith.constant 0 : index
     // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
+    %tdesc_a = xegpu.create_nd_tdesc %a[%cst0, %cst0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
     %load_a =  xegpu.load_nd %tdesc_a
       : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
       -> vector<128x128xf16>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b[%cst0, %cst0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
       order = [1, 0]>>
     %load_b =  xegpu.load_nd %tdesc_b
@@ -134,7 +149,8 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK: xegpu.prefetch_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.prefetch_nd %tdesc
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -153,7 +169,8 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   // CHECK-LABEL: broadcast_dim1
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
   gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x1xf32>
       -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
@@ -169,7 +186,8 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   // CHECK-LABEL: broadcast_dim0
   // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
   gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<1x128xf32>
       -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -229,9 +247,10 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %c1_i32 = arith.constant 1 : i32
     %c10_i32 = arith.constant 10 : i32
     %c0_i32 = arith.constant 0 : i32
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
     %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
-    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
 
     // CHECK: scf.while {{.*}} : (vector<16xf32>, i32) -> (vector<16xf32>, i32)
     %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
@@ -253,9 +272,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   gpu.func @scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c10 = arith.constant 10 : index
     %id = gpu.subgroup_id : index
-
-    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
-    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
 
     %4 = arith.cmpi eq, %id, %c10 : index
     // CHECK-LABEL: scf.if
@@ -282,8 +301,8 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   gpu.func @scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c10 = arith.constant 10 : index
     %id = gpu.subgroup_id : index
-
-    %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %cst0 = arith.constant 0 : index
+    %t = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
     %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
 
     %0 = arith.cmpi eq, %id, %c10 : index
@@ -292,14 +311,14 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>) {
       // CHECK-LABEL: xegpu.create_nd_tdesc
       //  CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
-      %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+      %2 = xegpu.create_nd_tdesc %arg0[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       //  CHECK-SAME: !xegpu.tensor_desc<16xf32>
       scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
     } else {
       // CHECK-LABEL: xegpu.create_nd_tdesc
       //  CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
-      %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+      %3 = xegpu.create_nd_tdesc %arg1[%cst0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
       //  CHECK-SAME: !xegpu.tensor_desc<16xf32>
       scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
@@ -321,7 +340,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %cond = arith.andi %cond1, %cond2 : i1
     scf.if %cond {
         // CHECK-NOT: index.sub
-        %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+        %tdesc = xegpu.create_nd_tdesc %src[%c0, %c0] : memref<256x128xf32>
           -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
         %load =  xegpu.load_nd %tdesc
           : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
@@ -334,7 +353,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
       // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
       // CHECK: %[[C2:.*]] = arith.constant 2 : index
       // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
-      %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
+      %tdesc = xegpu.create_nd_tdesc %src2[%c0, %c0] : memref<128x64xf32>
         -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
       %load =  xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
@@ -347,10 +366,11 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   // CHECK-LABEL: @subgroup_id_range_nested_if
   gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
     %sg_id = gpu.subgroup_id : index
-    %c1 = arith.constant 1 : i1 
+    %cst0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i1
     %c3 = arith.constant 3 : index
     %c32 = arith.constant 32 : index
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[%cst0, %cst0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
@@ -363,7 +383,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
         // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
         // CHECK: %[[C3:.*]] = arith.constant 3 : index
         // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
-        %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
+        %td = xegpu.create_nd_tdesc %src1[%cst0, %cst0] : memref<128x64xf32>
           -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
         %ld =  xegpu.load_nd %td
           : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>

>From d43e55ea03bbdaad98644a14a4dc8383b0a8af97 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 11 Aug 2025 15:36:12 +0000
Subject: [PATCH 02/12] Add new unify-ops test files

---
 .../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 14 ++++++++++++++
 .../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir    | 13 +++++++++++++
 2 files changed, 27 insertions(+)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
new file mode 100644
index 0000000000000..75e2cd7dc00b9
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.create_nd_tdesc
+      %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      gpu.return
+    }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
new file mode 100644
index 0000000000000..2d2a0a87870d1
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_distribution {
+  // CHECK-LABEL: create_nd_tdesc_no_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      gpu.return
+  }
+}
\ No newline at end of file

>From 036bc993aeb48b949bc1efa387e466af620264a4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 11 Aug 2025 15:37:23 +0000
Subject: [PATCH 03/12] Add missing trailing newline to wg-to-sg tests

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 2 +-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 26cac2735a4fb..e5cc65e6bd3d7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -247,4 +247,4 @@ gpu.module @test_round_robin_assignment {
                                    target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
     gpu.return
   }
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 059f9f01797eb..180ba8a162c9f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -365,4 +365,4 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
-}
\ No newline at end of file
+}

>From e21900340c3adea3480909230658d7f8dc31b10b Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 11 Aug 2025 15:38:06 +0000
Subject: [PATCH 04/12] Add missing trailing newline to unify-ops tests

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 2 +-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 75e2cd7dc00b9..b6f44b5bc0b68 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -11,4 +11,4 @@ gpu.module @test_distribution {
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
     }
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 2d2a0a87870d1..837ee386c0f99 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -10,4 +10,4 @@ gpu.module @test_distribution {
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
   }
-}
\ No newline at end of file
+}

>From 9e799d6b7b7284f4ba5fed55a30f520c30f864b2 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 11 Aug 2025 20:36:27 +0000
Subject: [PATCH 05/12] Fix builders

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  6 +--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 52 +++++++------------
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  5 +-
 3 files changed, 22 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 1a6a34c8d775a..480b43e740736 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,11 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
   let builders = [
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $shape,
-                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value ": $source,
                    "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>,
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 2cd086feb5deb..4dd937eb5114d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -156,41 +156,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
+  Type srcTy = source.getType();
+  assert((isa<IntegerType, MemRefType>(srcTy)) &&
+         "Source has to be either int or memref.");
 
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
-        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
-        staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
-         "Shape and strides must be present and of equal size for ui64 "
-         "initialization.");
-
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
@@ -198,6 +175,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+    // if shape and strides are from Memref, we don't need attributes for them
+    // to keep the IR print clean.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
         dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
         staticStridesAttr);
@@ -357,13 +346,10 @@ ParseResult parseOptionalDynamicIndexList(
 void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                    OperandRange values,
                                    DenseI64ArrayAttr integers) {
-
-  if (!integers)
+  if (!integers || integers.empty())
     return;
-
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, {},
-                               AsmParser::Delimiter::Square);
+  printDynamicIndexList(printer, op, values, integers,
+                        /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
 }
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d776b01c5e3bd..d626111161b44 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -240,9 +240,8 @@ struct WgToSgCreateNdOpNoOffset
     SmallVector<Value> newCreateNdOps;
     for (int i = 0; i < count; ++i) {
       auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), ValueRange(), ValueRange(),
-          ValueRange(), DenseI64ArrayAttr(), DenseI64ArrayAttr(),
-          DenseI64ArrayAttr());
+          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
+          op.getMixedStrides());
       newCreateNdOps.push_back(newOp);
     }
     rewriter.replaceOpWithMultiple(op, {newCreateNdOps});

>From fcbdb9138db8b3beb941c2aa7599e09572354de5 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 11 Aug 2025 22:06:49 +0000
Subject: [PATCH 06/12] Add pattern for load/store/prefetch nd with offsets

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  18 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  46 +++++
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 178 ++++++++++++++++++
 3 files changed, 241 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 480b43e740736..4f0e0f64f99be 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -272,6 +272,11 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
 
   let builders = [
     OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
@@ -348,6 +353,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                    "ArrayRef<OpFoldResult>": $offsets,
                     "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
@@ -419,7 +430,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                  "ArrayRef<OpFoldResult>": $offsets,
+                  "xegpu::CachePolicyAttr": $l1_hint,
+                  "xegpu::CachePolicyAttr": $l2_hint,
+                  "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4dd937eb5114d..c36ca150ba023 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -364,6 +364,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                l1_hint, l2_hint, l3_hint);
 }
 
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+        l2_hint, l3_hint);
+}
+
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
   if (tdescTy.isScattered())
@@ -406,6 +421,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                l3_hint);
 }
 
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                     UnitAttr packed, DenseI64ArrayAttr transpose,
+                     xegpu::CachePolicyAttr l1_hint,
+                     xegpu::CachePolicyAttr l2_hint,
+                     xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
@@ -512,6 +543,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
 }
 
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                      xegpu::CachePolicyAttr l1_hint,
+                      xegpu::CachePolicyAttr l2_hint,
+                      xegpu::CachePolicyAttr l3_hint) {
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d626111161b44..0939699ff044d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -77,6 +77,82 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }
 
+// Helper function to compute new offsets for subgroup operations.
+static SmallVector<SmallVector<OpFoldResult>>
+computeSgOffsets(PatternRewriter &rewriter, Location loc,
+                 xegpu::LayoutAttr layout, Value linearSgId,
+                 ArrayRef<int64_t> wgShape, ArrayRef<OpFoldResult> oldOffsets) {
+  SmallVector<SmallVector<OpFoldResult>> result;
+  auto maybeTdescOffsets =
+      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(maybeTdescOffsets))
+    return result;
+
+  for (auto &tdescOffsets : *maybeTdescOffsets) {
+    SmallVector<OpFoldResult> newOffsets;
+    size_t rank = tdescOffsets.size();
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = oldOffsets.size() - rank + i;
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, tdescOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, oldOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    result.push_back(std::move(newOffsets));
+  }
+  return result;
+}
+
+// Helper struct to hold extracted subgroup info for ops with explicit offsets.
+struct SgOffsetInfo {
+  Location loc;
+  Value tdesc;
+  xegpu::TensorDescType tdescTy;
+  xegpu::LayoutAttr layout;
+  SmallVector<int64_t> sgShape;
+  int count;
+  Value linearSgId;
+  SmallVector<OpFoldResult> oldOffsets;
+};
+
+// Helper function to extract subgroup info for ops with explicit offsets.
+// Returns std::nullopt on failure.
+template <typename OpTy>
+std::optional<SgOffsetInfo>
+extractSgOffsetInfo(OpTy op, ConversionPatternRewriter &rewriter) {
+  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+  if (offsetSize == 0)
+    return std::nullopt;
+
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return std::nullopt;
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return std::nullopt;
+
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  SmallVector<int64_t> sgShape;
+  int count;
+  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+  // Merge offsets positionally: each dynamic operand fills the next kDynamic
+  // slot of const_offsets, so mixed static/dynamic offsets keep their order.
+  DenseI64ArrayAttr constOffsets = op.getConstOffsetsAttr();
+  SmallVector<OpFoldResult> oldOffsets =
+      (constOffsets && !constOffsets.empty())
+          ? getMixedValues(constOffsets.asArrayRef(), op.getOffsets(), rewriter)
+          : llvm::to_vector_of<OpFoldResult>(op.getOffsets());
+
+  return SgOffsetInfo{loc,     tdesc, tdescTy,    layout,
+                      sgShape, count, linearSgId, oldOffsets};
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -275,6 +351,43 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   }
 };
 
+// This pattern transforms the LoadNdOp with explicit offsets to load subgroup
+// data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto infoOpt = extractSgOffsetInfo(op, rewriter);
+    if (!infoOpt)
+      return failure();
+    const auto &info = *infoOpt;
+
+    auto sgOffsets =
+        computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
+                         info.tdescTy.getShape(), info.oldOffsets);
+    if (sgOffsets.empty())
+      return failure();
+
+    SmallVector<Value> newLoadOps;
+    auto tdescRange = adaptor.getTensorDesc();
+    for (auto it : llvm::zip(sgOffsets, tdescRange)) {
+      VectorType newResTy =
+          VectorType::get(info.sgShape, info.tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          info.loc, newResTy, std::get<1>(it), std::get<0>(it),
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the StoreNdOp to store to a subgroup descriptor
 /// It creates a StoreNdOp op to store the updated values to the new subgroup
 /// src tensor descriptors.
@@ -297,6 +410,39 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+    : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto infoOpt = extractSgOffsetInfo(op, rewriter);
+    if (!infoOpt)
+      return failure();
+    const auto &info = *infoOpt;
+
+    auto sgOffsets =
+        computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
+                         info.tdescTy.getShape(), info.oldOffsets);
+    if (sgOffsets.empty())
+      return failure();
+
+    auto tdescRange = adaptor.getTensorDesc();
+    auto valueRange = adaptor.getValue();
+    for (auto it : llvm::zip(sgOffsets, tdescRange, valueRange)) {
+      rewriter.create<xegpu::StoreNdOp>(
+          info.loc, std::get<2>(it), std::get<1>(it), std::get<0>(it),
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
 /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
 /// offsets of the new subgroup src tensor descriptors.
@@ -383,6 +529,38 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+    : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto infoOpt = extractSgOffsetInfo(op, rewriter);
+    if (!infoOpt)
+      return failure();
+    const auto &info = *infoOpt;
+
+    auto sgOffsets =
+        computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
+                         info.tdescTy.getShape(), info.oldOffsets);
+    if (sgOffsets.empty())
+      return failure();
+
+    auto tdescRange = adaptor.getTensorDesc();
+    for (auto it : llvm::zip(sgOffsets, tdescRange)) {
+      rewriter.create<xegpu::PrefetchNdOp>(
+          info.loc, std::get<1>(it), std::get<0>(it), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// This pattern transforms vector.broadcast ops to work at subgroup level.
 struct WgToSgVectorBroadcastOp
     : public OpConversionPattern<vector::BroadcastOp> {

>From 639e997d8febb4bc043d6e38ab44f9a90fb71870 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 12 Aug 2025 05:13:51 +0000
Subject: [PATCH 07/12] Add tests

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  10 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |  44 +++-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 220 ++++++++++++++++++
 3 files changed, 269 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a414b14070c6c..bb5e03e32b654 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -120,9 +120,10 @@ struct SgOffsetInfo {
 template <typename OpTy>
 std::optional<SgOffsetInfo>
 extractSgOffsetInfo(OpTy op, ConversionPatternRewriter &rewriter) {
+
   int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-  if (offsetSize == 0)
-    return std::nullopt;
+  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+      return std::nullopt;
 
   Location loc = op.getLoc();
   Value tdesc = op.getTensorDesc();
@@ -832,8 +833,9 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns
       .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
           patterns.getContext());
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index b6f44b5bc0b68..f8f2ba4cbd393 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -10,5 +10,47 @@ gpu.module @test_distribution {
       %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
-    }
+  }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // sg_layout [8, 4] x sg_data [16, 16] covers 128x64 per round, so each
+    // subgroup gets 2x2 = 4 round-robin 16x16 tiles of the 256x128 tensor.
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
+    // CHECK-NOT: xegpu.load_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offset
+  gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
+    // Each subgroup stores the same 4 round-robin 16x16 tiles it loaded.
+    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.store_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // Each subgroup prefetches its 4 round-robin 16x16 tiles.
+    // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.prefetch_nd
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 025d48e22307e..87c13be3f8da8 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -21,4 +23,222 @@ gpu.module @test_distribution {
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offsets
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]  : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc[%cst0, %cst0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas
+  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_no_sg_data
+  gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      order = [1, 0]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      order = [1, 0]>>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+      order = [1, 0]>>
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+      order = [1, 0]>>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_with_no_create_nd_desc
+  gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+    // CHECK-NOT: vector<32x32xf32>
+    %dpas = xegpu.dpas %a, %b
+      {layout =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim1
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+  gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+      -> vector<256x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+      : vector<256x1xf32> to vector<256x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim0
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+  gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<1x128xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<1x128xf32> to vector<32x128xf32>
+    gpu.return
+  }
+
+  gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id  x
+    %block_id_y = gpu.block_id  y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+
+    //      CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
+    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+    // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
+    //      CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    //      CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    //      CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    //      CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
+    //      CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
+    //      CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
+        -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
+            !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
+      %8 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+      %9 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
+                          : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+      %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
+                                !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>
+    }
+    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32>
+            -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    xegpu.store_nd %6#2, %7  : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    gpu.return
+  }
 }

>From 35bdf571cc274c285bf8246814eb257bf91b437a Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 12 Aug 2025 23:15:27 +0000
Subject: [PATCH 08/12] Refactor

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 350 +++++++++---------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  53 +--
 2 files changed, 207 insertions(+), 196 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index bb5e03e32b654..a874a9635ca52 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -77,83 +77,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }
 
-// Helper function to compute new offsets for subgroup operations.
-static SmallVector<SmallVector<OpFoldResult>>
-computeSgOffsets(PatternRewriter &rewriter, Location loc,
-                 xegpu::LayoutAttr layout, Value linearSgId,
-                 ArrayRef<int64_t> wgShape, ArrayRef<OpFoldResult> oldOffsets) {
-  SmallVector<SmallVector<OpFoldResult>> result;
-  auto maybeTdescOffsets =
-      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-  if (failed(maybeTdescOffsets))
-    return result;
-
-  for (auto &tdescOffsets : *maybeTdescOffsets) {
-    SmallVector<OpFoldResult> newOffsets;
-    size_t rank = tdescOffsets.size();
-    for (size_t i = 0; i < rank; i++) {
-      size_t idx = oldOffsets.size() - rank + i;
-      Value add = rewriter.createOrFold<index::AddOp>(
-          loc, tdescOffsets[i],
-          getValueOrCreateConstantIndexOp(rewriter, loc, oldOffsets[idx]));
-      newOffsets.push_back(add);
-    }
-    result.push_back(std::move(newOffsets));
-  }
-  return result;
-}
-
-// Helper struct to hold extracted subgroup info for ops with explicit offsets.
-struct SgOffsetInfo {
-  Location loc;
-  Value tdesc;
-  xegpu::TensorDescType tdescTy;
-  xegpu::LayoutAttr layout;
-  SmallVector<int64_t> sgShape;
-  int count;
-  Value linearSgId;
-  SmallVector<OpFoldResult> oldOffsets;
-};
-
-// Helper function to extract subgroup info for ops with explicit offsets.
-// Returns std::nullopt on failure.
-template <typename OpTy>
-std::optional<SgOffsetInfo>
-extractSgOffsetInfo(OpTy op, ConversionPatternRewriter &rewriter) {
-
-  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
-      return std::nullopt;
-
-  Location loc = op.getLoc();
-  Value tdesc = op.getTensorDesc();
-  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
-  if (!tdescTy)
-    return std::nullopt;
-  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-  if (!layout)
-    return std::nullopt;
-
-  ArrayRef<int64_t> wgShape = tdescTy.getShape();
-  SmallVector<int64_t> sgShape;
-  int count;
-  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
-
-  Value linearSgId =
-      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-  SmallVector<OpFoldResult> oldOffsets;
-  if (auto constOffsets = op.getConstOffsetsAttr()) {
-    for (auto attr : constOffsets.asArrayRef())
-      oldOffsets.push_back(rewriter.getIndexAttr(attr));
-  }
-  for (auto v : op.getOffsets())
-    oldOffsets.push_back(v);
-
-  return SgOffsetInfo{loc,     tdesc, tdescTy,    layout,
-                      sgShape, count, linearSgId, oldOffsets};
-}
-
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -351,43 +274,6 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   }
 };
 
-// This pattern transforms the LoadNdOp with explicit offsets to load subgroup
-// data.
-struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
-  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    auto infoOpt = extractSgOffsetInfo(op, rewriter);
-    if (!infoOpt)
-      return failure();
-    const auto &info = *infoOpt;
-
-    auto sgOffsets =
-        computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
-                         info.tdescTy.getShape(), info.oldOffsets);
-    if (sgOffsets.empty())
-      return failure();
-
-    SmallVector<Value> newLoadOps;
-    auto tdescRange = adaptor.getTensorDesc();
-    for (auto it : llvm::zip(sgOffsets, tdescRange)) {
-      VectorType newResTy =
-          VectorType::get(info.sgShape, info.tdescTy.getElementType());
-      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-          info.loc, newResTy, std::get<1>(it), std::get<0>(it),
-          /*packed=*/nullptr,
-          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
-          op.getL3HintAttr());
-      newLoadOps.push_back(newLoadOp);
-    }
-    rewriter.replaceOpWithMultiple(op, {newLoadOps});
-    return success();
-  }
-};
-
 /// This pattern transforms the StoreNdOp to store to a subgroup descriptor
 /// It creates a StoreNdOp op to store the updated values to the new subgroup
 /// src tensor descriptors.
@@ -410,36 +296,192 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
+// Shared driver for the *WithOffset patterns: distributes an nd op with
+// explicit offsets to subgroup level; `createOp` emits the per-sg ops.
+template <typename OpTy, typename AdaptorTy, typename CreateFn>
+LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
+                                       ConversionPatternRewriter &rewriter,
+                                       CreateFn &&createOp) {
+  if (op.getOffsets().empty() && !op.getConstOffsetsAttr())
+    return failure();
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return failure();
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return failure();
+  SmallVector<int64_t> sgLayout;
+  if (auto sgLayoutAttr = layout.getSgLayout())
+    sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+  else
+    return rewriter.notifyMatchFailure(
+        op, "sgLayout attribute is required in layout");
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified &&
+      computeProduct(sgLayout) != endOfRange - startOfRange)
+    return rewriter.notifyMatchFailure(
+        op, "sg_layout size must match the sg_id_range");
+
+  // A ShapedType::kDynamic entry in const_offsets stands for the next dynamic
+  // operand; operands not claimed that way are appended after the constants.
+  SmallVector<OpFoldResult> oldOffsets;
+  ValueRange dynOffsets = op.getOffsets();
+  if (auto constOffsets = op.getConstOffsetsAttr()) {
+    for (int64_t c : constOffsets.asArrayRef()) {
+      if (ShapedType::isDynamic(c) && !dynOffsets.empty()) {
+        oldOffsets.push_back(dynOffsets.front());
+        dynOffsets = dynOffsets.drop_front();
+      } else {
+        oldOffsets.push_back(rewriter.getIndexAttr(c));
+      }
+    }
+  }
+  for (Value v : dynOffsets)
+    oldOffsets.push_back(v);
+  // The callbacks index oldOffsets[size - rank + i]; avoid size_t underflow.
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  if (oldOffsets.size() < wgShape.size())
+    return rewriter.notifyMatchFailure(op, "fewer offsets than tdesc rank");
+  SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+  // Fail-fast checks are done; ops are built from here on.
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+  if (sgIdRangeSpecified) {
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    linearSgId =
+        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+  }
+  auto maybeTdescOffsets =
+      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(maybeTdescOffsets))
+    return failure();
+  return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
+                  rewriter, op);
+}
+
+// Distributes a LoadNdOp with explicit offsets into per-subgroup load_nd ops.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xegpu::LoadNdOp op,
+      typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    return distributeNdOpWithOffset(
+        op, adaptor, rewriter,
+        [](Location loc, SmallVector<int64_t> &sgShape,
+           ArrayRef<SmallVector<Value>> tdescOffsetsList,
+           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
+           ConversionPatternRewriter &rewriter,
+           xegpu::LoadNdOp &op) -> LogicalResult {
+          SmallVector<Value> newLoadOps;
+          for (auto [tdescOffsets, tdesc] :
+               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
+            SmallVector<OpFoldResult> newOffsets;
+            size_t rank = tdescOffsets.size();
+            for (size_t i = 0; i < rank; i++) {
+              size_t idx = oldOffsets.size() - rank + i;
+              Value add = rewriter.createOrFold<index::AddOp>(
+                  loc, tdescOffsets[i],
+                  getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                  oldOffsets[idx]));
+              newOffsets.push_back(add);
+            }
+            VectorType newResTy = VectorType::get(
+                sgShape, cast<xegpu::TensorDescType>(tdesc.getType())
+                             .getElementType());
+            auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+                loc, newResTy, tdesc, newOffsets,
+                /*packed=*/nullptr,
+                /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+                op.getL3HintAttr());
+            newLoadOps.push_back(newLoadOp);
+          }
+          rewriter.replaceOpWithMultiple(op, {newLoadOps});
+          return success();
+        });
+  }
+};
+
 // This pattern transforms the StoreNdOp with explicit offsets to store
 // subgroup data.
 struct WgToSgStoreNdOpWithOffset
     : public OpConversionPattern<xegpu::StoreNdOp> {
   using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xegpu::StoreNdOp op,
+      typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    return distributeNdOpWithOffset(
+        op, adaptor, rewriter,
+        [](Location loc, SmallVector<int64_t> &sgShape,
+           ArrayRef<SmallVector<Value>> tdescOffsetsList,
+           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
+           ConversionPatternRewriter &rewriter,
+           xegpu::StoreNdOp &op) -> LogicalResult {
+          for (auto [tdescOffsets, tdesc, value] :
+               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
+                         adaptor.getValue())) {
+            SmallVector<OpFoldResult> newOffsets;
+            size_t rank = tdescOffsets.size();
+            for (size_t i = 0; i < rank; i++) {
+              size_t idx = oldOffsets.size() - rank + i;
+              Value add = rewriter.createOrFold<index::AddOp>(
+                  loc, tdescOffsets[i],
+                  getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                  oldOffsets[idx]));
+              newOffsets.push_back(add);
+            }
+            rewriter.create<xegpu::StoreNdOp>(
+                loc, value, tdesc, newOffsets, op.getL1HintAttr(),
+                op.getL2HintAttr(), op.getL3HintAttr());
+          }
+          rewriter.eraseOp(op);
+          return success();
+        });
+  }
+};
 
-  LogicalResult
-  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    auto infoOpt = extractSgOffsetInfo(op, rewriter);
-    if (!infoOpt)
-      return failure();
-    const auto &info = *infoOpt;
-
-    auto sgOffsets =
-        computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
-                         info.tdescTy.getShape(), info.oldOffsets);
-    if (sgOffsets.empty())
-      return failure();
-
-    auto tdescRange = adaptor.getTensorDesc();
-    auto valueRange = adaptor.getValue();
-    for (auto it : llvm::zip(sgOffsets, tdescRange, valueRange)) {
-      rewriter.create<xegpu::StoreNdOp>(
-          info.loc, std::get<2>(it), std::get<1>(it), std::get<0>(it),
-          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
-    }
-    rewriter.eraseOp(op);
-    return success();
+// Distributes a PrefetchNdOp with explicit offsets: one prefetch_nd is emitted
+// per (subgroup offsets, subgroup descriptor) pair, then the WG op is erased.
+struct WgToSgPrefetchNdOpWithOffset
+    : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xegpu::PrefetchNdOp op,
+      typename OpConversionPattern<xegpu::PrefetchNdOp>::OneToNOpAdaptor
+          adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    return distributeNdOpWithOffset(
+        op, adaptor, rewriter,
+        [](Location loc, SmallVector<int64_t> &sgShape,
+           ArrayRef<SmallVector<Value>> tdescOffsetsList,
+           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
+           ConversionPatternRewriter &rewriter,
+           xegpu::PrefetchNdOp &op) -> LogicalResult {
+          for (auto [tdescOffsets, tdesc] :
+               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
+            SmallVector<OpFoldResult> newOffsets;
+            size_t rank = tdescOffsets.size();
+            for (size_t i = 0; i < rank; i++) {
+              size_t idx = oldOffsets.size() - rank + i;
+              Value add = rewriter.createOrFold<index::AddOp>(
+                  loc, tdescOffsets[i],
+                  getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                  oldOffsets[idx]));
+              newOffsets.push_back(add);
+            }
+            rewriter.create<xegpu::PrefetchNdOp>(
+                loc, tdesc, newOffsets, op.getL1HintAttr(), op.getL2HintAttr(),
+                op.getL3HintAttr());
+          }
+          rewriter.eraseOp(op);
+          return success();
+        });
   }
 };
 
@@ -529,38 +571,6 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
-// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
-// subgroup data.
-struct WgToSgPrefetchNdOpWithOffset
-    : public OpConversionPattern<xegpu::PrefetchNdOp> {
-  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    auto infoOpt = extractSgOffsetInfo(op, rewriter);
-    if (!infoOpt)
-      return failure();
-    const auto &info = *infoOpt;
-
-    auto sgOffsets =
-        computeSgOffsets(rewriter, info.loc, info.layout, info.linearSgId,
-                         info.tdescTy.getShape(), info.oldOffsets);
-    if (sgOffsets.empty())
-      return failure();
-
-    auto tdescRange = adaptor.getTensorDesc();
-    for (auto it : llvm::zip(sgOffsets, tdescRange)) {
-      rewriter.create<xegpu::PrefetchNdOp>(
-          info.loc, std::get<1>(it), std::get<0>(it), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
-    }
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
 /// This pattern transforms vector.broadcast ops to work at subgroup level.
 struct WgToSgVectorBroadcastOp
     : public OpConversionPattern<vector::BroadcastOp> {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 87c13be3f8da8..a7d20613f657e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -199,7 +199,9 @@ gpu.module @test_distribution {
     gpu.return
   }
 
-  gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+  // CHECK-LABEL: gemm_with_load_store_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1024x1024xf16>, %[[ARG_1:.*]]: memref<1024x1024xf16>, %[[ARG_2:.*]]: memref<1024x1024xf32>
+  gpu.func @gemm_with_load_store_offset(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
     //CHECK: [[c0:%.+]] = arith.constant 0 : index
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
     //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
@@ -210,35 +212,34 @@ gpu.module @test_distribution {
     %block_id_y = gpu.block_id  y
     %0 = arith.muli %block_id_x, %c128 : index
     %1 = arith.muli %block_id_y, %c128 : index
-    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
-    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
-    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
-    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
-
+    %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    // CHECK: [[DESC_A:%.+]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x128xf16>
+    // CHECK: [[DESC_B:%.+]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x16xf16>
+    %3 = xegpu.create_nd_tdesc %arg0 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+    %4 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+    // load_nd with offset
+    %5 = xegpu.load_nd %2[%0, %1] : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+    %6 = xegpu.load_nd %3[%0, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+    %7 = xegpu.load_nd %4[%c0, %1] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+    // scf.for loop
     //      CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
     // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
-    // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
-    //      CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
-    //      CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
-    //      CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
-    //      CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
-    //      CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
-    //      CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
-    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
-        -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
-            !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
-      %8 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
-      %9 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
-      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
+    // CHECK-SAME: (vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>)
+    //      CHECK: [[c:%.+]] = xegpu.dpas [[arg4]], [[arg5]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    //      CHECK: [[a:%.+]] = xegpu.load_nd [[DESC_A]][{{%.*}}, {{%.*}}]  : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    //      CHECK: [[b:%.+]] = xegpu.load_nd [[DESC_B]][{{%.*}}, {{%.*}}]  : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    //      CHECK: scf.yield [[a]], [[b]], [[c]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>
+    %8:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %6, %arg5 = %7, %arg6 = %5)
+        -> (vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>) {
+      // load_nd with offset inside loop
+      %9 = xegpu.dpas %arg4, %arg5, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
                           : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
-      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
-      %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
-      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
-                                !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>
+      %10 = xegpu.load_nd %3[%arg3, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+      %11 = xegpu.load_nd %4[%c0, %arg3] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+      scf.yield %10, %11, %9 : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>
     }
-    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32>
-            -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
-    xegpu.store_nd %6#2, %7  : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    // store_nd with offset
+    xegpu.store_nd %8#2, %2[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
     gpu.return
   }
 }

>From a1b35a43e7f2b9853e507e3fe2ff96c310ef4b7c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 12 Aug 2025 23:30:33 +0000
Subject: [PATCH 09/12] Add more tests

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  7 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    | 29 ++++++++
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 66 +++++++++++++++++++
 3 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a874a9635ca52..6848fb6b92050 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -296,9 +296,6 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
-// This pattern transforms the LoadNdOp with explicit offsets to load subgroup
-// data.
-// Use a template parameter for the adaptor type
 template <typename OpTy, typename AdaptorTy, typename CreateFn>
 LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
                                        ConversionPatternRewriter &rewriter,
@@ -359,12 +356,12 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
   for (auto v : op.getOffsets())
     oldOffsets.push_back(v);
 
-  // Delegate to the operation-specific creation function
   return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
                   rewriter, op);
 }
 
-// Usage for LoadNdOp
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
 struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
   LogicalResult matchAndRewrite(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index f8f2ba4cbd393..6ff7a94d678a3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -53,4 +53,33 @@ gpu.module @test_distribution {
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
+
+  // CHECK-LABEL: dpas
+  // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+  gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
+    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    // CHECK-NOT: xegpu.dpas
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
+      -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x256xf16>
+      -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+      -> vector<128x256xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index a7d20613f657e..27fc474bd44cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -242,4 +242,70 @@ gpu.module @test_distribution {
     xegpu.store_nd %8#2, %2[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
     gpu.return
   }
+
+  // CHECK-LABEL: @subgroup_id_range
+  gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c31 = arith.constant 31 : index
+    %c3 = arith.constant 3 : index
+    %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c1 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %cond {
+        // CHECK-NOT: index.sub
+        %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+          -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %load =  xegpu.load_nd %tdesc[0, 0]
+          : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<256x128xf32>
+    } {sg_id_range = #xegpu.range<[0, 32]>}
+    %cond3 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c31 : index
+    %cond5 = arith.andi %cond3, %cond4 : i1
+    scf.if %cond5 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C2:.*]] = arith.constant 2 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+      %tdesc = xegpu.create_nd_tdesc %src2 : memref<128x64xf32>
+        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load =  xegpu.load_nd %tdesc[0, 0]
+        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<128x64xf32>
+      %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }{sg_id_range = #xegpu.range<[2, 18]>}
+    gpu.return
+  }
+
+  // CHECK-LABEL: @subgroup_id_range_nested_if
+  gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
+    %sg_id = gpu.subgroup_id : index
+    %c1 = arith.constant 1 : i1
+    %c3 = arith.constant 3 : index
+    %c32 = arith.constant 32 : index
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %c1 {
+      scf.if %cond {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C3:.*]] = arith.constant 3 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+        %td = xegpu.create_nd_tdesc %src1 : memref<128x64xf32>
+          -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %ld =  xegpu.load_nd %td[0, 0]
+          : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<128x64xf32>
+        %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }
+  } {sg_id_range = #xegpu.range<[3, 19]>}
+  gpu.return
+  }
 }

>From 45e56ffce9eb6ac423b170c3d7a1712a3ab40873 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 15 Aug 2025 18:14:53 +0000
Subject: [PATCH 10/12] Address feedback

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 198 +++++++++---------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  45 ----
 2 files changed, 99 insertions(+), 144 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6848fb6b92050..393e784a37837 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -296,10 +296,38 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
-template <typename OpTy, typename AdaptorTy, typename CreateFn>
-LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
-                                       ConversionPatternRewriter &rewriter,
-                                       CreateFn &&createOp) {
+// Utility function to compute distributed offsets for subgroup operations.
+// Returns a vector of new offsets for each subgroup, given the original op's
+// offsets and subgroup relative offsets.
+static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
+    Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+    ArrayRef<OpFoldResult> wgOffsets, ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> distributedOffsets;
+  Location loc = op->getLoc();
+  for (const auto &sgOffsets : sgOffsetsList) {
+    SmallVector<OpFoldResult> newOffsets;
+    size_t rank = sgOffsets.size();
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = wgOffsets.size() - rank + i;
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, sgOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    distributedOffsets.push_back(std::move(newOffsets));
+  }
+  return distributedOffsets;
+}
+
+// Utility function to get sgShape, sgOffsetList, and wgOffsets for a given
+// op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult
+prepareOpDistribution(OpTy op, AdaptorTy adaptor,
+                      ConversionPatternRewriter &rewriter,
+                      SmallVector<int64_t> &sgShape,
+                      SmallVector<SmallVector<Value>> &sgOffsetList,
+                      SmallVector<OpFoldResult> &wgOffsets) {
   int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
   if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
     return failure();
@@ -321,7 +349,6 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
         op, "sgLayout attribute is required in layout");
 
   ArrayRef<int64_t> wgShape = tdescTy.getShape();
-  SmallVector<int64_t> sgShape;
   int count;
   std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
@@ -343,21 +370,19 @@ LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
         rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
   }
 
-  auto maybeTdescOffsets =
-      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-  if (failed(maybeTdescOffsets))
+  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(sgOffsets))
     return failure();
 
-  SmallVector<OpFoldResult> oldOffsets;
   if (auto constOffsets = op.getConstOffsetsAttr()) {
     for (auto attr : constOffsets.asArrayRef())
-      oldOffsets.push_back(rewriter.getIndexAttr(attr));
+      wgOffsets.push_back(rewriter.getIndexAttr(attr));
   }
   for (auto v : op.getOffsets())
-    oldOffsets.push_back(v);
+    wgOffsets.push_back(v);
 
-  return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
-                  rewriter, op);
+  sgOffsetList = *sgOffsets;
+  return success();
 }
 
 // This pattern transforms the LoadNdOp with explicit offsets to load
@@ -368,39 +393,31 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       xegpu::LoadNdOp op,
       typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::LoadNdOp &op) -> LogicalResult {
-          SmallVector<Value> newLoadOps;
-          for (auto [tdescOffsets, tdesc] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            VectorType newResTy = VectorType::get(
-                sgShape, dyn_cast<xegpu::TensorDescType>(tdesc.getType())
-                             .getElementType());
-            auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-                loc, newResTy, tdesc, newOffsets,
-                /*packed=*/nullptr,
-                /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
-                op.getL3HintAttr());
-            newLoadOps.push_back(newLoadOp);
-          }
-          rewriter.replaceOpWithMultiple(op, {newLoadOps});
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    SmallVector<Value> newLoadOps;
+    for (auto [newOffsets, tdesc] :
+         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+      VectorType newResTy = VectorType::get(
+          sgShape,
+          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          op.getLoc(), newResTy, tdesc, newOffsets,
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
   }
 };
 
@@ -413,33 +430,24 @@ struct WgToSgStoreNdOpWithOffset
       xegpu::StoreNdOp op,
       typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::StoreNdOp &op) -> LogicalResult {
-          for (auto [tdescOffsets, tdesc, value] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
-                         adaptor.getValue())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            rewriter.create<xegpu::StoreNdOp>(
-                loc, value, tdesc, newOffsets, op.getL1HintAttr(),
-                op.getL2HintAttr(), op.getL3HintAttr());
-          }
-          rewriter.eraseOp(op);
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    for (auto [newOffsets, tdesc, value] : llvm::zip(
+             distributedOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, newOffsets,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
@@ -453,32 +461,24 @@ struct WgToSgPrefetchNdOpWithOffset
       typename OpConversionPattern<xegpu::PrefetchNdOp>::OneToNOpAdaptor
           adaptor,
       ConversionPatternRewriter &rewriter) const override {
-    return distributeNdOpWithOffset(
-        op, adaptor, rewriter,
-        [](Location loc, SmallVector<int64_t> &sgShape,
-           ArrayRef<SmallVector<Value>> tdescOffsetsList,
-           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
-           ConversionPatternRewriter &rewriter,
-           xegpu::PrefetchNdOp &op) -> LogicalResult {
-          for (auto [tdescOffsets, tdesc] :
-               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
-            SmallVector<OpFoldResult> newOffsets;
-            size_t rank = tdescOffsets.size();
-            for (size_t i = 0; i < rank; i++) {
-              size_t idx = oldOffsets.size() - rank + i;
-              Value add = rewriter.createOrFold<index::AddOp>(
-                  loc, tdescOffsets[i],
-                  getValueOrCreateConstantIndexOp(rewriter, loc,
-                                                  oldOffsets[idx]));
-              newOffsets.push_back(add);
-            }
-            rewriter.create<xegpu::PrefetchNdOp>(
-                loc, tdesc, newOffsets, op.getL1HintAttr(), op.getL2HintAttr(),
-                op.getL3HintAttr());
-          }
-          rewriter.eraseOp(op);
-          return success();
-        });
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+    SmallVector<OpFoldResult> wgOffsets;
+    if (failed(prepareOpDistribution(op, adaptor, rewriter, sgShape,
+                                     sgOffsetList, wgOffsets)))
+      return failure();
+
+    auto distributedOffsets =
+        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+
+    for (auto [newOffsets, tdesc] :
+         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+      rewriter.create<xegpu::PrefetchNdOp>(
+          op.getLoc(), tdesc, newOffsets, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 27fc474bd44cb..07a0b86223c33 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -28,23 +28,8 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
     //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
     //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
-    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
     //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -58,23 +43,8 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
     //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
     //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
-    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
     //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]  : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -90,23 +60,8 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
     //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
     //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
-    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
     //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %cst0 = arith.constant 0 : index
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>

>From 7d3dde7078eb1b5f2f87f8c497a02f64d1f13913 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 15 Aug 2025 20:29:17 +0000
Subject: [PATCH 11/12] change variable name

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 33 ++++++++++---------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 966eba3774b0c..d5f3da836ff96 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -296,13 +296,14 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
-// Utility function to compute distributed offsets for subgroup operations.
+// Utility function to compute global offsets for subgroup operations.
 // Returns a vector of new offsets for each subgroup, given the original op's
 // offsets and subgroup relative offsets.
-static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
-    Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
-    ArrayRef<OpFoldResult> wgOffsets, ConversionPatternRewriter &rewriter) {
-  SmallVector<SmallVector<OpFoldResult>> distributedOffsets;
+static SmallVector<SmallVector<OpFoldResult>>
+computeGlobalOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+                     ArrayRef<OpFoldResult> wgOffsets,
+                     ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> globalOffsets;
   Location loc = op->getLoc();
   for (const auto &sgOffsets : sgOffsetsList) {
     SmallVector<OpFoldResult> newOffsets;
@@ -314,9 +315,9 @@ static SmallVector<SmallVector<OpFoldResult>> computeDistributedOffsets(
           getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
       newOffsets.push_back(add);
     }
-    distributedOffsets.push_back(std::move(newOffsets));
+    globalOffsets.push_back(std::move(newOffsets));
   }
-  return distributedOffsets;
+  return globalOffsets;
 }
 
 // Utility function to get sgShape, sgOffsetList for a given
@@ -408,12 +409,12 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
     SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
 
     // Calculate the global offsets
-    auto distributedOffsets =
-        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+    auto globalOffsets =
+        computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
 
     SmallVector<Value> newLoadOps;
     for (auto [offsets, tdesc] :
-         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+         llvm::zip(globalOffsets, adaptor.getTensorDesc())) {
       VectorType newResTy = VectorType::get(
           sgShape,
           dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
@@ -449,11 +450,11 @@ struct WgToSgStoreNdOpWithOffset
     SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
 
     // Calculate the global offsets
-    auto distributedOffsets =
-        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+    auto globalOffsets =
+        computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
 
     for (auto [offsets, tdesc, value] : llvm::zip(
-             distributedOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+             globalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
       rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
                                         op.getL1HintAttr(), op.getL2HintAttr(),
                                         op.getL3HintAttr());
@@ -483,11 +484,11 @@ struct WgToSgPrefetchNdOpWithOffset
     SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
 
     // calculate the global offsets
-    auto distributedOffsets =
-        computeDistributedOffsets(op, sgOffsetList, wgOffsets, rewriter);
+    auto globalOffsets =
+        computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
 
     for (auto [offsets, tdesc] :
-         llvm::zip(distributedOffsets, adaptor.getTensorDesc())) {
+         llvm::zip(globalOffsets, adaptor.getTensorDesc())) {
       rewriter.create<xegpu::PrefetchNdOp>(
           op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
           op.getL3HintAttr());

>From df3a4665ab98084dd6341284afddb34614330c86 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 18 Aug 2025 16:29:53 +0000
Subject: [PATCH 12/12] feedback

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 69 +++++++++----------
 1 file changed, 33 insertions(+), 36 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d5f3da836ff96..d20b21625aef6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -182,16 +182,16 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                                    layout.dropSgLayoutAndData());
 
     SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
 
     for (auto tdescOffsets : *maybeTdescOffsets) {
       SmallVector<OpFoldResult> sgOffsets;
       size_t rank = tdescOffsets.size();
       for (size_t i = 0; i < rank; i++) {
-        size_t idx = wgOffsets.size() - rank + i;
+        size_t idx = origOffsets.size() - rank + i;
         Value add = rewriter.createOrFold<index::AddOp>(
             loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
         sgOffsets.push_back(add);
       }
 
@@ -300,24 +300,24 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
 // Returns a vector of new offsets for each subgroup, given the original op's
 // offsets and subgroup relative offsets.
 static SmallVector<SmallVector<OpFoldResult>>
-computeGlobalOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
-                     ArrayRef<OpFoldResult> wgOffsets,
-                     ConversionPatternRewriter &rewriter) {
-  SmallVector<SmallVector<OpFoldResult>> globalOffsets;
+computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+               ArrayRef<OpFoldResult> origOffsets,
+               ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> finalOffsets;
   Location loc = op->getLoc();
   for (const auto &sgOffsets : sgOffsetsList) {
     SmallVector<OpFoldResult> newOffsets;
     size_t rank = sgOffsets.size();
     for (size_t i = 0; i < rank; i++) {
-      size_t idx = wgOffsets.size() - rank + i;
+      size_t idx = origOffsets.size() - rank + i;
       Value add = rewriter.createOrFold<index::AddOp>(
           loc, sgOffsets[i],
-          getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+          getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
       newOffsets.push_back(add);
     }
-    globalOffsets.push_back(std::move(newOffsets));
+    finalOffsets.push_back(std::move(newOffsets));
   }
-  return globalOffsets;
+  return finalOffsets;
 }
 
 // Utility function to get sgShape, sgOffsetList for a given
@@ -341,11 +341,11 @@ LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
     return failure();
 
   SmallVector<int64_t> sgLayout;
-  if (auto sgLayoutAttr = layout.getSgLayout())
-    sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-  else
+  auto sgLayoutAttr = layout.getSgLayout();
+  if (!sgLayoutAttr)
     return rewriter.notifyMatchFailure(
         op, "sgLayout attribute is required in layout");
+  sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
 
   ArrayRef<int64_t> wgShape = tdescTy.getShape();
   int count;
@@ -378,16 +378,16 @@ LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
 }
 
 template <typename OpTy>
-SmallVector<OpFoldResult> getWgOffsets(OpTy op,
-                                       ConversionPatternRewriter &rewriter) {
-  SmallVector<OpFoldResult> wgOffsets;
+SmallVector<OpFoldResult> getOffsets(OpTy op,
+                                     ConversionPatternRewriter &rewriter) {
+  SmallVector<OpFoldResult> origOffsets;
   if (auto constOffsets = op.getConstOffsetsAttr()) {
     for (auto attr : constOffsets.asArrayRef())
-      wgOffsets.push_back(rewriter.getIndexAttr(attr));
+      origOffsets.push_back(rewriter.getIndexAttr(attr));
   }
   for (auto v : op.getOffsets())
-    wgOffsets.push_back(v);
-  return wgOffsets;
+    origOffsets.push_back(v);
+  return origOffsets;
 }
 
 // This pattern transforms the LoadNdOp with explicit offsets to load
@@ -406,15 +406,14 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
       return failure();
 
     // Get the original workgroup offsets
-    SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
 
-    // Calculate the global offsets
-    auto globalOffsets =
-        computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
 
     SmallVector<Value> newLoadOps;
     for (auto [offsets, tdesc] :
-         llvm::zip(globalOffsets, adaptor.getTensorDesc())) {
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
       VectorType newResTy = VectorType::get(
           sgShape,
           dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
@@ -447,14 +446,13 @@ struct WgToSgStoreNdOpWithOffset
       return failure();
 
     // Get the original workgroup offsets
-    SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
 
-    // Calculate the global offsets
-    auto globalOffsets =
-        computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
 
-    for (auto [offsets, tdesc, value] : llvm::zip(
-             globalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+    for (auto [offsets, tdesc, value] :
+         llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
       rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
                                         op.getL1HintAttr(), op.getL2HintAttr(),
                                         op.getL3HintAttr());
@@ -481,14 +479,13 @@ struct WgToSgPrefetchNdOpWithOffset
       return failure();
 
     // Get the original workgroup offsets
-    SmallVector<OpFoldResult> wgOffsets = getWgOffsets(op, rewriter);
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
 
-    // calculate the global offsets
-    auto globalOffsets =
-        computeGlobalOffsets(op, sgOffsetList, wgOffsets, rewriter);
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
 
     for (auto [offsets, tdesc] :
-         llvm::zip(globalOffsets, adaptor.getTensorDesc())) {
+         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
       rewriter.create<xegpu::PrefetchNdOp>(
           op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
           op.getL3HintAttr());



More information about the Mlir-commits mailing list