[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (PR #162095)

Dmitry Chigarev llvmlistbot at llvm.org
Thu Oct 30 07:36:12 PDT 2025


https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/162095

From ef520d09280a43cbc86cf760b24ab670508a0df8 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Mon, 6 Oct 2025 14:37:21 +0000
Subject: [PATCH 01/14] [MLIR][XeGPU][VectorToXeGPU] Lower
 vector.load/store/transfer_read/transfer_write to new offsets syntax

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 218 ++++++++++++------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  12 +-
 .../VectorToXeGPU/load-to-xegpu.mlir          |   4 +-
 .../VectorToXeGPU/store-to-xegpu.mlir         |   4 +-
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |   8 +-
 .../transfer-write-to-xegpu.mlir              |   4 +-
 6 files changed, 171 insertions(+), 79 deletions(-)
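
For readers skimming the patch: the "new offsets syntax" means the base offsets
move from xegpu.create_nd_tdesc onto the consuming xegpu.load_nd / xegpu.store_nd
ops. A minimal before/after sketch for a static 2-D source (shapes and SSA names
are illustrative only, mirroring the load_out_of_bounds test below):

  // Before: offsets are baked into the descriptor.
  %desc0 = xegpu.create_nd_tdesc %src[%off, %off]
      : memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
  %vec0 = xegpu.load_nd %desc0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

  // After: the descriptor only describes the source; offsets go on the access.
  %desc1 = xegpu.create_nd_tdesc %src
      : memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
  %vec1 = xegpu.load_nd %desc1[%off, %off]
      : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>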

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..41526a7e34971 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,64 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
+                                      SmallVector<OpFoldResult> &mixedShapes,
+                                      SmallVector<OpFoldResult> &mixedStrides,
+                                      SmallVector<int64_t> &strides,
+                                      TypedValue<MemRefType> src) {
+  auto srcTy = src.getType();
+  // In case of any dynamic shapes, source's shape and strides have to be
+  // explicitly provided.
+  SmallVector<Value> sourceDims;
+  unsigned srcRank = srcTy.getRank();
+  for (unsigned i = 0; i < srcRank; ++i)
+    sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
+
+  for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+    if (shape == ShapedType::kDynamic)
+      mixedShapes.push_back(sourceDims[idx]);
+    else
+      mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
+  }
+
+  // Compute strides in reverse order.
+  Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  // Last stride is guaranteed to be static and unit.
+  mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
+  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+    accStride =
+        arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
+    if (strides[i] == ShapedType::kDynamic)
+      mixedStrides.push_back(accStride);
+    else
+      mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
+  }
+  std::reverse(mixedStrides.begin(), mixedStrides.end());
+}
+
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
+  MemRefType srcTy = src.getType();
+  auto [strides, offset] = srcTy.getStridesAndOffset();
+
+  xegpu::CreateNdDescOp ndDesc;
+  if (srcTy.hasStaticShape())
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
+  else {
+    SmallVector<OpFoldResult> mixedShapes;
+    SmallVector<OpFoldResult> mixedStrides;
+    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
+                              src);
+
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           mixedShapes, mixedStrides);
+  }
+
+  return ndDesc;
+}
+
 static xegpu::CreateNdDescOp
 createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -109,45 +167,22 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                            getAsOpFoldResult(offsets));
   } else {
-    // In case of any dynamic shapes, source's shape and strides have to be
-    // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
+    SmallVector<OpFoldResult> mixedOffsets;
     for (Value offset : offsets) {
       std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
+      if (staticVal)
+        mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
+      else
+        mixedOffsets.push_back(offset);
     }
 
-    // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
-    }
-    std::reverse(dynStrides.begin(), dynStrides.end());
+    SmallVector<OpFoldResult> mixedShapes;
+    SmallVector<OpFoldResult> mixedStrides;
+    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
+                              src);
 
     ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+        rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
   }
 
   return ndDesc;
@@ -523,21 +558,35 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                          /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::LoadNdOp loadOp;
+
+    if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                       getAsOpFoldResult(readOp.getIndices()),
+                                       /*packed=*/nullptr, transposeAttr,
+                                       /*l1_hint=*/hint,
+                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType,
+                             dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
+                             readOp.getIndices());
+
+      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                       /*packed=*/nullptr, transposeAttr,
+                                       /*l1_hint=*/hint,
+                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -579,17 +628,30 @@ struct TransferWriteLowering
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
-
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::StoreNdOp storeOp;
+    if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+
+      storeOp =
+          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                   getAsOpFoldResult(writeOp.getIndices()),
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
+          writeOp.getIndices());
+
+      storeOp =
+          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -674,19 +736,32 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
 
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
 
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::LoadNdOp loadNdOp;
+
+    if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
+      loadNdOp = xegpu::LoadNdOp::create(
+          rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+          /*packed=*/nullptr, /*transpose=*/nullptr,
+          /*l1_hint=*/hint,
+          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
+      loadNdOp =
+          xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                  /*packed=*/nullptr, /*transpose=*/nullptr,
+                                  /*l1_hint=*/hint,
+                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -711,15 +786,28 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::StoreNdOp storeNdOp;
+    if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+      storeNdOp =
+          xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                   getAsOpFoldResult(storeOp.getIndices()),
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
+
+      storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                           /*l1_hint=*/hint,
+                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
+
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();
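
For a source with dynamic dimensions, createNdDescriptor still materializes
memref.dim ops and passes explicit shape/strides operands to the descriptor,
while the offsets now ride on the load/store. A sketch of the resulting IR for a
fully dynamic 3-D source (SSA names are illustrative; compare the
load_dynamic_source test updates below):

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = memref.dim %src, %c0 : memref<?x?x?xf32>
  %d1 = memref.dim %src, %c1 : memref<?x?x?xf32>
  %d2 = memref.dim %src, %c2 : memref<?x?x?xf32>
  %s0 = arith.muli %d2, %d1 : index
  %desc = xegpu.create_nd_tdesc %src, shape : [%d0, %d1, %d2], strides : [%s0, %d2, 1]
      : memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
  %vec = xegpu.load_nd %desc[%off, %off, %off]
      : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>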
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..8ed8b26dd2a0e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -258,8 +258,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -320,8 +322,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
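
The builder tweak above appears to be what keeps partially-dynamic descriptors
printable: when any dynamic shape or stride operands are present, the static
shape/strides attributes are retained so the shape/strides clauses still
round-trip. A sketch of the printed form for a memref<?x8x16xf32> source
(based on the load_dynamic_source2 test pattern further down; %c0/%dim0 are
illustrative names):

  %c0 = arith.constant 0 : index
  %dim0 = memref.dim %src, %c0 : memref<?x8x16xf32>
  %desc = xegpu.create_nd_tdesc %src, shape : [%dim0, 8, 16], strides : [128, 16, 1]
      : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>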
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..c7c0485768b99 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x15xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..19240abe1e75c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -74,9 +74,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79af1bd9a..72bdab0a4db3a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -83,9 +83,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
 // LOAD-ND-LABEL:  @load_zero_pad_out_of_bounds(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_zero_pad_out_of_bounds(
@@ -109,9 +109,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[OFFSET1:.+]]: index, 
 // LOAD-ND-SAME:   %[[OFFSET2:.+]]: index  
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
 // LOAD-ND-SAME:     -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc9414da4f6..ca3bbc11a5180 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -126,9 +126,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    %[[SRC]]
 // STORE-ND-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL:  @store_out_of_bounds(
 // STORE-SCATTER:   vector.transfer_write

From 46af25a0944431b69084a640dd399082d398e3c9 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 10:45:14 +0000
Subject: [PATCH 02/14] Relax len(offsets) == tdescRank requirement

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 210 +++++-------------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  96 ++++++--
 .../VectorToXeGPU/load-to-xegpu.mlir          |  12 +-
 .../VectorToXeGPU/store-to-xegpu.mlir         |  12 +-
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  20 +-
 .../transfer-write-to-xegpu.mlir              |  16 +-
 mlir/test/Dialect/XeGPU/invalid.mlir          |  14 +-
 7 files changed, 172 insertions(+), 208 deletions(-)
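
With the rank requirement relaxed, the VectorToXeGPU patterns no longer branch on
vector rank vs. source rank: they always create an offset-less descriptor and put
all source-rank offsets on the load/store, even when the descriptor has lower
rank. A sketch for a 1-D load out of a 3-D source (names illustrative, mirroring
the updated load_1D_vector test below):

  %desc = xegpu.create_nd_tdesc %src
      : memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
  %vec = xegpu.load_nd %desc[%off, %off, %off]
      : !xegpu.tensor_desc<8xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8xf32>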

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 41526a7e34971..f3dcb31f6b0be 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,41 +97,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
-                                      SmallVector<OpFoldResult> &mixedShapes,
-                                      SmallVector<OpFoldResult> &mixedStrides,
-                                      SmallVector<int64_t> &strides,
-                                      TypedValue<MemRefType> src) {
-  auto srcTy = src.getType();
-  // In case of any dynamic shapes, source's shape and strides have to be
-  // explicitly provided.
-  SmallVector<Value> sourceDims;
-  unsigned srcRank = srcTy.getRank();
-  for (unsigned i = 0; i < srcRank; ++i)
-    sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-  for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-    if (shape == ShapedType::kDynamic)
-      mixedShapes.push_back(sourceDims[idx]);
-    else
-      mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
-  }
-
-  // Compute strides in reverse order.
-  Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-  // Last stride is guaranteed to be static and unit.
-  mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
-  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-    accStride =
-        arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-    if (strides[i] == ShapedType::kDynamic)
-      mixedStrides.push_back(accStride);
-    else
-      mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
-  }
-  std::reverse(mixedStrides.begin(), mixedStrides.end());
-}
-
 static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 Location loc,
                                                 xegpu::TensorDescType descType,
@@ -143,46 +108,38 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
   if (srcTy.hasStaticShape())
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   else {
-    SmallVector<OpFoldResult> mixedShapes;
-    SmallVector<OpFoldResult> mixedStrides;
-    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
-                              src);
+    // In case of any dynamic shapes, source's shape and strides have to be
+    // explicitly provided.
+    SmallVector<Value> sourceDims;
+    unsigned srcRank = srcTy.getRank();
+    for (unsigned i = 0; i < srcRank; ++i)
+      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
 
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           mixedShapes, mixedStrides);
-  }
-
-  return ndDesc;
-}
-
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
-                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
-                   Operation::operand_range offsets) {
-  MemRefType srcTy = src.getType();
-  auto [strides, offset] = srcTy.getStridesAndOffset();
-
-  xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape()) {
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           getAsOpFoldResult(offsets));
-  } else {
-    SmallVector<OpFoldResult> mixedOffsets;
-    for (Value offset : offsets) {
-      std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (staticVal)
-        mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
+    SmallVector<OpFoldResult> mixedShapes;
+    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+      if (shape == ShapedType::kDynamic)
+        mixedShapes.push_back(sourceDims[idx]);
       else
-        mixedOffsets.push_back(offset);
+        mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
     }
 
-    SmallVector<OpFoldResult> mixedShapes;
+    // Compute strides in reverse order.
     SmallVector<OpFoldResult> mixedStrides;
-    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
-                              src);
+    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+    // Last stride is guaranteed to be static and unit.
+    mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
+    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+      accStride =
+          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
+      if (strides[i] == ShapedType::kDynamic)
+        mixedStrides.push_back(accStride);
+      else
+        mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
+    }
+    std::reverse(mixedStrides.begin(), mixedStrides.end());
 
-    ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           mixedShapes, mixedStrides);
   }
 
   return ndDesc;
@@ -564,29 +521,15 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::LoadNdOp loadOp;
-
-    if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType,
-          dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
-      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                       getAsOpFoldResult(readOp.getIndices()),
-                                       /*packed=*/nullptr, transposeAttr,
-                                       /*l1_hint=*/hint,
-                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc =
-          createNdDescriptor(rewriter, loc, descType,
-                             dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                             readOp.getIndices());
-
-      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                       /*packed=*/nullptr, transposeAttr,
-                                       /*l1_hint=*/hint,
-                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType,
+        dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                      getAsOpFoldResult(readOp.getIndices()),
+                                      /*packed=*/nullptr, transposeAttr,
+                                      /*l1_hint=*/hint,
+                                      /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -630,28 +573,15 @@ struct TransferWriteLowering
         xegpu::MemorySpace::Global);
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::StoreNdOp storeOp;
-    if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType,
-          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
-
-      storeOp =
-          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                   getAsOpFoldResult(writeOp.getIndices()),
-                                   /*l1_hint=*/hint,
-                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType,
-          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-          writeOp.getIndices());
-
-      storeOp =
-          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                   /*l1_hint=*/hint,
-                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType,
+        dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+
+    auto storeOp =
+        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                  getAsOpFoldResult(writeOp.getIndices()),
+                                  /*l1_hint=*/hint,
+                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -743,25 +673,13 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
 
-    xegpu::LoadNdOp loadNdOp;
-
-    if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc =
-          createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
-      loadNdOp = xegpu::LoadNdOp::create(
-          rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
-          /*packed=*/nullptr, /*transpose=*/nullptr,
-          /*l1_hint=*/hint,
-          /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
-      loadNdOp =
-          xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                  /*packed=*/nullptr, /*transpose=*/nullptr,
-                                  /*l1_hint=*/hint,
-                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
+    auto loadNdOp = xegpu::LoadNdOp::create(
+        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+        /*packed=*/nullptr, /*transpose=*/nullptr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -789,24 +707,14 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::StoreNdOp storeNdOp;
-    if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc =
-          createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
-
-      storeNdOp =
-          xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                   getAsOpFoldResult(storeOp.getIndices()),
-                                   /*l1_hint=*/hint,
-                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
-
-      storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                           /*l1_hint=*/hint,
-                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+    auto storeNdOp =
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                  getAsOpFoldResult(storeOp.getIndices()),
+                                  /*l1_hint=*/hint,
+                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8ed8b26dd2a0e..b565e39464b52 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -475,14 +475,30 @@ LogicalResult PrefetchNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  int64_t tDescRank = tdescTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
+  auto tDesc = getTensorDesc();
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else 
+      sourceRank = staticSource.size();
+
+    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+    int64_t constOffsetSize =
+        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+    auto tDescRank = tdescTy.getRank();
+    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitOpError(
+          "Offsets rank must match either the source or the TensorDesc rank.");
+  }
 
   return success();
 }
@@ -600,14 +616,30 @@ LogicalResult LoadNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
 
-  int64_t tDescRank = tdescTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
+  auto tDesc = getTensorDesc();
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else 
+      sourceRank = staticSource.size();
+
+    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+    int64_t constOffsetSize =
+        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+    auto tDescRank = tdescTy.getRank();
+    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitOpError(
+          "Offsets rank must match either the source or the TensorDesc rank.");
+  }
 
   return success();
 }
@@ -694,14 +726,30 @@ LogicalResult StoreNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << dstTy;
 
-  int64_t tDescRank = dstTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
+  auto tDesc = getTensorDesc();
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else 
+      sourceRank = staticSource.size();
+
+    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+    int64_t constOffsetSize =
+        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+    auto tDescRank = dstTy.getRank();
+    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitOpError(
+          "Offsets rank must match either the source or the TensorDesc rank.");
+  }
 
   return success();
 }
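
The relaxed verifier accepts offsets whose rank matches either the source or the
TensorDesc, and only errors out when it matches neither. A self-contained sketch
of both cases (the rejected line is kept as a comment; compare the new
subgroup_load_nd_offset_4 test in invalid.mlir below):

  func.func @offsets_rank_sketch(%src: memref<4x8x16xf16>) {
    %td = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
    // Accepted after this patch: three offsets match the source rank (3).
    %a = xegpu.load_nd %td[0, 0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
    // Still rejected: a single offset matches neither the source (3) nor the
    // TensorDesc (2) rank.
    // %b = xegpu.load_nd %td[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
    return
  }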
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index c7c0485768b99..b5fb2c4aa3e27 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,10 +10,10 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -29,9 +29,9 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -53,10 +53,10 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 19240abe1e75c..57e754f7d7c00 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,10 +12,10 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
 
 // -----
 
@@ -31,9 +31,9 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
@@ -55,10 +55,10 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 72bdab0a4db3a..78a2692119142 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -13,10 +13,10 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME:     %[[SRC]]
 // LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_1D_vector(
@@ -47,10 +47,10 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME:     %[[SRC]]
 // LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_2D_vector(
@@ -151,8 +151,8 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // LOAD-ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // LOAD-ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 
@@ -186,8 +186,8 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-LABEL:  @load_dynamic_source2(
 // LOAD-ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // LOAD-ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
 // LOAD-GATHER-LABEL:  @load_dynamic_source2(
@@ -419,10 +419,10 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME:     %[[SUBVIEW]]
 // LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8xf16>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_from_subview(
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index ca3bbc11a5180..e1b754f952bbe 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -16,10 +16,10 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    %[[SRC]]
 // STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
 
 // STORE-SCATTER-LABEL:  @store_1D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf32>,
@@ -50,10 +50,10 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    %[[SRC]]
 // STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL:  @store_2D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
@@ -91,10 +91,10 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // STORE-ND-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // STORE-ND:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_dynamic_source(
 // STORE-SCATTER-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
@@ -301,10 +301,10 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
 // STORE-ND-SAME:     : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// STORE-ND-SAME:     %[[SUBVIEW]]
 // STORE-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // STORE-ND-SAME:     boundary_check = false
-// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]], %[[OFF2]]] : vector<8xf16>
 
 // STORE-SCATTER-LABEL:  @store_to_subview(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf16>,
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ebbe3ce0ec0d0..00a586dee1f51 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -135,7 +135,7 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
 // -----
 func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
   %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
-// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
   %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
   return
 }
@@ -143,7 +143,7 @@ func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
 // -----
 func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-    // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
   xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -152,11 +152,19 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
   xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
   return
 }
 
+// -----
+func.func @subgroup_load_nd_offset_4(%src: memref<4x8x16xf16>, %x : index) {
+  %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+  %5 = xegpu.load_nd %3[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  return
+}
+
 // -----
 func.func @load_nd_layout(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>

From 65f57c7f14a844ec03928887c8c867ca4d6324d5 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 10:46:28 +0000
Subject: [PATCH 03/14] Apply formatting

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 39 +++++++++----------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 24 +++++++-----
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index f3dcb31f6b0be..7f11d427191e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -521,15 +521,15 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType,
-        dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                      getAsOpFoldResult(readOp.getIndices()),
-                                      /*packed=*/nullptr, transposeAttr,
-                                      /*l1_hint=*/hint,
-                                      /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+    auto loadOp = xegpu::LoadNdOp::create(
+        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
+        /*packed=*/nullptr, transposeAttr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -573,15 +573,15 @@ struct TransferWriteLowering
         xegpu::MemorySpace::Global);
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType,
-        dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
 
     auto storeOp =
         xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                  getAsOpFoldResult(writeOp.getIndices()),
-                                  /*l1_hint=*/hint,
-                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 getAsOpFoldResult(writeOp.getIndices()),
+                                 /*l1_hint=*/hint,
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -710,11 +710,10 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
 
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                  getAsOpFoldResult(storeOp.getIndices()),
-                                  /*l1_hint=*/hint,
-                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeNdOp = xegpu::StoreNdOp::create(
+        rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index b565e39464b52..0435216e306af 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -484,16 +484,18 @@ LogicalResult PrefetchNdOp::verify() {
     if (!staticSource || staticSource.empty()) {
       auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
       sourceRank = sourceTy.getRank();
-    } else 
+    } else
       sourceRank = staticSource.size();
 
     int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
     int64_t constOffsetSize =
         getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
     auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
     if (sourceRankMismatch && tdescRankMismatch)
       return emitOpError(
@@ -625,16 +627,18 @@ LogicalResult LoadNdOp::verify() {
     if (!staticSource || staticSource.empty()) {
       auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
       sourceRank = sourceTy.getRank();
-    } else 
+    } else
       sourceRank = staticSource.size();
 
     int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
     int64_t constOffsetSize =
         getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
     auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
     if (sourceRankMismatch && tdescRankMismatch)
       return emitOpError(
@@ -735,16 +739,18 @@ LogicalResult StoreNdOp::verify() {
     if (!staticSource || staticSource.empty()) {
       auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
       sourceRank = sourceTy.getRank();
-    } else 
+    } else
       sourceRank = staticSource.size();
 
     int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
     int64_t constOffsetSize =
         getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
     auto tDescRank = dstTy.getRank();
-    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
     if (sourceRankMismatch && tdescRankMismatch)
       return emitOpError(

>From 49d38a079b72d14c5cd40cb5096814935d78fe9a Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 11:06:12 +0000
Subject: [PATCH 04/14] generalize 'offsets-check'

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 120 ++++++++-----------------
 mlir/test/Dialect/XeGPU/invalid.mlir   |   8 --
 2 files changed, 39 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0435216e306af..0624e8c4a6a38 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,6 +121,39 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   return success();
 }
 
+// Verify that number of offsets matches either the source rank or the tdesc
+// rank.
+static LogicalResult
+isValidNdOffset(TypedValue<TensorDescType> tDesc,
+                std::optional<llvm::ArrayRef<long int>> constOffsets,
+                int64_t offsetSize,
+                function_ref<InFlightDiagnostic()> emitError) {
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else
+      sourceRank = staticSource.size();
+
+    int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+    auto tDescRank = tDesc.getType().getRank();
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitError() << "Offsets rank must match either the source or the "
+                            "TensorDesc rank.";
+  }
+  return success();
+}
+
 static LogicalResult
 isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                  VectorType valueTy, int64_t chunkSize,
@@ -476,33 +509,8 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto tDesc = getTensorDesc();
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-    int64_t constOffsetSize =
-        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-    auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitOpError(
-          "Offsets rank must match either the source or the TensorDesc rank.");
-  }
-
-  return success();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -619,33 +627,8 @@ LogicalResult LoadNdOp::verify() {
                          << tdescTy;
 
   auto tDesc = getTensorDesc();
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-    int64_t constOffsetSize =
-        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-    auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitOpError(
-          "Offsets rank must match either the source or the TensorDesc rank.");
-  }
-
-  return success();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -731,33 +714,8 @@ LogicalResult StoreNdOp::verify() {
                          << dstTy;
 
   auto tDesc = getTensorDesc();
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-    int64_t constOffsetSize =
-        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-    auto tDescRank = dstTy.getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitOpError(
-          "Offsets rank must match either the source or the TensorDesc rank.");
-  }
-
-  return success();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 00a586dee1f51..614f21bcebc48 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -157,14 +157,6 @@ func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   return
 }
 
-// -----
-func.func @subgroup_load_nd_offset_4(%src: memref<4x8x16xf16>, %x : index) {
-  %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-    // expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
-  %5 = xegpu.load_nd %3[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  return
-}
-
 // -----
 func.func @load_nd_layout(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>

>From f9f73ad7c7921d0ce038fae526fce6df23c85e97 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 11:23:52 +0000
Subject: [PATCH 05/14] fix windows build

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0624e8c4a6a38..b3bdfc58abafc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -125,7 +125,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
 // rank.
 static LogicalResult
 isValidNdOffset(TypedValue<TensorDescType> tDesc,
-                std::optional<llvm::ArrayRef<long int>> constOffsets,
+                std::optional<llvm::ArrayRef<int64_t>> constOffsets,
                 int64_t offsetSize,
                 function_ref<InFlightDiagnostic()> emitError) {
   if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {

>From 2c75b294e9ea00e1d9cb12de9cf26ac7a5121ef5 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:56:31 +0000
Subject: [PATCH 06/14] add docs for new offset syntax

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 45 +++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 426377fcf598f..93c9f305c080c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -261,6 +261,21 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
         : !xegpu.tensor_desc<8x16xf16>
     ```
 
+    The operation may take optional offsets for the tensor descriptor.
+    The number of offsets must be greater or equal to the rank of the tensor descriptor
+    and less than the rank of the source memref. The offsets are applied to the innermost
+    dimension of the source memref.
+
+    Examples:
+    ```mlir
+      %tdesc = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+      // memref[0, 0, %off0, %off1]
+      xegpu.prefetch_nd %tdesc[%off0, %off1] : !xegpu.tensor_desc<8x16xf16>
+      // memref[0, %off0, %off1, %off2]
+      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
+      // memref[%off0, %off1, %off2, %off3]
+      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
+    ```
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -350,6 +365,21 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
         : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
     ```
 
+    The operation may take optional offsets for the tensor descriptor.
+    The number of offsets must be greater or equal to the rank of the tensor descriptor
+    and less than the rank of the source memref. The offsets are applied to the innermost
+    dimension of the source memref.
+
+    Examples:
+    ```mlir
+      %1 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+      // memref[0, 0, %off0, %off1]
+      xegpu.load_nd %1[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+      // memref[0, %off0, %off1, %off2]
+      xegpu.load_nd %1[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+      // memref[%off0, %off1, %off2, %off3]
+      xegpu.load_nd %1[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+    ```
 
   }];
 
@@ -445,6 +475,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                              : vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
     ```
 
+    The operation may take optional offsets for the tensor descriptor.
+    The number of offsets must be greater or equal to the rank of the tensor descriptor
+    and less than the rank of the source memref. The offsets are applied to the innermost
+    dimension of the source memref.
+
+    Examples:
+    ```mlir
+      %2 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+      // memref[0, 0, %off0, %off1]
+      xegpu.store_nd %3, %2[%off0, %off1] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+      // memref[0, %off0, %off1, %off2]
+      xegpu.store_nd %3, %2[%off0, %off1, %off2] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+      // memref[%off0, %off1, %off2, %off3]
+      xegpu.store_nd %3, %2[%off0, %off1, %off2, %off3] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+    ```
 
   }];
 

>From 15f3aa70fd6709bec1a9dfed73d3c44d2dc0acf2 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:56:59 +0000
Subject: [PATCH 07/14] Update validation to not depend on 'create_nd_tdesc' op

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 29 ++++++--------------------
 1 file changed, 6 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index b3bdfc58abafc..76640bb59be46 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -128,29 +128,12 @@ isValidNdOffset(TypedValue<TensorDescType> tDesc,
                 std::optional<llvm::ArrayRef<int64_t>> constOffsets,
                 int64_t offsetSize,
                 function_ref<InFlightDiagnostic()> emitError) {
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
-    auto tDescRank = tDesc.getType().getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitError() << "Offsets rank must match either the source or the "
-                            "TensorDesc rank.";
-  }
+  int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+  auto tDescRank = tDesc.getType().getRank();
+  if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
+      ((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
+    return emitError() << "Offsets rank cannot be smaller than tensor "
+                          "descriptor rank.";
   return success();
 }
 

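Because the relaxed check above no longer looks at the producing create_nd_tdesc, it also covers descriptors that reach the op as block arguments. A minimal sketch, assuming the load_nd forms used elsewhere in this PR:

```mlir
func.func @tdesc_from_argument(%t: !xegpu.tensor_desc<8x16xf16>) -> vector<8x16xf16> {
  // Accepted: three offsets, which is not smaller than the rank-2 TensorDesc.
  // A single offset would be rejected with
  // "Offsets rank cannot be smaller than tensor descriptor rank."
  %v = xegpu.load_nd %t[0, 0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
  return %v : vector<8x16xf16>
}
```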
>From ccf8b92429503fd2ee9fcf175004708f51e7fe86 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:57:19 +0000
Subject: [PATCH 08/14] Use memref.extract_strided_metadata to compute strides

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 21 ++++---------------
 .../VectorToXeGPU/load-to-xegpu.mlir          |  4 ++--
 .../VectorToXeGPU/store-to-xegpu.mlir         |  4 ++--
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  4 ++--
 .../transfer-write-to-xegpu.mlir              |  4 ++--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 12 ++---------
 6 files changed, 14 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 7f11d427191e5..0f031be26cebc 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -105,9 +105,9 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
   auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape())
+  if (srcTy.hasStaticShape()) {
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
-  else {
+  } else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
     SmallVector<Value> sourceDims;
@@ -123,21 +123,8 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
         mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
     }
 
-    // Compute strides in reverse order.
-    SmallVector<OpFoldResult> mixedStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        mixedStrides.push_back(accStride);
-      else
-        mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
-    }
-    std::reverse(mixedStrides.begin(), mixedStrides.end());
-
+    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+    SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(), meta.getStrides().end());
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                            mixedShapes, mixedStrides);
   }
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index b5fb2c4aa3e27..1975c96bfe796 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -52,9 +52,9 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 57e754f7d7c00..63e78ca20bcee 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -54,9 +54,9 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 78a2692119142..81527a8111bb0 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -150,7 +150,7 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // LOAD-ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // LOAD-ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// LOAD-ND:        {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
@@ -186,7 +186,7 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-LABEL:  @load_dynamic_source2(
 // LOAD-ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // LOAD-ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [%c128, %c16, %c1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index e1b754f952bbe..83d33e1905f7c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -90,9 +90,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // STORE-ND-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // STORE-ND-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// STORE-ND:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
 // STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 614f21bcebc48..4b710d3f51557 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -132,18 +132,10 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
   return
 }
 
-// -----
-func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
-  %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
-// expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
-  %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
-  return
-}
-
 // -----
 func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-    // expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+    // expected-error at +1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
   xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -152,7 +144,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    // expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+    // expected-error at +1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
   xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
   return
 }

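For reference, the stride computation for a dynamic source now reduces to a single metadata query instead of a chain of arith.muli ops. A minimal sketch with illustrative value names, matching the CHECK lines above:

```mlir
%base, %off, %sizes:3, %strides:3 = memref.extract_strided_metadata %src
    : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
// %strides#0 and %strides#1 feed the descriptor's `strides : [...]` clause;
// the innermost stride remains the constant 1.
```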
>From 0b26a417c5c4e5a33a319d70564385b332dba39b Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 23:00:36 +0000
Subject: [PATCH 09/14] apply clang-format

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0f031be26cebc..11bf3152e5cc4 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -124,7 +124,8 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
     }
 
     auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
-    SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(), meta.getStrides().end());
+    SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(),
+                                           meta.getStrides().end());
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                            mixedShapes, mixedStrides);
   }

>From 1f0e95384aa384278b240a38aee1ea40dc21c2d2 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 22 Oct 2025 09:14:56 +0000
Subject: [PATCH 10/14] fix docs

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 93c9f305c080c..489bd513a0bd4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -262,9 +262,9 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     ```
 
     The operation may take optional offsets for the tensor descriptor.
-    The number of offsets must be greater or equal to the rank of the tensor descriptor
-    and less than the rank of the source memref. The offsets are applied to the innermost
-    dimension of the source memref.
+    The number of offsets must be greater than or equal to the rank of the tensor
+    descriptor and less than or equal to the rank of the source memref.
+    The offsets are applied to the innermost dimensions of the source memref.
 
     Examples:
     ```mlir
@@ -274,7 +274,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
       // memref[0, %off0, %off1, %off2]
       xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
       // memref[%off0, %off1, %off2, %off3]
-      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
+      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf16>
     ```
   }];
 
@@ -366,9 +366,9 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     ```
 
     The operation may take optional offsets for the tensor descriptor.
-    The number of offsets must be greater or equal to the rank of the tensor descriptor
-    and less than the rank of the source memref. The offsets are applied to the innermost
-    dimension of the source memref.
+    The number of offsets must be greater than or equal to the rank of the tensor
+    descriptor and less than or equal to the rank of the source memref.
+    The offsets are applied to the innermost dimensions of the source memref.
 
     Examples:
     ```mlir
@@ -476,9 +476,9 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     ```
 
     The operation may take optional offsets for the tensor descriptor.
-    The number of offsets must be greater or equal to the rank of the tensor descriptor
-    and less than the rank of the source memref. The offsets are applied to the innermost
-    dimension of the source memref.
+    The number of offsets must be greater than or equal to the rank of the tensor
+    descriptor and less than or equal to the rank of the source memref.
+    The offsets are applied to the innermost dimensions of the source memref.
 
     Examples:
     ```mlir

>From 173eb6ddc4f716dd12e2d94d5e7494abba53ce97 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 22 Oct 2025 12:03:51 +0000
Subject: [PATCH 11/14] use extractStridedMetadataOp to compute shapes for
 tdesc

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 18 ++----------------
 .../VectorToXeGPU/load-to-xegpu.mlir           | 10 ++--------
 .../VectorToXeGPU/store-to-xegpu.mlir          | 10 ++--------
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir  | 13 +++----------
 .../VectorToXeGPU/transfer-write-to-xegpu.mlir | 10 ++--------
 5 files changed, 11 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 11bf3152e5cc4..ee2e8a69edcc0 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -110,24 +110,10 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
   } else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<OpFoldResult> mixedShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        mixedShapes.push_back(sourceDims[idx]);
-      else
-        mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
-    }
-
     auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
-    SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(),
-                                           meta.getStrides().end());
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           mixedShapes, mixedStrides);
+                                           meta.getConstifiedMixedSizes(),
+                                           meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 1975c96bfe796..a3ed559f6413d 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -46,15 +46,9 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-LABEL: @load_dynamic_source(
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
+// CHECK-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 63e78ca20bcee..573e35de7b42e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -48,15 +48,9 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
+// CHECK-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 81527a8111bb0..1b0f492372eef 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -144,13 +144,7 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-LABEL:  @load_dynamic_source(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD-ND:        %[[C2:.+]] = arith.constant 2 : index
-// LOAD-ND:        %[[C1:.+]] = arith.constant 1 : index
-// LOAD-ND:        %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// LOAD-ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// LOAD-ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND:        {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// LOAD-ND:        {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
@@ -184,9 +178,8 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 }
 
 // LOAD-ND-LABEL:  @load_dynamic_source2(
-// LOAD-ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [%c128, %c16, %c1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND-DAG:    {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[SIZES]]#0, 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 83d33e1905f7c..8ca86c39d640d 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -84,15 +84,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
-// STORE-ND-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// STORE-ND-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// STORE-ND-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// STORE-ND-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// STORE-ND-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// STORE-ND-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// STORE-ND:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
+// STORE-ND-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
 // STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 

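With both sizes and strides taken from the constified metadata results, the lowering of a fully dynamic source ends up roughly as below. A minimal sketch mirroring the CHECK lines above, with illustrative offset names:

```mlir
%base, %off, %sizes:3, %strides:3 = memref.extract_strided_metadata %src
    : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
%desc = xegpu.create_nd_tdesc %src,
    shape : [%sizes#0, %sizes#1, %sizes#2], strides : [%strides#0, %strides#1, 1]
    : memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
%vec = xegpu.load_nd %desc[%i, %j, %k]
    : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```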
>From babf57e4fded2d1973885aa431bc26d77b175f15 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 30 Oct 2025 11:16:25 +0000
Subject: [PATCH 12/14] revert xegpu def changes

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 45 -------------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 67 ++++++++++---------
 mlir/test/Dialect/XeGPU/invalid.mlir          | 12 +++-
 3 files changed, 44 insertions(+), 80 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 489bd513a0bd4..426377fcf598f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -261,21 +261,6 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
         : !xegpu.tensor_desc<8x16xf16>
     ```
 
-    The operation may take optional offsets for the tensor descriptor.
-    The number of offsets must be greater than or equal to the rank of the tensor
-    descriptor and less than or equal to the rank of the source memref.
-    The offsets are applied to the innermost dimensions of the source memref.
-
-    Examples:
-    ```mlir
-      %tdesc = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
-      // memref[0, 0, %off0, %off1]
-      xegpu.prefetch_nd %tdesc[%off0, %off1] : !xegpu.tensor_desc<8x16xf16>
-      // memref[0, %off0, %off1, %off2]
-      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
-      // memref[%off0, %off1, %off2, %off3]
-      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf16>
-    ```
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -365,21 +350,6 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
         : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
     ```
 
-    The operation may take optional offsets for the tensor descriptor.
-    The number of offsets must be greater than or equal to the rank of the tensor
-    descriptor and less than or equal to the rank of the source memref.
-    The offsets are applied to the innermost dimensions of the source memref.
-
-    Examples:
-    ```mlir
-      %1 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
-      // memref[0, 0, %off0, %off1]
-      xegpu.load_nd %1[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-      // memref[0, %off0, %off1, %off2]
-      xegpu.load_nd %1[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-      // memref[%off0, %off1, %off2, %off3]
-      xegpu.load_nd %1[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-    ```
 
   }];
 
@@ -475,21 +445,6 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                              : vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
     ```
 
-    The operation may take optional offsets for the tensor descriptor.
-    The number of offsets must be greater than or equal to the rank of the tensor
-    descriptor and less than or equal to the rank of the source memref.
-    The offsets are applied to the innermost dimensions of the source memref.
-
-    Examples:
-    ```mlir
-      %2 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
-      // memref[0, 0, %off0, %off1]
-      xegpu.store_nd %3, %2[%off0, %off1] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
-      // memref[0, %off0, %off1, %off2]
-      xegpu.store_nd %3, %2[%off0, %off1, %off2] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
-      // memref[%off0, %off1, %off2, %off3]
-      xegpu.store_nd %3, %2[%off0, %off1, %off2, %off3] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
-    ```
 
   }];
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 76640bb59be46..abd12e2e69ac0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,22 +121,6 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   return success();
 }
 
-// Verify that number of offsets matches either the source rank or the tdesc
-// rank.
-static LogicalResult
-isValidNdOffset(TypedValue<TensorDescType> tDesc,
-                std::optional<llvm::ArrayRef<int64_t>> constOffsets,
-                int64_t offsetSize,
-                function_ref<InFlightDiagnostic()> emitError) {
-  int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
-  auto tDescRank = tDesc.getType().getRank();
-  if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
-    return emitError() << "Offsets rank cannot be smaller than tensor "
-                          "descriptor rank.";
-  return success();
-}
-
 static LogicalResult
 isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                  VectorType valueTy, int64_t chunkSize,
@@ -274,10 +258,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean (only do so for full-static case, otherwise
-    // printer would fail trying to print empty array-attr).
-    if (staticShape == memrefShape && staticStrides == memrefStrides &&
-        dynamicShape.empty() && dynamicStrides.empty()) {
+    // to keep the IR print clean.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -338,10 +320,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean (only do so for full-static case, otherwise
-    // printer would fail trying to print empty array-attr).
-    if (staticShape == memrefShape && staticStrides == memrefStrides &&
-        dynamicShape.empty() && dynamicStrides.empty()) {
+    // to keep the IR print clean.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -491,9 +471,16 @@ LogicalResult PrefetchNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  auto tDesc = getTensorDesc();
-  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
-                         [&]() { return emitOpError(); });
+  int64_t tDescRank = tdescTy.getRank();
+  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+  int64_t constOffsetSize =
+      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
+      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+    return emitOpError(
+        "Mismatched ranks between offsets and tensor descriptor");
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -609,9 +596,16 @@ LogicalResult LoadNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
 
-  auto tDesc = getTensorDesc();
-  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
-                         [&]() { return emitOpError(); });
+  int64_t tDescRank = tdescTy.getRank();
+  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+  int64_t constOffsetSize =
+      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
+      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+    return emitOpError(
+        "Mismatched ranks between offsets and tensor descriptor");
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -696,9 +690,16 @@ LogicalResult StoreNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << dstTy;
 
-  auto tDesc = getTensorDesc();
-  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
-                         [&]() { return emitOpError(); });
+  int64_t tDescRank = dstTy.getRank();
+  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+  int64_t constOffsetSize =
+      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
+      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+    return emitOpError(
+        "Mismatched ranks between offsets and tensor descriptor");
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 4b710d3f51557..ebbe3ce0ec0d0 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -132,10 +132,18 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
   return
 }
 
+// -----
+func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
+  %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
+// expected-error at +1 {{Mismatched ranks between offsets and tensor descriptor}}
+  %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
+  return
+}
+
 // -----
 func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-    // expected-error at +1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
+    // expected-error at +1 {{Mismatched ranks between offsets and tensor descriptor}}
   xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -144,7 +152,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    // expected-error at +1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
+    // expected-error at +1 {{Mismatched ranks between offsets and tensor descriptor}}
   xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
   return
 }

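After the revert above, the nd ops once again require a non-empty offsets list to match the TensorDesc rank exactly; offsets addressing the extra outer memref dimensions are instead folded into a memref.subview during the VectorToXeGPU conversion (see the collapse helper introduced in the next patch). A minimal sketch of the accepted form:

```mlir
%t = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// Exactly two offsets for the rank-2 descriptor; any other non-empty count
// triggers "Mismatched ranks between offsets and tensor descriptor".
%v = xegpu.load_nd %t[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
```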
>From 2a38c2cc7733c3ba934a66e5a198db864edd9c49 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 30 Oct 2025 14:32:34 +0000
Subject: [PATCH 13/14] collapse memref shape to 2d

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 128 ++++++++++++++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  12 +-
 .../VectorToXeGPU/load-to-xegpu.mlir          |  29 ++--
 .../VectorToXeGPU/store-to-xegpu.mlir         |  29 ++--
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  38 +++---
 .../transfer-write-to-xegpu.mlir              |  41 +++---
 6 files changed, 179 insertions(+), 98 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index ee2e8a69edcc0..34c302b4968c5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -358,6 +358,63 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       .getResult();
 }
 
+// Collapses the shape of an nD memref to the target rank while applying
+// offsets for the collapsed dimensions. Returns the new memref value and the
+// remaining offsets for the last targetRank dimensions. For example:
+//   input:  %memref = memref<2x4x8x32xf32>, offsets: [%i0, %i1, %i2, %i3],
+//           targetRank = 2
+//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+                                    Value memref,
+                                    SmallVector<OpFoldResult> offsets,
+                                    int64_t targetRank) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  unsigned rank = memrefType.getRank();
+
+  if (rank <= targetRank)
+    return {memref, offsets};
+
+  int64_t numCombinedDims = rank - targetRank;
+  SmallVector<OpFoldResult> subviewOffsets;
+  SmallVector<OpFoldResult> subviewSizes;
+  SmallVector<OpFoldResult> subviewStrides;
+
+  // For the combined dimensions: use the provided offsets, size=1, stride=1
+  for (unsigned i = 0; i < numCombinedDims; ++i) {
+    subviewOffsets.push_back(offsets[i]);
+    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  // For the last targetRank dimensions: offset=0, use full size, stride=1
+  SmallVector<int64_t> resultShape;
+  auto originalShape = memrefType.getShape();
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+  for (unsigned i = numCombinedDims; i < rank; ++i) {
+    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    if (ShapedType::isDynamic(originalShape[i])) {
+      subviewSizes.push_back(meta.getSizes()[i]);
+      resultShape.push_back(ShapedType::kDynamic);
+    } else {
+      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+      resultShape.push_back(originalShape[i]);
+    }
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  auto resultType = memref::SubViewOp::inferRankReducedResultType(
+      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+  auto subviewOp =
+      memref::SubViewOp::create(rewriter, loc, resultType, memref,
+                                subviewOffsets, subviewSizes, subviewStrides);
+
+  // Return the remaining offsets for the last targetRank dimensions
+  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+                                       offsets.end());
+  return {subviewOp.getResult(), newOffsets};
+}
+
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -493,17 +550,18 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+        vecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
-    auto loadOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
-        /*packed=*/nullptr, transposeAttr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                          /*packed=*/nullptr, transposeAttr,
+                                          /*l1_hint=*/hint,
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -541,21 +599,23 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, writeOp.getBase(),
+        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 getAsOpFoldResult(writeOp.getIndices()),
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+                                            ndDesc, indices,
+                                            /*l1_hint=*/hint,
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -643,17 +703,21 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+        vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
-        /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+    auto loadNdOp =
+        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                /*packed=*/nullptr, /*transpose=*/nullptr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -675,19 +739,23 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, storeOp.getBase(),
+        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto storeNdOp = xegpu::StoreNdOp::create(
-        rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeNdOp =
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
+                                 /*l1_hint=*/hint,
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 
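For reference, for a 2-D vector.load from a statically shaped 3-D source the reworked lowering now produces IR roughly like the following (a sketch matching the updated CHECK lines further below; the SSA names are illustrative, and the subview sizes/strides are inferred from the pattern rather than quoted from the patch):

  %collapsed = memref.subview %src[%i, 0, 0] [1, 16, 32] [1, 1, 1]
      : memref<8x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
  %desc = xegpu.create_nd_tdesc %collapsed
      : memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
  %vec = xegpu.load_nd %desc[%j, %k]
      : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

Offsets for the collapsed leading dimensions are folded into the rank-reducing subview; only the trailing offsets (one per vector dimension) are passed to xegpu.load_nd.
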
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..8ed8b26dd2a0e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -258,8 +258,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so in the fully static case; otherwise
+    // the printer would fail trying to print an empty array attribute).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -320,8 +322,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so in the fully static case; otherwise
+    // the printer would fail trying to print an empty array attribute).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
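The case guarded against here is a descriptor built from a dynamically shaped source, where the shape and strides must be printed explicitly, e.g. (a sketch reusing the form from the pre-existing tests; the value names are illustrative):

  %desc = xegpu.create_nd_tdesc %src, shape : [%d0, %d1, %d2], strides : [%s0, %s1, 1]
      : memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>

Dropping the shape/strides attributes is therefore only safe when there are no dynamic shape or stride operands at all.
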
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index a3ed559f6413d..ae5141db16c09 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-LABEL: @load_1D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -28,29 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-LABEL: @load_2D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
 
 func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
-    %offset: index) -> vector<8x16xf32> {
-  %0 = vector.load %source[%offset, %offset, %offset]
+    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%i, %j, %k]
     : memref<?x?x?xf32>, vector<8x16xf32>
   return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @load_dynamic_source(
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 573e35de7b42e..1a10d917623cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,11 +11,12 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
 // -----
 
@@ -30,16 +31,17 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
 func.func @store_dynamic_source(%vec: vector<8x16xf32>,
-    %source: memref<?x?x?xf32>, %offset: index) {
-  vector.store %vec, %source[%offset, %offset, %offset]
+    %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+  vector.store %vec, %source[%i, %j, %k]
     : memref<?x?x?xf32>, vector<8x16xf32>
   return
 }
@@ -47,12 +49,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-LABEL: @store_dynamic_source(
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 1b0f492372eef..c87a5304babfe 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -12,11 +12,12 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-ND-LABEL:  @load_1D_vector(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]]
-// LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD-ND-SAME:     %[[COLLAPSED]]
+// LOAD-ND-SAME:     memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_1D_vector(
@@ -46,11 +47,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-ND-LABEL:  @load_2D_vector(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]]
-// LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME:     %[[COLLAPSED]]
+// LOAD-ND-SAME:     memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_2D_vector(
@@ -143,10 +145,11 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 }
 // LOAD-ND-LABEL:  @load_dynamic_source(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
-// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD-ND:        {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND:        {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 
@@ -178,9 +181,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 }
 
 // LOAD-ND-LABEL:  @load_dynamic_source2(
-// LOAD-ND-DAG:    {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[SIZES]]#0, 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
+// LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
 // LOAD-GATHER-LABEL:  @load_dynamic_source2(
@@ -411,11 +416,12 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SUBVIEW]]
-// LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME:     %[[COLLAPSED]]
+// LOAD-ND-SAME:     memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8xf16>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]]]{{.*}}-> vector<8xf16>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_from_subview(
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 8ca86c39d640d..43a1a7206e2cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
 // RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
 
 
@@ -15,11 +15,12 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]]
-// STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND-SAME:    %[[COLLAPSED]]
+// STORE-ND-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
 // STORE-SCATTER-LABEL:  @store_1D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf32>,
@@ -49,11 +50,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]]
-// STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME:    %[[COLLAPSED]]
+// STORE-ND-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL:  @store_2D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
@@ -73,8 +75,8 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // -----
 gpu.module @xevm_module {
 gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
-    %source: memref<?x?x?xf32>, %offset: index) {
-  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+  vector.transfer_write %vec, %source[%i, %j, %k]
     {in_bounds = [true, true]}
     : vector<8x16xf32>, memref<?x?x?xf32>
   gpu.return
@@ -83,12 +85,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-LABEL: @store_dynamic_source(
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// STORE-ND-SAME:  %[[OFFSET:.+]]: index
-// STORE-ND:       {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// STORE-ND-SAME:  , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// STORE-ND-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// STORE-ND:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_dynamic_source(
 // STORE-SCATTER-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
@@ -292,13 +293,13 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND-SAME:   %[[VEC:.+]]: vector<8xf16>,
 // STORE-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // STORE-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
-// STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
-// STORE-ND-SAME:     : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
 // STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:     %[[SUBVIEW]]
-// STORE-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// STORE-ND-SAME:     %[[COLLAPSED]]
+// STORE-ND-SAME:     memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // STORE-ND-SAME:     boundary_check = false
-// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]], %[[OFF2]]] : vector<8xf16>
+// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>
 
 // STORE-SCATTER-LABEL:  @store_to_subview(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf16>,

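Putting the store_to_subview checks above together, the 1-D case now roughly takes this shape (a sketch; %sv, %collapsed and %vec are illustrative names, and the sizes/strides of the second subview are inferred rather than spelled out in the CHECK lines):

  %sv = memref.subview %src[%off1, %off2] [256, 256] [1, 1]
      : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  %collapsed = memref.subview %sv[%off2, 0] [1, 256] [1, 1]
      : memref<256x256xf16, strided<[4096, 1], offset: ?>> to memref<256xf16, strided<[1], offset: ?>>
  %desc = xegpu.create_nd_tdesc %collapsed
      : memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
  xegpu.store_nd %vec, %desc[%off2]
      : vector<8xf16>, !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_attr<boundary_check = false>>

The already-strided subview source is collapsed once more, so only the innermost offset remains on the store.
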
>From 8fd49c09420aca8ae70e0cfe91d9ac12b58a006a Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 30 Oct 2025 14:35:53 +0000
Subject: [PATCH 14/14] format comments

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 34c302b4968c5..49a564e0bbc87 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -362,13 +362,12 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
 // the collapsed dimensions. Returns the new memref value and the remaining
 // offsets for the last targetRank dimensions. For example:
 //   input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
-//   targetRank=2 output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, returned
-//   offsets: [%i2, %i3]
+//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
 static std::pair<Value, SmallVector<OpFoldResult>>
 convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                     Value memref,
                                     SmallVector<OpFoldResult> offsets,
-                                    int64_t targetRank) {
+                                    int64_t targetRank = 2) {
   auto memrefType = cast<MemRefType>(memref.getType());
   unsigned rank = memrefType.getRank();
 

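With the defaulted targetRank of 2, the example in the comment corresponds to IR along these lines (a sketch; the static subview sizes/strides and the strided result layout are inferred here, not quoted from the patch):

  %collapsed = memref.subview %m[%i0, %i1, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1]
      : memref<2x4x8x32xf32> to memref<8x32xf32, strided<[32, 1], offset: ?>>

with the remaining offsets [%i2, %i3] passed on to the nd load/store.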

