[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Fix transfer_read/write cases with non-contiguous memrefs (PR #158126)
Dmitry Chigarev
llvmlistbot at llvm.org
Thu Sep 11 11:06:35 PDT 2025
https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/158126
From 31299bcd30f8baaafd0e668f5181f178e95d3a0d Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 11 Sep 2025 17:47:33 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU][VectorToXeGPU] Fix transfer_read/write
cases with non-contiguous memrefs
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 101 ++++++++----------
.../VectorToXeGPU/transfer-read-to-xegpu.mlir | 77 ++++++++++---
.../transfer-write-to-xegpu.mlir | 72 +++++++++++--
3 files changed, 169 insertions(+), 81 deletions(-)
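
The previous lowering flattened the source with memref.collapse_shape, which only verifies for contiguous memrefs; a strided subview such as the one in the new tests cannot be collapsed to 1-D. A minimal sketch of the metadata the reworked lowering queries instead (value names are illustrative, shapes borrowed from the tests):

  %subview = memref.subview %src[%off1, %off2] [256, 256] [1, 1]
      : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
  // Dynamic offset and strides of the (non-contiguous) view.
  %bb, %offset, %sz0, %sz1, %str0, %str1 = memref.extract_strided_metadata %subview
      : memref<256x256xf16, strided<[4096, 1], offset: ?>>
        -> memref<f16>, index, index, index, index, index
  // i64 base pointer consumed by the scattered xegpu.load/xegpu.store.
  %ptr = memref.extract_aligned_pointer_as_index %subview
      : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
  %ptr_i64 = arith.index_cast %ptr : index to i64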
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 819c2e5973ffd..c29e345d10684 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -183,23 +183,28 @@ static void adjustStridesForPermutation(AffineMap permMap,
// Computes memory strides for vector transfer operations, handling both
// static and dynamic memrefs while applying permutation transformations
// for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
- PatternRewriter &rewriter) {
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
SmallVector<Value> strides;
Value baseMemref = xferOp.getBase();
AffineMap permMap = xferOp.getPermutationMap();
MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
Location loc = xferOp.getLoc();
+ Value offsetVal = nullptr;
if (memrefType.hasStaticShape()) {
int64_t offset;
SmallVector<int64_t> intStrides;
if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
- return {};
+ return {{}, offsetVal};
// Wrap static strides as MLIR values
for (int64_t s : intStrides)
strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
- } else {
+ if (!ShapedType::isDynamic(offset))
+ offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
+ }
+
+ if (strides.empty() || !offsetVal) {
// For dynamic shape memref, use memref.extract_strided_metadata to get
// stride values
unsigned rank = memrefType.getRank();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
auto meta = memref::ExtractStridedMetadataOp::create(
rewriter, loc, resultTypes, baseMemref);
- strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+ if (strides.empty())
+ strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+ if (!offsetVal)
+ offsetVal = meta.getOffset();
}
// Adjust strides according to the permutation map (e.g., for transpose)
adjustStridesForPermutation(permMap, strides);
- return strides;
+ return {strides, offsetVal};
}
// This function compute the vectors of localOffsets for scattered load/stores.
@@ -256,8 +266,8 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
// %offsets = orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
- PatternRewriter &rewriter,
- ArrayRef<Value> strides) {
+ PatternRewriter &rewriter, ArrayRef<Value> strides,
+ Value baseOffset) {
Location loc = xferOp.getLoc();
VectorType vectorType = xferOp.getVectorType();
SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
// Compute base offset from transfer read indices
- Value baseOffset = nullptr;
- if (!indices.empty()) {
- baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- for (size_t i = 0; i < indices.size(); ++i) {
- Value strideVal = strides[i];
- Value offsetContrib =
- arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
- baseOffset =
- arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
- }
- // Broadcast base offset to match vector shape
- Value bcastBase = vector::BroadcastOp::create(
- rewriter, loc, fullIndexVectorType, baseOffset);
- localOffsets =
- arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+ for (size_t i = 0; i < indices.size(); ++i) {
+ Value strideVal = strides[i];
+ Value offsetContrib =
+ arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+ baseOffset =
+ arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
+ // Broadcast base offset to match vector shape
+ Value bcastBase = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, baseOffset);
+ localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
return localOffsets;
}
// Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
- PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
Location loc = xferOp.getLoc();
-
- Value baseMemref = xferOp.getBase();
- MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
- Type elementType = memrefType.getElementType();
-
- // Compute the total number of elements in the memref
- MemRefType flatMemrefType;
- if (memrefType.hasStaticShape()) {
- auto totalElements = memrefType.getNumElements();
- flatMemrefType = MemRefType::get({totalElements}, elementType);
- } else {
- flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
- }
-
- SmallVector<ReassociationIndices> reassociation;
- ReassociationIndices allDims =
- llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
- reassociation.push_back(allDims);
-
- auto collapseOp = memref::CollapseShapeOp::create(
- rewriter, loc, flatMemrefType, baseMemref, reassociation);
- return collapseOp;
+ auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, xferOp.getBase())
+ .getResult();
+ return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
+ indexPtr)
+ .getResult();
}
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
if (!memrefType)
return rewriter.notifyMatchFailure(readOp, "Expected memref source");
- SmallVector<Value> strides = computeStrides(readOp, rewriter);
- if (strides.empty())
+ auto meta = computeMemrefMeta(readOp, rewriter);
+ if (meta.first.empty())
return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
- Value localOffsets = computeOffsets(readOp, rewriter, strides);
+ Value localOffsets =
+ computeOffsets(readOp, rewriter, meta.first, meta.second);
- Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+ Value flatMemref = memrefToIndexPtr(readOp, rewriter);
Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
if (!memrefType)
return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
- SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+ auto meta = computeMemrefMeta(writeOp, rewriter);
+ if (meta.first.empty())
+ return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
- Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+ Value localOffsets =
+ computeOffsets(writeOp, rewriter, meta.first, meta.second);
- Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+ Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
Value mask = vector::ConstantMaskOp::create(
rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index b373bdab80567..c4ca79af1bd9a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
}
@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
}
@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
}
@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32>
}
@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
+// LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
}
@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
// LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
// LOAD-GATHER: return %[[VEC]]
}
@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
// LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
}
@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
// LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
}
// -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
// LOAD-GATHER: vector.transfer_read
}
+// -----
+gpu.module @xevm_module {
+gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+ %c0 = arith.constant 0.0 : f16
+ %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ %0 = vector.transfer_read %subview[%off2, %off2], %c0
+ {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
+ gpu.return %0 : vector<8xf16>
+}
+
+// LOAD-ND-LABEL: @load_from_subview(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND: return %[[VEC]]
+
+// LOAD-GATHER-LABEL: @load_from_subview(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER: arith.muli {{.*}} : index
+// LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index
+// LOAD-GATHER: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index b3f761a545ee1..fcfc9414da4f6 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -30,8 +30,9 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-SCATTER-COUNT2: arith.addi {{.*}} : index
// STORE-SCATTER-DAG: %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex>
-// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
}
// -----
@@ -64,8 +65,9 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
-// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
}
// -----
@@ -104,8 +106,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
-// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
}
// -----
@@ -155,8 +158,9 @@ gpu.func @no_store_transposed(%vec: vector<8x16xf32>,
// STORE-SCATTER-COUNT2: vector.broadcast {{.*}} : vector<8x16xindex>
// STORE-SCATTER-DAG: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
// STORE-SCATTER-DAG: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
-// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<32x64xf32> into memref<2048xf32>
-// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1>
+// STORE-SCATTER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32x64xf32> -> index
+// STORE-SCATTER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
}
// -----
@@ -186,8 +190,9 @@ gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>,
// STORE-SCATTER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
// STORE-SCATTER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
// STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1>
+// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index
+// STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, i64, vector<8x16x32xindex>, vector<8x16x32xi1>
}
// -----
@@ -275,4 +280,49 @@ gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
// STORE-SCATTER-LABEL: @no_store_out_of_bounds_1D_vector(
// STORE-SCATTER: vector.transfer_write
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_to_subview(%vec: vector<8xf16>,
+ %source: memref<4096x4096xf16>, %off1: index, %off2: index) {
+ %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
+ : memref<4096x4096xf16>
+ to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ vector.transfer_write %vec, %subview[%off2, %off2]
+ {in_bounds = [true]}
+ : vector<8xf16>, memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ gpu.return
+}
+// STORE-ND-LABEL: @store_to_subview(
+// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>,
+// STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// STORE-ND-SAME: boundary_check = false
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+
+// STORE-SCATTER-LABEL: @store_to_subview(
+// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>,
+// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// STORE-SCATTER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-SCATTER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// STORE-SCATTER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// STORE-SCATTER-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-SCATTER: %[[BB:.+]], %[[OFFSET:.+]], {{.*}}, {{.*}} = memref.extract_strided_metadata %[[SUBVIEW]]
+// STORE-SCATTER-SAME: : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// STORE-SCATTER: %[[STEP:.+]] = vector.step : vector<8xindex>
+// STORE-SCATTER: arith.muli {{.*}} : index
+// STORE-SCATTER: arith.addi %[[OFFSET]]{{.*}} : index
+// STORE-SCATTER: arith.addi {{.*}} : index
+// STORE-SCATTER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]]
+// STORE-SCATTER-SAME: : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
+}
From 4d4f2cfeaeb62390bf68eb2564f1a38a65c52417 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 11 Sep 2025 18:06:20 +0000
Subject: [PATCH 2/2] Fix docstrings
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
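
The updated computeOffsets docstring now reads %offsets = memref_offset + orig_offset + local_offsets. A minimal sketch of that chain for a 1-D read from the subview sketched above (value names are illustrative; the mask and per-lane steps mirror the test checks):

  %mul0  = arith.muli %i, %str0 : index            // orig_offset contribution of the leading index
  %acc0  = arith.addi %offset, %mul0 : index       // memref_offset + orig_offset (row part)
  %mul1  = arith.muli %j, %str1 : index
  %lin   = arith.addi %acc0, %mul1 : index
  %splat = vector.broadcast %lin : index to vector<8xindex>
  %step  = vector.step : vector<8xindex>           // local_offsets
  %idx   = arith.addi %splat, %step : vector<8xindex>
  %mask  = arith.constant dense<true> : vector<8xi1>
  %vec   = xegpu.load %ptr_i64[%idx], %mask : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>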
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c29e345d10684..852c322cc6467 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -180,9 +180,9 @@ static void adjustStridesForPermutation(AffineMap permMap,
strides = applyPermutation(strides, perms64);
}
-// Computes memory strides for vector transfer operations, handling both
-// static and dynamic memrefs while applying permutation transformations
-// for XeGPU lowering.
+// Computes memory strides and a memref offset for vector transfer operations,
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
SmallVector<Value> strides;
@@ -264,7 +264,7 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
// %23 = arith.add %20, %21
// %local_offsets = arith.add %22, %23
// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
-// %offsets = orig_offset + local_offsets
+// %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter, ArrayRef<Value> strides,
Value baseOffset) {
@@ -339,7 +339,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
return localOffsets;
}
-// Collapse memref shape to 1D
+// Convert memref to i64 base pointer
static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
PatternRewriter &rewriter) {
Location loc = xferOp.getLoc();