[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (PR #162095)

Dmitry Chigarev llvmlistbot at llvm.org
Tue Oct 21 15:57:41 PDT 2025


https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/162095

>From e581a0bd94886d8e96571886a46f9980a4e42a9c Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Mon, 6 Oct 2025 14:37:21 +0000
Subject: [PATCH 1/8] [MLIR][XeGPU][VectorToXeGPU] Lower
 vector.load/store/transfer_read/transfer_write to new offsets syntax

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 218 ++++++++++++------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  12 +-
 .../VectorToXeGPU/load-to-xegpu.mlir          |   4 +-
 .../VectorToXeGPU/store-to-xegpu.mlir         |   4 +-
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |   8 +-
 .../transfer-write-to-xegpu.mlir              |   4 +-
 6 files changed, 171 insertions(+), 79 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..41526a7e34971 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,64 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
+                                      SmallVector<OpFoldResult> &mixedShapes,
+                                      SmallVector<OpFoldResult> &mixedStrides,
+                                      SmallVector<int64_t> &strides,
+                                      TypedValue<MemRefType> src) {
+  auto srcTy = src.getType();
+  // In case of any dynamic shapes, source's shape and strides have to be
+  // explicitly provided.
+  SmallVector<Value> sourceDims;
+  unsigned srcRank = srcTy.getRank();
+  for (unsigned i = 0; i < srcRank; ++i)
+    sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
+
+  for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+    if (shape == ShapedType::kDynamic)
+      mixedShapes.push_back(sourceDims[idx]);
+    else
+      mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
+  }
+
+  // Compute strides in reverse order.
+  Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  // Last stride is guaranteed to be static and unit.
+  mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
+  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+    accStride =
+        arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
+    if (strides[i] == ShapedType::kDynamic)
+      mixedStrides.push_back(accStride);
+    else
+      mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
+  }
+  std::reverse(mixedStrides.begin(), mixedStrides.end());
+}
+
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
+  MemRefType srcTy = src.getType();
+  auto [strides, offset] = srcTy.getStridesAndOffset();
+
+  xegpu::CreateNdDescOp ndDesc;
+  if (srcTy.hasStaticShape())
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
+  else {
+    SmallVector<OpFoldResult> mixedShapes;
+    SmallVector<OpFoldResult> mixedStrides;
+    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
+                              src);
+
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           mixedShapes, mixedStrides);
+  }
+
+  return ndDesc;
+}
+
 static xegpu::CreateNdDescOp
 createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -109,45 +167,22 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                            getAsOpFoldResult(offsets));
   } else {
-    // In case of any dynamic shapes, source's shape and strides have to be
-    // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
+    SmallVector<OpFoldResult> mixedOffsets;
     for (Value offset : offsets) {
       std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
+      if (staticVal)
+        mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
+      else
+        mixedOffsets.push_back(offset);
     }
 
-    // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
-    }
-    std::reverse(dynStrides.begin(), dynStrides.end());
+    SmallVector<OpFoldResult> mixedShapes;
+    SmallVector<OpFoldResult> mixedStrides;
+    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
+                              src);
 
     ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+        rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
   }
 
   return ndDesc;
@@ -523,21 +558,35 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                          /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::LoadNdOp loadOp;
+
+    if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                       getAsOpFoldResult(readOp.getIndices()),
+                                       /*packed=*/nullptr, transposeAttr,
+                                       /*l1_hint=*/hint,
+                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType,
+                             dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
+                             readOp.getIndices());
+
+      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                       /*packed=*/nullptr, transposeAttr,
+                                       /*l1_hint=*/hint,
+                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -579,17 +628,30 @@ struct TransferWriteLowering
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
-
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::StoreNdOp storeOp;
+    if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+
+      storeOp =
+          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                   getAsOpFoldResult(writeOp.getIndices()),
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType,
+          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
+          writeOp.getIndices());
+
+      storeOp =
+          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -674,19 +736,32 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
 
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
 
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::LoadNdOp loadNdOp;
+
+    if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
+      loadNdOp = xegpu::LoadNdOp::create(
+          rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+          /*packed=*/nullptr, /*transpose=*/nullptr,
+          /*l1_hint=*/hint,
+          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
+      loadNdOp =
+          xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                  /*packed=*/nullptr, /*transpose=*/nullptr,
+                                  /*l1_hint=*/hint,
+                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -711,15 +786,28 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::StoreNdOp storeNdOp;
+    if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
+      xegpu::CreateNdDescOp ndDesc =
+          createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+      storeNdOp =
+          xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                   getAsOpFoldResult(storeOp.getIndices()),
+                                   /*l1_hint=*/hint,
+                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
+    } else {
+      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+          rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
+
+      storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                           /*l1_hint=*/hint,
+                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
+    }
+
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e0a8ac40648e0..d8f3d7bc4956d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -215,8 +215,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -277,8 +279,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..c7c0485768b99 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x15xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..19240abe1e75c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -74,9 +74,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79af1bd9a..72bdab0a4db3a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -83,9 +83,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
 // LOAD-ND-LABEL:  @load_zero_pad_out_of_bounds(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_zero_pad_out_of_bounds(
@@ -109,9 +109,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[OFFSET1:.+]]: index, 
 // LOAD-ND-SAME:   %[[OFFSET2:.+]]: index  
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
 // LOAD-ND-SAME:     -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc9414da4f6..ca3bbc11a5180 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -126,9 +126,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    %[[SRC]]
 // STORE-ND-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL:  @store_out_of_bounds(
 // STORE-SCATTER:   vector.transfer_write

>From b56c1cdce34b1d49b013d6cf2bdd7ea0c290d678 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 10:45:14 +0000
Subject: [PATCH 2/8] Relax len(offsets) == tdescRank requirement

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 210 +++++-------------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  96 ++++++--
 .../VectorToXeGPU/load-to-xegpu.mlir          |  12 +-
 .../VectorToXeGPU/store-to-xegpu.mlir         |  12 +-
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  20 +-
 .../transfer-write-to-xegpu.mlir              |  16 +-
 mlir/test/Dialect/XeGPU/invalid.mlir          |  14 +-
 7 files changed, 172 insertions(+), 208 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 41526a7e34971..f3dcb31f6b0be 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,41 +97,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
-                                      SmallVector<OpFoldResult> &mixedShapes,
-                                      SmallVector<OpFoldResult> &mixedStrides,
-                                      SmallVector<int64_t> &strides,
-                                      TypedValue<MemRefType> src) {
-  auto srcTy = src.getType();
-  // In case of any dynamic shapes, source's shape and strides have to be
-  // explicitly provided.
-  SmallVector<Value> sourceDims;
-  unsigned srcRank = srcTy.getRank();
-  for (unsigned i = 0; i < srcRank; ++i)
-    sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-  for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-    if (shape == ShapedType::kDynamic)
-      mixedShapes.push_back(sourceDims[idx]);
-    else
-      mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
-  }
-
-  // Compute strides in reverse order.
-  Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-  // Last stride is guaranteed to be static and unit.
-  mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
-  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-    accStride =
-        arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-    if (strides[i] == ShapedType::kDynamic)
-      mixedStrides.push_back(accStride);
-    else
-      mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
-  }
-  std::reverse(mixedStrides.begin(), mixedStrides.end());
-}
-
 static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 Location loc,
                                                 xegpu::TensorDescType descType,
@@ -143,46 +108,38 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
   if (srcTy.hasStaticShape())
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   else {
-    SmallVector<OpFoldResult> mixedShapes;
-    SmallVector<OpFoldResult> mixedStrides;
-    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
-                              src);
+    // In case of any dynamic shapes, source's shape and strides have to be
+    // explicitly provided.
+    SmallVector<Value> sourceDims;
+    unsigned srcRank = srcTy.getRank();
+    for (unsigned i = 0; i < srcRank; ++i)
+      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
 
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           mixedShapes, mixedStrides);
-  }
-
-  return ndDesc;
-}
-
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
-                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
-                   Operation::operand_range offsets) {
-  MemRefType srcTy = src.getType();
-  auto [strides, offset] = srcTy.getStridesAndOffset();
-
-  xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape()) {
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           getAsOpFoldResult(offsets));
-  } else {
-    SmallVector<OpFoldResult> mixedOffsets;
-    for (Value offset : offsets) {
-      std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (staticVal)
-        mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
+    SmallVector<OpFoldResult> mixedShapes;
+    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+      if (shape == ShapedType::kDynamic)
+        mixedShapes.push_back(sourceDims[idx]);
       else
-        mixedOffsets.push_back(offset);
+        mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
     }
 
-    SmallVector<OpFoldResult> mixedShapes;
+    // Compute strides in reverse order.
     SmallVector<OpFoldResult> mixedStrides;
-    computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
-                              src);
+    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+    // Last stride is guaranteed to be static and unit.
+    mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
+    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+      accStride =
+          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
+      if (strides[i] == ShapedType::kDynamic)
+        mixedStrides.push_back(accStride);
+      else
+        mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
+    }
+    std::reverse(mixedStrides.begin(), mixedStrides.end());
 
-    ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           mixedShapes, mixedStrides);
   }
 
   return ndDesc;
@@ -564,29 +521,15 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::LoadNdOp loadOp;
-
-    if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType,
-          dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
-      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                       getAsOpFoldResult(readOp.getIndices()),
-                                       /*packed=*/nullptr, transposeAttr,
-                                       /*l1_hint=*/hint,
-                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc =
-          createNdDescriptor(rewriter, loc, descType,
-                             dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                             readOp.getIndices());
-
-      loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                       /*packed=*/nullptr, transposeAttr,
-                                       /*l1_hint=*/hint,
-                                       /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType,
+        dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+                                      getAsOpFoldResult(readOp.getIndices()),
+                                      /*packed=*/nullptr, transposeAttr,
+                                      /*l1_hint=*/hint,
+                                      /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -630,28 +573,15 @@ struct TransferWriteLowering
         xegpu::MemorySpace::Global);
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::StoreNdOp storeOp;
-    if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType,
-          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
-
-      storeOp =
-          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                   getAsOpFoldResult(writeOp.getIndices()),
-                                   /*l1_hint=*/hint,
-                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType,
-          dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-          writeOp.getIndices());
-
-      storeOp =
-          xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                   /*l1_hint=*/hint,
-                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType,
+        dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+
+    auto storeOp =
+        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+                                  getAsOpFoldResult(writeOp.getIndices()),
+                                  /*l1_hint=*/hint,
+                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -743,25 +673,13 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
 
-    xegpu::LoadNdOp loadNdOp;
-
-    if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc =
-          createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
-      loadNdOp = xegpu::LoadNdOp::create(
-          rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
-          /*packed=*/nullptr, /*transpose=*/nullptr,
-          /*l1_hint=*/hint,
-          /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
-      loadNdOp =
-          xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                  /*packed=*/nullptr, /*transpose=*/nullptr,
-                                  /*l1_hint=*/hint,
-                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
+    auto loadNdOp = xegpu::LoadNdOp::create(
+        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+        /*packed=*/nullptr, /*transpose=*/nullptr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -789,24 +707,14 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::StoreNdOp storeNdOp;
-    if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
-      xegpu::CreateNdDescOp ndDesc =
-          createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
-
-      storeNdOp =
-          xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                   getAsOpFoldResult(storeOp.getIndices()),
-                                   /*l1_hint=*/hint,
-                                   /*l2_hint=*/hint, /*l3_hint=*/hint);
-    } else {
-      xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-          rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
-
-      storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                           /*l1_hint=*/hint,
-                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
-    }
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+    auto storeNdOp =
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+                                  getAsOpFoldResult(storeOp.getIndices()),
+                                  /*l1_hint=*/hint,
+                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index d8f3d7bc4956d..1e2f009648bd9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -432,14 +432,30 @@ LogicalResult PrefetchNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  int64_t tDescRank = tdescTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
+  auto tDesc = getTensorDesc();
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else 
+      sourceRank = staticSource.size();
+
+    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+    int64_t constOffsetSize =
+        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+    auto tDescRank = tdescTy.getRank();
+    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitOpError(
+          "Offsets rank must match either the source or the TensorDesc rank.");
+  }
 
   return success();
 }
@@ -557,14 +573,30 @@ LogicalResult LoadNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
 
-  int64_t tDescRank = tdescTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
+  auto tDesc = getTensorDesc();
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else 
+      sourceRank = staticSource.size();
+
+    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+    int64_t constOffsetSize =
+        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+    auto tDescRank = tdescTy.getRank();
+    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitOpError(
+          "Offsets rank must match either the source or the TensorDesc rank.");
+  }
 
   return success();
 }
@@ -651,14 +683,30 @@ LogicalResult StoreNdOp::verify() {
                          << " is not consistent with tensor descriptor "
                          << dstTy;
 
-  int64_t tDescRank = dstTy.getRank();
-  int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-  int64_t constOffsetSize =
-      getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-  if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
-      ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
-    return emitOpError(
-        "Mismatched ranks between offsets and tensor descriptor");
+  auto tDesc = getTensorDesc();
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else 
+      sourceRank = staticSource.size();
+
+    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+    int64_t constOffsetSize =
+        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+    auto tDescRank = dstTy.getRank();
+    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitOpError(
+          "Offsets rank must match either the source or the TensorDesc rank.");
+  }
 
   return success();
 }
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index c7c0485768b99..b5fb2c4aa3e27 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,10 +10,10 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -29,9 +29,9 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -53,10 +53,10 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 19240abe1e75c..57e754f7d7c00 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,10 +12,10 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
 
 // -----
 
@@ -31,9 +31,9 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
@@ -55,10 +55,10 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 72bdab0a4db3a..78a2692119142 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -13,10 +13,10 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME:     %[[SRC]]
 // LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_1D_vector(
@@ -47,10 +47,10 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME:     %[[SRC]]
 // LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_2D_vector(
@@ -151,8 +151,8 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // LOAD-ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // LOAD-ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 
@@ -186,8 +186,8 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-LABEL:  @load_dynamic_source2(
 // LOAD-ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // LOAD-ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
 // LOAD-GATHER-LABEL:  @load_dynamic_source2(
@@ -419,10 +419,10 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME:     %[[SUBVIEW]]
 // LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8xf16>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_from_subview(
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index ca3bbc11a5180..e1b754f952bbe 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -16,10 +16,10 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    %[[SRC]]
 // STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
 
 // STORE-SCATTER-LABEL:  @store_1D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf32>,
@@ -50,10 +50,10 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    %[[SRC]]
 // STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL:  @store_2D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
@@ -91,10 +91,10 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // STORE-ND-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // STORE-ND:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_dynamic_source(
 // STORE-SCATTER-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
@@ -301,10 +301,10 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
 // STORE-ND-SAME:     : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// STORE-ND-SAME:     %[[SUBVIEW]]
 // STORE-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // STORE-ND-SAME:     boundary_check = false
-// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]], %[[OFF2]]] : vector<8xf16>
 
 // STORE-SCATTER-LABEL:  @store_to_subview(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf16>,
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 228ef69d9a478..db284dc2f0797 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -135,7 +135,7 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
 // -----
 func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
   %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
-// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
   %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
   return
 }
@@ -143,7 +143,7 @@ func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
 // -----
 func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-    // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
   xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -152,11 +152,19 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
   xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
   return
 }
 
+// -----
+func.func @subgroup_load_nd_offset_4(%src: memref<4x8x16xf16>, %x : index) {
+  %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+  %5 = xegpu.load_nd %3[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  return
+}
+
 // -----
 func.func @load_nd_layout(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>

>From 8581183e6219410f08fc6fdb4f21864a49e88694 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 10:46:28 +0000
Subject: [PATCH 3/8] Apply formatting

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 39 +++++++++----------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 24 +++++++-----
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index f3dcb31f6b0be..7f11d427191e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -521,15 +521,15 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
                                                   ArrayRef<int64_t>{1, 0});
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType,
-        dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
-                                      getAsOpFoldResult(readOp.getIndices()),
-                                      /*packed=*/nullptr, transposeAttr,
-                                      /*l1_hint=*/hint,
-                                      /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+    auto loadOp = xegpu::LoadNdOp::create(
+        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
+        /*packed=*/nullptr, transposeAttr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
@@ -573,15 +573,15 @@ struct TransferWriteLowering
         xegpu::MemorySpace::Global);
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType,
-        dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
 
     auto storeOp =
         xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                  getAsOpFoldResult(writeOp.getIndices()),
-                                  /*l1_hint=*/hint,
-                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+                                 getAsOpFoldResult(writeOp.getIndices()),
+                                 /*l1_hint=*/hint,
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -710,11 +710,10 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
 
-    auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
-                                  getAsOpFoldResult(storeOp.getIndices()),
-                                  /*l1_hint=*/hint,
-                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeNdOp = xegpu::StoreNdOp::create(
+        rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1e2f009648bd9..a34fd0c831de7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -441,16 +441,18 @@ LogicalResult PrefetchNdOp::verify() {
     if (!staticSource || staticSource.empty()) {
       auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
       sourceRank = sourceTy.getRank();
-    } else 
+    } else
       sourceRank = staticSource.size();
 
     int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
     int64_t constOffsetSize =
         getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
     auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
     if (sourceRankMismatch && tdescRankMismatch)
       return emitOpError(
@@ -582,16 +584,18 @@ LogicalResult LoadNdOp::verify() {
     if (!staticSource || staticSource.empty()) {
       auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
       sourceRank = sourceTy.getRank();
-    } else 
+    } else
       sourceRank = staticSource.size();
 
     int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
     int64_t constOffsetSize =
         getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
     auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
     if (sourceRankMismatch && tdescRankMismatch)
       return emitOpError(
@@ -692,16 +696,18 @@ LogicalResult StoreNdOp::verify() {
     if (!staticSource || staticSource.empty()) {
       auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
       sourceRank = sourceTy.getRank();
-    } else 
+    } else
       sourceRank = staticSource.size();
 
     int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
     int64_t constOffsetSize =
         getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
     auto tDescRank = dstTy.getRank();
-    bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
         ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
     if (sourceRankMismatch && tdescRankMismatch)
       return emitOpError(

>From e04202b55a4e05afa1ccb165f4a399fe968d04be Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 11:06:12 +0000
Subject: [PATCH 4/8] generalize 'offsets-check'

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 120 ++++++++-----------------
 mlir/test/Dialect/XeGPU/invalid.mlir   |   8 --
 2 files changed, 39 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index a34fd0c831de7..07a4b94727146 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,6 +121,39 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   return success();
 }
 
+// Verify that number of offsets matches either the source rank or the tdesc
+// rank.
+static LogicalResult
+isValidNdOffset(TypedValue<TensorDescType> tDesc,
+                std::optional<llvm::ArrayRef<long int>> constOffsets,
+                int64_t offsetSize,
+                function_ref<InFlightDiagnostic()> emitError) {
+  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+    // If CreateNdDescOp is available, we can further
+    // check the offsets rank against the source rank.
+    auto staticSource = createTDescOp.getConstShapeAttr();
+    int64_t sourceRank;
+    if (!staticSource || staticSource.empty()) {
+      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+      sourceRank = sourceTy.getRank();
+    } else
+      sourceRank = staticSource.size();
+
+    int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+    auto tDescRank = tDesc.getType().getRank();
+    bool sourceRankMismatch =
+        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+    bool tdescRankMismatch =
+        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+    if (sourceRankMismatch && tdescRankMismatch)
+      return emitError() << "Offsets rank must match either the source or the "
+                            "TensorDesc rank.";
+  }
+  return success();
+}
+
 static LogicalResult
 isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
                                  VectorType valueTy, int64_t chunkSize,
@@ -433,33 +466,8 @@ LogicalResult PrefetchNdOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   auto tDesc = getTensorDesc();
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-    int64_t constOffsetSize =
-        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-    auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitOpError(
-          "Offsets rank must match either the source or the TensorDesc rank.");
-  }
-
-  return success();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -576,33 +584,8 @@ LogicalResult LoadNdOp::verify() {
                          << tdescTy;
 
   auto tDesc = getTensorDesc();
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-    int64_t constOffsetSize =
-        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-    auto tDescRank = tdescTy.getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitOpError(
-          "Offsets rank must match either the source or the TensorDesc rank.");
-  }
-
-  return success();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -688,33 +671,8 @@ LogicalResult StoreNdOp::verify() {
                          << dstTy;
 
   auto tDesc = getTensorDesc();
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
-    int64_t constOffsetSize =
-        getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
-    auto tDescRank = dstTy.getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitOpError(
-          "Offsets rank must match either the source or the TensorDesc rank.");
-  }
-
-  return success();
+  return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+                         [&]() { return emitOpError(); });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index db284dc2f0797..27bd457a96833 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -157,14 +157,6 @@ func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   return
 }
 
-// -----
-func.func @subgroup_load_nd_offset_4(%src: memref<4x8x16xf16>, %x : index) {
-  %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
-  %5 = xegpu.load_nd %3[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  return
-}
-
 // -----
 func.func @load_nd_layout(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>

>From 37e1843fcc4778023fecd6addfc87c8d3cbcacb2 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 11:23:52 +0000
Subject: [PATCH 5/8] fix windows build

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 07a4b94727146..e09a084ac7ad2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -125,7 +125,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
 // rank.
 static LogicalResult
 isValidNdOffset(TypedValue<TensorDescType> tDesc,
-                std::optional<llvm::ArrayRef<long int>> constOffsets,
+                std::optional<llvm::ArrayRef<int64_t>> constOffsets,
                 int64_t offsetSize,
                 function_ref<InFlightDiagnostic()> emitError) {
   if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {

>From 614887bff0d28156b7b3ebfef9b0c8f37e7ff64b Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:56:31 +0000
Subject: [PATCH 6/8] add docs for new offset syntax

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 45 +++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..ad3cfac8de7bb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -261,6 +261,21 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
         : !xegpu.tensor_desc<8x16xf16>
     ```
 
+    The operation may take optional offsets for the tensor descriptor.
+    The number of offsets must be greater than or equal to the rank of the tensor
+    descriptor and less than or equal to the rank of the source memref. The offsets
+    are applied to the innermost dimensions of the source memref.
+
+    Examples:
+    ```mlir
+      %tdesc = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+      // memref[0, 0, %off0, %off1]
+      xegpu.prefetch_nd %tdesc[%off0, %off1] : !xegpu.tensor_desc<8x16xf32>
+      // memref[0, %off0, %off1, %off2]
+      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32>
+      // memref[%off0, %off1, %off2, %off3]
+      xegpu.prefetch_nd %tdesc[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32>
+    ```
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -350,6 +365,21 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
         : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
     ```
 
+    The operation may take optional offsets for the tensor descriptor.
+    The number of offsets must be greater than or equal to the rank of the tensor
+    descriptor and less than or equal to the rank of the source memref. The offsets
+    are applied to the innermost dimensions of the source memref.
+
+    Examples:
+    ```mlir
+      %1 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+      // memref[0, 0, %off0, %off1]
+      xegpu.load_nd %1[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+      // memref[0, %off0, %off1, %off2]
+      xegpu.load_nd %1[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+      // memref[%off0, %off1, %off2, %off3]
+      xegpu.load_nd %1[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+    ```
 
   }];
 
@@ -445,6 +475,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
                              : vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
     ```
 
+    The operation may take optional offsets for the tensor descriptor.
+    The number of offsets must be greater than or equal to the rank of the tensor
+    descriptor and less than or equal to the rank of the source memref. The offsets
+    are applied to the innermost dimensions of the source memref.
+
+    Examples:
+    ```mlir
+      %2 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+      // memref[0, 0, %off0, %off1]
+      xegpu.store_nd %3, %2[%off0, %off1] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+      // memref[0, %off0, %off1, %off2]
+      xegpu.store_nd %3, %2[%off0, %off1, %off2] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+      // memref[%off0, %off1, %off2, %off3]
+      xegpu.store_nd %3, %2[%off0, %off1, %off2, %off3] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+    ```
 
   }];
 

>From beeac485f7760ad0791950f45adf9205f708058d Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:56:59 +0000
Subject: [PATCH 7/8] Update validation to not depend on 'create_nd_tdesc' op

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 29 ++++++--------------------
 1 file changed, 6 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e09a084ac7ad2..55ec8a6bde6eb 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -128,29 +128,12 @@ isValidNdOffset(TypedValue<TensorDescType> tDesc,
                 std::optional<llvm::ArrayRef<int64_t>> constOffsets,
                 int64_t offsetSize,
                 function_ref<InFlightDiagnostic()> emitError) {
-  if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
-    // If CreateNdDescOp is available, we can further
-    // check the offsets rank against the source rank.
-    auto staticSource = createTDescOp.getConstShapeAttr();
-    int64_t sourceRank;
-    if (!staticSource || staticSource.empty()) {
-      auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
-      sourceRank = sourceTy.getRank();
-    } else
-      sourceRank = staticSource.size();
-
-    int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
-    auto tDescRank = tDesc.getType().getRank();
-    bool sourceRankMismatch =
-        ((offsetSize != 0) && (offsetSize != sourceRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
-    bool tdescRankMismatch =
-        ((offsetSize != 0) && (offsetSize != tDescRank)) ||
-        ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
-    if (sourceRankMismatch && tdescRankMismatch)
-      return emitError() << "Offsets rank must match either the source or the "
-                            "TensorDesc rank.";
-  }
+  int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+  auto tDescRank = tDesc.getType().getRank();
+  if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
+      ((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
+    return emitError() << "Offsets rank cannot be smaller than tensor "
+                          "descriptor rank.";
   return success();
 }
 

>From 0ef9ed7479ecaf3398cd774e430d40852d3edbb2 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:57:19 +0000
Subject: [PATCH 8/8] Use memref.extract_strided_metadata to compute strides

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 21 ++++---------------
 .../VectorToXeGPU/load-to-xegpu.mlir          |  4 ++--
 .../VectorToXeGPU/store-to-xegpu.mlir         |  4 ++--
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  4 ++--
 .../transfer-write-to-xegpu.mlir              |  4 ++--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 12 ++---------
 6 files changed, 14 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 7f11d427191e5..0f031be26cebc 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -105,9 +105,9 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
   auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape())
+  if (srcTy.hasStaticShape()) {
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
-  else {
+  } else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
     SmallVector<Value> sourceDims;
@@ -123,21 +123,8 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
         mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
     }
 
-    // Compute strides in reverse order.
-    SmallVector<OpFoldResult> mixedStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        mixedStrides.push_back(accStride);
-      else
-        mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
-    }
-    std::reverse(mixedStrides.begin(), mixedStrides.end());
-
+    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+    SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(), meta.getStrides().end());
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                            mixedShapes, mixedStrides);
   }
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index b5fb2c4aa3e27..1975c96bfe796 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -52,9 +52,9 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 57e754f7d7c00..63e78ca20bcee 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -54,9 +54,9 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 78a2692119142..81527a8111bb0 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -150,7 +150,7 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // LOAD-ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // LOAD-ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// LOAD-ND:        {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
@@ -186,7 +186,7 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-LABEL:  @load_dynamic_source2(
 // LOAD-ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // LOAD-ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [%c128, %c16, %c1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index e1b754f952bbe..83d33e1905f7c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -90,9 +90,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
 // STORE-ND-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
 // STORE-ND-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// STORE-ND:       {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
 // STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 27bd457a96833..d203cf82d7960 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -132,18 +132,10 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
   return
 }
 
-// -----
-func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
-  %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
-// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
-  %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
-  return
-}
-
 // -----
 func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+    // expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
   xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
   return
 }
@@ -152,7 +144,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
 func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
   %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-    // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+    // expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
   xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
   return
 }



More information about the Mlir-commits mailing list