[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Fix transfer_read/write cases with non-contiguous memrefs (PR #158126)

Dmitry Chigarev llvmlistbot at llvm.org
Thu Sep 11 11:06:35 PDT 2025


https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/158126

From 31299bcd30f8baaafd0e668f5181f178e95d3a0d Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 11 Sep 2025 17:47:33 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU][VectorToXeGPU] Fix transfer_read/write
 cases with non-contiguous memrefs

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 101 ++++++++----------
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  77 ++++++++++---
 .../transfer-write-to-xegpu.mlir              |  72 +++++++++++--
 3 files changed, 169 insertions(+), 81 deletions(-)
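
In short: the previous gather/scatter lowering flattened the source with
memref.collapse_shape, which in general requires the collapsed dimensions to be
contiguous in memory, so sources such as memref.subview results with a strided,
offset layout could not be handled. The lowering now takes the aligned base
pointer as an i64 and folds the memref's own offset (the static offset when
known, otherwise the one returned by memref.extract_strided_metadata) into the
computed element indices. A minimal before/after sketch, adapted from the
updated tests below (%indices and %mask stand for the offset vector and mask
the pattern already builds):

  // before: only valid for contiguous layouts
  %flat = memref.collapse_shape %src [[0, 1]] : memref<32x64xf32> into memref<2048xf32>
  %vec  = xegpu.load %flat[%indices], %mask
      : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>

  // after: also works for strided/offset (non-contiguous) layouts
  %ptr     = memref.extract_aligned_pointer_as_index %src : memref<32x64xf32> -> index
  %ptr_i64 = arith.index_cast %ptr : index to i64
  %vec     = xegpu.load %ptr_i64[%indices], %mask
      : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>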

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 819c2e5973ffd..c29e345d10684 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -183,23 +183,28 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides for vector transfer operations, handling both
 // static and dynamic memrefs while applying permutation transformations
 // for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
-                                         PatternRewriter &rewriter) {
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
   AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
+  Value offsetVal = nullptr;
   if (memrefType.hasStaticShape()) {
     int64_t offset;
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
-      return {};
+      return {{}, offsetVal};
     // Wrap static strides as MLIR values
     for (int64_t s : intStrides)
       strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
-  } else {
+    if (!ShapedType::isDynamic(offset))
+      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  }
+
+  if (strides.empty() || !offsetVal) {
     // For dynamic shape memref, use memref.extract_strided_metadata to get
     // stride values
     unsigned rank = memrefType.getRank();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 
     auto meta = memref::ExtractStridedMetadataOp::create(
         rewriter, loc, resultTypes, baseMemref);
-    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (strides.empty())
+      strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (!offsetVal)
+      offsetVal = meta.getOffset();
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
   adjustStridesForPermutation(permMap, strides);
-  return strides;
+  return {strides, offsetVal};
 }
 
 // This function compute the vectors of localOffsets for scattered load/stores.
@@ -256,8 +266,8 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 //   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
 //   %offsets =  orig_offset + local_offsets
 static Value computeOffsets(VectorTransferOpInterface xferOp,
-                            PatternRewriter &rewriter,
-                            ArrayRef<Value> strides) {
+                            PatternRewriter &rewriter, ArrayRef<Value> strides,
+                            Value baseOffset) {
   Location loc = xferOp.getLoc();
   VectorType vectorType = xferOp.getVectorType();
   SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
         arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
 
   // Compute base offset from transfer read indices
-  Value baseOffset = nullptr;
-  if (!indices.empty()) {
-    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    for (size_t i = 0; i < indices.size(); ++i) {
-      Value strideVal = strides[i];
-      Value offsetContrib =
-          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
-      baseOffset =
-          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
-    }
-    // Broadcast base offset to match vector shape
-    Value bcastBase = vector::BroadcastOp::create(
-        rewriter, loc, fullIndexVectorType, baseOffset);
-    localOffsets =
-        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+  for (size_t i = 0; i < indices.size(); ++i) {
+    Value strideVal = strides[i];
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
   }
+  // Broadcast base offset to match vector shape
+  Value bcastBase = vector::BroadcastOp::create(
+      rewriter, loc, fullIndexVectorType, baseOffset);
+  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
   return localOffsets;
 }
 
 // Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+                              PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
-
-  Value baseMemref = xferOp.getBase();
-  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
-  Type elementType = memrefType.getElementType();
-
-  // Compute the total number of elements in the memref
-  MemRefType flatMemrefType;
-  if (memrefType.hasStaticShape()) {
-    auto totalElements = memrefType.getNumElements();
-    flatMemrefType = MemRefType::get({totalElements}, elementType);
-  } else {
-    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
-  }
-
-  SmallVector<ReassociationIndices> reassociation;
-  ReassociationIndices allDims =
-      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
-  reassociation.push_back(allDims);
-
-  auto collapseOp = memref::CollapseShapeOp::create(
-      rewriter, loc, flatMemrefType, baseMemref, reassociation);
-  return collapseOp;
+  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+                      rewriter, loc, xferOp.getBase())
+                      .getResult();
+  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
+                                    indexPtr)
+      .getResult();
 }
 
 static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(readOp, rewriter);
-  if (strides.empty())
+  auto meta = computeMemrefMeta(readOp, rewriter);
+  if (meta.first.empty())
     return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(readOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(readOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+  auto meta = computeMemrefMeta(writeOp, rewriter);
+  if (meta.first.empty())
+    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(writeOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index b373bdab80567..c4ca79af1bd9a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 
 }
 
@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER:        %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // LOAD-GATHER:        gpu.return %[[RES]] : vector<8x16xf32>
 }
 
@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER-DAG:    %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER-DAG:    %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
+// LOAD-GATHER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
+// LOAD-GATHER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
 
 }
 
@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
 // LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
 // LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
 // LOAD-GATHER:        return %[[VEC]]
 }
 
@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
 // LOAD-GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
 
 }
 
@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
 }
 
 // -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
 // LOAD-GATHER:        vector.transfer_read
 }
 
+// -----
+gpu.module @xevm_module {
+gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+  %c0 = arith.constant 0.0 : f16
+  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  %0 = vector.transfer_read %subview[%off2, %off2], %c0
+    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
+  gpu.return %0 : vector<8xf16>
+}
+
+// LOAD-ND-LABEL:  @load_from_subview(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME:     boundary_check = false
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND:        return %[[VEC]]
+
+// LOAD-GATHER-LABEL:  @load_from_subview(
+// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-GATHER-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-GATHER:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER:        arith.muli {{.*}} : index
+// LOAD-GATHER:        arith.addi %[[OFFSET]]{{.*}} : index
+// LOAD-GATHER:        arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index b3f761a545ee1..fcfc9414da4f6 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -30,8 +30,9 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-SCATTER-COUNT2: arith.addi {{.*}} : index
 // STORE-SCATTER-DAG:    %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
 // STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex>
-// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// STORE-SCATTER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
 }
 
 // -----
@@ -64,8 +65,9 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
 // STORE-SCATTER-DAG:    %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
-// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1>
+// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// STORE-SCATTER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
 }
 
 // -----
@@ -104,8 +106,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
 // STORE-SCATTER-DAG:   %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // STORE-SCATTER-DAG:   %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
-// STORE-SCATTER-DAG:   %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// STORE-SCATTER-DAG:   %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// STORE-SCATTER-DAG:   %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
 }
 
 // -----
@@ -155,8 +158,9 @@ gpu.func @no_store_transposed(%vec: vector<8x16xf32>,
 // STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
 // STORE-SCATTER-DAG:    %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
-// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<32x64xf32> into memref<2048xf32>
-// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1>
+// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64xf32> -> index
+// STORE-SCATTER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
 }
 
 // -----
@@ -186,8 +190,9 @@ gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>,
 // STORE-SCATTER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
 // STORE-SCATTER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
 // STORE-SCATTER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> 
+// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index
+// STORE-SCATTER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, i64, vector<8x16x32xindex>, vector<8x16x32xi1>
 }
 
 // -----
@@ -275,4 +280,49 @@ gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
 
 // STORE-SCATTER-LABEL:  @no_store_out_of_bounds_1D_vector(
 // STORE-SCATTER:        vector.transfer_write
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_to_subview(%vec: vector<8xf16>,
+    %source: memref<4096x4096xf16>, %off1: index, %off2: index) {
+  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
+      : memref<4096x4096xf16>
+        to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  vector.transfer_write %vec, %subview[%off2, %off2]
+      {in_bounds = [true]}
+      : vector<8xf16>, memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  gpu.return
+}
+// STORE-ND-LABEL:  @store_to_subview(
+// STORE-ND-SAME:   %[[VEC:.+]]: vector<8xf16>,
+// STORE-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// STORE-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// STORE-ND-SAME:     : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME:     %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// STORE-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// STORE-ND-SAME:     boundary_check = false
+// STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+
+// STORE-SCATTER-LABEL:  @store_to_subview(
+// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf16>,
+// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// STORE-SCATTER-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// STORE-SCATTER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// STORE-SCATTER-SAME:     : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-SCATTER:        %[[BB:.+]], %[[OFFSET:.+]], {{.*}}, {{.*}} = memref.extract_strided_metadata %[[SUBVIEW]]
+// STORE-SCATTER-SAME:     : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// STORE-SCATTER:        %[[STEP:.+]] = vector.step : vector<8xindex>
+// STORE-SCATTER:        arith.muli {{.*}} : index
+// STORE-SCATTER:        arith.addi %[[OFFSET]]{{.*}} : index
+// STORE-SCATTER:        arith.addi {{.*}} : index
+// STORE-SCATTER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// STORE-SCATTER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]]
+// STORE-SCATTER-SAME:     : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// STORE-SCATTER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
+}

From 4d4f2cfeaeb62390bf68eb2564f1a38a65c52417 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 11 Sep 2025 18:06:20 +0000
Subject: [PATCH 2/2] Fix docstrings

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)
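
The updated comment spells out the full offset composition. For the subview
case added in the first patch (layout strided<[4096, 1], offset: ?>, read at
[%i, %j]), the emitted index computation looks roughly like this sketch (names
are illustrative; %offset is the offset returned by
memref.extract_strided_metadata, %step is the vector.step result):

  %row   = arith.muli %i, %c4096 : index                 // index * stride
  %tmp   = arith.addi %offset, %row : index              // + memref offset
  %base  = arith.addi %tmp, %j : index                   // + %j * 1
  %splat = vector.broadcast %base : index to vector<8xindex>
  %idx   = arith.addi %splat, %step : vector<8xindex>    // + per-lane local offsets

i.e. %offsets = memref_offset + orig_offset + local_offsets, as the docstring
now states.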

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c29e345d10684..852c322cc6467 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -180,9 +180,9 @@ static void adjustStridesForPermutation(AffineMap permMap,
   strides = applyPermutation(strides, perms64);
 }
 
-// Computes memory strides for vector transfer operations, handling both
-// static and dynamic memrefs while applying permutation transformations
-// for XeGPU lowering.
+// Computes memory strides and a memref offset for vector transfer operations,
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
 static std::pair<SmallVector<Value>, Value>
 computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
@@ -264,7 +264,7 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
 //   %23 = arith.add %20, %21
 //   %local_offsets = arith.add %22, %23
 //   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
-//   %offsets =  orig_offset + local_offsets
+//   %offsets =  memref_offset + orig_offset + local_offsets
 static Value computeOffsets(VectorTransferOpInterface xferOp,
                             PatternRewriter &rewriter, ArrayRef<Value> strides,
                             Value baseOffset) {
@@ -339,7 +339,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   return localOffsets;
 }
 
-// Collapse memref shape to 1D
+// Convert memref to i64 base pointer
 static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
                               PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();


