[Mlir-commits] [mlir] 747050b - [MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (#162095)

llvmlistbot at llvm.org
Tue Nov 4 04:52:28 PST 2025


Author: Dmitry Chigarev
Date: 2025-11-04T13:52:23+01:00
New Revision: 747050bcceca18d32dc1140461984ec2c30ae96a

URL: https://github.com/llvm/llvm-project/commit/747050bcceca18d32dc1140461984ec2c30ae96a
DIFF: https://github.com/llvm/llvm-project/commit/747050bcceca18d32dc1140461984ec2c30ae96a.diff

LOG: [MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (#162095)

Changes the `VectorToXeGPU` pass to generate `xegpu.load_nd/store_nd`
ops using the new syntax, where offsets are specified at the load/store
op level.
```mlir
// from this
%desc = xegpu.create_nd_tdesc %src[%off1, %off2]: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

// to this
%desc = xegpu.create_nd_tdesc %src: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc[%off1, %off2] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
```
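
The transposed `transfer_read` path is handled the same way: the `transpose`
attribute stays on the load while the offsets move to `load_nd`. A sketch based
on the updated `load_transposed` test (types abbreviated, value names
illustrative):

```mlir
%desc = xegpu.create_nd_tdesc %src : memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32>
%res = xegpu.load_nd %desc[%off1, %off2] <{transpose = array<i64: 1, 0>}>
  : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
```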

In order to support cases with dimension reduction at the
`create_nd_tdesc` level (e.g. `memref<8x8x16xf16> ->
tensor_desc<8x16xf16>`), a `memref.subview` is inserted that
collapses the source shape to 2D, for example:

```mlir
// input:
%0 = vector.load %source[%off0, %off1, %off2] : memref<8x16x32xf32>, vector<8x16xf32>

// --vector-to-xegpu (old)
%tdesc = xegpu.create_nd_tdesc %source[%off0, %off1, %off2] : memref<8x16x32xf32> -> tdesc<8x16xf32>
%vec = xegpu.load_nd %tdesc

// --vector-to-xegpu (new)
%collapsed = memref.subview %source[%off0, 0, 0] [1, 16, 32] [1, 1, 1] :
    memref<8x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
%tdesc = xegpu.create_nd_tdesc %collapsed : memref<16x32xf32, ...> -> tdesc<8x16xf32>
%vec = xegpu.load_nd %tdesc[%off1, %off2]
```
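
Dynamically shaped sources follow the same path: after the collapse, the
descriptor's shape and strides are taken from `memref.extract_strided_metadata`
instead of the previous `memref.dim`/`arith.muli` chains. A rough sketch of the
resulting IR (value names and the exact printed form of the shape/strides
operands are illustrative):

```mlir
// input:
%0 = vector.load %source[%i, %j, %k] : memref<?x?x?xf32>, vector<8x16xf32>

// --vector-to-xegpu (new), abbreviated
// %sz1/%sz2 are the trailing sizes taken from memref.extract_strided_metadata on %source
%collapsed = memref.subview %source[%i, 0, 0] [1, %sz1, %sz2] [1, 1, 1]
    : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %collapsed
%tdesc = xegpu.create_nd_tdesc %collapsed, shape : [%sizes#0, %sizes#1], strides : [%strides#0, 1]
    : memref<?x?xf32, strided<[?, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
%vec = xegpu.load_nd %tdesc[%j, %k] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```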

<details><summary>Why do we need to change this?</summary>

```mlir
// reduce dim and apply all 3 offsets at load_nd
%desc = xegpu.create_nd_tdesc %source : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
// error: xegpu.load_nd len(offsets) != desc.rank
%res = xegpu.load_nd %desc[%off, %off, %off] : !xegpu.tensor_desc<16x32xf32> -> vector<8x16xf32>
```

</details>
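
The store side is lowered the same way; a sketch matching the updated
`store-to-xegpu.mlir` checks (value names and the trailing type of
`store_nd` are illustrative):

```mlir
// input:
vector.store %vec, %source[%off0, %off1, %off2] : memref<8x16x32xf32>, vector<8x16xf32>

// --vector-to-xegpu (new)
%collapsed = memref.subview %source[%off0, 0, 0] [1, 16, 32] [1, 1, 1]
    : memref<8x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
%tdesc = xegpu.create_nd_tdesc %collapsed
    : memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %vec, %tdesc[%off1, %off2]
    : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
```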

---------

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 91c1aa55fdb4e..abea84f6b01fe 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
-                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
-                   Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
   auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
   if (srcTy.hasStaticShape()) {
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           getAsOpFoldResult(offsets));
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
-    for (Value offset : offsets) {
-      std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
-    }
-
-    // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
-    }
-    std::reverse(dynStrides.begin(), dynStrides.end());
-
-    ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           meta.getConstifiedMixedSizes(),
+                                           meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;
@@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       .getResult();
 }
 
+// Collapses shapes of a nD memref to the target rank while applying offsets for
+// the collapsed dimensions. Returns the new memref value and the remaining
+// offsets for the last targetRank dimensions. For example:
+//   input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+                                    Value memref,
+                                    SmallVector<OpFoldResult> offsets,
+                                    int64_t targetRank) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  unsigned rank = memrefType.getRank();
+
+  if (rank <= targetRank)
+    return {memref, offsets};
+
+  int64_t numCombinedDims = rank - targetRank;
+  SmallVector<OpFoldResult> subviewOffsets;
+  SmallVector<OpFoldResult> subviewSizes;
+  SmallVector<OpFoldResult> subviewStrides;
+
+  // For the combined dimensions: use the provided offsets, size=1, stride=1
+  for (unsigned i = 0; i < numCombinedDims; ++i) {
+    subviewOffsets.push_back(offsets[i]);
+    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  // For the last targetRank dimensions: offset=0, use full size, stride=1
+  SmallVector<int64_t> resultShape;
+  auto originalShape = memrefType.getShape();
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+  for (unsigned i = numCombinedDims; i < rank; ++i) {
+    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    if (ShapedType::isDynamic(originalShape[i])) {
+      subviewSizes.push_back(meta.getSizes()[i]);
+      resultShape.push_back(ShapedType::kDynamic);
+    } else {
+      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+      resultShape.push_back(originalShape[i]);
+    }
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  auto resultType = memref::SubViewOp::inferRankReducedResultType(
+      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+  auto subviewOp =
+      memref::SubViewOp::create(rewriter, loc, resultType, memref,
+                                subviewOffsets, subviewSizes, subviewStrides);
+
+  // Return the remaining offsets for the last targetRank dimensions
+  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+                                       offsets.end());
+  return {subviewOp.getResult(), newOffsets};
+}
+
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -523,18 +545,19 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+        vecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
@@ -575,21 +598,23 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, writeOp.getBase(),
+        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
-
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+                                            ndDesc, indices,
+                                            /*l1_hint=*/hint,
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -674,19 +699,24 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
 
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+        vecTy.getRank());
 
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+    auto loadNdOp =
+        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                /*packed=*/nullptr, /*transpose=*/nullptr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -708,18 +738,24 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, storeOp.getBase(),
+        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
     auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7b6c4b6c2c813..c8f5c86c03686 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -280,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }
@@ -342,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
     auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
 
     // if shape and strides are from Memref, we don't need attributes for them
-    // to keep the IR print clean.
-    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+    // to keep the IR print clean (only do so for full-static case, otherwise
+    // printer would fail trying to print empty array-attr).
+    if (staticShape == memrefShape && staticStrides == memrefStrides &&
+        dynamicShape.empty() && dynamicStrides.empty()) {
       staticShapeAttr = DenseI64ArrayAttr();
       staticStridesAttr = DenseI64ArrayAttr();
     }

diff  --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..ae5141db16c09 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-LABEL: @load_1D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -28,35 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-LABEL: @load_2D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
 
 func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
-    %offset: index) -> vector<8x16xf32> {
-  %0 = vector.load %source[%offset, %offset, %offset]
+    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+  %0 = vector.load %source[%i, %j, %k]
     : memref<?x?x?xf32>, vector<8x16xf32>
   return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @load_dynamic_source(
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----
@@ -72,9 +67,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x15xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
 // -----

diff  --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..1a10d917623cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,11 +11,12 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
 // -----
 
@@ -30,16 +31,17 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK-SAME:    %[[COLLAPSED]]
+// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 
 func.func @store_dynamic_source(%vec: vector<8x16xf32>,
-    %source: memref<?x?x?xf32>, %offset: index) {
-  vector.store %vec, %source[%offset, %offset, %offset]
+    %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+  vector.store %vec, %source[%i, %j, %k]
     : memref<?x?x?xf32>, vector<8x16xf32>
   return
 }
@@ -47,18 +49,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-LABEL: @store_dynamic_source(
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // -----
 
@@ -74,9 +69,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
 // CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:    %[[SRC]]
 // CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
 

diff  --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79af1bd9a..c87a5304babfe 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -12,11 +12,12 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-ND-LABEL:  @load_1D_vector(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD-ND-SAME:     %[[COLLAPSED]]
+// LOAD-ND-SAME:     memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_1D_vector(
@@ -46,11 +47,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-ND-LABEL:  @load_2D_vector(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME:     %[[COLLAPSED]]
+// LOAD-ND-SAME:     memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_2D_vector(
@@ -83,9 +85,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
 // LOAD-ND-LABEL:  @load_zero_pad_out_of_bounds(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_zero_pad_out_of_bounds(
@@ -109,9 +111,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
 // LOAD-ND-SAME:   %[[OFFSET1:.+]]: index, 
 // LOAD-ND-SAME:   %[[OFFSET2:.+]]: index  
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
 // LOAD-ND-SAME:     -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
@@ -143,16 +145,11 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 }
 // LOAD-ND-LABEL:  @load_dynamic_source(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
-// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD-ND:        %[[C2:.+]] = arith.constant 2 : index
-// LOAD-ND:        %[[C1:.+]] = arith.constant 1 : index
-// LOAD-ND:        %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// LOAD-ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// LOAD-ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND:        {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
 
@@ -184,10 +181,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 }
 
 // LOAD-ND-LABEL:  @load_dynamic_source2(
-// LOAD-ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
+// LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
 // LOAD-GATHER-LABEL:  @load_dynamic_source2(
@@ -418,11 +416,12 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
-// LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME:     %[[COLLAPSED]]
+// LOAD-ND-SAME:     memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // LOAD-ND-SAME:     boundary_check = false
-// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]]]{{.*}}-> vector<8xf16>
 // LOAD-ND:        return %[[VEC]]
 
 // LOAD-GATHER-LABEL:  @load_from_subview(

diff  --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc9414da4f6..43a1a7206e2cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
 // RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
 
 
@@ -15,11 +15,12 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND-SAME:    %[[COLLAPSED]]
+// STORE-ND-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
 // STORE-SCATTER-LABEL:  @store_1D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf32>,
@@ -49,11 +50,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME:    %[[COLLAPSED]]
+// STORE-ND-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME:    boundary_check = false
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL:  @store_2D_vector(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
@@ -73,8 +75,8 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // -----
 gpu.module @xevm_module {
 gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
-    %source: memref<?x?x?xf32>, %offset: index) {
-  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+  vector.transfer_write %vec, %source[%i, %j, %k]
     {in_bounds = [true, true]}
     : vector<8x16xf32>, memref<?x?x?xf32>
   gpu.return
@@ -83,18 +85,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-LABEL: @store_dynamic_source(
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// STORE-ND-SAME:  %[[OFFSET:.+]]: index
-// STORE-ND-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// STORE-ND-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// STORE-ND-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// STORE-ND-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// STORE-ND-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// STORE-ND-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// STORE-ND:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_dynamic_source(
 // STORE-SCATTER-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
@@ -126,9 +121,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
 // STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    %[[SRC]]
 // STORE-ND-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL:  @store_out_of_bounds(
 // STORE-SCATTER:   vector.transfer_write
@@ -298,13 +293,13 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND-SAME:   %[[VEC:.+]]: vector<8xf16>,
 // STORE-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // STORE-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
-// STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
-// STORE-ND-SAME:     : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
 // STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
-// STORE-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// STORE-ND-SAME:     %[[COLLAPSED]]
+// STORE-ND-SAME:     memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
 // STORE-ND-SAME:     boundary_check = false
-// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>
 
 // STORE-SCATTER-LABEL:  @store_to_subview(
 // STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf16>,


        

